diff --git a/aten/src/ATen/cuda/CUDADataType.h b/aten/src/ATen/cuda/CUDADataType.h index 92259edd63d..e48a3d51eaf 100644 --- a/aten/src/ATen/cuda/CUDADataType.h +++ b/aten/src/ATen/cuda/CUDADataType.h @@ -93,11 +93,16 @@ inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type) return CUDA_R_8F_E5M2; #endif #else // USE_ROCM -#if ROCM_VERSION >= 60000 +#if defined(HIP_NEW_TYPE_ENUMS) case c10::ScalarType::Float8_e4m3fnuz: return HIP_R_8F_E4M3_FNUZ; case c10::ScalarType::Float8_e5m2fnuz: return HIP_R_8F_E5M2_FNUZ; +#else + case c10::ScalarType::Float8_e4m3fnuz: + return static_cast(1000); + case c10::ScalarType::Float8_e5m2fnuz: + return static_cast(1001); #endif #endif default: diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 4b009211708..4dd80420584 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1279,6 +1279,9 @@ if(USE_ROCM) if(HIPBLASLT_CUSTOM_COMPUTE_TYPE) list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_COMPUTE_TYPE) endif() + if(HIP_NEW_TYPE_ENUMS) + list(APPEND HIP_CXX_FLAGS -DHIP_NEW_TYPE_ENUMS) + endif() add_definitions(-DROCM_VERSION=${ROCM_VERSION_DEV_INT}) add_definitions(-DTORCH_HIP_VERSION=${TORCH_HIP_VERSION}) message("TORCH_HIP_VERSION=${TORCH_HIP_VERSION} is added as a compiler defines") diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 197c60e7b7f..1abeb06228e 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -243,4 +243,29 @@ if(HIP_FOUND) endif() endif() + # check whether HIP declares new types + set(file "${PROJECT_BINARY_DIR}/hip_new_types.cc") + file(WRITE ${file} "" + "#include \n" + "int main() {\n" + " hipDataType baz = HIP_R_8F_E4M3_FNUZ;\n" + " return 0;\n" + "}\n" + ) + + try_compile(hipblaslt_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file} + CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}" + COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__ + OUTPUT_VARIABLE hipblaslt_compile_output) + + if(hipblaslt_compile_result) + set(HIP_NEW_TYPE_ENUMS ON) + #message("HIP is using new type enums: ${hipblaslt_compile_output}") + message("HIP is using new type enums") + else() + set(HIP_NEW_TYPE_ENUMS OFF) + #message("HIP is NOT using new type enums: ${hipblaslt_compile_output}") + message("HIP is NOT using new type enums") + endif() + endif()