mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Addressing PR feedback
This commit is contained in:
parent
8d5eeb6189
commit
64d578a77d
|
|
@ -476,7 +476,7 @@ class GpuDriver {
|
|||
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
// tests the current device for MFMA insn support (ROCm only)
|
||||
static port::Status GetMFMASupport(bool& support);
|
||||
static port::StatusOr<bool> GetMFMASupport();
|
||||
#endif
|
||||
|
||||
// Returns the number of multiprocessors on the device (note that the device
|
||||
|
|
|
|||
|
|
@ -1618,9 +1618,8 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
|||
"precondition violation";
|
||||
}
|
||||
}
|
||||
bool hasXDLOPS = false;
|
||||
port::Status status = GpuDriver::GetMFMASupport(hasXDLOPS);
|
||||
if (status.ok() && hasXDLOPS) {
|
||||
port::StatusOr<bool> maybe_hasXDLOPS = GpuDriver::GetMFMASupport();
|
||||
if (maybe_hasXDLOPS.ok() && maybe_hasXDLOPS.ValueOrDie()) {
|
||||
VLOG(1) << "Using rocblas_gemm_ex";
|
||||
return DoBlasInternal(
|
||||
wrap::rocblas_gemm_ex, stream, /* pointer_mode_host = */ true,
|
||||
|
|
|
|||
|
|
@ -1095,8 +1095,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||
device)};
|
||||
}
|
||||
|
||||
/* static */ port::Status GpuDriver::GetMFMASupport(bool& supported) {
|
||||
supported = false;
|
||||
/* static */ port::StatusOr<bool> GpuDriver::GetMFMASupport() {
|
||||
hipDeviceProp_t props;
|
||||
int dev = 0;
|
||||
hipError_t result = hipGetDevice(&dev);
|
||||
|
|
@ -1111,8 +1110,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||
if(pos!=string::npos)
|
||||
gcnArchName = gcnArchName.substr(pos+3);
|
||||
VLOG(1)<<"GCN arch name (stripped) " << gcnArchName;
|
||||
supported = (gcnArchName=="908" || gcnArchName=="909");
|
||||
return port::Status::OK();
|
||||
return ((gcnArchName == "908") || (gcnArchName == "909"));
|
||||
}
|
||||
return port::Status{
|
||||
port::error::INTERNAL,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user