Addressing PR feedback

This commit is contained in:
Deven Desai 2021-04-12 14:19:13 +00:00
parent 8d5eeb6189
commit 64d578a77d
3 changed files with 5 additions and 8 deletions

View File

@@ -476,7 +476,7 @@ class GpuDriver {
#if TENSORFLOW_USE_ROCM
// tests the current device for MFMA insn support (ROCm only)
static port::Status GetMFMASupport(bool& support);
static port::StatusOr<bool> GetMFMASupport();
#endif
// Returns the number of multiprocessors on the device (note that the device

View File

@@ -1618,9 +1618,8 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
"precondition violation";
}
}
bool hasXDLOPS = false;
port::Status status = GpuDriver::GetMFMASupport(hasXDLOPS);
if (status.ok() && hasXDLOPS) {
port::StatusOr<bool> maybe_hasXDLOPS = GpuDriver::GetMFMASupport();
if (maybe_hasXDLOPS.ok() && maybe_hasXDLOPS.ValueOrDie()) {
VLOG(1) << "Using rocblas_gemm_ex";
return DoBlasInternal(
wrap::rocblas_gemm_ex, stream, /* pointer_mode_host = */ true,

View File

@@ -1095,8 +1095,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
device)};
}
/* static */ port::Status GpuDriver::GetMFMASupport(bool& supported) {
supported = false;
/* static */ port::StatusOr<bool> GpuDriver::GetMFMASupport() {
hipDeviceProp_t props;
int dev = 0;
hipError_t result = hipGetDevice(&dev);
@@ -1111,8 +1110,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
if(pos!=string::npos)
gcnArchName = gcnArchName.substr(pos+3);
VLOG(1)<<"GCN arch name (stripped) " << gcnArchName;
supported = (gcnArchName=="908" || gcnArchName=="909");
return port::Status::OK();
return ((gcnArchName == "908") || (gcnArchName == "909"));
}
return port::Status{
port::error::INTERNAL,