Switching to rocblas_gemm_ex for MFMA-enabled architectures

This commit is contained in:
Eugene Kuznetsov 2021-02-03 18:00:14 +00:00 committed by Deven Desai
parent 63019e2b06
commit 5a336a4150
3 changed files with 53 additions and 3 deletions

View File

@ -474,6 +474,11 @@ class GpuDriver {
static port::Status GetGpuGCNArchName(GpuDeviceHandle device,
std::string* gcnArchName);
#if TENSORFLOW_USE_ROCM
// tests the current device for MFMA insn support (ROCm only)
static port::Status GetMFMASupport(bool& support);
#endif
// Returns the number of multiprocessors on the device (note that the device
// may be multi-GPU-per-board).
static port::StatusOr<int> GetMultiprocessorCount(GpuDeviceHandle device);

View File

@ -267,6 +267,7 @@ namespace wrap {
__macro(rocblas_ztrmm) \
__macro(rocblas_sgeam) \
__macro(rocblas_dgeam) \
__macro(rocblas_gemm_ex) \
/*__macro(rocblas_cgeam) \
__macro(rocblas_zgeam) \
__macro(rocblas_sdgmm) \
@ -1617,9 +1618,13 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
"precondition violation";
}
}
const Eigen::half alpha_half(alpha);
const Eigen::half beta_half(beta);
return DoBlasInternal(
bool hasXDLOPS = false;
auto status = GpuDriver::GetMFMASupport(hasXDLOPS);
if(!hasXDLOPS) {
VLOG(1) << "Using rocblas_hgemm";
const Eigen::half alpha_half(alpha);
const Eigen::half beta_half(beta);
return DoBlasInternal(
wrap::rocblas_hgemm, stream, /* pointer_mode_host = */ true,
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
reinterpret_cast<const rocblas_half *>(&alpha_half),
@ -1627,6 +1632,21 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
reinterpret_cast<const rocblas_half *>(GpuMemory(b)), ldb,
reinterpret_cast<const rocblas_half *>(&beta_half),
reinterpret_cast<rocblas_half *>(GpuMemoryMutable(c)), ldc);
} else {
VLOG(1) << "Using rocblas_gemm_ex";
return DoBlasInternal(
wrap::rocblas_gemm_ex, stream, /* pointer_mode_host = */ true,
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb),
(rocblas_int)m, (rocblas_int)n, (rocblas_int)k,
reinterpret_cast<const void*>(&alpha),
reinterpret_cast<const void*>(GpuMemory(a)), rocblas_datatype_f16_r, lda,
reinterpret_cast<const void*>(GpuMemory(b)), rocblas_datatype_f16_r, ldb,
reinterpret_cast<const void*>(&beta),
reinterpret_cast<const void*>(GpuMemoryMutable(c)),
rocblas_datatype_f16_r, ldc,
reinterpret_cast<void*>(GpuMemoryMutable(c)), rocblas_datatype_f16_r, ldc,
rocblas_datatype_f32_r, rocblas_gemm_algo_standard, 0, 0);
}
}
bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,

View File

@ -1095,6 +1095,31 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
device)};
}
/* static */ port::Status GpuDriver::GetMFMASupport(bool& supported) {
supported = false;
hipDeviceProp_t props;
int dev = 0;
hipError_t result = hipGetDevice(&dev);
result = tensorflow::wrap::hipGetDeviceProperties(&props, dev);
if (result == hipSuccess) {
std::string gcnArchName = props.gcnArchName;
VLOG(1)<<"GCN arch name " << gcnArchName;
auto pos = gcnArchName.find(":");
if(pos!=string::npos)
gcnArchName = gcnArchName.substr(0, pos);
pos = gcnArchName.find("gfx");
if(pos!=string::npos)
gcnArchName = gcnArchName.substr(pos+3);
VLOG(1)<<"GCN arch name (stripped) " << gcnArchName;
supported = (gcnArchName=="908" || gcnArchName=="909");
return port::Status::OK();
}
return port::Status{
port::error::INTERNAL,
absl::StrFormat("failed to determine AMDGpu GCN Arch Name for device %d",
dev)};
}
// Helper function that turns the integer output of hipDeviceGetAttribute to
// type T and wraps it in a StatusOr.
template <typename T>