[ROCm] Make optional features in LoadHIP better conditioned. (#155305)

* The `rocm-core` CMake package only started appearing in ROCm 6.4, so rework the version probing to work if it is not present. Also collapses the unneeded operating system conditioning in favor of feature probing.
* Make `hipsparselt` optional: it only started appearing in ROCm 6.4 and it is not in all recent distribution channels yet.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155305
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
Stella Laurenzo 2025-06-07 02:20:55 +00:00 committed by PyTorch MergeBot
parent 5596cefba6
commit 10cd1de518
3 changed files with 48 additions and 23 deletions

View File

@ -101,8 +101,8 @@ else()
set(AT_CUSPARSELT_ENABLED 1)
endif()
# Add hipSPARSELt support flag
if(USE_ROCM AND ROCM_VERSION VERSION_GREATER_EQUAL "6.4.0")
# Add hipSPARSELt support flag if the package is available.
if(USE_ROCM AND hipsparselt_FOUND)
set(AT_HIPSPARSELT_ENABLED 1)
else()
set(AT_HIPSPARSELT_ENABLED 0)

View File

@ -1063,7 +1063,13 @@ if(USE_ROCM)
# Math libraries
list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
roc::hipblas roc::rocblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsparselt roc::hipsolver roc::hipblaslt)
roc::hipblas roc::rocblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver roc::hipblaslt)
# hipsparselt is an optional component that will eventually be enabled by default.
if(hipsparselt_FOUND)
list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
roc::hipsparselt
)
endif()
# ---[ Kernel asserts
# Kernel asserts is disabled for ROCm by default.

View File

@ -65,8 +65,14 @@ list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
macro(find_package_and_print_version PACKAGE_NAME)
find_package("${PACKAGE_NAME}" ${ARGN})
message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}")
list(APPEND ROCM_INCLUDE_DIRS ${${PACKAGE_NAME}_INCLUDE_DIR})
if(NOT ${PACKAGE_NAME}_FOUND)
message("Optional package ${PACKAGE_NAME} not found")
else()
message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}")
if(${PACKAGE_NAME}_INCLUDE_DIR)
list(APPEND ROCM_INCLUDE_DIRS ${${PACKAGE_NAME}_INCLUDE_DIR})
endif()
endif()
endmacro()
# Find the HIP Package
@ -76,16 +82,32 @@ find_package_and_print_version(HIP 1.0 MODULE)
if(HIP_FOUND)
set(PYTORCH_FOUND_HIP TRUE)
find_package_and_print_version(hip REQUIRED CONFIG)
# Find ROCM version for checks. UNIX filename is rocm_version.h, Windows is hip_version.h
if(UNIX)
find_package_and_print_version(rocm-core REQUIRED CONFIG)
find_file(ROCM_VERSION_HEADER_PATH NAMES rocm_version.h
HINTS ${rocm_core_INCLUDE_DIR}/rocm-core /usr/include)
else() # Win32
find_file(ROCM_VERSION_HEADER_PATH NAMES hip_version.h
HINTS ${hip_INCLUDE_DIR}/hip)
# The rocm-core package was only introduced in ROCm 6.4, so we make it optional.
find_package(rocm-core CONFIG)
# Some old consumer HIP SDKs do not distribute rocm_version.h, so we allow
# falling back to the hip version, which everyone should have.
# rocm_version.h lives in the rocm-core package and hip_version.h lives in the
# hip (lower-case) package. Both are probed above and will be in
# ROCM_INCLUDE_DIRS if available.
find_file(ROCM_VERSION_HEADER_PATH
NAMES rocm-core/rocm_version.h
NO_DEFAULT_PATH
PATHS ${ROCM_INCLUDE_DIRS}
)
set(ROCM_LIB_NAME "ROCM")
if(NOT ROCM_VERSION_HEADER_PATH)
find_file(ROCM_VERSION_HEADER_PATH
NAMES hip/hip_version.h
NO_DEFAULT_PATH
PATHS ${ROCM_INCLUDE_DIRS}
)
set(ROCM_LIB_NAME "HIP")
endif()
if(NOT ROCM_VERSION_HEADER_PATH)
message(FATAL_ERROR "Could not find hip/hip_version.h or rocm-core/rocm_version.h in ${ROCM_INCLUDE_DIRS}")
endif()
get_filename_component(ROCM_HEADER_NAME ${ROCM_VERSION_HEADER_PATH} NAME)
@ -96,15 +118,10 @@ if(HIP_FOUND)
endif()
# Read the ROCM headerfile into a variable
file(READ ${ROCM_HEADER_FILE} ROCM_HEADER_CONTENT)
message(STATUS "Reading ROCM version from: ${ROCM_HEADER_FILE}")
message(STATUS "Content: ${ROCM_HEADER_CONTENT}")
file(READ "${ROCM_HEADER_FILE}" ROCM_HEADER_CONTENT)
# Since Windows currently supports only a part of ROCm and names it HIP-SDK,
# we need to refer to the HIP-SDK equivalents of entities existing in ROCm lib.
if(UNIX)
set(ROCM_LIB_NAME "ROCM")
else() # Win32
set(ROCM_LIB_NAME "HIP")
endif()
# Below we use a RegEx to find ROCM version numbers.
# Note that CMake does not support \s for blank space. That is
# why in the regular expressions below we have a blank space in
@ -151,7 +168,6 @@ if(HIP_FOUND)
find_package_and_print_version(miopen REQUIRED)
find_package_and_print_version(hipfft REQUIRED)
find_package_and_print_version(hipsparse REQUIRED)
find_package_and_print_version(hipsparselt REQUIRED)
find_package_and_print_version(rocprim REQUIRED)
find_package_and_print_version(hipcub REQUIRED)
find_package_and_print_version(rocthrust REQUIRED)
@ -172,6 +188,9 @@ if(HIP_FOUND)
find_package_and_print_version(hsa-runtime64 REQUIRED)
endif()
# Optional components.
find_package_and_print_version(hipsparselt) # Will be required when ready.
list(REMOVE_DUPLICATES ROCM_INCLUDE_DIRS)
if(UNIX)