diff --git a/CMakeLists.txt b/CMakeLists.txt index 5a43e0da8f2..efad5419aaf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -874,7 +874,7 @@ cmake_dependent_option( "Whether to build the flash_attention kernel for scaled dot product attention.\ Will be disabled if not supported by the platform" ON - "USE_CUDA OR USE_ROCM;NOT MSVC" + "USE_CUDA OR USE_ROCM" OFF) cmake_dependent_option( @@ -909,7 +909,7 @@ cmake_dependent_option( # USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake # if(USE_ROCM) - if(UNIX AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)) + if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION) include(cmake/External/aotriton.cmake) endif() endif() diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index b8b43e0086c..c2193f2378d 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -95,6 +95,72 @@ #endif #endif +#if defined(USE_ROCM) && (defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)) +namespace pytorch_flash +{ +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_fwd( + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + out_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_) { +#if defined(USE_ROCM_CK_SDPA) + if (at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { + const int non_null_window_left = window_size_left.value_or(-1); + const int non_null_window_right = window_size_right.value_or(-1); + std::optional dummy_attn_bias = std::nullopt; + return mha_fwd_ck( + q, + k, + v, + out_, + p_dropout, + softmax_scale, + is_causal, + non_null_window_left, + non_null_window_right, + return_softmax, + gen_, + dummy_attn_bias); // Not used in flash attention + } +#endif + return mha_fwd_aot( + q, + k, + v, + out_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); +} +} +#endif + namespace at { namespace cuda::philox { diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h index f6f2240d4f0..71a19590659 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h @@ -270,7 +270,7 @@ std::tuple mha_varle #endif TORCH_API -inline std::tuple< +std::tuple< at::Tensor, at::Tensor, at::Tensor, @@ -294,42 +294,7 @@ mha_fwd( std::optional window_size_right, const float softcap, const bool return_softmax, - std::optional gen_) { -#if defined(USE_ROCM_CK_SDPA) - if (at::globalContext().getROCmFAPreferredBackend() == - at::ROCmFABackend::Ck) { - const int non_null_window_left = window_size_left.value_or(-1); - const int non_null_window_right = window_size_right.value_or(-1); - std::optional dummy_attn_bias = std::nullopt; - return mha_fwd_ck( - q, - k, - v, - out_, - p_dropout, - softmax_scale, - is_causal, - non_null_window_left, - non_null_window_right, - return_softmax, - gen_, - dummy_attn_bias); // Not used in flash attention - } -#endif - return mha_fwd_aot( - q, - k, - v, - out_, - alibi_slopes_, - p_dropout, - softmax_scale, - is_causal, - window_size_left, - window_size_right, - return_softmax, - gen_); -} + std::optional gen_); inline std::tuple< at::Tensor, diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index 5d915877465..4f7a79a78bf 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -45,13 +45,88 @@ if(NOT __AOTRITON_INCLUDED) ) set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore set(__AOTRITON_Z "gz") + # Set the default __AOTRITON_LIB path + set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so") + if(WIN32) + set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/aotriton_v2.lib") + endif() + + function(aotriton_build_windows_dependencies dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR) + # Windows-specific dependencies - build these first + if(NOT noimage) + message(FATAL_ERROR "noimage must be ON for Windows builds") + endif() + # Build dlfcn-win32 + set(__DLFCN_WIN32_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32") + set(__DLFCN_WIN32_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32-install") + + ExternalProject_Add(${dlfcn-win32_external} + GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git + GIT_TAG v1.4.2 + PREFIX ${__DLFCN_WIN32_PREFIX} + INSTALL_DIR ${__DLFCN_WIN32_INSTALL_DIR} + CMAKE_ARGS + -DCMAKE_INSTALL_PREFIX=${__DLFCN_WIN32_INSTALL_DIR} + -DCMAKE_BUILD_TYPE=Release + -DCMAKE_C_COMPILER=cl + -DCMAKE_CXX_COMPILER=cl + -DBUILD_SHARED_LIBS=ON + -DBUILD_TESTS=OFF + BUILD_BYPRODUCTS + "${__DLFCN_WIN32_INSTALL_DIR}/lib/dl.lib" + "${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll" + ) + ExternalProject_Add_Step(${dlfcn-win32_external} copy_to_aotriton + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll" + "${__AOTRITON_INSTALL_DIR}/lib/" + DEPENDEES install + ) + set(${dlfcn-win32_DIR} "${__DLFCN_WIN32_INSTALL_DIR}/share/dlfcn-win32" CACHE PATH "Path to dlfcn-win32 CMake config" FORCE) + + # Build xz/liblzma + set(__XZ_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/xz") + set(__XZ_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/xz-install") + + ExternalProject_Add(${xz_external} + GIT_REPOSITORY https://github.com/tukaani-project/xz.git + GIT_TAG v5.8.1 + PREFIX ${__XZ_PREFIX} + INSTALL_DIR ${__XZ_INSTALL_DIR} + CMAKE_ARGS + -DCMAKE_INSTALL_PREFIX=${__XZ_INSTALL_DIR} + -DCMAKE_BUILD_TYPE=Release + -DBUILD_SHARED_LIBS=ON + -DENABLE_NLS=OFF + -DXZ_TOOL_LZMAINFO=OFF + -DXZ_TOOL_XZ=OFF + -DXZ_TOOL_XZDEC=OFF + -DXZ_TOOL_LZMADEC=OFF + BUILD_BYPRODUCTS + "${__XZ_INSTALL_DIR}/lib/lzma.lib" + "${__XZ_INSTALL_DIR}/bin/liblzma.dll" + ) + ExternalProject_Add_Step(${xz_external} copy_to_aotriton + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${__XZ_INSTALL_DIR}/bin/liblzma.dll" + "${__AOTRITON_INSTALL_DIR}/lib/" + DEPENDEES install + ) + set(${liblzma_DIR} "${__XZ_INSTALL_DIR}/lib/cmake/liblzma" CACHE PATH "Path to xz/liblzma CMake config" FORCE) + endfunction() + function(aotriton_build_from_source noimage project) if(noimage) SET(RECURSIVE "OFF") else() SET(RECURSIVE "ON") endif() + if(WIN32) + message(STATUS "Building AOTriton Windows dependencies") + aotriton_build_windows_dependencies(dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR) + endif() message(STATUS "PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}") + ExternalProject_Add(${project} GIT_REPOSITORY https://github.com/ROCm/aotriton.git GIT_SUBMODULES_RECURSE ${RECURSIVE} @@ -65,12 +140,19 @@ if(NOT __AOTRITON_INCLUDED) -DAOTRITON_GPU_BUILD_TIMEOUT=0 -DAOTRITON_NO_PYTHON=ON -DAOTRITON_NOIMAGE_MODE=${noimage} - BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so" + -DHIP_PLATFORM=amd + $<$:-Ddlfcn-win32_DIR=${dlfcn-win32_DIR}> + $<$:-Dliblzma_DIR=${liblzma_DIR}> + BUILD_BYPRODUCTS + "${__AOTRITON_LIB}" USES_TERMINAL_DOWNLOAD TRUE USES_TERMINAL_CONFIGURE TRUE USES_TERMINAL_BUILD TRUE USES_TERMINAL_INSTALL TRUE ) + if(WIN32) + add_dependencies(${project} dlfcn-win32_external xz_external) + endif() endfunction() set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) @@ -95,7 +177,7 @@ if(NOT __AOTRITON_INCLUDED) INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory "${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime" "${__AOTRITON_INSTALL_DIR}" - BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so" + BUILD_BYPRODUCTS "${__AOTRITON_LIB}" ) message(STATUS "Using AOTriton Runtime from pre-compiled binary ${__AOTRITON_URL}.\ Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.") @@ -111,14 +193,35 @@ if(NOT __AOTRITON_INCLUDED) string(CONCAT __AOTRITON_URL "${__AOTRITON_BASE_URL}" "${__AOTRITON_VER}/${__AOTRITON_FILE}") + + # Set up directories + set(__AOTRITON_DOWNLOAD_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_download-${image}) + set(__AOTRITON_EXTRACT_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}) + set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}) + set(__DOWNLOAD_NO_EXTRACT "") + set(__BUILD_COMMANDS "") + + # On Windows, we need custom tar extraction with UTF-8 support + if(WIN32) + set(__DOWNLOAD_NO_EXTRACT "DOWNLOAD_NO_EXTRACT;TRUE") + set(__BUILD_COMMANDS + COMMAND ${CMAKE_COMMAND} -E make_directory "${__AOTRITON_EXTRACT_DIR}" + COMMAND tar --options hdrcharset=UTF-8 -xf "${__AOTRITON_DOWNLOAD_DIR}/${__AOTRITON_FILE}" -C "${__AOTRITON_EXTRACT_DIR}" + ) + set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}/aotriton) + endif() + ExternalProject_Add(${project} URL "${__AOTRITON_URL}" URL_HASH SHA256=${__AOTRITON_SHA256} - SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image} + DOWNLOAD_DIR ${__AOTRITON_DOWNLOAD_DIR} + ${__DOWNLOAD_NO_EXTRACT} + SOURCE_DIR ${__AOTRITON_EXTRACT_DIR} CONFIGURE_COMMAND "" BUILD_COMMAND "" + ${__BUILD_COMMANDS} INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory - "${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}" + "${__AOTRITON_INSTALL_SOURCE_DIR}" "${__AOTRITON_INSTALL_DIR}" BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__" @@ -164,7 +267,7 @@ if(NOT __AOTRITON_INCLUDED) endforeach() endforeach() endif() - target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so) + target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_LIB}) target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include) set(AOTRITON_FOUND TRUE) endif() # __AOTRITON_INCLUDED diff --git a/tools/linter/dictionary.txt b/tools/linter/dictionary.txt index 706881a8f10..c4a250db048 100644 --- a/tools/linter/dictionary.txt +++ b/tools/linter/dictionary.txt @@ -12,6 +12,7 @@ BU contiguities contiguity coo +DEPENDEES deser din dout