Addressing PR feedback

This commit is contained in:
Deven Desai 2021-04-08 20:43:06 +00:00
parent 6a74de1f86
commit 457df8e17a
3 changed files with 5 additions and 11 deletions

View File

@ -149,6 +149,7 @@ tf_cuda_library(
deps = [
"//tensorflow/core:lib",
"//tensorflow/core/profiler/internal/cpu:annotation_stack",
"//tensorflow/core/profiler/utils:time_utils",
"//tensorflow/stream_executor/rocm:roctracer_wrapper",
"@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/container:flat_hash_map",

View File

@ -433,7 +433,6 @@ class RocmTraceCollectorImpl : public profiler::RocmTraceCollector {
absl::StrCat("/device:GPU:", device_ordinal, "/memcpy"));
}
memcpy_dev_stats->add_node_stats()->Swap(ns.release());
} break;
case RocmTracerEventType::MemoryAlloc: {
std::string details = absl::StrCat(
@ -718,8 +717,6 @@ Status GpuTracer::DoStart() {
GetRocmTraceCollectorOptions(rocm_tracer_->NumGpus());
uint64_t start_gputime_ns = RocmTracer::GetTimestamp();
uint64_t start_walltime_ns = tensorflow::EnvTime::NowNanos();
// VLOG(3) << "CPU Start Time : " << start_walltime_ns / 1000
// << " , GPU Start Time : " << start_gputime_ns / 1000;
rocm_trace_collector_ = std::make_unique<RocmTraceCollectorImpl>(
trace_collector_options, start_walltime_ns, start_gputime_ns);

View File

@ -15,11 +15,6 @@ limitations under the License.
#include "tensorflow/core/profiler/internal/gpu/rocm_tracer.h"
#include <chrono>
#include <iostream>
#include <sstream>
#include <thread>
#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "rocm/rocm_config.h"
@ -31,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/profiler/internal/cpu/annotation_stack.h"
#include "tensorflow/core/profiler/utils/time_utils.h"
namespace tensorflow {
namespace profiler {
@ -85,7 +81,7 @@ const char* GetActivityDomainName(uint32_t domain) {
return "";
}
string GetActivityDomainOpName(uint32_t domain, uint32_t op) {
std::string GetActivityDomainOpName(uint32_t domain, uint32_t op) {
std::ostringstream oss;
oss << GetActivityDomainName(domain) << " - ";
switch (domain) {
@ -458,7 +454,7 @@ void RocmApiCallbackImpl::AddMemcpyEventUponApiExit(
event.thread_id = GetCachedTID();
event.correlation_id = data->correlation_id;
// ROCM TODO: figure out a way to properly populate this field.
// TODO(rocm): figure out a way to properly populate this field.
event.memcpy_info.destination = 0;
switch (cbid) {
case HIP_API_ID_hipMemcpyDtoH:
@ -1106,7 +1102,7 @@ Status RocmTracer::DisableActivityTracing() {
<< ", Threshold = " << threshold;
VLOG(3) << "Wait for pending activity records : sleep for " << duration_ms
<< " ms";
std::this_thread::sleep_for(std::chrono::milliseconds(duration_ms));
tensorflow::profiler::SleepForMillis(duration_ms);
}
ClearPendingActivityRecordsCount();