mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[ROCm] Enable finding HIP and ROCm libraries on Windows (#137279)
This PR introduces support for finding HIP-SDK Libraries on Windows. Since reading the code changes using the diff view is a bit cumbersome due to introduced if branch, let me explain what was changed: - The linux-specific steps to find HIP packages have been dragged into `if(UNIX) block` - Windows steps follow in the `else()` clause The separation was needed, because of several factors: - HIP SDK for Windows typically names its components using `hip` in their names (for exmaple: `hip_version.h` instead of `rocm_version.h`, `HIP_VERSION_DEV_MAJOR` instead of `ROCM_VERSION_DEV_MAJOR`, etc.), - The libraries included in HIP SDK are only a subset of what is available in Linux ROCm (missing hsa-rt, rccl, roctx) - MIOpen isn't a part of HIP SDK, but can be built separately and as of now requires additional path to be defined using and env var. - Windows can only find hip package in version greater than 1.0 and its libraries if the lowercase `find_package(hip ...)` is invoked first. This is because the lowercase `hip` name will cause the mechanism to find hip's packages using [config mode](https://cmake.org/cmake/help/latest/command/find_package.html#search-modes) which is the only one supported on Windows, assuming we also want to [include its libraries](https://rocm.docs.amd.com/en/latest/conceptual/cmake-packages.html#consuming-the-hip-api-in-c-code). The upper-case module-mode-seearched `find_package(HIP)` is used later for inclusion of macros such as `hip_add_library` and related macros. Pull Request resolved: https://github.com/pytorch/pytorch/pull/137279 Approved by: https://github.com/jeffdaily
This commit is contained in:
parent
33573488d0
commit
4cbb3b4bd2
|
|
@ -883,7 +883,7 @@ cmake_dependent_option(
|
|||
# USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake
|
||||
#
|
||||
if(USE_ROCM)
|
||||
if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
|
||||
if(UNIX AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION))
|
||||
include(cmake/External/aotriton.cmake)
|
||||
endif()
|
||||
endif()
|
||||
|
|
|
|||
|
|
@ -1073,7 +1073,9 @@ if(USE_ROCM)
|
|||
|
||||
set(Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
|
||||
hip::amdhip64 MIOpen hiprtc::hiprtc) # libroctx will be linked in with MIOpen
|
||||
list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS roc::hipblaslt)
|
||||
if(UNIX)
|
||||
list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS roc::hipblaslt)
|
||||
endif(UNIX)
|
||||
|
||||
list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
|
||||
roc::hipblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,11 @@ if(DEFINED ENV{ROCM_PATH})
|
|||
"Set a valid ROCM_PATH or unset ROCM_PATH environment variable to fix.")
|
||||
endif()
|
||||
else()
|
||||
set(ROCM_PATH /opt/rocm)
|
||||
if(UNIX)
|
||||
set(ROCM_PATH /opt/rocm)
|
||||
else() # Win32
|
||||
set(ROCM_PATH C:/opt/rocm)
|
||||
endif()
|
||||
if(NOT EXISTS ${ROCM_PATH})
|
||||
message(STATUS
|
||||
"ROCM_PATH environment variable is not set and ${ROCM_PATH} does not exist.\n"
|
||||
|
|
@ -28,7 +32,6 @@ else()
|
|||
set(ROCM_INCLUDE_DIRS $ENV{ROCM_INCLUDE_DIRS})
|
||||
endif()
|
||||
|
||||
|
||||
# MAGMA_HOME
|
||||
if(NOT DEFINED ENV{MAGMA_HOME})
|
||||
set(MAGMA_HOME ${ROCM_PATH}/magma)
|
||||
|
|
@ -37,6 +40,16 @@ else()
|
|||
set(MAGMA_HOME $ENV{MAGMA_HOME})
|
||||
endif()
|
||||
|
||||
# MIOpen isn't a part of HIP-SDK for Windows and hence, may have a different
|
||||
# installation directory.
|
||||
if(WIN32)
|
||||
if(NOT DEFINED ENV{MIOPEN_PATH})
|
||||
set(miopen_DIR C:/opt/miopen/lib/cmake/miopen)
|
||||
else()
|
||||
set(miopen_DIR $ENV{MIOPEN_PATH}/lib/cmake/miopen)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
torch_hip_get_arch_list(PYTORCH_ROCM_ARCH)
|
||||
if(PYTORCH_ROCM_ARCH STREQUAL "")
|
||||
message(FATAL_ERROR "No GPU arch specified for ROCm build. Please use PYTORCH_ROCM_ARCH environment variable to specify GPU archs to build for.")
|
||||
|
|
@ -46,7 +59,11 @@ message("Building PyTorch for GPU arch: ${PYTORCH_ROCM_ARCH}")
|
|||
# Add HIP to the CMAKE Module Path
|
||||
# needed because the find_package call to this module uses the Module mode search
|
||||
# https://cmake.org/cmake/help/latest/command/find_package.html#search-modes
|
||||
set(CMAKE_MODULE_PATH ${ROCM_PATH}/lib/cmake/hip ${CMAKE_MODULE_PATH})
|
||||
if(UNIX)
|
||||
set(CMAKE_MODULE_PATH ${ROCM_PATH}/lib/cmake/hip ${CMAKE_MODULE_PATH})
|
||||
else() # Win32
|
||||
set(CMAKE_MODULE_PATH ${ROCM_PATH}/cmake/ ${CMAKE_MODULE_PATH})
|
||||
endif()
|
||||
|
||||
# Add ROCM_PATH to CMAKE_PREFIX_PATH, needed because the find_package
|
||||
# call to individual ROCM components uses the Config mode search
|
||||
|
|
@ -66,15 +83,29 @@ if(HIP_FOUND)
|
|||
set(PYTORCH_FOUND_HIP TRUE)
|
||||
|
||||
# Find ROCM version for checks
|
||||
if(EXISTS ${ROCM_INCLUDE_DIRS}/rocm-core/rocm_version.h)
|
||||
set(ROCM_HEADER_FILE ${ROCM_INCLUDE_DIRS}/rocm-core/rocm_version.h)
|
||||
if(UNIX)
|
||||
set(ROCM_VERSION_HEADER_PATH ${ROCM_INCLUDE_DIRS}/rocm-core/rocm_version.h)
|
||||
else()
|
||||
message(FATAL_ERROR "********************* rocm_version.h could not be found ******************\n")
|
||||
set(ROCM_VERSION_HEADER_PATH ${ROCM_INCLUDE_DIRS}/hip/hip_version.h)
|
||||
endif()
|
||||
get_filename_component(ROCM_HEADER_NAME ${ROCM_VERSION_HEADER_PATH} NAME)
|
||||
|
||||
if(EXISTS ${ROCM_VERSION_HEADER_PATH})
|
||||
set(ROCM_HEADER_FILE ${ROCM_VERSION_HEADER_PATH})
|
||||
else()
|
||||
message(FATAL_ERROR "********************* ${ROCM_HEADER_NAME} could not be found ******************\n")
|
||||
endif()
|
||||
|
||||
# Read the ROCM headerfile into a variable
|
||||
file(READ ${ROCM_HEADER_FILE} ROCM_HEADER_CONTENT)
|
||||
|
||||
# Since Windows currently supports only a part of ROCm and names it HIP-SDK,
|
||||
# we need to refer to the HIP-SDK equivalents of entities existing in ROCm lib.
|
||||
if(UNIX)
|
||||
set(ROCM_LIB_NAME "ROCM")
|
||||
else() # Win32
|
||||
set(ROCM_LIB_NAME "HIP")
|
||||
endif()
|
||||
# Below we use a RegEx to find ROCM version numbers.
|
||||
# Note that CMake does not support \s for blank space. That is
|
||||
# why in the regular expressions below we have a blank space in
|
||||
|
|
@ -83,21 +114,22 @@ if(HIP_FOUND)
|
|||
# 1. Match regular expression
|
||||
# 2. Strip the non-numerical part of the string
|
||||
# 3. Strip leading and trailing spaces
|
||||
string(REGEX MATCH "ROCM_VERSION_MAJOR[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT})
|
||||
string(REPLACE "ROCM_VERSION_MAJOR" "" TEMP2 ${TEMP1})
|
||||
|
||||
string(REGEX MATCH "${ROCM_LIB_NAME}_VERSION_MAJOR[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT})
|
||||
string(REPLACE "${ROCM_LIB_NAME}_VERSION_MAJOR" "" TEMP2 ${TEMP1})
|
||||
string(STRIP ${TEMP2} ROCM_VERSION_DEV_MAJOR)
|
||||
string(REGEX MATCH "ROCM_VERSION_MINOR[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT})
|
||||
string(REPLACE "ROCM_VERSION_MINOR" "" TEMP2 ${TEMP1})
|
||||
string(REGEX MATCH "${ROCM_LIB_NAME}_VERSION_MINOR[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT})
|
||||
string(REPLACE "${ROCM_LIB_NAME}_VERSION_MINOR" "" TEMP2 ${TEMP1})
|
||||
string(STRIP ${TEMP2} ROCM_VERSION_DEV_MINOR)
|
||||
string(REGEX MATCH "ROCM_VERSION_PATCH[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT})
|
||||
string(REPLACE "ROCM_VERSION_PATCH" "" TEMP2 ${TEMP1})
|
||||
string(REGEX MATCH "${ROCM_LIB_NAME}_VERSION_PATCH[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT})
|
||||
string(REPLACE "${ROCM_LIB_NAME}_VERSION_PATCH" "" TEMP2 ${TEMP1})
|
||||
string(STRIP ${TEMP2} ROCM_VERSION_DEV_PATCH)
|
||||
|
||||
# Create ROCM_VERSION_DEV_INT which is later used as a preprocessor macros
|
||||
set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}")
|
||||
math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}")
|
||||
|
||||
message("\n***** ROCm version from rocm_version.h ****\n")
|
||||
message("\n***** ROCm version from ${ROCM_HEADER_NAME} ****\n")
|
||||
message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}")
|
||||
message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}")
|
||||
message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}")
|
||||
|
|
@ -112,51 +144,56 @@ if(HIP_FOUND)
|
|||
# Find ROCM components using Config mode
|
||||
# These components will be searced for recursively in ${ROCM_PATH}
|
||||
message("\n***** Library versions from cmake find_package *****\n")
|
||||
find_package_and_print_version(hip REQUIRED)
|
||||
find_package_and_print_version(hsa-runtime64 REQUIRED)
|
||||
find_package_and_print_version(hip REQUIRED CONFIG)
|
||||
find_package_and_print_version(amd_comgr REQUIRED)
|
||||
find_package_and_print_version(rocrand REQUIRED)
|
||||
find_package_and_print_version(hiprand REQUIRED)
|
||||
find_package_and_print_version(rocblas REQUIRED)
|
||||
find_package_and_print_version(hipblas REQUIRED)
|
||||
find_package_and_print_version(hipblaslt REQUIRED)
|
||||
find_package_and_print_version(miopen REQUIRED)
|
||||
find_package_and_print_version(hipfft REQUIRED)
|
||||
find_package_and_print_version(hipsparse REQUIRED)
|
||||
find_package_and_print_version(rccl)
|
||||
find_package_and_print_version(rocprim REQUIRED)
|
||||
find_package_and_print_version(hipcub REQUIRED)
|
||||
find_package_and_print_version(rocthrust REQUIRED)
|
||||
find_package_and_print_version(hipsolver REQUIRED)
|
||||
find_package_and_print_version(hiprtc REQUIRED)
|
||||
|
||||
# roctx is part of roctracer
|
||||
find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)
|
||||
if(UNIX)
|
||||
find_package_and_print_version(rccl)
|
||||
find_package_and_print_version(hsa-runtime64 REQUIRED)
|
||||
find_package_and_print_version(hipblaslt REQUIRED)
|
||||
|
||||
# check whether HIP declares new types
|
||||
set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}")
|
||||
set(file "${PROJECT_BINARY_DIR}/hip_new_types.cc")
|
||||
file(WRITE ${file} ""
|
||||
"#include <hip/library_types.h>\n"
|
||||
"int main() {\n"
|
||||
" hipDataType baz = HIP_R_8F_E4M3_FNUZ;\n"
|
||||
" return 0;\n"
|
||||
"}\n"
|
||||
)
|
||||
# roctx is part of roctracer
|
||||
find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)
|
||||
|
||||
try_compile(hip_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
|
||||
CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
|
||||
COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
|
||||
OUTPUT_VARIABLE hip_compile_output)
|
||||
# check whether HIP declares new types
|
||||
set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}")
|
||||
set(file "${PROJECT_BINARY_DIR}/hip_new_types.cc")
|
||||
file(WRITE ${file} ""
|
||||
"#include <hip/library_types.h>\n"
|
||||
"int main() {\n"
|
||||
" hipDataType baz = HIP_R_8F_E4M3_FNUZ;\n"
|
||||
" return 0;\n"
|
||||
"}\n"
|
||||
)
|
||||
|
||||
if(hip_compile_result)
|
||||
try_compile(hip_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
|
||||
CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
|
||||
COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
|
||||
OUTPUT_VARIABLE hip_compile_output)
|
||||
|
||||
if(hip_compile_result)
|
||||
set(HIP_NEW_TYPE_ENUMS ON)
|
||||
#message("HIP is using new type enums: ${hip_compile_output}")
|
||||
message("HIP is using new type enums")
|
||||
else()
|
||||
set(HIP_NEW_TYPE_ENUMS OFF)
|
||||
#message("HIP is NOT using new type enums: ${hip_compile_output}")
|
||||
message("HIP is NOT using new type enums")
|
||||
endif()
|
||||
else() # Win32
|
||||
# With HIP-SDK 6.2, HIP declares new enum types on Windows
|
||||
set(HIP_NEW_TYPE_ENUMS ON)
|
||||
#message("HIP is using new type enums: ${hip_compile_output}")
|
||||
message("HIP is using new type enums")
|
||||
else()
|
||||
set(HIP_NEW_TYPE_ENUMS OFF)
|
||||
#message("HIP is NOT using new type enums: ${hip_compile_output}")
|
||||
message("HIP is NOT using new type enums")
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user