mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Addressing PR feedback
This commit is contained in:
parent
6a74de1f86
commit
457df8e17a
|
|
@ -149,6 +149,7 @@ tf_cuda_library(
|
|||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/profiler/internal/cpu:annotation_stack",
|
||||
"//tensorflow/core/profiler/utils:time_utils",
|
||||
"//tensorflow/stream_executor/rocm:roctracer_wrapper",
|
||||
"@com_google_absl//absl/container:fixed_array",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
|
|
|
|||
|
|
@ -433,7 +433,6 @@ class RocmTraceCollectorImpl : public profiler::RocmTraceCollector {
|
|||
absl::StrCat("/device:GPU:", device_ordinal, "/memcpy"));
|
||||
}
|
||||
memcpy_dev_stats->add_node_stats()->Swap(ns.release());
|
||||
|
||||
} break;
|
||||
case RocmTracerEventType::MemoryAlloc: {
|
||||
std::string details = absl::StrCat(
|
||||
|
|
@ -718,8 +717,6 @@ Status GpuTracer::DoStart() {
|
|||
GetRocmTraceCollectorOptions(rocm_tracer_->NumGpus());
|
||||
uint64_t start_gputime_ns = RocmTracer::GetTimestamp();
|
||||
uint64_t start_walltime_ns = tensorflow::EnvTime::NowNanos();
|
||||
// VLOG(3) << "CPU Start Time : " << start_walltime_ns / 1000
|
||||
// << " , GPU Start Time : " << start_gputime_ns / 1000;
|
||||
rocm_trace_collector_ = std::make_unique<RocmTraceCollectorImpl>(
|
||||
trace_collector_options, start_walltime_ns, start_gputime_ns);
|
||||
|
||||
|
|
|
|||
|
|
@ -15,11 +15,6 @@ limitations under the License.
|
|||
|
||||
#include "tensorflow/core/profiler/internal/gpu/rocm_tracer.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <thread>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/node_hash_map.h"
|
||||
#include "rocm/rocm_config.h"
|
||||
|
|
@ -31,6 +26,7 @@ limitations under the License.
|
|||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
#include "tensorflow/core/profiler/internal/cpu/annotation_stack.h"
|
||||
#include "tensorflow/core/profiler/utils/time_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace profiler {
|
||||
|
|
@ -85,7 +81,7 @@ const char* GetActivityDomainName(uint32_t domain) {
|
|||
return "";
|
||||
}
|
||||
|
||||
string GetActivityDomainOpName(uint32_t domain, uint32_t op) {
|
||||
std::string GetActivityDomainOpName(uint32_t domain, uint32_t op) {
|
||||
std::ostringstream oss;
|
||||
oss << GetActivityDomainName(domain) << " - ";
|
||||
switch (domain) {
|
||||
|
|
@ -458,7 +454,7 @@ void RocmApiCallbackImpl::AddMemcpyEventUponApiExit(
|
|||
event.thread_id = GetCachedTID();
|
||||
event.correlation_id = data->correlation_id;
|
||||
|
||||
// ROCM TODO: figure out a way to properly populate this field.
|
||||
// TODO(rocm): figure out a way to properly populate this field.
|
||||
event.memcpy_info.destination = 0;
|
||||
switch (cbid) {
|
||||
case HIP_API_ID_hipMemcpyDtoH:
|
||||
|
|
@ -1106,7 +1102,7 @@ Status RocmTracer::DisableActivityTracing() {
|
|||
<< ", Threshold = " << threshold;
|
||||
VLOG(3) << "Wait for pending activity records : sleep for " << duration_ms
|
||||
<< " ms";
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(duration_ms));
|
||||
tensorflow::profiler::SleepForMillis(duration_ms);
|
||||
}
|
||||
ClearPendingActivityRecordsCount();
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user