[ROCm] Fix caffe2 build with hipblasv2 api (#116073)

Summary: We need this change, together with D52244365, to keep the caffe2 build working against the hipBLAS V2 API.

Test Plan: OSS CI

Differential Revision: D52275058

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116073
Approved by: https://github.com/jeffdaily, https://github.com/malfet
Authored by Xiaodong Wang on 2023-12-20 04:02:29 +00:00; committed by PyTorch MergeBot
parent a597a00c87
commit c72bc61bcd
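
For reference, the patch applies one pattern throughout: when building against hipBLAS V2 (ROCm 6.0 or newer with HIPBLAS_V2 defined), the *GemmEx compute-type argument must be the new compute-type enum value (HIPBLAS_COMPUTE_32F) rather than the V1 data-type enum (HIPBLAS_R_32F / CUDA_R_32F), and the HIP_R_16F / HIP_R_32F aliases only need to be defined by hand for V1. The snippet below is a minimal editorial sketch of that compile-time selection, not part of the patch itself; the surrounding caffe2 GEMM calls are omitted.

#ifndef HIPBLAS_V2
// hipBLAS v1 still uses its own data-type enum, so map the hipify'd names back.
#define HIP_R_16F HIPBLAS_R_16F
#define HIP_R_32F HIPBLAS_R_32F
#endif // HIPBLAS_V2

#if ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
// hipBLAS v2: the compute type is a separate compute-type enum.
auto compute_type = HIPBLAS_COMPUTE_32F;
#else
// hipBLAS v1: the data-type enum doubles as the compute type.
auto compute_type = HIPBLAS_R_32F;
#endif

// compute_type is then passed as the compute-type argument of hipblasGemmEx,
// cublasGemmBatchedEx, and cublasGemmStridedBatchedEx in the hunks below.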


@@ -44,8 +44,10 @@
 // until we use hipblas v2
 // hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
 // however hipblas v1 is still using its custom type
+#ifndef HIPBLAS_V2
 #define HIP_R_16F HIPBLAS_R_16F
 #define HIP_R_32F HIPBLAS_R_32F
+#endif // HIPBLAS_V2
 #else // USE_ROCM
 #define CUBLAS_HALF_TYPE __half
 #endif // USE_ROCM
@@ -618,6 +620,11 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext>(
 // It has more general hipblasGemmEx API which is more close to cublasGemmEx.
 // hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C,
 // whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C
+#if ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+auto compute_type = HIPBLAS_R_32F;
+#endif
 HIPBLAS_ENFORCE(hipblasGemmEx(
 context->hipblas_handle(),
 cu_trans_B,
@@ -636,7 +643,7 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext>(
 C,
 HIPBLAS_R_16F,
 N,
-HIPBLAS_R_32F, // compute type
+compute_type,
 HIPBLAS_GEMM_DEFAULT));
 #else
 CUBLAS_ENFORCE(cublasSgemmEx(
@@ -854,6 +861,11 @@ CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext>(
 thrust::device_vector<void*> C_device(C, C + batch_size);
 CUBLAS_ENFORCE(cublasSetPointerMode(
 context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+auto compute_type = CUDA_R_32F;
+#endif
 CUBLAS_ENFORCE(cublasGemmBatchedEx(
 context->cublas_handle(),
 cu_trans_B,
@@ -873,7 +885,7 @@ CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext>(
 CUDA_R_16F,
 ldc,
 batch_size,
-CUDA_R_32F,
+compute_type,
 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 } else if (math_type == TensorProto_DataType_FLOAT16) {
 // Convert alpha, beta from float -> __half
@@ -945,6 +957,11 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
 if (math_type == TensorProto_DataType_FLOAT) {
 CUBLAS_ENFORCE(cublasSetPointerMode(
 context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+auto compute_type = CUDA_R_32F;
+#endif
 CUBLAS_ENFORCE(cublasGemmStridedBatchedEx(
 context->cublas_handle(),
 cu_trans_B,
@@ -967,7 +984,7 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
 ldc,
 C_stride,
 batch_size,
-CUDA_R_32F,
+compute_type,
 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 } else if (math_type == TensorProto_DataType_FLOAT16) {
 // Convert alpha, beta from float -> __half
@@ -1059,6 +1076,11 @@ CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>(
 // It has more general hipblasGemmEx API which is more close to cublasGemmEx.
 // hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C,
 // whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C
+#if ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+auto compute_type = HIPBLAS_R_32F;
+#endif
 HIPBLAS_ENFORCE(hipblasGemmEx(
 context->hipblas_handle(),
 cu_trans_A,
@@ -1077,7 +1099,7 @@ CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>(
 y,
 HIPBLAS_R_16F,
 ldc,
-HIPBLAS_R_32F, // compute type
+compute_type,
 HIPBLAS_GEMM_DEFAULT));
 #else
 CUBLAS_ENFORCE(cublasSgemmEx(