Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[ROCm] Fix caffe2 build with hipblasv2 api (#116073)
Summary: We need this change, along with D52244365, to make the caffe2 build happy.

Test Plan: OSS CI

Differential Revision: D52275058

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116073
Approved by: https://github.com/jeffdaily, https://github.com/malfet
This commit is contained in:
parent a597a00c87
commit c72bc61bcd
@@ -44,8 +44,10 @@
 // until we use hipblas v2
 // hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
 // however hipblas v1 is still using its custom type
+#ifndef HIPBLAS_V2
 #define HIP_R_16F HIPBLAS_R_16F
 #define HIP_R_32F HIPBLAS_R_32F
+#endif // HIPBLAS_V2
 #else // USE_ROCM
 #define CUBLAS_HALF_TYPE __half
 #endif // USE_ROCM
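The #ifndef guard added in this first hunk is what keeps the old aliases out of hipblas v2 builds. A hedged reading of the intent, based on the hipblas v2 migration rather than anything stated in the commit: once HIPBLAS_V2 is defined, the hipified HIP_R_16F / HIP_R_32F names are meant to reach the v2 API unchanged, so the v1-era aliases onto HIPBLAS_R_* must no longer be installed.

// Sketch of what the guard protects (interpretation, not part of the commit).
#ifndef HIPBLAS_V2
// hipblas v1: the hipified names still have to be mapped onto hipblas' own type.
#define HIP_R_16F HIPBLAS_R_16F
#define HIP_R_32F HIPBLAS_R_32F
#endif // HIPBLAS_V2
// hipblas v2: HIP_R_16F / HIP_R_32F pass through to the API as-is.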
@@ -618,6 +620,11 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext>(
   // It has more general hipblasGemmEx API which is more close to cublasGemmEx.
   // hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C,
   // whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C
+#if ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+  auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+  auto compute_type = HIPBLAS_R_32F;
+#endif
   HIPBLAS_ENFORCE(hipblasGemmEx(
       context->hipblas_handle(),
       cu_trans_B,
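The compute_type selection introduced in the hunk above is the heart of the fix, and the same pattern is repeated in GemmBatched, GemmStridedBatched, and Gemv below. A minimal standalone sketch of the idea, assuming a ROCm 6.0+ toolchain with the hipblas headers included; the comments are an interpretation of the API change, not text from the commit:

// Pick the value for hipblasGemmEx's computeType argument at compile time.
#if ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
  // hipblas v2: computeType expects a compute-type enumerator.
  auto compute_type = HIPBLAS_COMPUTE_32F;
#else
  // hipblas v1: computeType still takes a datatype enumerator.
  auto compute_type = HIPBLAS_R_32F;
#endif
// compute_type then replaces the hard-coded HIPBLAS_R_32F (or CUDA_R_32F in the
// cublas-spelled paths), so FP16 GEMMs keep accumulating in FP32 under either
// API version.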
@@ -636,7 +643,7 @@ CAFFE2_CUDA_EXPORT void Gemm<at::Half, CUDAContext>(
       C,
       HIPBLAS_R_16F,
       N,
-      HIPBLAS_R_32F, // compute type
+      compute_type,
       HIPBLAS_GEMM_DEFAULT));
 #else
   CUBLAS_ENFORCE(cublasSgemmEx(
@@ -854,6 +861,11 @@ CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext>(
     thrust::device_vector<void*> C_device(C, C + batch_size);
     CUBLAS_ENFORCE(cublasSetPointerMode(
         context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+    auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+    auto compute_type = CUDA_R_32F;
+#endif
     CUBLAS_ENFORCE(cublasGemmBatchedEx(
         context->cublas_handle(),
         cu_trans_B,
@@ -873,7 +885,7 @@ CAFFE2_CUDA_EXPORT void GemmBatched<at::Half, CUDAContext>(
         CUDA_R_16F,
         ldc,
         batch_size,
-        CUDA_R_32F,
+        compute_type,
         CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   } else if (math_type == TensorProto_DataType_FLOAT16) {
     // Convert alpha, beta from float -> __half
@@ -945,6 +957,11 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
   if (math_type == TensorProto_DataType_FLOAT) {
     CUBLAS_ENFORCE(cublasSetPointerMode(
         context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
+#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+    auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+    auto compute_type = CUDA_R_32F;
+#endif
     CUBLAS_ENFORCE(cublasGemmStridedBatchedEx(
         context->cublas_handle(),
         cu_trans_B,
@@ -967,7 +984,7 @@ CAFFE2_CUDA_EXPORT void GemmStridedBatched<at::Half, CUDAContext>(
         ldc,
         C_stride,
         batch_size,
-        CUDA_R_32F,
+        compute_type,
         CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   } else if (math_type == TensorProto_DataType_FLOAT16) {
     // Convert alpha, beta from float -> __half
@@ -1059,6 +1076,11 @@ CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>(
   // It has more general hipblasGemmEx API which is more close to cublasGemmEx.
   // hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C,
   // whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C
+#if ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
+  auto compute_type = HIPBLAS_COMPUTE_32F;
+#else
+  auto compute_type = HIPBLAS_R_32F;
+#endif
   HIPBLAS_ENFORCE(hipblasGemmEx(
       context->hipblas_handle(),
       cu_trans_A,
@@ -1077,7 +1099,7 @@ CAFFE2_CUDA_EXPORT void Gemv<at::Half, CUDAContext>(
       y,
       HIPBLAS_R_16F,
       ldc,
-      HIPBLAS_R_32F, // compute type
+      compute_type,
       HIPBLAS_GEMM_DEFAULT));
 #else
   CUBLAS_ENFORCE(cublasSgemmEx(