From c72bc61bcdca487940ebd22facc83937dbc63cec Mon Sep 17 00:00:00 2001
From: Xiaodong Wang
Date: Wed, 20 Dec 2023 04:02:29 +0000
Subject: [PATCH] [ROCm] Fix caffe2 build with hipblasv2 api (#116073)

Summary: we need this change along with D52244365 to make caffe2 build happy

Test Plan: OSS CI

Differential Revision: D52275058

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116073
Approved by: https://github.com/jeffdaily, https://github.com/malfet
---
 caffe2/utils/math_gpu.cu | 30 ++++++++++++++++++++++++++----
 1 file changed, 26 insertions(+), 4 deletions(-)

diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index 3a587638d33..98f090ab416 100644
--- a/caffe2/utils/math_gpu.cu
+++ b/caffe2/utils/math_gpu.cu
@@ -44,8 +44,10 @@
 // until we use hipblas v2
 // hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
 // however hipblas v1 is still using its custom type
+#ifndef HIPBLAS_V2
 #define HIP_R_16F HIPBLAS_R_16F
 #define HIP_R_32F HIPBLAS_R_32F
+#endif // HIPBLAS_V2
 #else // USE_ROCM
 #define CUBLAS_HALF_TYPE __half
 #endif // USE_ROCM
@@ -618,6 +620,11 @@ CAFFE2_CUDA_EXPORT void Gemm(
     // It has more general hipblasGemmEx API which is more close to cublasGemmEx.
     // hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C,
     // whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C
+#if ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+    auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+    auto compute_type = HIPBLAS_R_32F;
+#endif
     HIPBLAS_ENFORCE(hipblasGemmEx(
         context->hipblas_handle(),
         cu_trans_B,
@@ -636,7 +643,7 @@ CAFFE2_CUDA_EXPORT void Gemm(
         C,
         HIPBLAS_R_16F,
         N,
-        HIPBLAS_R_32F, // compute type
+        compute_type,
         HIPBLAS_GEMM_DEFAULT));
 #else
     CUBLAS_ENFORCE(cublasSgemmEx(
@@ -854,6 +861,11 @@ CAFFE2_CUDA_EXPORT void GemmBatched(
     thrust::device_vector C_device(C, C + batch_size);
     CUBLAS_ENFORCE(cublasSetPointerMode(
         context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+    auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+    auto compute_type = CUDA_R_32F;
+#endif
     CUBLAS_ENFORCE(cublasGemmBatchedEx(
         context->cublas_handle(),
         cu_trans_B,
@@ -873,7 +885,7 @@ CAFFE2_CUDA_EXPORT void GemmBatched(
         CUDA_R_16F,
         ldc,
         batch_size,
-        CUDA_R_32F,
+        compute_type,
         CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   } else if (math_type == TensorProto_DataType_FLOAT16) {
     // Convert alpha, beta from float -> __half
@@ -945,6 +957,11 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched(
   if (math_type == TensorProto_DataType_FLOAT) {
     CUBLAS_ENFORCE(cublasSetPointerMode(
         context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+    auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+    auto compute_type = CUDA_R_32F;
+#endif
     CUBLAS_ENFORCE(cublasGemmStridedBatchedEx(
         context->cublas_handle(),
         cu_trans_B,
@@ -967,7 +984,7 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched(
         ldc,
         C_stride,
         batch_size,
-        CUDA_R_32F,
+        compute_type,
         CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   } else if (math_type == TensorProto_DataType_FLOAT16) {
     // Convert alpha, beta from float -> __half
@@ -1059,6 +1076,11 @@ CAFFE2_CUDA_EXPORT void Gemv(
     // It has more general hipblasGemmEx API which is more close to cublasGemmEx.
     // hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C,
     // whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C
+#if ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+    auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+    auto compute_type = HIPBLAS_R_32F;
+#endif
     HIPBLAS_ENFORCE(hipblasGemmEx(
         context->hipblas_handle(),
         cu_trans_A,
@@ -1077,7 +1099,7 @@ CAFFE2_CUDA_EXPORT void Gemv(
         y,
         HIPBLAS_R_16F,
         ldc,
-        HIPBLAS_R_32F, // compute type
+        compute_type,
         HIPBLAS_GEMM_DEFAULT));
 #else
     CUBLAS_ENFORCE(cublasSgemmEx(
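
Note on the recurring change: hipBLAS v2 (ROCm 6.0+) takes the GEMM compute type as a hipblasComputeType_t value (HIPBLAS_COMPUTE_32F), while the v1 API and the CUDA path keep passing a datatype enum (HIPBLAS_R_32F / CUDA_R_32F), so each call site now selects the constant at compile time. Below is a minimal sketch of that selection, assuming the same USE_ROCM, ROCM_VERSION, and HIPBLAS_V2 macros the file already relies on; the fp16_gemm_compute_type() helper name is hypothetical and not part of the patch.

// Illustrative sketch only, not part of the patch.
#if defined(USE_ROCM)
#include <hipblas/hipblas.h> // some ROCm installs expose this as <hipblas.h>
#else
#include <cublas_v2.h>
#endif

#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
// hipBLAS v2: compute types come from the hipblasComputeType_t enum.
inline hipblasComputeType_t fp16_gemm_compute_type() {
  return HIPBLAS_COMPUTE_32F;
}
#elif defined(USE_ROCM)
// hipBLAS v1: the compute type is still passed as a hipblasDatatype_t value.
inline hipblasDatatype_t fp16_gemm_compute_type() {
  return HIPBLAS_R_32F;
}
#else
// CUDA path: the call sites keep passing CUDA_R_32F as before.
inline cudaDataType_t fp16_gemm_compute_type() {
  return CUDA_R_32F;
}
#endif

The patch itself makes the selection inline with "auto compute_type = ...;" at each call site rather than through a shared helper, which keeps the change local to the four affected GEMM/GEMV paths.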