mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Switching to rocblas_gemm_ex for MFMA-enabled architectures
This commit is contained in:
parent
63019e2b06
commit
5a336a4150
|
|
@ -474,6 +474,11 @@ class GpuDriver {
|
|||
static port::Status GetGpuGCNArchName(GpuDeviceHandle device,
|
||||
std::string* gcnArchName);
|
||||
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
// Tests the current device for MFMA instruction support (ROCm only).
// `support` is an out-parameter: it is reset to false on entry and set to
// true only when the device's GCN arch name identifies gfx908 or gfx909.
// Returns an INTERNAL error status if the device properties cannot be
// queried; `support` is false in that case.
|
||||
static port::Status GetMFMASupport(bool& support);
|
||||
#endif
|
||||
|
||||
// Returns the number of multiprocessors on the device (note that the device
|
||||
// may be multi-GPU-per-board).
|
||||
static port::StatusOr<int> GetMultiprocessorCount(GpuDeviceHandle device);
|
||||
|
|
|
|||
|
|
@ -267,6 +267,7 @@ namespace wrap {
|
|||
__macro(rocblas_ztrmm) \
|
||||
__macro(rocblas_sgeam) \
|
||||
__macro(rocblas_dgeam) \
|
||||
__macro(rocblas_gemm_ex) \
|
||||
/*__macro(rocblas_cgeam) \
|
||||
__macro(rocblas_zgeam) \
|
||||
__macro(rocblas_sdgmm) \
|
||||
|
|
@ -1617,9 +1618,13 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
|||
"precondition violation";
|
||||
}
|
||||
}
|
||||
const Eigen::half alpha_half(alpha);
|
||||
const Eigen::half beta_half(beta);
|
||||
return DoBlasInternal(
|
||||
bool hasXDLOPS = false;
|
||||
auto status = GpuDriver::GetMFMASupport(hasXDLOPS);
|
||||
if(!hasXDLOPS) {
|
||||
VLOG(1) << "Using rocblas_hgemm";
|
||||
const Eigen::half alpha_half(alpha);
|
||||
const Eigen::half beta_half(beta);
|
||||
return DoBlasInternal(
|
||||
wrap::rocblas_hgemm, stream, /* pointer_mode_host = */ true,
|
||||
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
|
||||
reinterpret_cast<const rocblas_half *>(&alpha_half),
|
||||
|
|
@ -1627,6 +1632,21 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
|||
reinterpret_cast<const rocblas_half *>(GpuMemory(b)), ldb,
|
||||
reinterpret_cast<const rocblas_half *>(&beta_half),
|
||||
reinterpret_cast<rocblas_half *>(GpuMemoryMutable(c)), ldc);
|
||||
} else {
|
||||
VLOG(1) << "Using rocblas_gemm_ex";
|
||||
return DoBlasInternal(
|
||||
wrap::rocblas_gemm_ex, stream, /* pointer_mode_host = */ true,
|
||||
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb),
|
||||
(rocblas_int)m, (rocblas_int)n, (rocblas_int)k,
|
||||
reinterpret_cast<const void*>(&alpha),
|
||||
reinterpret_cast<const void*>(GpuMemory(a)), rocblas_datatype_f16_r, lda,
|
||||
reinterpret_cast<const void*>(GpuMemory(b)), rocblas_datatype_f16_r, ldb,
|
||||
reinterpret_cast<const void*>(&beta),
|
||||
reinterpret_cast<const void*>(GpuMemoryMutable(c)),
|
||||
rocblas_datatype_f16_r, ldc,
|
||||
reinterpret_cast<void*>(GpuMemoryMutable(c)), rocblas_datatype_f16_r, ldc,
|
||||
rocblas_datatype_f32_r, rocblas_gemm_algo_standard, 0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
||||
|
|
|
|||
|
|
@ -1095,6 +1095,31 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||
device)};
|
||||
}
|
||||
|
||||
// Reports whether the current HIP device supports MFMA instructions.
//
// `supported` is reset to false on entry and becomes true only when the
// device's gcnArchName, after dropping everything from the first ':' onward
// and everything up to and including "gfx", equals "908" or "909".
//
// Returns OK when the architecture name could be read, and an INTERNAL status
// (with `supported` left false) when either the current device or its
// properties cannot be queried.
/* static */ port::Status GpuDriver::GetMFMASupport(bool& supported) {
  supported = false;

  int dev = 0;
  hipError_t result = hipGetDevice(&dev);
  // Check this call on its own: if it fails, `dev` still holds its initial 0
  // and the property query below would silently describe device 0 instead of
  // the device the caller is actually running on.
  if (result != hipSuccess) {
    return port::Status{
        port::error::INTERNAL,
        "failed to query the current device while checking MFMA support"};
  }

  hipDeviceProp_t props;
  result = tensorflow::wrap::hipGetDeviceProperties(&props, dev);
  if (result != hipSuccess) {
    return port::Status{
        port::error::INTERNAL,
        absl::StrFormat("failed to determine AMDGpu GCN Arch Name for device %d",
                        dev)};
  }

  std::string gcnArchName = props.gcnArchName;
  VLOG(1) << "GCN arch name " << gcnArchName;

  // Keep only the part before the first ':' (if any).
  auto pos = gcnArchName.find(":");
  if (pos != std::string::npos) {
    gcnArchName = gcnArchName.substr(0, pos);
  }

  // Strip the "gfx" prefix so only the architecture number remains.
  pos = gcnArchName.find("gfx");
  if (pos != std::string::npos) {
    gcnArchName = gcnArchName.substr(pos + 3);
  }
  VLOG(1) << "GCN arch name (stripped) " << gcnArchName;

  supported = (gcnArchName == "908" || gcnArchName == "909");
  return port::Status::OK();
}
|
||||
|
||||
// Helper function that turns the integer output of hipDeviceGetAttribute to
|
||||
// type T and wraps it in a StatusOr.
|
||||
template <typename T>
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user