diff --git a/.gitmodules b/.gitmodules index 32c0c205948..282746ed0b5 100644 --- a/.gitmodules +++ b/.gitmodules @@ -151,3 +151,6 @@ [submodule "third_party/VulkanMemoryAllocator"] path = third_party/VulkanMemoryAllocator url = https://github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator.git +[submodule "third_party/cutlass"] + path = third_party/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/CMakeLists.txt b/CMakeLists.txt index f701923d2e8..3b4d71d2c1a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -721,6 +721,13 @@ set(BUILD_ONEDNN_GRAPH OFF) include(cmake/Dependencies.cmake) +# Moved this CMake option down here because CMAKE_CUDA_COMPILER_VERSION is not available until now +option(USE_FLASH_ATTENTION "Whether to build the flash_attention kernel for scaled dot product attention" OFF) +if(USE_FLASH_ATTENTION) + ADD_DEFINITIONS(-DUSE_FLASH_ATTENTION) +ENDIF() + + if(USE_CUDA AND (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 10.2) AND (CMAKE_HOST_SYSTEM_NAME MATCHES "Windows")) # CUDA < 10.2 doesn't support compiling and extracting header dependencies in # one call, so instead CMake calls nvcc twice with && in between. diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 2380f66ce2b..3055e290094 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -130,15 +130,13 @@ file(GLOB native_cuda_h "native/cuda/*.h" "native/cuda/*.cuh") file(GLOB native_cuda_linalg_cpp "native/cuda/linalg/*.cpp") file(GLOB native_hip_h "native/hip/*.h" "native/hip/*.cuh") file(GLOB native_cudnn_cpp "native/cudnn/*.cpp") -file(GLOB native_nested_cuda_cu "native/nested/cuda/*.cu") -file(GLOB native_nested_cuda_cpp "native/nested/cuda/*.cpp") file(GLOB native_sparse_cuda_cu "native/sparse/cuda/*.cu") file(GLOB native_sparse_cuda_cpp "native/sparse/cuda/*.cpp") file(GLOB native_quantized_cuda_cu "native/quantized/cuda/*.cu") file(GLOB native_quantized_cuda_cpp "native/quantized/cuda/*.cpp") file(GLOB native_quantized_cudnn_cpp "native/quantized/cudnn/*.cpp") -file(GLOB native_transformers_cuda_cu "native/transformers/cuda/*.cu") -file(GLOB native_transformers_cuda_cpp "native/transformers/cuda/*.cpp") +file(GLOB native_nested_cuda_cu "native/nested/cuda/*.cu") +file(GLOB native_nested_cuda_cpp "native/nested/cuda/*.cpp") file(GLOB native_hip_hip "native/hip/*.hip") file(GLOB native_hip_cpp "native/hip/*.cpp") @@ -151,11 +149,22 @@ file(GLOB native_sparse_hip_hip "native/sparse/hip/*.hip") file(GLOB native_sparse_hip_cpp "native/sparse/hip/*.cpp") file(GLOB native_quantized_hip_hip "native/quantized/hip/*.hip") file(GLOB native_quantized_hip_cpp "native/quantized/hip/*.cpp") +file(GLOB native_transformers_cuda_cu "native/transformers/cuda/*.cu") +file(GLOB native_transformers_cuda_cpp "native/transformers/cuda/*.cpp") file(GLOB native_transformers_hip_hip "native/transformers/hip/*.hip") file(GLOB native_transformers_hip_cpp "native/transformers/hip/*.cpp") file(GLOB native_quantized_cudnn_hip_cpp "native/quantized/cudnn/hip/*.cpp") file(GLOB native_utils_cpp "native/utils/*.cpp") +# flash_attention sources +file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu") +file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp") + +if(USE_FLASH_ATTENTION) + list(APPEND native_transformers_cuda_cu ${flash_attention_cuda_cu}) + list(APPEND native_transformers_cuda_cpp ${flash_attention_cuda_cpp}) +endif() + # XNNPACK file(GLOB native_xnnpack "native/xnnpack/*.cpp") @@ -415,6 +424,9 @@ if(NOT MSVC AND NOT
EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE) endif() if(USE_CUDA AND NOT USE_ROCM) + if(USE_FLASH_ATTENTION) + list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include) + endif() if($ENV{ATEN_STATIC_CUDA}) list(APPEND ATen_CUDA_DEPENDENCY_LIBS ${CUDA_LIBRARIES} diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 73634685a43..d3aeeb00f84 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -13136,6 +13136,11 @@ structured: True variants: function +- func: _flash_scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool causal) -> Tensor + variants: function + dispatch: + CUDA: flash_scaled_dot_product_attention + - func: _transformer_decoder_only_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, Tensor? incr_key=None, Tensor? incr_value=None) -> (Tensor, Tensor, Tensor) variants: function dispatch: diff --git a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp index 231eca94f07..10ffc21b263 100644 --- a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp +++ b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp @@ -7,6 +7,7 @@ #include #include +#include namespace at { namespace native { @@ -243,5 +244,196 @@ Tensor NestedTensor_to_mask(const Tensor& nt, c10::optional mask_dim, c } return result; } +std::tuple<Tensor, int64_t> cumulative_and_max_seq_len(Tensor qkv) { + TORCH_CHECK( + qkv.is_nested(), + "QKV must be nested for flash cumulative_seq_len calculation."); + auto* nt_impl = get_nested_tensor_impl(qkv); + const auto& sizes = nt_impl->get_nested_size_tensor(); + auto size_tensor_stride = sizes.stride(0); + + const int64_t batch_size = qkv.size(0); + auto cumulative_seqlen = at::zeros( + {batch_size + 1}, TensorOptions().device(at::kCPU).dtype(at::kInt)); + + auto* sizes_ptr = sizes.data_ptr<int64_t>(); + auto* cumulative_seqlen_ptr = cumulative_seqlen.data_ptr<int32_t>(); + + int32_t sum = 0; + int64_t max_seqlen = -1; + cumulative_seqlen_ptr[0] = sum; + for (const auto i : c10::irange(batch_size)) { + // Calculate the cumulative sum of the sequence lengths + auto current_seq_len = sizes_ptr[i * size_tensor_stride]; + sum += current_seq_len; + cumulative_seqlen_ptr[i + 1] = sum; + + // Find the max element while we traverse + max_seqlen = std::max(max_seqlen, current_seq_len); + } + // Send to GPU; this is a pretty lightweight calculation for a normal batch size, + // but it may eventually need to happen on the GPU + cumulative_seqlen = cumulative_seqlen.to(TensorOptions().device(at::kCUDA)); + return std::tuple<Tensor, int64_t>{cumulative_seqlen, max_seqlen}; +} + +Tensor flash_attention_helper( + const Tensor& query, + const Tensor& key, + const Tensor& value, + double dropout_p, + bool causal) { + // Query is of size (batch_size x ragged_seq_len x (3 or 1) x n_heads x + // head_dim) + int64_t head_dim{query.size(-1)}; + int64_t num_heads{query.size(-2)}; + + auto cumulative_and_max_q = cumulative_and_max_seq_len(query); + Tensor cumulative_sequence_length_q = std::get<0>(cumulative_and_max_q); + int64_t max_seqlen_batch_q =
std::get<1>(cumulative_and_max_q); + + if (key.is_same(value) || query.is_same(key) || query.is_same(value)) { + int64_t Nnz_q{cumulative_sequence_length_q[-1].item<int64_t>()}; + + // For the packed case we need to set the output size for dim 2 to 1 + auto atten_size = get_nested_size_tensor(query); + atten_size.index({at::indexing::Slice(), 1}) = 1; + + auto qkv_buffer_reshaped = + get_buffer(query).view({Nnz_q, 3, num_heads, head_dim}); + + // If we are passing in query, key, value all as the same tensor, then we have + // packed them into one tensor and need to slice for flash attention + Tensor atten_buffer = at::_flash_scaled_dot_product_attention( + qkv_buffer_reshaped.index({at::indexing::Slice(), 0}), + qkv_buffer_reshaped.index({at::indexing::Slice(), 1}), + qkv_buffer_reshaped.index({at::indexing::Slice(), 2}), + cumulative_sequence_length_q, + cumulative_sequence_length_q, + max_seqlen_batch_q, + max_seqlen_batch_q, + dropout_p, + causal); + // The output of flash attention is a regular tensor; wrap it back up to + // form a nested tensor + return wrap_buffer(atten_buffer.view(-1), atten_size); + } + + // Query, Key, and Value are not all the same tensor, so we also need to + // calculate the metadata for K + + // The nested tensors will be of shape {batch_size x ragged_seq_len x + // num_heads * head_dim} + auto cumulative_and_max_k = cumulative_and_max_seq_len(key); + Tensor cumulative_sequence_length_k = std::get<0>(cumulative_and_max_k); + int64_t max_seqlen_batch_k = std::get<1>(cumulative_and_max_k); + + // K and V have to have the same Nnz; this should probably be TORCH_CHECKed before now. + // Assume it holds so that we do not have to iterate over v + int64_t Nnz_q{cumulative_sequence_length_q[-1].item<int64_t>()}; + int64_t Nnz_kv{cumulative_sequence_length_k[-1].item<int64_t>()}; + + auto query_buffer_reshaped = + get_buffer(query).view({Nnz_q, num_heads, head_dim}); + auto key_buffer_reshaped = + get_buffer(key).view({Nnz_kv, num_heads, head_dim}); + auto value_buffer_reshaped = + get_buffer(value).view({Nnz_kv, num_heads, head_dim}); + + Tensor atten_buffer = at::_flash_scaled_dot_product_attention( + query_buffer_reshaped, + key_buffer_reshaped, + value_buffer_reshaped, + cumulative_sequence_length_q, + cumulative_sequence_length_k, + max_seqlen_batch_q, + max_seqlen_batch_k, + dropout_p, + causal); + // The output of flash attention is a regular tensor; wrap it back up to + // form a nested tensor whose size should match the query tensor + return wrap_buffer(atten_buffer.view(-1), get_nested_size_tensor(query)); +} + +Tensor flash_attention_helper_dense( + const Tensor& query, + const Tensor& key, + const Tensor& value, + double dropout_p, + bool causal) { + TORCH_INTERNAL_ASSERT( + !query.is_nested() && !key.is_nested() && !value.is_nested()); + // Query is of size (batch_size x dense_seq_len x 3 x n_heads x + // head_dim) + const auto batch_size = query.size(0); + auto max_seqlen_batch_q = query.size(1); + int64_t head_dim{query.size(-1)}; + int64_t num_heads{query.size(-2)}; + + auto cumulative_sequence_length_q = at::arange( + 0, + (batch_size + 1) * max_seqlen_batch_q, + max_seqlen_batch_q, + TensorOptions().device(at::kCUDA).dtype(at::kInt)); + int64_t Nnz_q{batch_size * max_seqlen_batch_q}; + + if (key.is_same(value) || query.is_same(key) || query.is_same(value)) { + // In the dense case flash attention expects an input that is + // (b*s) x num_heads x head_dim + auto query_reshaped = query.reshape({Nnz_q, 3, num_heads, head_dim}); + // If we are passing in query, key, value all as the same tensor, then we have + //
packed them into one tensor and need to slice for flash attention + + Tensor atten_buffer = at::_flash_scaled_dot_product_attention( + query_reshaped.index({at::indexing::Slice(), 0}), + query_reshaped.index({at::indexing::Slice(), 1}), + query_reshaped.index({at::indexing::Slice(), 2}), + cumulative_sequence_length_q, + cumulative_sequence_length_q, + max_seqlen_batch_q, + max_seqlen_batch_q, + dropout_p, + causal); + // Reshape output to convert nnz to batch_size and seq_len + return atten_buffer.reshape( + {batch_size, max_seqlen_batch_q, num_heads, head_dim}); + } + + // Query, Key, and Value are not all the same tensor, so we also need to + // calculate the metadata for K + auto max_seqlen_batch_k = key.size(1); + auto cumulative_sequence_length_k = at::arange( + 0, + (batch_size + 1) * max_seqlen_batch_k, + max_seqlen_batch_k, + TensorOptions().device(at::kCUDA).dtype(at::kInt)); + + // K and V have to have the same Nnz; this should probably be TORCH_CHECKed before now. + // Assume it holds so that we do not have to iterate over v + int64_t Nnz_kv{batch_size * max_seqlen_batch_k}; + + // Check that query, key, and value all have the same head dim + TORCH_INTERNAL_ASSERT(query.size(-1) == key.size(-1)); + TORCH_INTERNAL_ASSERT(query.size(-1) == value.size(-1)); + + auto query_reshaped = query.reshape({Nnz_q, num_heads, head_dim}); + auto key_reshaped = key.reshape({Nnz_kv, num_heads, head_dim}); + auto value_reshaped = value.reshape({Nnz_kv, num_heads, head_dim}); + + Tensor atten_buffer = at::_flash_scaled_dot_product_attention( + query_reshaped, + key_reshaped, + value_reshaped, + cumulative_sequence_length_q, + cumulative_sequence_length_k, + max_seqlen_batch_q, + max_seqlen_batch_k, + dropout_p, + causal); + // Reshape output to convert nnz to batch_size and seq_len + return atten_buffer.reshape( + {batch_size, max_seqlen_batch_q, num_heads, head_dim}); +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h index 77eb0145d68..09b35d9c39e 100644 --- a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h +++ b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h @@ -83,5 +83,19 @@ void add_padding_kernelLauncher( const std::vector<int64_t>& output_sizes, const int batch_size, const int output_batch_size); + +Tensor flash_attention_helper_dense( + const Tensor& query, + const Tensor& key, + const Tensor& value, + double dropout_p, + bool causal); + +Tensor flash_attention_helper( + const Tensor& query, + const Tensor& key, + const Tensor& value, + double dropout_p, + bool causal); } // namespace native } // namespace at diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp index 54d2b7ffd0c..a2625ced1e1 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -9,10 +10,18 @@ #include #endif +// TODO: Consider moving all flash_attention code, including the nested tensor parts, +// to the transformer library +#ifdef USE_FLASH_ATTENTION +#include +#endif + #include #include #include +#include + namespace at { namespace native { namespace { @@ -207,5 +216,37 @@ Tensor NestedTensor_to_padded_tensor_cuda( return NestedTensor_to_padded_tensor_generic(t, padding, output_size); } +Tensor flash_scaled_dot_product_attention( + const Tensor& query, + const Tensor& key, + const Tensor&
value, + const Tensor& cumulative_sequence_length_q, + const Tensor& cumulative_sequence_length_k, + const int64_t max_seqlen_batch_q, + const int64_t max_seqlen_batch_k, + double dropout_p, + bool causal) { +#if defined(USE_FLASH_ATTENTION) + auto softmax_scale = std::pow(query.size(-1), -0.5); + std::vector<Tensor> output = fmha::mha_fwd( + query, + key, + value, + cumulative_sequence_length_q, + cumulative_sequence_length_k, + max_seqlen_batch_q, + max_seqlen_batch_k, + dropout_p, + softmax_scale, + false, + causal, + false, + c10::nullopt); + return output[0]; +#endif + TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for this build."); + return Tensor{}; +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/epilogue.h b/aten/src/ATen/native/transformers/cuda/flash_attn/epilogue.h new file mode 100644 index 00000000000..c95d73164c5 --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/epilogue.h @@ -0,0 +1,149 @@ +/****************************************************************************** + * Copyright (c) 2022, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <typename MmaCore> +struct FMHAEpilogue { + + using ThreadblockShape = typename MmaCore::Shape; + using WarpMma = typename MmaCore::MmaTensorOp; + using LayoutC = typename MmaCore::LayoutC; + using Element = typename MmaCore::ElementA; + using ElementC = typename MmaCore::ElementC; + + static constexpr int kPartitionsK = ThreadblockShape::kK / MmaCore::WarpShape::kK; + + using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + typename WarpMma::Shape, + typename WarpMma::Policy::Operator::Shape, + typename WarpMma::Policy::Operator::ElementC, + typename WarpMma::Policy::Operator::FragmentC, + LayoutC>; + using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; + static constexpr int kIterationsStore = AccumulatorFragmentIterator::kIterations; + + // Maybe elementsPerAccess should vary: 4 for d=64, 2 for d=32? + using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< + ThreadblockShape, typename WarpMma::Shape, kPartitionsK, Element, /*ElementsPerAccess=*/4>::Type; + using OutputTileThreadMapAccum = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< + ThreadblockShape, typename WarpMma::Shape, kPartitionsK, ElementC, /*ElementsPerAccess=*/4>::Type; + + using GmemIterator = fmha::EpiloguePredicatedTileIterator< + OutputTileThreadMap, + Element + >; + // which ThreadMap should we use?
+ using GmemIteratorAccum = fmha::EpiloguePredicatedTileIterator< + // OutputTileThreadMapAccum, + OutputTileThreadMap, + ElementC + >; + + + using DefaultIterators = cutlass::epilogue::threadblock::detail::DefaultIteratorsTensorOp< + Element, ElementC, /*ElementsPerAccess=*/4, ThreadblockShape, typename WarpMma::Shape, + typename WarpMma::Policy::Operator::Shape, typename OutputTileThreadMap::CompactedThreadMap>; + using WarpTileIterator = typename DefaultIterators::WarpTileIterator; + static_assert(WarpTileIterator::kIterations == kIterationsStore); + using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator; + using OutputFragment = typename SharedLoadIterator::Fragment; + + // using Padding = cutlass::MatrixShape<0, 0>; + using Padding = cutlass::MatrixShape<0, 64 / cutlass::sizeof_bits::value * 4>; + static constexpr int kFragmentsPerIteration = kIterationsStore; // TODO: could be 1 for Volta? + /*Using kIterationsStore here so that we get the right storage size*/ + using EpilogueBase = typename cutlass::epilogue::threadblock::EpilogueBase< + ThreadblockShape, typename WarpMma::Shape, kPartitionsK, AccumulatorFragmentIterator, WarpTileIterator, + Padding, kIterationsStore>; + + using SharedStorage = typename EpilogueBase::SharedStorage; + static constexpr int kSmemTiles = EpilogueBase::kFragmentsPerIteration; + static constexpr int kSmemPointerOffset = SharedStorage::StorageShape::kCount / kSmemTiles; + static constexpr int kSmemPointerOffsetPerWarp = SharedStorage::StorageShape::kCount / (kSmemTiles * kPartitionsK); + + SharedStorage *shared_storage; + WarpTileIterator warp_tile_iterator; + + inline __device__ FMHAEpilogue(void *smem, const int tidx) + : shared_storage(reinterpret_cast(smem)) + , warp_tile_iterator(shared_storage->reference(), threadIdx.x % 32) { + + // const int warp_idx = tidx / 32; + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + // https://github.com/NVIDIA/cutlass/blob/e66bfcb1f880792caa46b1e983c4114e23afa5f3/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h#L520 + const int warp_idx = __shfl_sync(0xffffffff, tidx / 32, 0); + + cutlass::MatrixCoord warp_offset{kIterationsStore * warp_idx, 0}; + + warp_tile_iterator.add_tile_offset(warp_offset); + } + + // Store the accumulators. 
+ inline __device__ void store(const AccumulatorTile &acc) { + AccumulatorFragmentIterator accum_fragment_iterator(acc); + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kIterationsStore; ++p) { + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + warp_tile_iterator.store(accum_fragment); + if (p < kIterationsStore - 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffsetPerWarp); + } + } + if (kIterationsStore > 1) { + warp_tile_iterator.add_pointer_offset((1 - kIterationsStore) * kSmemPointerOffsetPerWarp); + } + } + + // Load the accumulators + template + inline __device__ void load(OutputFragment (&out)[kFragmentsPerIteration], + const int tidx) { + SharedLoadIterator shared_load_iterator(shared_storage->reference(), tidx); + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < EpilogueBase::kFragmentsPerIteration; ++p) { + OutputFragment aligned_accum_fragment[kPartitionsK]; + shared_load_iterator.load(aligned_accum_fragment[0]); + cutlass::plus add_fragments; + if (kPartitionsK > 1) { + CUTLASS_PRAGMA_UNROLL + for ( int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator.add_pointer_offset(kSmemPointerOffsetPerWarp * kIterationsStore); + shared_load_iterator.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + shared_load_iterator.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffsetPerWarp * kIterationsStore); + } + if (p < EpilogueBase::kFragmentsPerIteration - 1) { + shared_load_iterator.add_pointer_offset(kSmemPointerOffsetPerWarp); + } + + out[p] = zero_init ? aligned_accum_fragment[0] : add_fragments(out[p], aligned_accum_fragment[0]); + } + } + +}; + +} // namespace fmha diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/epilogue_predicated_tile_iterator.h b/aten/src/ATen/native/transformers/cuda/flash_attn/epilogue_predicated_tile_iterator.h new file mode 100644 index 00000000000..1c04fdfd705 --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/epilogue_predicated_tile_iterator.h @@ -0,0 +1,493 @@ +// Adapted from cutlass/epilogue/threadblock/predicated_tile_iterator.h +// We just want to add the move() function, but idk how to do it without +// copying the code here. + +/****************************************************************************** + * Copyright (c) 2022, Tri Dao. + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////// + +using namespace cutlass; +using namespace cutlass::epilogue::threadblock; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. +/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_, ///< Element data type + bool ScatterD = false, ///< Scatter D operand or not + bool UseCUDAStore = false +> +class EpiloguePredicatedTileIterator { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); + static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); + static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); + static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout): + PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc() + ) + { } + + CUTLASS_HOST_DEVICE + Params(Base const &base) : + Base(base) { } + }; + + /// Mask object 
+ struct Mask { + + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + +private: + + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + PredicatedTileIteratorParams params_; + + /// Byte-level pointer + uint8_t *byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const *indices_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + +private: + + // + // Methods + // + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + EpiloguePredicatedTileIterator( + PredicatedTileIteratorParams const & params, + Element *pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord(), + int const *indices = nullptr + ): + params_(params), indices_(indices) + { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + + mask_.predicates[c] = ((thread_offset.column() + + ThreadMap::Delta::kColumn * c) < extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { + mask_.clear(); + } + + if (ScatterD && !indices) { + mask_.clear(); + } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + + if (ScatterD) { + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const { + + uint8_t *byte_pointer = byte_pointer_; + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + 
CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast(byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column], + (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) const { + + load_with_byte_offset(frag, 0); + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const { + uint8_t *byte_pointer = byte_pointer_; + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast(byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + bool guard = row_guard && mask_.predicates[column]; + + if (UseCUDAStore) { + if (guard) { + memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] = + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += 
params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) const { + + store_with_byte_offset(frag, 0); + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { + return MatrixCoord(thread_start_row_, thread_start_column_); + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { + return thread_start_row_; + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { + return thread_start_column_; + } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { + return extent_row_; + } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { + return extent_column_; + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + void move(const int step=1) { + + if (!ScatterD) { + byte_pointer_ += step * params_.advance_row; + } + + thread_start_row_ += step * ThreadMap::Shape::kRow; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask &mask) const { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const &mask) { + mask_ = mask; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + + +} // namespace fmha diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha.h b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha.h new file mode 100644 index 00000000000..d259280fac5 --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha.h @@ -0,0 +1,154 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +#include + +#include +#include +#include + + +constexpr int TOTAL_DIM = 0; +constexpr int H_DIM = 1; +constexpr int D_DIM = 2; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + // The QKV matrices. + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. + // size_t qkv_stride_in_elts; + // size_t qkv_stride_in_bytes; + // TD [2022-04-16]: We're using 32-bit indexing to save registers. + // The code probably won't work for arrays larger than 2GB. + uint32_t q_row_stride_in_elts; + uint32_t k_row_stride_in_elts; + uint32_t v_row_stride_in_elts; + uint32_t q_head_stride_in_elts; + uint32_t k_head_stride_in_elts; + uint32_t v_head_stride_in_elts; + + // The number of heads. + int h; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct FMHA_fprop_params : public Qkv_params { + + // The O matrix (output). + void * __restrict__ o_ptr; + + // The stride between rows of O. + // size_t o_stride_in_elts; + // size_t o_stride_in_bytes; + uint32_t o_row_stride_in_elts; + uint32_t o_head_stride_in_elts; + + // The pointer to the O_tmp matrix, which holds O intermediate value during + // the loop; + void *__restrict__ o_tmp_ptr; + + // The pointer to the S matrix. + void * __restrict__ s_ptr; + // The stride between rows of the S matrix. + // int64_t s_stride_in_bytes; + uint32_t s_stride_in_bytes; + + // The pointer to the softmax sum. + void * __restrict__ softmax_lse_ptr; + + // The dimensions. + int b, seqlen_q, seqlen_k, d; + + // The scaling factors for the kernel. + float scale_bmm1; + + // array of length b+1 holding starting offset of each sequence. + int * __restrict__ cu_seqlens_q; + int * __restrict__ cu_seqlens_k; + + int *__restrict__ blockmask; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + uint32_t p_dropout_in_uint; + uint16_t p_dropout_in_uint16_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + float scale_bmm1_rp_dropout; + + // Random state. 
+ at::PhiloxCudaState philox_args; + + bool is_bf16; + bool is_causal; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <typename Kernel_params> +struct Launch_params { + Launch_params(cudaDeviceProp * props_, + cudaStream_t stream_, + bool is_dropout_, + bool return_softmax_) + : elts_per_thread(0) + , props(props_) + , stream(stream_) + , is_dropout(is_dropout_) + , return_softmax(return_softmax_) { + } + + size_t elts_per_thread; + + cudaDeviceProp * props; + + cudaStream_t stream; + + bool is_dropout; + bool return_softmax; + + Kernel_params params; + int num_full_heads; + int num_main_groups; + int heads_last_wave; + int main_steps; + int rest_steps; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_fmha_fprop(Launch_params<FMHA_fprop_params> &launch_params, const bool configure); diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp new file mode 100644 index 00000000000..a3970f25028 --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp @@ -0,0 +1,244 @@ +/****************************************************************************** + * Copyright (c) 2022, Tri Dao. + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#include +#include +#include + +#include +#include + +#include + +#define CHECK_SHAPE(x, ...)
TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +namespace fmha { + +void set_params_fprop(FMHA_fprop_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t h, + const size_t d, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *o_packed_d, + void *o_tmp_d, + void *s_d, + void *softmax_lse_d, + float p_dropout, + float softmax_scale, + bool is_causal) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.is_bf16 = q.dtype() == at::kBFloat16; + + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + params.q_row_stride_in_elts = q.stride(0); + params.k_row_stride_in_elts = k.stride(0); + params.v_row_stride_in_elts = v.stride(0); + params.q_head_stride_in_elts = q.stride(1); + params.k_head_stride_in_elts = k.stride(1); + params.v_head_stride_in_elts = v.stride(1); + params.o_ptr = o_packed_d; + params.o_row_stride_in_elts = h * d; + params.o_head_stride_in_elts = d; + params.o_tmp_ptr = o_tmp_d; + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + + // S = softmax(P) + params.s_ptr = s_d; + params.s_stride_in_bytes = b * h * seqlen_k * 2; // 2 = sizeof(Element) + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = b; + params.h = h; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.d = d; + + // Set the different scale values. + params.scale_bmm1 = softmax_scale; + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f - p_dropout; + // Convert p from float to int so we don't have to convert the random uint to float to compare. 
+ // [Minor] We want to round down since when we do the comparison we use <= instead of < + params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); + params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); + params.rp_dropout = 1.f / params.p_dropout; + params.scale_bmm1_rp_dropout = params.rp_dropout * params.scale_bmm1; + TORCH_CHECK(p_dropout < 1.f); + + params.is_causal = is_causal; +} + +std::vector<at::Tensor> +mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + const int max_seqlen_q_, + const int max_seqlen_k_, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + const bool return_softmax, + c10::optional<at::Generator> gen_) { + + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm80 = dprops->major == 8 && dprops->minor == 0; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + TORCH_CHECK(is_sm8x || is_sm75); + auto stream = at::cuda::getCurrentCUDAStream().stream(); + bool is_dropout = p_dropout > 0.0; + Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || (is_sm8x && q_dtype == at::kBFloat16)); + TORCH_CHECK(k.dtype() == q_dtype); + TORCH_CHECK(v.dtype() == q_dtype); + TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt); + TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt); + + TORCH_CHECK(q.is_cuda()); + TORCH_CHECK(k.is_cuda()); + TORCH_CHECK(v.is_cuda()); + TORCH_CHECK(cu_seqlens_q.is_cuda()); + TORCH_CHECK(cu_seqlens_k.is_cuda()); + + TORCH_CHECK(q.stride(-1) == 1); + TORCH_CHECK(k.stride(-1) == 1); + TORCH_CHECK(v.stride(-1) == 1); + TORCH_CHECK(cu_seqlens_q.is_contiguous()); + TORCH_CHECK(cu_seqlens_k.is_contiguous()); + + const auto sizes = q.sizes(); + + const int batch_size = cu_seqlens_q.numel() - 1; + const int total_q = sizes[TOTAL_DIM]; + const int num_heads = sizes[H_DIM]; + const int head_size = sizes[D_DIM]; + const int total_k = k.size(TOTAL_DIM); + TORCH_CHECK(batch_size > 0); + TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128)); + const int head_size_rounded = head_size <= 64 ? 64 : 128; + + CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(k, total_k, num_heads, head_size); + CHECK_SHAPE(v, total_k, num_heads, head_size); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + int blocksize_c = ((head_size_rounded == 128 && (is_dropout || !is_sm80)) || (is_sm75 && head_size_rounded == 64 && is_dropout)) ?
128 : 256; + // Need to round max_seqlen_k to multiples of blocksize_c + int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c; + if( max_seqlen_k_ <= 128 ) { + max_seqlen_k = 128; + } else if( max_seqlen_k_ <= 256 ) { + max_seqlen_k = 256; + } + int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16; + bool loop = max_seqlen_k > blocksize_c; + + auto opts = q.options(); + + auto o = at::empty({ total_q, num_heads, head_size }, opts); + + at::Tensor o_tmp; + if (loop) { o_tmp = at::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat)); } + + auto softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + // auto softmax_lse = torch::full({batch_size, num_heads, max_seqlen_k}, -std::numeric_limits<float>::infinity(), opts.dtype(at::kFloat)); + + at::Tensor s; + if (return_softmax) { s = at::empty({ batch_size, num_heads, max_seqlen_q, max_seqlen_k }, opts); } + + if( zero_tensors ) { + o.zero_(); + softmax_lse.fill_(-std::numeric_limits<float>::infinity()); + if (return_softmax) {s.zero_();} + } + + auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + + set_params_fprop(launch_params.params, + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads, + head_size, + q, k, v, + cu_seqlens_q.data_ptr(), + cu_seqlens_k.data_ptr(), + o.data_ptr(), + loop ? o_tmp.data_ptr() : nullptr, + return_softmax ? s.data_ptr() : nullptr, + softmax_lse.data_ptr(), + p_dropout, + softmax_scale, + is_causal); + + run_fmha_fprop(launch_params, /*configure=*/ true); + // Number of times random numbers will be generated per thread, used to offset the Philox + // counter in the THC random state + int64_t counter_offset = launch_params.elts_per_thread; + at::PhiloxCudaState rng_engine_inputs; + + if( is_dropout ) { + // See Note [Acquire lock when using random generators] + std::lock_guard<std::mutex> lock(gen->mutex_); + launch_params.params.philox_args = gen->philox_cuda_state(counter_offset); + } + + run_fmha_fprop(launch_params, /*configure=*/false); + + std::vector<at::Tensor> result = {o, softmax_lse}; + if (return_softmax) {result.push_back(s);} + return result; +} +} // namespace fmha diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h new file mode 100644 index 00000000000..3dca7e2ac89 --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h @@ -0,0 +1,24 @@ +#pragma once +#include + +#include +#include + +namespace fmha { + +std::vector<at::Tensor> +mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + const int max_seqlen_q_, + const int max_seqlen_k_, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + const bool return_softmax, + c10::optional<at::Generator> gen_); + +} // namespace fmha diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_fprop_kernel_1xN.h b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_fprop_kernel_1xN.h new file mode 100644 index 00000000000..5b3cd8fb68d --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_fprop_kernel_1xN.h @@ -0,0 +1,722 @@ +/*************************************************************************************************** + *
Copyright (c) 2022, Tri Dao. + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Gemm_Q_K_base { + using Smem_O = fmha::FMHAEpilogue; + using WarpMma = typename Kernel_traits::MmaCoreQK::MmaTensorOp; + + // The description of the CTA tile for the 1st batched GEMM. 
+ using Cta_tile_p = typename Kernel_traits::Cta_tile_p; + + static constexpr size_t SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2; + + __device__ inline Gemm_Q_K_base(char * smem_ptr_q, char * smem_ptr_k) + : smem_q_ptr(smem_ptr_q) + , smem_k_ptr(smem_ptr_k) { + + } + + __device__ inline void load_q(int byte_offset=0) { + typename WarpMma::LayoutA layout_A = WarpMma::LayoutA::packed({Cta_tile_p::M, Cta_tile_p::K}); + typename WarpMma::IteratorA iter_A({reinterpret_cast(smem_q_ptr + byte_offset), layout_A}, threadIdx.x % 32); + iter_A.load(frag_q[0]); + } + + + __device__ inline void reload_q(int byte_offset=0) { + typename WarpMma::LayoutA layout_A = WarpMma::LayoutA::packed({Cta_tile_p::M, Cta_tile_p::K}); + typename WarpMma::IteratorA iter_A({reinterpret_cast(smem_q_ptr + byte_offset), layout_A}, threadIdx.x % 32); + iter_A.load(frag_q[0]); + } + + typename WarpMma::FragmentA frag_q[2]; + char *smem_q_ptr; + char *smem_k_ptr; +}; + +template +struct Gemm_Q_K : public Gemm_Q_K_base { + + using Base = Gemm_Q_K_base; + using Cta_tile_p = typename Base::Cta_tile_p; + using Smem_O = typename Base::Smem_O; + using WarpMma = typename Base::WarpMma; + + static constexpr int kIterations = WarpMma::Shape::kK / WarpMma::InstructionShape::kK; + + static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V; + // If V is stored in shared memory, we can't load K using the same shared memory. + static_assert(Kernel_traits::V_IN_REGS); + + static constexpr size_t SMEM_OFFSET_O = Kernel_traits::BYTES_PER_SMEM_Q; + static constexpr size_t SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + sizeof(typename Smem_O::SharedStorage); + static constexpr size_t SMEM_OFFSET_V = Kernel_traits::BYTES_PER_SMEM_Q + (SHARE_SMEM_FOR_K_AND_V ? 0 : Kernel_traits::BYTES_PER_SMEM_K); + + // Q | K / V + // | O | SOFTMAX + static constexpr size_t SMEM_BYTES = Kernel_traits::BYTES_PER_SMEM_Q + + std::max((size_t)(SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Kernel_traits::BYTES_PER_SMEM_K, + sizeof(typename Smem_O::SharedStorage) + Base::SMEM_BYTES_SOFTMAX); + + __device__ inline Gemm_Q_K(char * smem_) + : Base(smem_, smem_ + Kernel_traits::BYTES_PER_SMEM_Q) { + } + + __device__ inline void load_k(){ + typename WarpMma::LayoutB layout_B = WarpMma::LayoutB::packed({Cta_tile_p::K, Cta_tile_p::N}); + typename WarpMma::IteratorB iter_B({reinterpret_cast(Base::smem_k_ptr), layout_B}, threadIdx.x % 32); + const int warp_idx = threadIdx.x / 32; + iter_B.add_tile_offset({0, warp_idx}); + #pragma unroll + for( int ki = 0; ki < kIterations; ++ki ) { + iter_B.load(frag_k[ki]); + ++iter_B; + } + } + + __device__ inline void operator()(WarpMma warp_mma, typename WarpMma::FragmentC &acc_p, int byte_offset_q=0){ + typename WarpMma::LayoutA layout_A = WarpMma::LayoutA::packed({Base::Cta_tile_p::M, Base::Cta_tile_p::K}); + typename WarpMma::IteratorA iter_A({reinterpret_cast(Base::smem_q_ptr + byte_offset_q), layout_A}, threadIdx.x % 32); + ++iter_A; + // Do this part of P^T = (Q * K^T)^T. + #pragma unroll + for( int ki = 0; ki < kIterations; ++ki ) { + // Trigger the load from shared memory for the next series of Q values. + if (ki + 1 < kIterations) { iter_A.load(Base::frag_q[(ki + 1) % 2]); ++iter_A; } + // Do the math for the values already in registers. + warp_mma(acc_p, Base::frag_q[ki % 2], frag_k[ki], acc_p); + } + } + + __device__ inline void reload_k(){ + // Noop. 
+ } + + typename WarpMma::FragmentB frag_k[kIterations]; +}; + + +template +struct Gemm_Q_K : public Gemm_Q_K_base { + using Base = Gemm_Q_K_base; + using Cta_tile_p = typename Base::Cta_tile_p; + using Smem_O = typename Base::Smem_O; + using WarpMma = typename Base::WarpMma; + + static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V; + static constexpr bool V_IN_REGS = Kernel_traits::V_IN_REGS; + static_assert(V_IN_REGS || !SHARE_SMEM_FOR_K_AND_V); + + static constexpr size_t SMEM_OFFSET_V = Kernel_traits::BYTES_PER_SMEM_Q + (SHARE_SMEM_FOR_K_AND_V ? 0 : Kernel_traits::BYTES_PER_SMEM_K); + static constexpr size_t SMEM_OFFSET_O = SMEM_OFFSET_V + Kernel_traits::BYTES_PER_SMEM_V; + static constexpr size_t SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + sizeof(typename Smem_O::SharedStorage); + + // If V_IN_REGS and SHARE_SMEM_FOR_K_AND_V: Q | K/V | O | SOFTMAX + // If !V_IN_REGS (then !SHARE_SMEM_FOR_K_AND_V): Q | K | V | O | SOFTMAX + static constexpr size_t SMEM_BYTES = Kernel_traits::BYTES_PER_SMEM_Q + + (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Kernel_traits::BYTES_PER_SMEM_K + + sizeof(typename Smem_O::SharedStorage) + Base::SMEM_BYTES_SOFTMAX; + + __device__ inline Gemm_Q_K(char * smem_) + : Base(smem_, smem_ + Kernel_traits::BYTES_PER_SMEM_Q) { + } + + __device__ inline void load_k(){ + typename WarpMma::LayoutB layout_B = WarpMma::LayoutB::packed({Cta_tile_p::K, Cta_tile_p::N}); + typename WarpMma::IteratorB iter_B({reinterpret_cast(Base::smem_k_ptr), layout_B}, threadIdx.x % 32); + const int warp_idx = threadIdx.x / 32; + iter_B.add_tile_offset({0, warp_idx}); + iter_B.load(frag_k[0]); + } + + __device__ inline void operator()(WarpMma warp_mma, typename WarpMma::FragmentC &acc_p, int byte_offset_q=0){ + typename WarpMma::LayoutA layout_A = WarpMma::LayoutA::packed({Base::Cta_tile_p::M, Base::Cta_tile_p::K}); + typename WarpMma::IteratorA iter_A({reinterpret_cast(Base::smem_q_ptr + byte_offset_q), layout_A}, threadIdx.x % 32); + ++iter_A; + typename WarpMma::LayoutB layout_B = WarpMma::LayoutB::packed({Cta_tile_p::K, Cta_tile_p::N}); + typename WarpMma::IteratorB iter_B({reinterpret_cast(Base::smem_k_ptr), layout_B}, threadIdx.x % 32); + const int warp_idx = threadIdx.x / 32; + iter_B.add_tile_offset({0, warp_idx}); + ++iter_B; + + // Do this part of P^T = (Q * K^T)^T. + constexpr int kIterations = WarpMma::Shape::kK / WarpMma::InstructionShape::kK; + #pragma unroll + for( int ki = 0; ki < kIterations; ++ki ) { + // Trigger the load from shared memory for the next series of Q values. + if (ki + 1 < kIterations) { + iter_A.load(Base::frag_q[(ki + 1) % 2]); ++iter_A; + iter_B.load(frag_k[(ki + 1) % 2]); ++iter_B; + } + // Do the math for the values already in registers. + warp_mma(acc_p, Base::frag_q[ki % 2], frag_k[ki % 2], acc_p); + } + } + __device__ inline void reload_k(){ + typename WarpMma::LayoutB layout_B = WarpMma::LayoutB::packed({Cta_tile_p::K, Cta_tile_p::N}); + typename WarpMma::IteratorB iter_B({reinterpret_cast(Base::smem_k_ptr), layout_B}, threadIdx.x % 32); + const int warp_idx = threadIdx.x / 32; + iter_B.add_tile_offset({0, warp_idx}); + iter_B.load(frag_k[0]); + } + + typename WarpMma::FragmentB frag_k[2]; +}; + +template +constexpr size_t get_dynamic_smem_size(){ + return Gemm_Q_K::SMEM_BYTES; +} + +template +inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const int bidh, int begin, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) { + + // The description of the CTA tile for the 1st batched GEMM. 
+ using Cta_tile_p = typename Kernel_traits::Cta_tile_p; + // The description of the CTA tile for the 2nd batched GEMM. + using Cta_tile_o = typename Kernel_traits::Cta_tile_o; + + // The MMA tile for the 1st GEMM. + using Mma_tile_p = fmha::Hmma_tile; + // The MMA tile for the 2nd GEMM. + using Mma_tile_o = fmha::Hmma_tile; + + using InstructionShape = typename Kernel_traits::MmaInstructionShape; + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + + using ThreadblockShapeQK = typename Kernel_traits::ThreadblockShapeQK; + using LayoutQ = typename Kernel_traits::LayoutQ; + using LayoutK = typename Kernel_traits::LayoutK; + using LayoutP = typename Kernel_traits::LayoutP; + using MmaCoreQK = typename Kernel_traits::MmaCoreQK; + using WarpMmaQK = typename MmaCoreQK::MmaTensorOp; + using SmemLayoutQ = typename MmaCoreQK::SmemLayoutA; + using SmemLayoutK = typename MmaCoreQK::SmemLayoutB; + using SmemIteratorQ = typename MmaCoreQK::SmemIteratorA; + using SmemIteratorK = typename MmaCoreQK::SmemIteratorB; + + using ThreadblockShapePV = typename Kernel_traits::ThreadblockShapePV; + using LayoutV = typename Kernel_traits::LayoutV; + using LayoutO = typename Kernel_traits::LayoutO; + using MmaCorePV = typename Kernel_traits::MmaCorePV; + using WarpMmaPV = typename MmaCorePV::MmaTensorOp; + using WarpIteratorV = typename WarpMmaPV::IteratorB; + using SmemLayoutV = typename MmaCorePV::SmemLayoutB; + using SmemIteratorV = typename MmaCorePV::SmemIteratorB; + constexpr int kIterationsPV = WarpMmaPV::Shape::kK / WarpMmaPV::InstructionShape::kK; + + // The global memory tile to load Q. + // Copy from mma_piplined_testbed.h + using GmemIteratorQ = typename Kernel_traits::GmemIteratorQ; + // The global memory tile to load K. + using GmemIteratorK = typename Kernel_traits::GmemIteratorK; + // The global memory tile to load V. + using GmemIteratorV = typename Kernel_traits::GmemIteratorV; + // The global memory tile to store O. + using GmemIteratorO = typename fmha::FMHAEpilogue::GmemIterator; + using GmemIteratorOAccum = typename fmha::FMHAEpilogue::GmemIteratorAccum; + + using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; + + using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum; + + using Smem_softmax_lse = typename Kernel_traits::Smem_softmax_lse; + + using Gemm1 = Gemm_Q_K; + + using Softmax = fmha::Softmax; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + const BlockInfoPadded binfo(params, bidb, bidh, tidx); + if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return; + + Gemm1 gemm_q_k(smem_); + // Allocate the global memory tile loader for S. + Gmem_tile_s gmem_s(params, binfo, tidx); + Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); + + // Wind gmem tiles to the correct position. + static_assert(Cta_tile_p::N % Cta_tile_p::M == 0); + const int begin_og = begin; + begin = Is_causal ? std::max(begin, loop_step_idx * Cta_tile_p::N / Cta_tile_p::M) : begin; + const int steps_og = steps; + steps -= begin - begin_og; + if (Return_softmax) { gmem_s.move(begin); } + gmem_softmax_lse.move(begin); + + fmha::Mask mask(binfo, tidx, loop_step_idx); + + // The base pointer of smem_v; + char *smem_v_addr = &smem_[Gemm1::SMEM_OFFSET_V]; + + // Allocate the shared memory tile loader for V. We use the same as K so be careful!!! 
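The `std::max` fast-forward above skips query-row blocks that a causal mask would zero out entirely for the current K/V block. A standalone sketch of that arithmetic (editorial illustration), assuming a 16-row step and a 256-key block as in the `<256, 64, 16, 1, 4, ...>` traits used by the dispatcher, and an example sequence length of 512:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    constexpr int M = 16;     // Cta_tile_p::M, query rows per step
    constexpr int N = 256;    // Cta_tile_p::N, keys per K/V block
    const int steps = 32;     // e.g. 512 query rows -> 32 row blocks

    for (int loop_step_idx = 0; loop_step_idx < 2; ++loop_step_idx) {
        int begin = 0;
        const int begin_og = begin;
        // Causal: the first row block that can attend to this K/V block.
        begin = std::max(begin, loop_step_idx * N / M);
        const int remaining_steps = steps - (begin - begin_og);
        std::printf("K/V block %d: start at row block %d, %d row blocks remain\n",
                    loop_step_idx, begin, remaining_steps);
    }
    return 0;
}
```

The shared-memory iterators for Q, K and V are set up next.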
+ + SmemLayoutQ layout_Q = SmemLayoutQ::packed({ThreadblockShapeQK::kM, ThreadblockShapeQK::kK}); + SmemIteratorQ smem_q({reinterpret_cast(smem_), layout_Q}, tidx); + SmemLayoutK layout_K = SmemLayoutK::packed({ThreadblockShapeQK::kK, ThreadblockShapeQK::kN}); + SmemIteratorK smem_k({reinterpret_cast(smem_ + Kernel_traits::BYTES_PER_SMEM_Q), layout_K}, tidx); + SmemLayoutV layout_V = SmemLayoutV::packed({ThreadblockShapePV::kK, ThreadblockShapePV::kN}); + // SmemIterator stores to smem and WarpIterator loads from smem + SmemIteratorV smem_v({reinterpret_cast(smem_v_addr), layout_V}, tidx); + WarpIteratorV iter_V({reinterpret_cast(smem_v_addr), layout_V}, threadIdx.x % 32); + + // Allocate the shared memory tile loader for O. We use the same as K so be careful!!! + using Smem_O = fmha::FMHAEpilogue; + Smem_O smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx); + + // Allocate the global memory tile loader for Q. + // cutlass::transform::threadblock::PredicatedTileIterator deals with seqlen not divisible + // by 16 in a different way than we want. If the seqlen_q is 36, the first iteration would + // load 4 rows and the next two iterations would load 16 rows each. Instead we round the + // actual_seqlen_q to be multiple of 16, then change the mask in the last iteration, so + // that in this case we would load 16, 16, 4. + LayoutQ gmem_layout_Q(params.q_row_stride_in_elts); + typename GmemIteratorQ::Params gmem_Q_params(gmem_layout_Q); + const uint32_t row_offset_q = (binfo.sum_s_q + begin * ThreadblockShapeQK::kM) * params.q_row_stride_in_elts + binfo.bidh * params.q_head_stride_in_elts; + const int actual_seqlen_q = binfo.actual_seqlen_q - begin * ThreadblockShapeQK::kM; + const int seqlen_q_remainder = actual_seqlen_q % ThreadblockShapeQK::kM; + const int extent_q = ((actual_seqlen_q <= ThreadblockShapeQK::kM) || (seqlen_q_remainder == 0)) ? actual_seqlen_q : actual_seqlen_q + ThreadblockShapeQK::kM - seqlen_q_remainder; + GmemIteratorQ gmem_q(gmem_Q_params, + reinterpret_cast(params.q_ptr) + row_offset_q, + {extent_q, params.d}, + tidx); + + // Allocate the global memory tile loader for K. + LayoutK gmem_layout_K(params.k_row_stride_in_elts); + typename GmemIteratorK::Params gmem_K_params(gmem_layout_K); + const uint32_t row_offset_k = (binfo.sum_s_k + loop_step_idx * ThreadblockShapeQK::kN) * params.k_row_stride_in_elts + binfo.bidh * params.k_head_stride_in_elts; + const int extent_k = min(binfo.actual_seqlen_k - loop_step_idx * ThreadblockShapeQK::kN, ThreadblockShapeQK::kN); + GmemIteratorK gmem_k(gmem_K_params, + reinterpret_cast(params.k_ptr) + row_offset_k, + {params.d, extent_k}, + tidx); + + // Allocate the global memory tile loader for V. + LayoutV gmem_layout_V(params.v_row_stride_in_elts); + typename GmemIteratorV::Params gmem_V_params(gmem_layout_V); + const uint32_t row_offset_v = (binfo.sum_s_k + loop_step_idx * ThreadblockShapePV::kK) * params.v_row_stride_in_elts + binfo.bidh * params.v_head_stride_in_elts; + // extent_v is the same as extent_k + GmemIteratorV gmem_v(gmem_V_params, + reinterpret_cast(params.v_ptr) + row_offset_v, + {extent_k, params.d}, + tidx); + + // Allocate the global memory tile loader for O. 
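A standalone sketch (editorial illustration) of the extent rounding described in the comment above, assuming `ThreadblockShapeQK::kM == 16`: rather than letting the predicated iterator peel the remainder first (4, 16, 16 rows for `seqlen_q = 36`), the extent is padded up to a multiple of 16 and the final iteration is masked, giving 16, 16, 4.

```cpp
#include <cstdio>

int main() {
    constexpr int kM = 16;   // ThreadblockShapeQK::kM
    for (int actual_seqlen_q : {36, 48, 130}) {
        const int rem = actual_seqlen_q % kM;
        const int extent_q = (actual_seqlen_q <= kM || rem == 0)
                                 ? actual_seqlen_q
                                 : actual_seqlen_q + kM - rem;
        std::printf("actual_seqlen_q=%3d -> extent_q=%3d (last tile keeps %2d valid rows)\n",
                    actual_seqlen_q, extent_q, rem == 0 ? kM : rem);
    }
    return 0;
}
```

The O and O-accumulator global-memory iterators follow.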
+ LayoutO gmem_layout_O(params.o_row_stride_in_elts); + typename GmemIteratorO::Params gmem_O_params(gmem_layout_O); + const uint32_t row_offset_o = (binfo.sum_s_q + begin * ThreadblockShapeQK::kM) * params.o_row_stride_in_elts + binfo.bidh * params.o_head_stride_in_elts; + GmemIteratorO gmem_o(gmem_O_params, + reinterpret_cast(params.o_ptr) + row_offset_o, + {actual_seqlen_q, params.d}, + tidx); + + typename GmemIteratorOAccum::Params gmem_Oaccum_params(gmem_layout_O); + GmemIteratorOAccum gmem_o_accum(gmem_Oaccum_params, + reinterpret_cast(params.o_tmp_ptr) + row_offset_o, + {actual_seqlen_q, params.d}, + tidx); + + // Create the object to do the softmax. + Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_SOFTMAX], tidx); + + Smem_softmax_lse smem_softmax_lse(reinterpret_cast(&smem_[Gemm1::SMEM_BYTES])); + + if (!Is_first) { + if (Return_softmax) { gmem_s.move(loop_step_idx * steps_og); } + } + + if (!Is_first) { __syncthreads(); } + + // Trigger the loads for V. + typename GmemIteratorV::Fragment gmem_frag_v; + gmem_frag_v.clear(); + gmem_v.load(gmem_frag_v); + + // Trigger the loads for Q. + typename GmemIteratorQ::Fragment gmem_frag_q; + gmem_frag_q.clear(); + gmem_q.load(gmem_frag_q); + + // Trigger the loads for K. + typename GmemIteratorK::Fragment gmem_frag_k; + gmem_frag_k.clear(); + gmem_k.load(gmem_frag_k); + + float p_prev_lse[Mma_tile_p::MMAS_M * 2]; + if (!Is_first) { + gmem_softmax_lse.load(reinterpret_cast(p_prev_lse)); + } + + // Commit the data for Q and V to shared memory. + smem_v.store(gmem_frag_v); + smem_q.store(gmem_frag_q); + + // Commit the data for K to shared memory. + if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) { + smem_k.store(gmem_frag_k); + } + + __syncthreads(); + + // Load the fragments for Q. + gemm_q_k.load_q(); + + // Load the fragments for V. We keep the data in registers during the entire + // kernel. copied from mma_pipelined.h + const int warp_idx = threadIdx.x / 32; + iter_V.add_tile_offset({kIterationsPV * warp_idx, 0}); + typename WarpIteratorV::Fragment frag_v[kIterationsPV]; + static_assert(WarpIteratorV::Fragment::kStorageElements == 4 * Mma_tile_o::MMAS_N || WarpIteratorV::Fragment::kStorageElements == 2 * Mma_tile_o::MMAS_N ); + #pragma unroll + for( int ki = 0; ki < kIterationsPV; ++ki ) { + iter_V.load(frag_v[ki]); + ++iter_V; + } + + // Commit the data for K to shared memory if it has not been done already. + if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) { + // Make sure we are done loading the fragments for K. + __syncthreads(); + + // Commit the data to shared memory for K. + smem_k.store(gmem_frag_k); + + // Make sure the data is in shared memory. + __syncthreads(); + } + + // Load the fragments for K. + gemm_q_k.load_k(); + + // Load over the entire sequence length. + for( int l = 0; l < steps; l++ ) { + if((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q) break; + + // Declare the accumulators for the 1st gemm. + WarpMmaQK mma_qk; + typename WarpMmaQK::FragmentC acc_p; + acc_p.clear(); + + // Do this part of P = Q * K^T. + gemm_q_k(mma_qk, acc_p); + + typename Smem_O::OutputFragment out[Smem_O::kIterationsStore]; + static_assert(GmemIteratorOAccum::kIterations == Smem_O::kIterationsStore); + static_assert(GmemIteratorO::kIterations == Smem_O::kIterationsStore); + if (!Is_first) { + #pragma unroll + for (int iter = 0; iter < GmemIteratorOAccum::kIterations; ++iter) { + gmem_o_accum.load(out[iter]); + gmem_o_accum.move(); + } + } + + // Trigger the load for the next Q values. 
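For orientation, a standalone sketch (editorial illustration) of the `Q | K/V | O | SOFTMAX` shared-memory layout whose offsets (`Gemm1::SMEM_OFFSET_V/O/SOFTMAX`) are used above, reading the `<256, 128, 16, 1, 4, 0x18u>` head_dim-128 traits as a 16x128 Q tile and a 256x128 K/V tile in fp16 with K and V sharing one buffer (that reading is an assumption here). The size of the O reduction buffer is a placeholder, since `sizeof(Smem_O::SharedStorage)` is not visible in this hunk; the total still shows why this configuration needs the dynamic shared-memory limit raised past 48 KB at launch time.

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    constexpr std::size_t kElemBytes = 2;                    // sizeof(cutlass::half_t)
    constexpr std::size_t kM = 16, kN = 256, kK = 128;       // assumed ThreadblockShapeQK for head_dim = 128
    constexpr std::size_t kWarpsN = 4;

    constexpr std::size_t bytes_q = kM * kK * kElemBytes;                    // BYTES_PER_SMEM_Q
    constexpr std::size_t bytes_kv = kN * kK * kElemBytes;                   // BYTES_PER_SMEM_K (== V)
    constexpr std::size_t bytes_softmax = kM * kWarpsN * sizeof(float) * 2;  // SMEM_BYTES_SOFTMAX
    constexpr std::size_t bytes_o = 4096;                    // placeholder for sizeof(Smem_O::SharedStorage)

    constexpr std::size_t offset_v = bytes_q;                // K and V share one buffer (FLAGS & 0x08u)
    constexpr std::size_t offset_o = offset_v + bytes_kv;
    constexpr std::size_t offset_softmax = offset_o + bytes_o;
    std::printf("V at %zu, O at %zu, softmax at %zu, total >= %zu bytes\n",
                offset_v, offset_o, offset_softmax, offset_softmax + bytes_softmax);
    return 0;
}
```

The next-Q prefetch for the following iteration comes next.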
+ if( l < steps - 1) { + ++gmem_q; + // If actual_seqlen_q is not a multiple of 16, we change the mask in the last iteration + // to load the "residue" tile. + if ((l + 1 == steps - 1) && (actual_seqlen_q % ThreadblockShapeQK::kM != 0)) { + // TODO: this probably only works for head_dim = 64 and head_dim = 128, which is + // what we have right now. Maybe for head_dim = 32 or 96, this could be different. + const int row_idx = tidx / (GmemIteratorQ::Shape::kColumn / GmemIteratorQ::Fragment::kElements); + if (row_idx >= actual_seqlen_q - (l + 1) * ThreadblockShapeQK::kM) { + gmem_q.clear_mask(); + } + } + gmem_q.load(gmem_frag_q); + } + + // Load the mask for that iteration. + mask.load(begin + l); + + // Convert from the accumulator type to FP32 for Softmax. + softmax.unpack_noscale(acc_p); + + // Apply the mask. + softmax.apply_mask(mask); + + if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) { + // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction + __syncthreads(); + } + + // Compute the max. + float p_max[Mma_tile_p::MMAS_M * 2]; + if (!Is_first) { + smem_softmax_lse.store_pair(p_prev_lse); + for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi] / params.scale_bmm1; } + } + + // Trigger the load for the next LSE values. + if( l < steps - 1) { + if (!Is_first) { + gmem_softmax_lse.load_next(reinterpret_cast(p_prev_lse)); + } + } + + softmax.template reduce_max(p_max); + + // Compute the exponential value. + softmax.scale_apply_exp(p_max, params.scale_bmm1); + + // We don't finalize the sum reduction here, as that would incur an extra sync_threads(). + // Instead, we reduce the sum from each warp, write to smem, then wait until the sync_threads() + // from storing acc_o. Then we read the sum of each warp from smem and finalize the reduction. + // As a consequence, we don't scale acc_p by the inverse sum, we scale the output by the inverse sum. + // Compute the sum. + float p_sum[Mma_tile_p::MMAS_M * 2]; + // softmax.reduce_sum(p_sum); + softmax.reduce_sum_before_sync_(p_sum); + + constexpr bool encode_dropout_in_sign_bit = Return_softmax; + if (Is_dropout) { + softmax.template apply_dropout_16bits(ph0, ph1, params.p_dropout_in_uint16_t); + } + + static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M); + static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N); + softmax.pack_noconvert(acc_p); + cutlass::NumericArrayConverter convert_p; + auto frag_p = convert_p(acc_p); + + if (Return_softmax) { + gmem_s.store(reinterpret_cast(&)[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]>(frag_p), mask); + gmem_s.move(); + } + + // Commit the values for Q into shared memory. + if (l < steps - 1) { smem_q.store(gmem_frag_q); } + + if (Is_dropout && encode_dropout_in_sign_bit) { + cutlass::epilogue::thread::ReLu relu; + frag_p = relu(frag_p); + } + + // Declare the accumulators for the 2nd gemm. + WarpMmaPV mma_pv; + typename WarpMmaPV::FragmentC acc_o; + static_assert(WarpMmaPV::FragmentC::kElements == Mma_tile_o::MMAS_M * Mma_tile_o::MMAS_N * 8); + acc_o.clear(); + + // For some reason, WarpMmaPV::FragmentA has length K * N * (8|4) instead of just N * (8|4). + // We have to first cast frag_p to be array of k x (N * (8|4)), then cast each row to be + // an array of WarpMmaPV::FragmentA (which is what mma_pv expects). 
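The `encode_dropout_in_sign_bit` path above appears to mark dropped probabilities by flipping their sign, so the signed values can be written to S for the backward pass and the ReLU afterwards zeroes them before the PV GEMM; the 1/(1 - p) rescaling is deferred to the final output scaling (`params.rp_dropout`). A scalar sketch of that idea (editorial illustration; the negation inside `apply_dropout_16bits` is an assumption here, only the ReLU step is visible above):

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    float frag_p[4] = {0.10f, 0.25f, 0.40f, 0.25f};     // softmax probabilities held by one thread
    const bool keep[4] = {true, false, true, false};    // stand-in for the Philox draws

    for (int i = 0; i < 4; ++i) {
        if (!keep[i]) frag_p[i] = -frag_p[i];            // dropped entries keep their magnitude, sign-flipped
    }
    // gmem_s.store(...) would write the signed values at this point.
    for (float &p : frag_p) p = std::max(p, 0.0f);       // ReLu: dropped entries become 0 for the PV GEMM
    std::printf("after relu: %.2f %.2f %.2f %.2f\n", frag_p[0], frag_p[1], frag_p[2], frag_p[3]);
    return 0;
}
```

The PV GEMM with the reshaped fragments follows below.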
+ static_assert(decltype(frag_p)::kElements == kIterationsPV * Mma_tile_o::MMAS_M * WarpMmaPV::FragmentA::kElements); + const auto frag_p_reshaped = reinterpret_cast (&)[kIterationsPV]>(frag_p); + #pragma unroll + for( int ki = 0; ki < kIterationsPV; ++ki ) { + mma_pv(acc_o, reinterpret_cast(frag_p_reshaped[ki]), frag_v[ki], acc_o); + } + // Swizzle the elements and do the final reduction. + smem_o.store(acc_o); + + // The mapping from tidx to rows changes between the softmax and the + // O-reduction. So we recalculate the max. + using OutputTileThreadMap = typename Smem_O::OutputTileThreadMap; + constexpr int kOutputRowsPerThread = OutputTileThreadMap::Iterations::kRow * Smem_O::kIterationsStore; + float p_max_o[kOutputRowsPerThread][Mma_tile_o::MMAS_M]; + int rows[kOutputRowsPerThread]; + cutlass::MatrixCoord output_thread_offset = OutputTileThreadMap::initial_offset(tidx); + const int output_thread_start_row = output_thread_offset.row(); + const int output_thread_start_column = output_thread_offset.column(); + for (int iter = 0; iter < Smem_O::kIterationsStore; ++iter) { + for (int row = 0; row < OutputTileThreadMap::Iterations::kRow; ++row) { + rows[iter * OutputTileThreadMap::Iterations::kRow + row] = output_thread_start_row + iter * OutputTileThreadMap::Shape::kRow + row; + } + } + + softmax.reduce_max_after_sync_(p_max_o, rows); + static_assert(Mma_tile_o::MMAS_M == 1); + for (int jj = 0; jj < kOutputRowsPerThread; jj++) { + p_max_o[jj][0] *= params.scale_bmm1; + } + float p_prev_scale_o[kOutputRowsPerThread]; + if (!Is_first) { + smem_softmax_lse.load(p_prev_scale_o, rows); + } + + // Make sure the data is in shared memory. + __syncthreads(); + + static_assert(Mma_tile_o::MMAS_M == 1); + float p_sum_o[kOutputRowsPerThread][Mma_tile_o::MMAS_M]; + softmax.reduce_sum_after_sync_(p_sum_o, rows); + if (!Is_first) { + for (int jj = 0; jj < kOutputRowsPerThread; jj++) { + p_prev_scale_o[jj] = expf(p_prev_scale_o[jj] - p_max_o[jj][0]); + p_sum_o[jj][0] += p_prev_scale_o[jj]; + } + } + + float p_sum_log[kOutputRowsPerThread][Mma_tile_o::MMAS_M]; + #pragma unroll + for (int jj = 0; jj < kOutputRowsPerThread; jj++) { + float sum = p_sum_o[jj][0]; + p_sum_log[jj][0] = (sum == 0.f || sum != sum) ? -INFINITY : p_max_o[jj][0] + __logf(sum); + if (output_thread_start_column == 0) { + gmem_softmax_lse.store_row( + reinterpret_cast(p_sum_log[jj]), rows[jj]); + } + } + gmem_softmax_lse.move(); + + // Load from shared memory. + using ArrayTypeO = cutlass::Array; + static_assert(OutputTileThreadMap::kElementsPerAccess * kOutputRowsPerThread == Smem_O::kIterationsStore * Smem_O::OutputFragment::kElements); + cutlass::multiplies multiply_fragments; + if (!Is_first) { + auto out_reshaped = reinterpret_cast(out); + for (int jj = 0; jj < kOutputRowsPerThread; jj++) { + out_reshaped[jj] = multiply_fragments(out_reshaped[jj], p_prev_scale_o[jj]); + } + } + smem_o.template load(out, tidx); + + const bool is_final_write = + Is_last + || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k) + || ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N)); + auto out_reshaped = reinterpret_cast(out); + #pragma unroll + for (int jj = 0; jj < kOutputRowsPerThread; jj++) { + float sum = p_sum_o[jj][0]; + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + if (Is_dropout && is_final_write) { + inv_sum *= params.rp_dropout; + } + out_reshaped[jj] = multiply_fragments(out_reshaped[jj], inv_sum); + } + + // Output the values. 
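The block above is the running-softmax update at the heart of the kernel: results from earlier K/V blocks are rescaled by `exp(prev_lse - new_max)`, folded into the current block's sum and output, and only the final write is divided by the total sum, with the new log-sum-exp stored as `max + log(sum)`. A simplified scalar sketch of that identity (editorial illustration with made-up values; the `scale_bmm1` bookkeeping and the zero/NaN sum guards are omitted, and the partial output is kept unnormalized here, whereas the kernel stores a running-normalized partial in `o_tmp` whose normalization cancels through the lse term):

```cpp
#include <cmath>
#include <cstdio>

int main() {
    // Carried over from the previous K/V blocks for one query row.
    const float prev_lse = 2.30f;    // log-sum-exp over the keys seen so far
    const float prev_out = 0.85f;    // unnormalized partial output (one channel)

    // Contribution of the current K/V block for the same row.
    const float cur_max = 3.10f;     // running row max after this block
    const float cur_sum = 1.75f;     // sum of exp(score - cur_max) over this block
    const float cur_out = 1.20f;     // acc_o for this block (already uses exp(score - cur_max))

    const float prev_scale = std::exp(prev_lse - cur_max);    // p_prev_scale_o
    const float total_sum  = cur_sum + prev_scale;            // p_sum_o after merging
    const float merged_out = prev_out * prev_scale + cur_out; // previous output rescaled, then accumulated
    const float new_lse    = cur_max + std::log(total_sum);   // p_sum_log, written to gmem_softmax_lse
    const float final_out  = merged_out / total_sum;          // inv_sum applied only at the final write

    std::printf("lse=%.4f, normalized out=%.4f\n", new_lse, final_out);
    return 0;
}
```

The final-write path for O and the O accumulator follows below.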
+ if (is_final_write) { + typename GmemIteratorO::Fragment out_converted; + cutlass::NumericArrayConverter convert_o; + #pragma unroll + for (int iter = 0; iter < GmemIteratorO::kIterations; ++iter) { + out_converted = convert_o(out[iter]); + gmem_o.store(out_converted); + gmem_o.move(); + } + // We also need to move gmem_o_accum. For example, if Is_causal=true and seqlen=512, + // in the first loop, we write the first 256 rows to gmem_o and the last 256 rows to gmem_o_accum. + if (Is_first && !Is_last) { gmem_o_accum.move(GmemIteratorOAccum::kIterations); } + } else { + if (!Is_first) { gmem_o_accum.move(-GmemIteratorOAccum::kIterations); } + #pragma unroll + for (int iter = 0; iter < GmemIteratorOAccum::kIterations; ++iter) { + gmem_o_accum.store(out[iter]); + gmem_o_accum.move(); + } + } + + gemm_q_k.reload_k(); + + // Trigger the load from shared memory for the next series of Q values. + if(l < steps - 1) { + gemm_q_k.reload_q(); + } + + } // Outer loop over the sequence length. +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void device_1xN_loop(const Params ¶ms) { + + // The block index for the batch. + const int bidb = blockIdx.x; + // The block index for the head. + const int bidh = blockIdx.y; + // The thread index. + const int tidx = threadIdx.x; + + const int tidx_global = (bidb * params.h + bidh) * blockDim.x * 2 + tidx; + auto seeds = at::cuda::philox::unpack(params.philox_args); + // We use 2 Philox generators to match the dropout pattern in the backward pass. + // Forward pass uses 128 threads while backward pass uses 256 threads, so each thread + // in the forward pass is simulating the droout pattern of 2 threads in the backward pass. + Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds)); + Philox ph1(std::get<0>(seeds), tidx_global + blockDim.x, std::get<1>(seeds)); + constexpr int M = Kernel_traits::Cta_tile_p::M; + const int STEPS = (params.seqlen_q + M - 1) / M; + + constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; + if (params.seqlen_k == blocksize_c) { + fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph0, ph1, 0); + } else { + const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c; + fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph0, ph1, 0); + for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) { + fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph0, ph1, loop_step_idx); + } + fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph0, ph1, max_loop_steps - 1); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_fprop_kernel_dispatch.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_fprop_kernel_dispatch.cu new file mode 100644 index 00000000000..344aa07dd3b --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_fprop_kernel_dispatch.cu @@ -0,0 +1,134 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ +#include +#include +#include +#include +#include + +template +__global__ void fmha_fprop_loop_kernel(FMHA_fprop_params params) { + fmha::device_1xN_loop(params); +} + +template +void run_fmha_loop_(Launch_params &launch_params, + const bool configure) { + constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; + const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c; + + if (configure) { + using Mma_tile_p = fmha::Hmma_tile; + constexpr int M = Kernel_traits::Cta_tile_p::M; + size_t STEPS = (launch_params.params.seqlen_q + M - 1) / M; + constexpr size_t MMAS_M = Mma_tile_p::MMAS_M; + constexpr size_t MMAS_N = Mma_tile_p::MMAS_N; + size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps; + launch_params.elts_per_thread = elts_per_head; + return; + } + + constexpr size_t smem_size_softmax_lse = Kernel_traits::Smem_softmax_lse::BYTES_PER_TILE; + // Don't need smem_size_softmax_lse if we're not looping + const size_t smem_size = fmha::get_dynamic_smem_size() + + (loop_steps > 1 ? smem_size_softmax_lse : 0); + // printf("smem_size = %d\n", smem_size); + + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { + auto kernel = launch_params.params.is_causal + ? (launch_params.return_softmax + ? &fmha_fprop_loop_kernel + : &fmha_fprop_loop_kernel) + : (launch_params.return_softmax + ? &fmha_fprop_loop_kernel + : &fmha_fprop_loop_kernel); + // constexpr bool IsDropoutConstTmp = false; + // auto kernel = launch_params.params.is_causal + // ? (launch_params.return_softmax + // ? &fmha_fprop_loop_kernel + // : &fmha_fprop_loop_kernel) + // : (launch_params.return_softmax + // ? 
&fmha_fprop_loop_kernel + // : &fmha_fprop_loop_kernel); + if( smem_size >= 48L * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + dim3 grid(launch_params.params.b, launch_params.params.h); + kernel<<>>( + launch_params.params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + }); +} + +void run_fmha_fprop(Launch_params &launch_params, + const bool configure) { + BOOL_SWITCH(launch_params.params.is_bf16, IsBf16Const, [&] { + using elem_type = std::conditional::type; + auto dprops = at::cuda::getCurrentDeviceProperties(); + if (launch_params.params.d <= 64) { + if( launch_params.params.seqlen_k == 128 ) { + // TD [2022-08-20]: One might expect that not sharing the smem between K & V + // could be faster, but seems like it's the same speed. + using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; + run_fmha_loop_(launch_params, configure); + } else if( launch_params.params.seqlen_k >= 256 ) { + if (dprops->major == 8 && dprops->minor >= 0) { + using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; + run_fmha_loop_(launch_params, configure); + } else if (dprops->major == 7 && dprops->minor == 5) { + if (launch_params.is_dropout) { // Need to use the same block size as backward + using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; + run_fmha_loop_(launch_params, configure); + } else { + using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; + run_fmha_loop_(launch_params, configure); + } + } + } + } else if (launch_params.params.d <= 128) { + if( launch_params.params.seqlen_k == 128 ) { + using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; + run_fmha_loop_(launch_params, configure); + } else { + if (dprops->major == 8 && dprops->minor == 0 && !launch_params.is_dropout) { + // TD [2022-06-05] Keep K in smem to reduce register spilling + // Gives about 6% speedup compared to using block size 128. + using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>; + // using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; + run_fmha_loop_(launch_params, configure); + } else { // Need to use the same block size as backward + using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; + run_fmha_loop_(launch_params, configure); + } + } + } + }); +} diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_kernel.h b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_kernel.h new file mode 100644 index 00000000000..a321e839b3b --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_kernel.h @@ -0,0 +1,71 @@ +/****************************************************************************** + * Copyright (c) 2022, Tri Dao. + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockInfoPadded { + + template + __device__ BlockInfoPadded(const Params ¶ms, + const int bidb, + const int bidh, + const int tidx) + : bidb(bidb), bidh(bidh), h(params.h) { + + // The block index. + sum_s_k = params.cu_seqlens_k[bidb]; + actual_seqlen_k = params.cu_seqlens_k[bidb + 1] - sum_s_k; + sum_s_q = params.cu_seqlens_q[bidb]; + actual_seqlen_q = params.cu_seqlens_q[bidb + 1] - sum_s_q; + + tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx; + } + + __device__ bool stop_early(const int start_col = 0) const { + return actual_seqlen_k <= start_col; + } + + uint32_t actual_seqlen_q; + uint32_t actual_seqlen_k; + uint32_t sum_s_q; + uint32_t sum_s_k; + uint32_t bidh; + uint32_t bidb; + uint32_t tidx_global; + uint32_t h; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_utils.h b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_utils.h new file mode 100644 index 00000000000..9a40ecb59f2 --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_utils.h @@ -0,0 +1,52 @@ + + +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define FMHA_CHECK_CUDA( call ) \ + do { \ + cudaError_t status_ = call; \ + if( status_ != cudaSuccess ) { \ + fprintf( stderr, \ + "CUDA error (%s:%d): %s\n", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString( status_ ) ); \ + exit( 1 ); \ + } \ + } while( 0 ) + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/gemm.h b/aten/src/ATen/native/transformers/cuda/flash_attn/gemm.h new file mode 100644 index 00000000000..703e1a2629b --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/gemm.h @@ -0,0 +1,95 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // The number of rows in the CTA tile. + int M_, + // The number of cols in the CTA tile. + int N_, + // The number of elements in the the K dimension of the GEMM loop. 
+ int K_, + // The number of rows of warps. + int WARPS_M_, + // The number of cols of warps. + int WARPS_N_, + // The number of warps in the K dimension of the GEMM loop. + int WARPS_K_> +struct Cta_tile_ { + + static constexpr int M = M_, N = N_, K = K_; + // The number of warps. + static constexpr int WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_; + // The number of warps per CTA. + static constexpr int WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K; + // The number of threads per warp. + static constexpr int THREADS_PER_WARP = 32; + // The number of threads per CTA. + static constexpr int THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Hmma_tile { + // The number of elements computed with a single warp-MMA. + static constexpr int M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16; + + // The number of elements computed with a single CTA-MMA. + static constexpr int M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M, + N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N, + K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K; + + // The number of MMAs needed to compute the GEMM. + static constexpr int MMAS_M = DivUpConstexpr(Cta_tile::M, M_PER_MMA_PER_CTA), + MMAS_N = DivUpConstexpr(Cta_tile::N, N_PER_MMA_PER_CTA), + MMAS_K = DivUpConstexpr(Cta_tile::K, K_PER_MMA_PER_CTA); + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +using Cta_tile_extd = Cta_tile_; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/gmem_tile.h b/aten/src/ATen/native/transformers/cuda/flash_attn/gmem_tile.h new file mode 100644 index 00000000000..0102c0611be --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/gmem_tile.h @@ -0,0 +1,272 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename Cta_tile, int BYTES_PER_ELEMENT > +struct Gmem_tile_mma_sd { + + // The mma tile. + using Mma_tile = fmha::Hmma_tile; + + // Each STG stores 8 elements. + static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 8; + // The number of MMAs in the M dimension. + static constexpr int MMAS_M = Mma_tile::MMAS_M; + // The number of MMAs in the N dimension. + static constexpr int MMAS_N = Mma_tile::MMAS_N; + // The number of rows computed per MMA per thread block. + static constexpr int M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA; + // The number of cols computed per MMA per thread block. + static constexpr int N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA; + // The number of threads per block. + static constexpr int THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA; + // The size of each row in bytes. I.e. how many bytes are stored per STG. + static constexpr int BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG; + // The distance between elements stored per loop (in bytes). + static constexpr int LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW; + + // The type of elements stored per STG. + using Type = typename fmha::Uint_from_size_in_bytes::Type; + + // Ctor. + template + inline __device__ Gmem_tile_mma_sd(void *ptr, const Params ¶ms, const int bidb, const int bidh, const int tidx) + : ptr_(static_cast(ptr)) { + + // The block index. + // size_t bidx = bidb * params.h + bidh; + uint32_t bidx = bidb * params.h + bidh; + + // The distance between two blocks (in bytes). + // const size_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT; + const uint32_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT; + // Set store location for each thread at the beginning of the loop + ptr_ += bidx * block_stride_bytes + tidx * BYTES_PER_STG; + } + + // Store to global memory. + inline __device__ void store(const Type &data, const int mi, const int ni) { + // size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; + uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; + fmha::stg(ptr_ + offset, data); + } + + // Load from global memory. + inline __device__ void load(Type &data, const int mi, const int ni) { + // size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; + uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; + fmha::ldg(data, ptr_ + offset); + } + + // Move to the next tile. + inline __device__ void move(const int steps = 1) { + ptr_ += LOOP_STRIDE_BYTES * steps; + } + + // The pointer in global memory. + char *ptr_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename Cta_tile, typename Base = Gmem_tile_mma_sd > +struct Gmem_tile_mma_s : public Base { + + // The number of mmas in the vertical dimension. 
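A standalone sketch (editorial illustration) of the `Gmem_tile_mma_sd` addressing above, assuming fp16 elements, a 128-thread CTA and a 256x256 S matrix per (batch, head): each STG covers 8 elements (16 B), one "row" is one STG from every thread (2048 B), and MMA (mi, ni) of a given block lands at a fixed offset from the per-thread base pointer. `Gmem_tile_mma_s`, which continues below, builds its fragment stores on this base.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    constexpr int kBytesPerElement = 2;                          // fp16
    constexpr int kThreadsPerCta = 128;                          // 1x4 warps
    constexpr int kBytesPerStg = kBytesPerElement * 8;           // 16 B per STG
    constexpr int kBytesPerRow = kThreadsPerCta * kBytesPerStg;  // 2048 B per (mi, ni) slot
    constexpr int kMmasN = 4;                                    // 256 keys / (16 * 4 warps along N)
    constexpr std::uint32_t seqlen_q = 256, seqlen_k = 256;
    const std::uint32_t block_stride = seqlen_q * seqlen_k * kBytesPerElement;  // bytes of S per (b, h)

    const int bidx = 3, tidx = 17, mi = 0, ni = 2;               // example block, thread and MMA indices
    const std::uint32_t base = bidx * block_stride + tidx * kBytesPerStg;
    const std::uint32_t offset = (mi * kMmasN + ni) * kBytesPerRow;
    std::printf("thread %d of block %d stores MMA (%d,%d) at byte %u\n",
                tidx, bidx, mi, ni, base + offset);
    return 0;
}
```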
+ static constexpr int M = Base::MMAS_M; + // The number of mmas in the horizontal dimension. + static constexpr int N = Base::MMAS_N; + // The type of the vectors stored by each STG. + using Type = typename Base::Type; + + // Ctor. + template< typename Params, typename Block_info > + inline __device__ Gmem_tile_mma_s(const Params ¶ms, const Block_info& binfo, const int tidx) + : Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) { + } + + // Store to global memory. + template + inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){ + static_assert(Fragment::kStorageElements == 4); + #pragma unroll + for( int mi = 0; mi < M; mi++ ) { + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + uint4 dst; + dst.x = frag[ni][mi].raw_data()[0]; + dst.y = frag[ni][mi].raw_data()[2]; + dst.z = frag[ni][mi].raw_data()[1]; + dst.w = frag[ni][mi].raw_data()[3]; + if( mask.any_valid(mi, ni) ) { + Base::store(dst, mi, ni); + } + } + } + } + + // Load from global memory. + template + inline __device__ void load(uint4 (®s)[M][N], const Mask &mask) { + #pragma unroll + for( int mi = 0; mi < M; mi++ ) { + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + regs[mi][ni] = make_uint4(0, 0, 0, 0); + if( mask.any_valid(mi, ni) ) { + Base::load(regs[mi][ni], mi, ni); + } + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // The dimensions of the tile computed by the CTA. + typename Cta_tile +> +struct Gmem_summary_stats { + + // The Mma tile. + using Mma_tile = fmha::Hmma_tile; + + // The number of MMAs in M/N dimensions. + static constexpr int MMAS_M = Mma_tile::MMAS_M; + + // The size of each element. + static constexpr int BYTES_PER_ELEMENT = 4; + static constexpr int BYTES_PER_MMA = (Cta_tile::THREADS_PER_WARP / 4) * 2 * BYTES_PER_ELEMENT; + static constexpr int ROWS = Cta_tile::M; + + // Ctor. + template + inline __device__ Gmem_summary_stats(void *ptr, const Params ¶ms, const int tidx) + : ptr_(reinterpret_cast(ptr)), tidx_(tidx) { + + // The block index for the batch. + const int bidb = blockIdx.x; + // The block index for the head. + const int bidh = blockIdx.y; + // The block index. + // size_t bidx = bidb * params.h + bidh; + uint32_t bidx = bidb * params.h + bidh; + + // Extract the position in the warp. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // The distance between two blocks (in bytes). + // size_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT; + uint32_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT; + + // Set store location for each thread at the beginning of the loop + ptr_row_ = ptr_ + bidx * block_stride_bytes; + ptr_ += bidx * block_stride_bytes + (lane / 4) * BYTES_PER_ELEMENT; + } + + // Store data to global memory. + inline __device__ void store(const uint32_t (&data)[MMAS_M * 2]) { + int warp = tidx_ / Cta_tile::THREADS_PER_WARP; + int lane = tidx_ % Cta_tile::THREADS_PER_WARP; + if ((warp == 0) && (lane % 4 == 0)) { + #pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // TODO: Not sure if it's right for MMAS_M > 1 + fmha::stg(ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT, data[mi * 2 + 0]); + fmha::stg(ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT, data[mi * 2 + 1]); + } + } + } + + // Store data to global memory. 
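A standalone sketch (editorial illustration) of the thread-to-row mapping that `Gmem_summary_stats::store` above relies on, assuming the usual m16n8 tensor-op accumulator layout: lane L of the warp holds statistics for rows L/4 and L/4 + 8, so only warp 0's lanes with L % 4 == 0 write, and the second value goes 8 elements further along (the `+ 8 * BYTES_PER_ELEMENT` offset). The row-wise store and load helpers continue below.

```cpp
#include <cstdio>

int main() {
    constexpr int kThreadsPerWarp = 32;
    for (int lane = 0; lane < kThreadsPerWarp; ++lane) {
        if (lane % 4 != 0) continue;       // lanes 1-3 of each quad hold duplicate row statistics
        const int row0 = lane / 4;         // first row owned by this lane
        const int row1 = lane / 4 + 8;     // second row, stored 8 * BYTES_PER_ELEMENT further on
        std::printf("warp 0, lane %2d -> rows %2d and %2d\n", lane, row0, row1);
    }
    return 0;
}
```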
+ inline __device__ void store_row(const uint32_t (&data)[MMAS_M], const int row) { + #pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // TODO: Not sure if it's right for MMAS_M > 1 + fmha::stg(ptr_row_ + mi * BYTES_PER_MMA + row * BYTES_PER_ELEMENT, data[mi]); + } + } + + // Load from global memory. + inline __device__ void load(uint32_t (&data)[MMAS_M * 2]) { + #pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // TODO: Not sure if it's right for MMAS_M > 1 + fmha::ldg(data[mi * 2 + 0], ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT); + fmha::ldg(data[mi * 2 + 1], ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT); + } + } + + // Load from global memory. + inline __device__ void load_next(uint32_t (&data)[MMAS_M * 2], int move_steps=1) { + char *ptr_next = ptr_ + move_steps * ROWS * BYTES_PER_ELEMENT; + #pragma unroll + for (int mi = 0; mi < MMAS_M; ++mi) { + // TODO: Not sure if it's right for MMAS_M > 1 + fmha::ldg(data[mi * 2 + 0], ptr_next + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT); + fmha::ldg(data[mi * 2 + 1], ptr_next + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT); + } + } + + // Store data to global memory. + template + inline __device__ void load_row(uint32_t (&data)[N], const int row[N]) { + #pragma unroll + for (int ni = 0; ni < N; ++ni) { + fmha::ldg(data[ni], ptr_row_ + row[ni] * BYTES_PER_ELEMENT); + } + } + + // Move the pointer to the next location. + inline __device__ void move() { + ptr_ += ROWS * BYTES_PER_ELEMENT; + ptr_row_ += ROWS * BYTES_PER_ELEMENT; + } + + // Move the pointer to the next location. + inline __device__ void move(const int steps) { + ptr_ += ROWS * BYTES_PER_ELEMENT * steps; + ptr_row_ += ROWS * BYTES_PER_ELEMENT * steps; + } + + // The pointer. + char *ptr_; + char *ptr_row_; + const int tidx_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits.h b/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits.h new file mode 100644 index 00000000000..26fff55c867 --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits.h @@ -0,0 +1,154 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include + +#include + +#include +#include +#include + +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FMHA_kernel_traits { + + // The CTA description for the 1st GEMM. + using Cta_tile_p = fmha::Cta_tile_extd; + // The CTA description for the 2nd GEMM. + using Cta_tile_o = fmha::Cta_tile_extd; + + // Do we use one buffer for K and V. + static constexpr bool SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u; + // Do we keep K in registers. + static constexpr bool K_IN_REGS = (FLAGS & 0x10u) == 0u; + // Do we keep V in registers. + static constexpr bool V_IN_REGS = (FLAGS & 0x100u) == 0u; + + // The global memory tile to load/store S. + using Gmem_tile_s = fmha::Gmem_tile_mma_s; + + // The global memory tile to store the softmax sum. + using Gmem_softmax_sum = fmha::Gmem_summary_stats; + + // The number of threads. + static constexpr int THREADS = Cta_tile_p::THREADS_PER_CTA; + // Make sure the number of threads matches both CTAs. + static_assert(THREADS == Cta_tile_o::THREADS_PER_CTA, ""); + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MmaInstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; +#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using MmaInstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; +#else + // using MmaInstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using MmaInstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + // TD [2022-06-02] We don't support Volta (SM70) yet. +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; +#else + using Element = cutlass::half_t; +#endif + using ElementAccum = float; + + static_assert(WARPS_M == 1); + using ThreadblockShapeQK = cutlass::gemm::GemmShape; + using WarpCountQK = cutlass::gemm::GemmShape; + using WarpShapeQK = cutlass::gemm::GemmShape< + ThreadblockShapeQK::kM, + ThreadblockShapeQK::kN / WarpCountQK::kN, ThreadblockShapeQK::kK>; + using LayoutQ = cutlass::layout::RowMajor; + using LayoutK = cutlass::layout::ColumnMajor; + using LayoutP = cutlass::layout::RowMajor; + using MmaCoreQK = typename fmha::FMHAMmaCore< + ThreadblockShapeQK, WarpShapeQK, MmaInstructionShape, Element, LayoutQ, + Element, LayoutK, ElementAccum, LayoutP, + cutlass::arch::OpClassTensorOp>; + + using ThreadblockShapePV = cutlass::gemm::GemmShape; + using WarpCountPV = cutlass::gemm::GemmShape; + using WarpShapePV = cutlass::gemm::GemmShape; + using LayoutV = cutlass::layout::RowMajor; + using LayoutO = cutlass::layout::RowMajor; + using MmaCorePV = typename fmha::FMHAMmaCore< + ThreadblockShapePV, WarpShapePV, MmaInstructionShape, Element, LayoutP, + Element, LayoutV, ElementAccum, LayoutO, + cutlass::arch::OpClassTensorOp>; + + // The global memory tile to load Q. 
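A standalone sketch (editorial illustration) decoding the `FLAGS` bits above for the two values the dispatcher actually passes (`0x08u` and `0x18u`): bit 0x08 shares one shared-memory buffer between K and V, a cleared bit 0x10 keeps K in registers, and a cleared bit 0x100 keeps V in registers.

```cpp
#include <cstdio>

int main() {
    for (unsigned flags : {0x08u, 0x18u}) {
        const bool share_smem_k_v = (flags & 0x08u) != 0u;   // SHARE_SMEM_FOR_K_AND_V
        const bool k_in_regs      = (flags & 0x10u) == 0u;   // K_IN_REGS
        const bool v_in_regs      = (flags & 0x100u) == 0u;  // V_IN_REGS
        std::printf("FLAGS=0x%03x: SHARE_SMEM_FOR_K_AND_V=%d K_IN_REGS=%d V_IN_REGS=%d\n",
                    flags, share_smem_k_v, k_in_regs, v_in_regs);
    }
    return 0;
}
```

This matches the dispatcher's choice of 0x18u for head_dim 128 on SM80, where K is deliberately kept in shared memory to reduce register spilling. The predicated global-memory iterators for Q, K and V are defined next.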
+ // Copy from mma_piplined_testbed.h + using GmemIteratorQ = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + Element, + LayoutQ, + 0, + typename MmaCoreQK::IteratorThreadMapA + >; + + // The global memory tile to load K. + using GmemIteratorK = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + Element, + LayoutK, + 1, + typename MmaCoreQK::IteratorThreadMapB + >; + + // The global memory tile to load V. + using GmemIteratorV = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + Element, + LayoutV, + 0, + typename MmaCorePV::IteratorThreadMapB + >; + + // The shared memory tile to store softmax lse. + using Smem_softmax_lse = fmha::Smem_tile_softmax_lse; + + // The amount of shared memory needed to load Q and K. + static constexpr size_t BYTES_PER_SMEM_Q = ThreadblockShapeQK::kM * ThreadblockShapeQK::kK * sizeof(Element); + static constexpr size_t BYTES_PER_SMEM_K = ThreadblockShapeQK::kN * ThreadblockShapeQK::kK * sizeof(Element); + static constexpr size_t BYTES_PER_SMEM_V = ThreadblockShapePV::kN * ThreadblockShapePV::kK * sizeof(Element); + static_assert(BYTES_PER_SMEM_K == BYTES_PER_SMEM_V); + static constexpr size_t BYTES_PER_SMEM_QK = BYTES_PER_SMEM_Q + BYTES_PER_SMEM_K; + // The extra amount of shared memory needed to load V. + static constexpr size_t BYTES_PER_SMEM_V_EXTRA = SHARE_SMEM_FOR_K_AND_V ? 0u : BYTES_PER_SMEM_V; + // The amount of shared memory needed for Q, K and V.. + static constexpr size_t BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V_EXTRA; + +}; diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/mask.h b/aten/src/ATen/native/transformers/cuda/flash_attn/mask.h new file mode 100644 index 00000000000..6169c89550b --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/mask.h @@ -0,0 +1,92 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#pragma once + +#include +namespace fmha { + + +template +struct Mask { + using Mma_tile = fmha::Hmma_tile; + + template + __device__ Mask(const BInfo &binfo, int tidx, const int loop_step_idx_ = 0) + : actual_seqlen_k(binfo.actual_seqlen_k - loop_step_idx_ * Cta_tile::N) + , loop_step_idx(loop_step_idx_) { + + const int warp = tidx / Cta_tile::THREADS_PER_WARP; + const int lane = tidx % Cta_tile::THREADS_PER_WARP; + + static_assert(Cta_tile::WARPS_K == 1, ""); + + // find the warp in the Cta tile + const int warp_n = (warp / Cta_tile::WARPS_M); + const int warp_m = (warp % Cta_tile::WARPS_M); + // decompose warp into 8x4 tile + const int quad = lane / 4; + const int tid = (lane % 4) * 2; + row = warp_m * 16 + quad; + // col = warp_n * 16 + tid; + col = warp_n * Mma_tile::N_PER_MMA * Mma_tile::MMAS_N + tid; + } + + inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const { + + // ii and jj iterate over the 2x4 fragment + // const int current_col = (Is_causal ? loop_step_idx * Cta_tile::N : 0) + ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1); + // const int current_col = ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1); + const int current_col = ni * Mma_tile::N_PER_MMA + col + (jj & 2) * 4 + (jj & 1); + const int current_row = row_offset + ii * 8; + const bool col_valid = current_col < actual_seqlen_k; + // const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen_k; + //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen_k; + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("current_col=%d, current_row=%d, actual_seqlen_k=%d, col_valid=%d, all_valid=%d\n", current_col, current_row, actual_seqlen_k, col_valid, all_valid); + // } + return Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid; + // return row_valid && col_valid; + } + + //BERT Mask: if upper left is invalid, none are valid + inline __device__ bool any_valid(const int mi, const int ni) const { + return is_valid(mi, ni, 0, 0) || is_valid(mi, ni, 1, 0); + } + + inline __device__ void load(const int it) { + row_offset = it * Cta_tile::M + row; + } + int row_offset; + + int row; + int col; + const int loop_step_idx; + const int actual_seqlen_k; +}; + +} // namespace fmha diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/mma_core_sm75.h b/aten/src/ATen/native/transformers/cuda/flash_attn/mma_core_sm75.h new file mode 100644 index 00000000000..96d9c9461f6 --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/mma_core_sm75.h @@ -0,0 +1,382 @@ +// Adapted from cutlass/gemm/threadblock/default_mma_core_sm75.h +// This is very similar, except we make it work for head_dim=128. +// The original cutlass version only allows kK of the thread block to be +// at most 64. Here we set kCrosswise = max(64, ThreadblockShape::kK) instead. + +/****************************************************************************** + * Copyright (c) 2022, Tri Dao. + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace fmha { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template defininng default matrix multiply operators inferred from threadblock tile size, +/// global memory data layout, and target math instruction. 
+template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape, + /// Shape of warp-level matrix multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + /// Operation performed by MMA + typename Operator = cutlass::arch::OpMultiplyAdd +> +struct FMHAMmaCore; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: column-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Operation performed by MMA + typename Operator_> +struct FMHAMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + /// Number of warps present + using WarpCount = cutlass::gemm::GemmShape< + Shape::kM / WarpShape::kM, + Shape::kN / WarpShape::kN, + Shape::kK / WarpShape::kK + >; + + // Divisibility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && + !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." 
+ ); + + /// Number of threads per warp + static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 128; + + /// Cutlass only supports Crosswise at most 64 + static int const kCrosswise = std::min(Shape::kK, 64); + + /// Default Operator + using Operator = Operator_; + + // Warp thread arrangement + static int const kWarpThreadArrangementContiguousA = + kCrosswise / (kAccessSizeInBits / cutlass::sizeof_bits::value); + + static int const kWarpThreadArrangementStridedA = + kWarpSize / kWarpThreadArrangementContiguousA; + + static int const kWarpThreadArrangementContiguousB = + kCrosswise / (kAccessSizeInBits / cutlass::sizeof_bits::value); + + static int const kWarpThreadArrangementStridedB = + kWarpSize / kWarpThreadArrangementContiguousB; + + // + // Shared memory layouts + // + + using SmemLayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, kCrosswise>; + + // Shared memory layout + using SmemLayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, kCrosswise>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = cutlass::transform::PitchLinearWarpRakedThreadMap< + cutlass::layout::PitchLinearShape, kThreads, + cutlass::layout::PitchLinearShape, + kAccessSizeInBits / cutlass::sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = cutlass::transform::threadblock::RegularTileIterator< + cutlass::MatrixShape, + ElementA, + SmemLayoutA, + 0, + IteratorThreadMapA + >; + + /// ThreadMap of iterator B + using IteratorThreadMapB = cutlass::transform::PitchLinearWarpRakedThreadMap< + cutlass::layout::PitchLinearShape, kThreads, + cutlass::layout::PitchLinearShape, + kAccessSizeInBits / cutlass::sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = cutlass::transform::threadblock::RegularTileIterator< + cutlass::MatrixShape, + ElementB, + SmemLayoutB, + 1, + IteratorThreadMapB + >; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = cutlass::gemm::threadblock::MmaPolicy< + MmaTensorOp, + cutlass::MatrixShape<0, 0>, + cutlass::MatrixShape<0, 0>, + WarpCount::kK + >; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: row-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Operation performed by MMA + typename Operator_> +struct FMHAMmaCore { + 
using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + /// Number of warps present + using WarpCount = cutlass::gemm::GemmShape< + Shape::kM / WarpShape::kM, + Shape::kN / WarpShape::kN, + Shape::kK / WarpShape::kK + >; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && + !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." + ); + + /// Number of threads per warp + static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 128; + + /// Cutlass only supports Crosswise at most 64 + static int const kCrosswise = std::min(Shape::kK, 64); + + /// Default Operator + using Operator = Operator_; + + // Warp thread arrangement + static int const kWarpThreadArrangementContiguousA = + kCrosswise / (kAccessSizeInBits / cutlass::sizeof_bits::value); + + static int const kWarpThreadArrangementStridedA = + kWarpSize / kWarpThreadArrangementContiguousA; + + // + // Shared memory layouts + // + + using SmemLayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, kCrosswise>; + + // Shared memory layout + using SmemLayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, int(128 / sizeof(ElementB))>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = cutlass::transform::PitchLinearWarpRakedThreadMap< + cutlass::layout::PitchLinearShape, kThreads, + cutlass::layout::PitchLinearShape, + kAccessSizeInBits / cutlass::sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = cutlass::transform::threadblock::RegularTileIterator< + cutlass::MatrixShape, + ElementA, + SmemLayoutA, + 0, + IteratorThreadMapA + >; + + /// ThreadMap of iterator B + using IteratorThreadMapB = cutlass::transform::PitchLinearWarpRakedThreadMap< + cutlass::layout::PitchLinearShape, + kThreads, + cutlass::layout::PitchLinearShape<8, 4>, + kAccessSizeInBits / cutlass::sizeof_bits::value + >; + + /// Shared memory iterator to B operand + using SmemIteratorB = cutlass::transform::threadblock::RegularTileIterator< + cutlass::MatrixShape, + ElementB, + SmemLayoutB, + 0, + IteratorThreadMapB + >; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = cutlass::gemm::threadblock::MmaPolicy< + MmaTensorOp, + cutlass::MatrixShape<0, 0>, + cutlass::MatrixShape<0, 0>, + WarpCount::kK + >; +}; + + +} // namespace fmha diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/philox.cuh b/aten/src/ATen/native/transformers/cuda/flash_attn/philox.cuh new file mode 100644 index 00000000000..456b320b64e --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/philox.cuh @@ -0,0 +1,146 @@ +// Pytorch 
also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu +#pragma once +// Philox CUDA. + +#include + +namespace { + +class Philox { +public: + __device__ inline Philox(unsigned long long seed, + unsigned long long subsequence, + unsigned long long offset) + : STATE(0) + , key(reinterpret_cast(seed)) { + //key.x = (unsigned int)seed; + //key.y = (unsigned int)(seed >> 32); + //counter = make_uint4(0, 0, 0, 0); + //counter.z = (unsigned int)(subsequence); + //counter.w = (unsigned int)(subsequence >> 32); + //STATE = 0; + //incr_n(offset / 4); + + // key = reinterpret_cast(seed); + ull2 * tmp = reinterpret_cast(&counter); + tmp->x = offset / 4; + tmp->y = subsequence; + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w); + // } + } + __device__ inline uint4 operator()() { + // if (STATE == 0) { + uint4 counter_ = counter; + uint2 key_ = key; + // 7-round philox + #pragma unroll + for (int i = 0; i < 6; i++) { + counter_ = single_round(counter_, key_); + key_.x += (kPhilox10A); + key_.y += (kPhilox10B); + } + // output = single_round(counter_, key_); + uint4 output = single_round(counter_, key_); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); + // printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w); + // } + incr(); + // } + // return a float4 directly + // unsigned long ret; + // switch(STATE) { + // case 0: ret = output.x; break; + // case 1: ret = output.y; break; + // case 2: ret = output.z; break; + // case 3: ret = output.w; break; + //} + // STATE = (STATE + 1) % 4; + return output; + } + +private: + struct ull2 { + uint64_t x; + uint64_t y; + }; + uint4 counter; + // uint4 output; + const uint2 key; + unsigned int STATE; + __device__ inline void incr_n(unsigned long long n) { + unsigned int nlo = (unsigned int)(n); + unsigned int nhi = (unsigned int)(n >> 32); + counter.x += nlo; + if (counter.x < nlo) + nhi++; + counter.y += nhi; + if (nhi <= counter.y) + return; + if (++counter.z) + return; + ++counter.w; + } + + __device__ uint4 incr128 (uint4 ctr) + { + uint4 res; + asm ("add.cc.u32 %0, %4, %8;\n\t" + "addc.cc.u32 %1, %5, %9;\n\t" + "addc.cc.u32 %2, %6, %10;\n\t" + "addc.u32 %3, %7, %11;\n\t" + : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w) + : "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w), + "n"(1), "n"(0), "n"(0), "n"(0)); + return res; + } + + __device__ inline void incr() { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); + // } + counter = incr128(counter); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); + // } + } + __device__ unsigned int mulhilo32(unsigned int a, unsigned int b, + unsigned int *result_high) { + *result_high = __umulhi(a, b); + return a * b; + } + __device__ uint2 mulhilo32_v2 (const unsigned int a, const unsigned int b) + { + uint2 *res; + unsigned long long tmp; + asm ("mul.wide.u32 %0, %1, %2;\n\t" + : "=l"(tmp) + : "r"(a), "r"(b)); + res = (uint2*)(&tmp); + return *res; + } + __device__ inline uint4 single_round(const uint4 ctr, const uint2 key) { + //unsigned int hi0; + 
//unsigned int hi1; + //unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0); + //unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1); + //uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0}; + uint2 res0 = mulhilo32_v2(kPhiloxSA, ctr.x); + uint2 res1 = mulhilo32_v2(kPhiloxSB, ctr.z); + uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; + return ret; + } + static const unsigned long kPhilox10A = 0x9E3779B9; + static const unsigned long kPhilox10B = 0xBB67AE85; + static const unsigned long kPhiloxSA = 0xD2511F53; + static const unsigned long kPhiloxSB = 0xCD9E8D57; +}; +// Inverse of 2^32. +constexpr float M_RAN_INVM32 = 2.3283064e-10f; +__device__ __inline__ float4 uniform4(const uint4 x) { + return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32, + x.w * M_RAN_INVM32); +} + +} // namespace diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h b/aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h new file mode 100644 index 00000000000..4f4c93da5ea --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h @@ -0,0 +1,446 @@ +/****************************************************************************** + * Copyright (c) 2022, Tri Dao. + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#pragma once + +#include +#include + +#include +#include + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float apply_exp_(float x, float max) { + return __expf(x - max); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float apply_exp2_(float x, float max) { + return exp2f(x - max); + // With fast-math, this produces the same PTX instruction as the assembly below + // float diff = x - max; + // float res; + // asm ("ex2.approx.ftz.f32 %0, %1;\n\t" : "=f"(res) : "f"(diff)); + // return res; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct ReadType {}; +template<> struct ReadType<4> { using T = float;}; +template<> struct ReadType<8> { using T = float2;}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_reduce { + // Helper class to distribute MMA tiles reduced over rows per warp over quads. + + // The Mma tile. + using Mma_tile = fmha::Hmma_tile; + + // The number of MMAs in M/N dimensions. + static constexpr int MMAS_M = Mma_tile::MMAS_M; + static constexpr int MMAS_N = Mma_tile::MMAS_N; + + static constexpr int WARPS_M = Cta_tile::WARPS_M; + static constexpr int WARPS_N = Cta_tile::WARPS_N; + + + static constexpr int ROWS = WARPS_M * MMAS_M * 16; + static constexpr int COLS = WARPS_N; + static_assert(COLS == 4 || COLS == 8); + static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8; + static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float); + static constexpr int ELTS_PER_TILE = ROWS * COLS; + + using read_t = typename ReadType::T; + + __device__ inline Smem_tile_reduce(float *smem_, const int tidx) { + + int lane = tidx % 32; + int warp = tidx / 32; + + int warp_m = warp % WARPS_M; + int warp_n = warp / WARPS_M; + + qid_ = lane % 4; + int qp = lane / 4; + + // Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps. + // This won't affect reading as we assume commutative reduction ops. + const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN); + smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col]; + smem_read_ = &reinterpret_cast(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_]; + smem_read_row_ = &reinterpret_cast(smem_)[warp_m * 16 * MMAS_M * 4 + qid_]; + + } + + __device__ inline void store(float (&frag)[2 * MMAS_M]) { + if( qid_ == 0 ) { + #pragma unroll + for( int mi = 0; mi < MMAS_M; mi++ ) { + int offset = mi * 16 * WARPS_N; + smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0]; + smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1]; + } + } + } + + __device__ inline void load(read_t (&frag)[2 * MMAS_M]) { + #pragma unroll + for( int mi = 0; mi < MMAS_M; mi++ ) { + int offset = mi * 16 * 4; + frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4]; + frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4]; + } + } + + __device__ inline void load_row(read_t (&frag)[MMAS_M], int row) { + #pragma unroll + for( int mi = 0; mi < MMAS_M; mi++ ) { + int offset = mi * 16 * 4; + frag[mi] = smem_read_row_[offset + 0 * 8 * 4 + row * 4]; + } + } + + int qid_; + float *smem_write_; + read_t *smem_read_; + read_t *smem_read_row_; + +}; + + +template +struct Softmax_base { + + // The Mma tile. 
+ using Mma_tile = fmha::Hmma_tile; + + // The number of MMAs in M/N dimensions. + static constexpr int MMAS_M = Mma_tile::MMAS_M; + static constexpr int MMAS_N = Mma_tile::MMAS_N; + + // The number of groups of warp such that we have at most 4 warps writing consecutive elements. + static constexpr int GROUPS = fmha::DivUpConstexpr(Cta_tile::WARPS_N, 4); + // The number of elements that we are going to store per row. + static constexpr int ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS; + // The number of rows. + static constexpr int ROWS = Cta_tile::M * GROUPS; + // The total number of elements. + static constexpr int ELEMENTS = ROWS * ELEMENTS_PER_ROW; + + // Ctor. + template + inline __device__ Softmax_base(const Params ¶ms, void *smem, int tidx) + : // packed_mask_ptr_(reinterpret_cast(params.packed_mask_ptr)), + smem_(reinterpret_cast(smem)), tidx_(tidx) { + + // Extract the position in the warp. + int warp = tidx / Cta_tile::THREADS_PER_WARP; + int lane = tidx % Cta_tile::THREADS_PER_WARP; + + // Decompose the warp index into M and N. + int warp_m = warp % Cta_tile::WARPS_M; + int warp_n = warp / Cta_tile::WARPS_M; + + // Decompose the warp-n index into group/position-inside-the-group. + int warp_g = warp_n / ELEMENTS_PER_ROW; + int warp_i = warp_n % ELEMENTS_PER_ROW; + + // The location written by the threads. + int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4; + int write_col = warp_i; + + // Assemble the write pointer. + smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col]; + + // Assemble the read pointer. + smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4]; + } + + template + inline __device__ void apply_mask(const Mask &mask) { + #pragma unroll + for( int mi = 0; mi < MMAS_M; ++mi ) { + #pragma unroll + for( int ii = 0; ii < 2; ++ii ) { + #pragma unroll + for( int ni = 0; ni < MMAS_N; ++ni ) { + #pragma unroll + for( int jj = 0; jj < 4; ++jj ) { + if( !mask.is_valid(mi, ni, ii, jj) ) { + elt_[2 * mi + ii][4 * ni + jj] = zero ? 0.f : -INFINITY; + } + } + } + } + } + } + + // Apply the exp to all the elements. + template + inline __device__ void scale_apply_exp(const float (&max)[MMAS_M * 2], const float scale_) { + const float max_scale = scale_max ? scale_ * M_LOG2E : M_LOG2E; + const float scale = scale_ * M_LOG2E; + #pragma unroll + for( int mi = 0; mi < MMAS_M * 2; ++mi ) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + const float max_scaled = max[mi] * max_scale; + #pragma unroll + for( int ni = 0; ni < MMAS_N * 4; ++ni ) { + elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * scale, max_scaled); + } + } + } + + template + inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t) { + // We encode the dropout pattern in the sign bit of the non-negative + // softmax to distinguish from pre-existing zeros + auto encode_dropout = [](bool keep, float val) { + return keep ? val : (encode_dropout_in_sign_bit ? 
-val : float(0)); + }; + #pragma unroll + for( int mi = 0; mi < MMAS_M; mi++ ) { + #pragma unroll + for( int ni = 0; ni < MMAS_N; ni++ ) { + uint4 random_uint4 = ph(); + uint16_t (&rnd)[8] = reinterpret_cast(random_uint4); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, rnd.x, rnd.y, rnd.z, rnd.w); + // } + #pragma unroll + for (int ii = 0; ii < 2; ++ii) { + #pragma unroll + for (int jj = 0; jj < 4; ++jj) { + elt_[mi * 2 + ii][4 * ni + jj] = + encode_dropout(rnd[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]); + } + } + } + } + } + + template + inline __device__ void apply_dropout_16bits(Philox &ph0, Philox &ph1, uint16_t p_dropout_in_uint16_t) { + // We encode the dropout pattern in the sign bit of the non-negative + // softmax to distinguish from pre-existing zeros + auto encode_dropout = [](bool keep, float val) { + return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0)); + }; + #pragma unroll + for( int mi = 0; mi < MMAS_M; mi++ ) { + static_assert(MMAS_N % 2 == 0); + #pragma unroll + for( int ni = 0; ni < MMAS_N; ni += 2 ) { + uint4 random_uint4 = ph0(); + uint16_t (&rnd0)[8] = reinterpret_cast(random_uint4); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, rnd0.x, rnd0.y, rnd0.z, rnd0.w); + // } + #pragma unroll + for (int ii = 0; ii < 2; ++ii) { + #pragma unroll + for (int jj = 0; jj < 4; ++jj) { + elt_[mi * 2 + ii][4 * ni + jj] = + encode_dropout(rnd0[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]); + } + } + random_uint4 = ph1(); + uint16_t (&rnd1)[8] = reinterpret_cast(random_uint4); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, rnd1.x, rnd1.y, rnd1.z, rnd1.w); + // } + #pragma unroll + for (int ii = 0; ii < 2; ++ii) { + #pragma unroll + for (int jj = 0; jj < 4; ++jj) { + elt_[mi * 2 + ii][4 * (ni + 1) + jj] = + encode_dropout(rnd1[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * (ni + 1) + jj]); + } + } + } + } + } + + // Shared memory for the CTA-wide reduction. + float *smem_, *smem_write_, *smem_read_; + // The current thread index. + int tidx_; + // The elements. + float elt_[MMAS_M * 2][MMAS_N * 4]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax : public Softmax_base { + + // The base class. + using Base = Softmax_base; + + static constexpr int WARPS_M = Cta_tile::WARPS_M; + static constexpr int WARPS_N = Cta_tile::WARPS_N; + // The MMAs. + static constexpr int MMAS_M = Base::MMAS_M; + static constexpr int MMAS_N = Base::MMAS_N; + + using Smem_tile_red = Smem_tile_reduce; + static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N); + // Ctor. + template + inline __device__ Softmax(const Params ¶ms, void *smem, int tidx) + : Base(params, smem, tidx) + , smem_sum_(static_cast(smem), tidx) + , smem_max_(static_cast(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) { + } + + // Pack the data to a fragment for the next GEMM. + inline __device__ void pack_noconvert(cutlass::Array &frag) const { + #pragma unroll + for( int mi = 0; mi < MMAS_M; ++mi ) { + #pragma unroll + for( int ki = 0; ki < MMAS_N; ++ki ) { + // 1st row - 4 elements per row. 
+ frag[ki * MMAS_M * 8 + mi * 8 + 0] = this->elt_[2 * mi + 0][4 * ki + 0]; + frag[ki * MMAS_M * 8 + mi * 8 + 1] = this->elt_[2 * mi + 0][4 * ki + 1]; + frag[ki * MMAS_M * 8 + mi * 8 + 4] = this->elt_[2 * mi + 0][4 * ki + 2]; + frag[ki * MMAS_M * 8 + mi * 8 + 5] = this->elt_[2 * mi + 0][4 * ki + 3]; + // 2nd row - 4 elements per row. + frag[ki * MMAS_M * 8 + mi * 8 + 2] = this->elt_[2 * mi + 1][4 * ki + 0]; + frag[ki * MMAS_M * 8 + mi * 8 + 3] = this->elt_[2 * mi + 1][4 * ki + 1]; + frag[ki * MMAS_M * 8 + mi * 8 + 6] = this->elt_[2 * mi + 1][4 * ki + 2]; + frag[ki * MMAS_M * 8 + mi * 8 + 7] = this->elt_[2 * mi + 1][4 * ki + 3]; + } + } + } + + template + inline __device__ void unpack_noscale(const FragmentC (&acc)) { + static_assert(FragmentC::kElements == MMAS_M * MMAS_N * 8, ""); + #pragma unroll + for( int mi = 0; mi < MMAS_M; ++mi ) { + #pragma unroll + for( int ni = 0; ni < MMAS_N; ++ni ) { + // 1st row - 4 elements per row. + this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi * MMAS_N * 8 + ni * 8 + 0]; + this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi * MMAS_N * 8 + ni * 8 + 1]; + this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi * MMAS_N * 8 + ni * 8 + 4]; + this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi * MMAS_N * 8 + ni * 8 + 5]; + // 2nd row - 4 elements per row. + this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi * MMAS_N * 8 + ni * 8 + 2]; + this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi * MMAS_N * 8 + ni * 8 + 3]; + this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi * MMAS_N * 8 + ni * 8 + 6]; + this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi * MMAS_N * 8 + ni * 8 + 7]; + } + } + } + + template + __device__ inline void thread_reduce_(float (&frag)[2 * MMAS_M], Operator &op) { + #pragma unroll + for( int mi = 0; mi < 2 * MMAS_M; mi++ ) { + frag[mi] = zero_init ? this->elt_[mi][0] : op(frag[mi], this->elt_[mi][0]); + #pragma unroll + for( int ni = 1; ni < 4 * MMAS_N; ni++ ) { + frag[mi] = op(frag[mi], this->elt_[mi][ni]); + } + } + } + + template + __device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red & smem_red) { + thread_reduce_(frag, op); + quad_reduce(frag, frag, op); + smem_red.store(frag); + __syncthreads(); + typename Smem_tile_red::read_t tmp[2 * MMAS_M]; + smem_red.load(tmp); + quad_allreduce(frag, tmp, op); + } + + template + __device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){ + MaxOp max; + reduce_(frag, max, smem_max_); + } + + __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){ + SumOp sum; + reduce_(frag, sum, smem_sum_); + } + + template + __device__ inline void reduce_sum_before_sync_(float (&frag)[2 * MMAS_M]){ + SumOp sum; + thread_reduce_(frag, sum); + quad_reduce(frag, frag, sum); + smem_sum_.store(frag); + } + + template + __device__ inline void reduce_after_sync_(float (&frag)[NROWS][MMAS_M], + const int (&rows)[NROWS], + Operator &op, Smem_tile_red & smem_red) { + #pragma unroll + for (int ii = 0; ii < NROWS; ii++) { + typename Smem_tile_red::read_t tmp[MMAS_M]; + smem_red.load_row(tmp, rows[ii]); + quad_allreduce(frag[ii], tmp, op); + } + } + + template + __device__ inline void reduce_sum_after_sync_(float (&frag)[NROWS][MMAS_M], + const int (&rows)[NROWS]){ + SumOp sum; + reduce_after_sync_(frag, rows, sum, smem_sum_); + } + + template + __device__ inline void reduce_max_after_sync_(float (&frag)[NROWS][MMAS_M], + const int (&rows)[NROWS]){ + MaxOp max; + reduce_after_sync_(frag, rows, max, smem_max_); + } + + Smem_tile_red smem_max_; + Smem_tile_red smem_sum_; +}; + 
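// [Editor's note] Illustrative sketch, not part of the original patch. scale_apply_exp
// above relies on the identity
//   exp(scale * (x - max)) == exp2(x * scale * log2(e) - max * scale * log2(e)),
// so with log2(e) folded into the scale the inner loop reduces to one fused
// multiply-add feeding exp2f (ex2.approx under fast math, per the comment in
// apply_exp2_). The two helpers below compute the same value both ways; the
// function names are made up for illustration only.
inline __device__ float softmax_exp_reference(float x, float max, float scale) {
    return __expf((x - max) * scale);
}
inline __device__ float softmax_exp_rewritten(float x, float max, float scale) {
    const float scale_log2e = scale * M_LOG2E;   // hoisted out of the element loop
    const float max_scaled  = max * scale_log2e; // per-row constant
    return exp2f(x * scale_log2e - max_scaled);  // single ffma feeding ex2
}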
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/static_switch.h b/aten/src/ATen/native/transformers/cuda/flash_attn/static_switch.h new file mode 100644 index 00000000000..7920ac045d0 --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/static_switch.h @@ -0,0 +1,25 @@ +// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/summary_stats.h b/aten/src/ATen/native/transformers/cuda/flash_attn/summary_stats.h new file mode 100644 index 00000000000..812aaea7977 --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/summary_stats.h @@ -0,0 +1,55 @@ +/****************************************************************************** + * Copyright (c) 2022, Tri Dao. + ******************************************************************************/ + +#pragma once + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Smem_tile_softmax_lse { + + static constexpr int kMmaM = (kRows / kWarpCountM) / kRowsPerMma; + static_assert(kMmaM * kRowsPerMma * kWarpCountM == kRows); + // static_assert(kWarpCountM == 1); + // Otherwise we might need to check warp_idx / kWarpCountM == 0 instead of just warp_idx == 0 + + // The size of one buffer in bytes in shared memory. + static constexpr size_t BYTES_PER_TILE = kRows * sizeof(float); + + inline __device__ Smem_tile_softmax_lse(float *smem) : smem_(smem) { + } + + inline __device__ void store_pair(const float (&sum)[kMmaM * 2]) { + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + // This makes a difference of 50us for BERT. + // const int warp_idx = threadIdx.x / 32; + const int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const int lane_idx = threadIdx.x % 32; + const int warp_n = warp_idx / kWarpCountM; + // Extract the position in the warp. 
+ const int row = lane_idx / 4; + if ((lane_idx % 4 == 0) && (warp_n == 0)) { + #pragma unroll + for (int mi = 0; mi < kMmaM; ++mi) { + smem_[mi * kRowsPerMma + row + 0] = sum[mi * 2 + 0]; + smem_[mi * kRowsPerMma + row + 8] = sum[mi * 2 + 1]; + } + } + } + + template + inline __device__ void load(float (&sum)[N], const int (&row)[N]) { + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + sum[ni] = smem_[row[ni]]; + } + } + + float * const smem_; +}; + +} // namespace fmha diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/utils.h b/aten/src/ATen/native/transformers/cuda/flash_attn/utils.h new file mode 100644 index 00000000000..e70f634c26d --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/utils.h @@ -0,0 +1,404 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include +// #include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void *ptr); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace fmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Row {}; +struct Col {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int M, int N > +struct Div_up { + enum { VALUE = (M + N-1) / N }; +}; + +constexpr int DivUpConstexpr(int M, int N) { return (M + N - 1) / N; } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int A, int B > +struct Max { + enum { VALUE = A >= B ? A : B }; +}; + +constexpr int MaxConstexpr(int A, int B) { return A >= B ? 
A : B; } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int A, int B, int C > +struct Max_3 { + enum { VALUE = Max::VALUE, C>::VALUE }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int A, int B > +struct Min { + enum { VALUE = A <= B ? A : B }; +}; + +constexpr int MinConstexpr(int A, int B) { return A <= B ? A : B; } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int SIZE_IN_BYTES > +struct Uint_from_size_in_bytes { +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Uint_from_size_in_bytes<1> { + using Type = uint8_t; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Uint_from_size_in_bytes<2> { + using Type = uint16_t; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Uint_from_size_in_bytes<4> { + using Type = uint32_t; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Uint_from_size_in_bytes<8> { + using Type = uint2; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Uint_from_size_in_bytes<16> { + using Type = uint4; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename T > +inline __device__ __host__ T div_up(T m, T n) { + return (m + n-1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ uint32_t hrelu2(uint32_t x); + +template<> +inline __device__ uint32_t hrelu2<__half>(uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); +#else + asm volatile( \ + "{\n" \ + "\t .reg .f16x2 sela;\n" \ + "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \ + "\t and.b32 %0, sela, %1;\n" + "}\n" : "=r"(res) : "r"(x), "r"(zero)); +#endif + return res; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template<> +inline __device__ uint32_t hrelu2<__nv_bfloat16>(uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; + asm volatile( "max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); + return res; +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint16_t float_to_half(float f) { + uint16_t h; + asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f)); + return h; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ uint32_t float2_pack(float a, float b); + +template <> +inline __device__ uint32_t float2_pack<__half>(float a, float b) { + __half2 result = __floats2half2_rn(a, b); + return reinterpret_cast(result); +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template <> +inline __device__ uint32_t float2_pack<__nv_bfloat16>(float a, float b) { + __nv_bfloat162 result = __floats2bfloat162_rn(a, b); + return reinterpret_cast(result); +} +#endif + 
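// [Editor's note] Illustrative sketch, not part of the original patch. It shows the
// register layout produced by float2_pack<__half> above: two fp32 values are rounded
// to fp16 and packed side by side into one 32-bit word, which is the element format
// the packed helpers further down in this header (half2_unpack, hfma2_to_float,
// hmulsum8) operate on. The function name is made up for illustration; it assumes
// <cuda_fp16.h> is available, as the rest of this header already requires.
inline __device__ float2 float2_pack_roundtrip_demo(float a, float b) {
    uint32_t packed = float2_pack<__half>(a, b);      // round to fp16 and pack
    __half2 h2 = reinterpret_cast<__half2 &>(packed); // view the word as a __half2
    return __half22float2(h2);                        // ~(a, b) up to fp16 rounding
}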
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ uint2 float4_pack(float x, float y, float z, float w) { + uint2 d; + d.x = float2_pack(x, y); + d.y = float2_pack(z, w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float2 half2_unpack(uint32_t a); + +template <> +inline __device__ float2 half2_unpack<__half>(uint32_t a) { + return __half22float2(reinterpret_cast<__half2 (&)>(a)); +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template <> +inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) { + return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a)); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert two half2's or bf162's into float, then take their dot product. +template +inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) { + float2 af = fmha::half2_unpack(a); + float2 bf = fmha::half2_unpack(b); + return af.x * bf.x + af.y * bf.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Converted two vectors of 8 half's or bf16's into float, then take their dot product. +template +inline __device__ float hmulsum8(const uint4 a, const uint4 b) { + float sum; + sum = fmha::hfma2_to_float(a.x, b.x); + sum += fmha::hfma2_to_float(a.y, b.y); + sum += fmha::hfma2_to_float(a.z, b.z); + sum += fmha::hfma2_to_float(a.w, b.w); + return sum; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// L D G +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldg(uint8_t &dst, const void *ptr) { + dst = *reinterpret_cast(ptr); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldg(uint16_t &dst, const void *ptr) { + dst = *reinterpret_cast(ptr); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldg(uint32_t &dst, const void *ptr) { + dst = *reinterpret_cast(ptr); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldg(uint2 &dst, const void *ptr) { + dst = *reinterpret_cast(ptr); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void ldg(uint4 &dst, const void *ptr) { + dst = *reinterpret_cast(ptr); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// S T G +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void *ptr, uint8_t val) { + *reinterpret_cast(ptr) = val; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void *ptr, uint16_t val) { + *reinterpret_cast(ptr) = val; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void *ptr, uint32_t val) { + *reinterpret_cast(ptr) = val; +} + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void *ptr, uint2 val) { + *reinterpret_cast(ptr) = val; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void stg(void *ptr, uint4 val) { + *reinterpret_cast(ptr) = val; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { +__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { +// This is slightly faster +__device__ inline float operator()(float const &x, float const &y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ inline T operator()(T const & x, T const & y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Allreduce<2> { +template +static __device__ inline T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator &op) { + #pragma unroll + for(int mi=0; mi < M; mi++){ + dst[mi] = src[mi]; + dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2)); + dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator &op) { + float tmp[M]; + #pragma unroll + for(int mi=0; mi < M; mi++){ + tmp[mi] = op(src[mi].x, src[mi].y); + } + quad_reduce(dst, tmp, op); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator &op) { + #pragma unroll + for(int mi=0; mi < M; mi++){ + dst[mi] = src[mi]; + dst[mi] = Allreduce<4>::run(dst[mi], op); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator &op) { + float tmp[M]; + #pragma unroll + for(int mi=0; mi < M; mi++){ + tmp[mi] = op(src[mi].x, src[mi].y); + } + quad_allreduce(dst, tmp, op); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace fmha diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 27f8381209e..a7da0abff9a 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -77,6 +77,7 @@ function(caffe2_print_configuration_summary) message(STATUS " USE_CUDNN : ${USE_CUDNN}") message(STATUS " USE_EXPERIMENTAL_CUDNN_V8_API: ${USE_EXPERIMENTAL_CUDNN_V8_API}") 
message(STATUS " CUDA version : ${CUDA_VERSION}") + message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}") if(${USE_CUDNN}) message(STATUS " cuDNN version : ${CUDNN_VERSION}") endif() diff --git a/setup.py b/setup.py index 5d91ab13c1c..c1cf72b42b3 100644 --- a/setup.py +++ b/setup.py @@ -322,7 +322,7 @@ def get_submodule_folders(): git_modules_path = os.path.join(cwd, ".gitmodules") default_modules_path = [os.path.join(third_party_path, name) for name in [ "gloo", "cpuinfo", "tbb", "onnx", - "foxi", "QNNPACK", "fbgemm" + "foxi", "QNNPACK", "fbgemm", "cutlass" ]] if not os.path.exists(git_modules_path): return default_modules_path diff --git a/third_party/cutlass b/third_party/cutlass new file mode 160000 index 00000000000..b72cbf957df --- /dev/null +++ b/third_party/cutlass @@ -0,0 +1 @@ +Subproject commit b72cbf957df8cf84a6d0ff91c190ad51a9c1d24a