diff --git a/CMakeLists.txt b/CMakeLists.txt index 3b682eb96cf..c8af5f00b5c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -883,7 +883,7 @@ cmake_dependent_option( # USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake # if(USE_ROCM) - if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION) + if(UNIX AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)) include(cmake/External/aotriton.cmake) endif() endif() diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 61f571e6bdc..1831f84edb4 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1073,7 +1073,9 @@ if(USE_ROCM) set(Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS hip::amdhip64 MIOpen hiprtc::hiprtc) # libroctx will be linked in with MIOpen - list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS roc::hipblaslt) + if(UNIX) + list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS roc::hipblaslt) + endif(UNIX) list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS roc::hipblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver) diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 1499977f8e4..3eb34b0b833 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -13,7 +13,11 @@ if(DEFINED ENV{ROCM_PATH}) "Set a valid ROCM_PATH or unset ROCM_PATH environment variable to fix.") endif() else() - set(ROCM_PATH /opt/rocm) + if(UNIX) + set(ROCM_PATH /opt/rocm) + else() # Win32 + set(ROCM_PATH C:/opt/rocm) + endif() if(NOT EXISTS ${ROCM_PATH}) message(STATUS "ROCM_PATH environment variable is not set and ${ROCM_PATH} does not exist.\n" @@ -28,7 +32,6 @@ else() set(ROCM_INCLUDE_DIRS $ENV{ROCM_INCLUDE_DIRS}) endif() - # MAGMA_HOME if(NOT DEFINED ENV{MAGMA_HOME}) set(MAGMA_HOME ${ROCM_PATH}/magma) @@ -37,6 +40,16 @@ else() set(MAGMA_HOME $ENV{MAGMA_HOME}) endif() +# MIOpen isn't a part of HIP-SDK for Windows and hence, may have a different +# installation directory. +if(WIN32) + if(NOT DEFINED ENV{MIOPEN_PATH}) + set(miopen_DIR C:/opt/miopen/lib/cmake/miopen) + else() + set(miopen_DIR $ENV{MIOPEN_PATH}/lib/cmake/miopen) + endif() +endif() + torch_hip_get_arch_list(PYTORCH_ROCM_ARCH) if(PYTORCH_ROCM_ARCH STREQUAL "") message(FATAL_ERROR "No GPU arch specified for ROCm build. Please use PYTORCH_ROCM_ARCH environment variable to specify GPU archs to build for.") @@ -46,7 +59,11 @@ message("Building PyTorch for GPU arch: ${PYTORCH_ROCM_ARCH}") # Add HIP to the CMAKE Module Path # needed because the find_package call to this module uses the Module mode search # https://cmake.org/cmake/help/latest/command/find_package.html#search-modes -set(CMAKE_MODULE_PATH ${ROCM_PATH}/lib/cmake/hip ${CMAKE_MODULE_PATH}) +if(UNIX) + set(CMAKE_MODULE_PATH ${ROCM_PATH}/lib/cmake/hip ${CMAKE_MODULE_PATH}) +else() # Win32 + set(CMAKE_MODULE_PATH ${ROCM_PATH}/cmake/ ${CMAKE_MODULE_PATH}) +endif() # Add ROCM_PATH to CMAKE_PREFIX_PATH, needed because the find_package # call to individual ROCM components uses the Config mode search @@ -66,15 +83,29 @@ if(HIP_FOUND) set(PYTORCH_FOUND_HIP TRUE) # Find ROCM version for checks - if(EXISTS ${ROCM_INCLUDE_DIRS}/rocm-core/rocm_version.h) - set(ROCM_HEADER_FILE ${ROCM_INCLUDE_DIRS}/rocm-core/rocm_version.h) + if(UNIX) + set(ROCM_VERSION_HEADER_PATH ${ROCM_INCLUDE_DIRS}/rocm-core/rocm_version.h) else() - message(FATAL_ERROR "********************* rocm_version.h could not be found ******************\n") + set(ROCM_VERSION_HEADER_PATH ${ROCM_INCLUDE_DIRS}/hip/hip_version.h) + endif() + get_filename_component(ROCM_HEADER_NAME ${ROCM_VERSION_HEADER_PATH} NAME) + + if(EXISTS ${ROCM_VERSION_HEADER_PATH}) + set(ROCM_HEADER_FILE ${ROCM_VERSION_HEADER_PATH}) + else() + message(FATAL_ERROR "********************* ${ROCM_HEADER_NAME} could not be found ******************\n") endif() # Read the ROCM headerfile into a variable file(READ ${ROCM_HEADER_FILE} ROCM_HEADER_CONTENT) + # Since Windows currently supports only a part of ROCm and names it HIP-SDK, + # we need to refer to the HIP-SDK equivalents of entities existing in ROCm lib. + if(UNIX) + set(ROCM_LIB_NAME "ROCM") + else() # Win32 + set(ROCM_LIB_NAME "HIP") + endif() # Below we use a RegEx to find ROCM version numbers. # Note that CMake does not support \s for blank space. That is # why in the regular expressions below we have a blank space in @@ -83,21 +114,22 @@ if(HIP_FOUND) # 1. Match regular expression # 2. Strip the non-numerical part of the string # 3. Strip leading and trailing spaces - string(REGEX MATCH "ROCM_VERSION_MAJOR[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT}) - string(REPLACE "ROCM_VERSION_MAJOR" "" TEMP2 ${TEMP1}) + + string(REGEX MATCH "${ROCM_LIB_NAME}_VERSION_MAJOR[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT}) + string(REPLACE "${ROCM_LIB_NAME}_VERSION_MAJOR" "" TEMP2 ${TEMP1}) string(STRIP ${TEMP2} ROCM_VERSION_DEV_MAJOR) - string(REGEX MATCH "ROCM_VERSION_MINOR[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT}) - string(REPLACE "ROCM_VERSION_MINOR" "" TEMP2 ${TEMP1}) + string(REGEX MATCH "${ROCM_LIB_NAME}_VERSION_MINOR[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT}) + string(REPLACE "${ROCM_LIB_NAME}_VERSION_MINOR" "" TEMP2 ${TEMP1}) string(STRIP ${TEMP2} ROCM_VERSION_DEV_MINOR) - string(REGEX MATCH "ROCM_VERSION_PATCH[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT}) - string(REPLACE "ROCM_VERSION_PATCH" "" TEMP2 ${TEMP1}) + string(REGEX MATCH "${ROCM_LIB_NAME}_VERSION_PATCH[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT}) + string(REPLACE "${ROCM_LIB_NAME}_VERSION_PATCH" "" TEMP2 ${TEMP1}) string(STRIP ${TEMP2} ROCM_VERSION_DEV_PATCH) # Create ROCM_VERSION_DEV_INT which is later used as a preprocessor macros set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}") math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}") - message("\n***** ROCm version from rocm_version.h ****\n") + message("\n***** ROCm version from ${ROCM_HEADER_NAME} ****\n") message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}") message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}") message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}") @@ -112,51 +144,56 @@ if(HIP_FOUND) # Find ROCM components using Config mode # These components will be searced for recursively in ${ROCM_PATH} message("\n***** Library versions from cmake find_package *****\n") - find_package_and_print_version(hip REQUIRED) - find_package_and_print_version(hsa-runtime64 REQUIRED) + find_package_and_print_version(hip REQUIRED CONFIG) find_package_and_print_version(amd_comgr REQUIRED) find_package_and_print_version(rocrand REQUIRED) find_package_and_print_version(hiprand REQUIRED) find_package_and_print_version(rocblas REQUIRED) find_package_and_print_version(hipblas REQUIRED) - find_package_and_print_version(hipblaslt REQUIRED) find_package_and_print_version(miopen REQUIRED) find_package_and_print_version(hipfft REQUIRED) find_package_and_print_version(hipsparse REQUIRED) - find_package_and_print_version(rccl) find_package_and_print_version(rocprim REQUIRED) find_package_and_print_version(hipcub REQUIRED) find_package_and_print_version(rocthrust REQUIRED) find_package_and_print_version(hipsolver REQUIRED) find_package_and_print_version(hiprtc REQUIRED) - # roctx is part of roctracer - find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib) + if(UNIX) + find_package_and_print_version(rccl) + find_package_and_print_version(hsa-runtime64 REQUIRED) + find_package_and_print_version(hipblaslt REQUIRED) - # check whether HIP declares new types - set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}") - set(file "${PROJECT_BINARY_DIR}/hip_new_types.cc") - file(WRITE ${file} "" - "#include \n" - "int main() {\n" - " hipDataType baz = HIP_R_8F_E4M3_FNUZ;\n" - " return 0;\n" - "}\n" - ) + # roctx is part of roctracer + find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib) - try_compile(hip_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file} - CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}" - COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__ - OUTPUT_VARIABLE hip_compile_output) + # check whether HIP declares new types + set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}") + set(file "${PROJECT_BINARY_DIR}/hip_new_types.cc") + file(WRITE ${file} "" + "#include \n" + "int main() {\n" + " hipDataType baz = HIP_R_8F_E4M3_FNUZ;\n" + " return 0;\n" + "}\n" + ) - if(hip_compile_result) + try_compile(hip_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file} + CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}" + COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__ + OUTPUT_VARIABLE hip_compile_output) + + if(hip_compile_result) + set(HIP_NEW_TYPE_ENUMS ON) + #message("HIP is using new type enums: ${hip_compile_output}") + message("HIP is using new type enums") + else() + set(HIP_NEW_TYPE_ENUMS OFF) + #message("HIP is NOT using new type enums: ${hip_compile_output}") + message("HIP is NOT using new type enums") + endif() + else() # Win32 + # With HIP-SDK 6.2, HIP declares new enum types on Windows set(HIP_NEW_TYPE_ENUMS ON) - #message("HIP is using new type enums: ${hip_compile_output}") - message("HIP is using new type enums") - else() - set(HIP_NEW_TYPE_ENUMS OFF) - #message("HIP is NOT using new type enums: ${hip_compile_output}") - message("HIP is NOT using new type enums") endif() - endif()