[ROCm] backward compatible type enums (#118137)

Fixes builds of pytorch using unreleased ROCm packages that are missing type enums introduced in ROCm 6.0 release.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118137
Approved by: https://github.com/xw285cornell, https://github.com/anupambhatnagar
This commit is contained in:
Jeff Daily 2024-01-26 08:40:13 +00:00 committed by PyTorch MergeBot
parent f8e14f3b46
commit 2c9a90cde6
3 changed files with 34 additions and 1 deletions

View File

@ -93,11 +93,16 @@ inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type)
return CUDA_R_8F_E5M2;
#endif
#else // USE_ROCM
#if ROCM_VERSION >= 60000
#if defined(HIP_NEW_TYPE_ENUMS)
case c10::ScalarType::Float8_e4m3fnuz:
return HIP_R_8F_E4M3_FNUZ;
case c10::ScalarType::Float8_e5m2fnuz:
return HIP_R_8F_E5M2_FNUZ;
#else
case c10::ScalarType::Float8_e4m3fnuz:
return static_cast<hipDataType>(1000);
case c10::ScalarType::Float8_e5m2fnuz:
return static_cast<hipDataType>(1001);
#endif
#endif
default:

View File

@ -1279,6 +1279,9 @@ if(USE_ROCM)
if(HIPBLASLT_CUSTOM_COMPUTE_TYPE)
list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_COMPUTE_TYPE)
endif()
if(HIP_NEW_TYPE_ENUMS)
list(APPEND HIP_CXX_FLAGS -DHIP_NEW_TYPE_ENUMS)
endif()
add_definitions(-DROCM_VERSION=${ROCM_VERSION_DEV_INT})
add_definitions(-DTORCH_HIP_VERSION=${TORCH_HIP_VERSION})
message("TORCH_HIP_VERSION=${TORCH_HIP_VERSION} is added as a compiler defines")

View File

@ -243,4 +243,29 @@ if(HIP_FOUND)
endif()
endif()
# check whether HIP declares new types
set(file "${PROJECT_BINARY_DIR}/hip_new_types.cc")
file(WRITE ${file} ""
"#include <hip/library_types.h>\n"
"int main() {\n"
" hipDataType baz = HIP_R_8F_E4M3_FNUZ;\n"
" return 0;\n"
"}\n"
)
try_compile(hipblaslt_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
OUTPUT_VARIABLE hipblaslt_compile_output)
if(hipblaslt_compile_result)
set(HIP_NEW_TYPE_ENUMS ON)
#message("HIP is using new type enums: ${hipblaslt_compile_output}")
message("HIP is using new type enums")
else()
set(HIP_NEW_TYPE_ENUMS OFF)
#message("HIP is NOT using new type enums: ${hipblaslt_compile_output}")
message("HIP is NOT using new type enums")
endif()
endif()