mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[ROCm] backward compatible type enums (#118137)
Fixes builds of pytorch using unreleased ROCm packages that are missing type enums introduced in ROCm 6.0 release. Pull Request resolved: https://github.com/pytorch/pytorch/pull/118137 Approved by: https://github.com/xw285cornell, https://github.com/anupambhatnagar
This commit is contained in:
parent
f8e14f3b46
commit
2c9a90cde6
|
|
@ -93,11 +93,16 @@ inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type)
|
|||
return CUDA_R_8F_E5M2;
|
||||
#endif
|
||||
#else // USE_ROCM
|
||||
#if ROCM_VERSION >= 60000
|
||||
#if defined(HIP_NEW_TYPE_ENUMS)
|
||||
case c10::ScalarType::Float8_e4m3fnuz:
|
||||
return HIP_R_8F_E4M3_FNUZ;
|
||||
case c10::ScalarType::Float8_e5m2fnuz:
|
||||
return HIP_R_8F_E5M2_FNUZ;
|
||||
#else
|
||||
case c10::ScalarType::Float8_e4m3fnuz:
|
||||
return static_cast<hipDataType>(1000);
|
||||
case c10::ScalarType::Float8_e5m2fnuz:
|
||||
return static_cast<hipDataType>(1001);
|
||||
#endif
|
||||
#endif
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -1279,6 +1279,9 @@ if(USE_ROCM)
|
|||
if(HIPBLASLT_CUSTOM_COMPUTE_TYPE)
|
||||
list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_COMPUTE_TYPE)
|
||||
endif()
|
||||
if(HIP_NEW_TYPE_ENUMS)
|
||||
list(APPEND HIP_CXX_FLAGS -DHIP_NEW_TYPE_ENUMS)
|
||||
endif()
|
||||
add_definitions(-DROCM_VERSION=${ROCM_VERSION_DEV_INT})
|
||||
add_definitions(-DTORCH_HIP_VERSION=${TORCH_HIP_VERSION})
|
||||
message("TORCH_HIP_VERSION=${TORCH_HIP_VERSION} is added as a compiler defines")
|
||||
|
|
|
|||
|
|
@ -243,4 +243,29 @@ if(HIP_FOUND)
|
|||
endif()
|
||||
endif()
|
||||
|
||||
# check whether HIP declares new types
|
||||
set(file "${PROJECT_BINARY_DIR}/hip_new_types.cc")
|
||||
file(WRITE ${file} ""
|
||||
"#include <hip/library_types.h>\n"
|
||||
"int main() {\n"
|
||||
" hipDataType baz = HIP_R_8F_E4M3_FNUZ;\n"
|
||||
" return 0;\n"
|
||||
"}\n"
|
||||
)
|
||||
|
||||
try_compile(hipblaslt_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
|
||||
CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
|
||||
COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
|
||||
OUTPUT_VARIABLE hipblaslt_compile_output)
|
||||
|
||||
if(hipblaslt_compile_result)
|
||||
set(HIP_NEW_TYPE_ENUMS ON)
|
||||
#message("HIP is using new type enums: ${hipblaslt_compile_output}")
|
||||
message("HIP is using new type enums")
|
||||
else()
|
||||
set(HIP_NEW_TYPE_ENUMS OFF)
|
||||
#message("HIP is NOT using new type enums: ${hipblaslt_compile_output}")
|
||||
message("HIP is NOT using new type enums")
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user