mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44653 This changes the profiler, per an offline discussion with ilia-cher, so that the `disableProfiler()` event consolidation logic can be called from different threads (i.e. threads where the profiler was not explicitly enabled). This is needed to support the functionality enabled by D23638387, where we defer profiling event collection until executing an async callback that can execute on a different thread, to support RPC async function profiling. This is done by introducing 2 flags, `cleanupTLSState` and `consolidate`, which control whether we should clean up thread local settings (we don't do this when calling `disableProfiler()` on non-main threads) and whether we should consolidate all profiled events. Backwards compatibility is ensured since both options are true by default. Added a test in `test_misc.cpp` to test this. ghstack-source-id: 112605620 Reviewed By: mrshenli Differential Revision: D23638499 fbshipit-source-id: f5bbb0d41ef883c5e5870bc27e086b8b8908f46b
94 lines
2.1 KiB
C++
94 lines
2.1 KiB
C++
#include <c10/util/ThreadLocalDebugInfo.h>

#include <utility>
|
|
|
|
namespace c10 {
|
|
|
|
namespace {
|
|
thread_local std::shared_ptr<ThreadLocalDebugInfo> debug_info = nullptr;
|
|
}
|
|
|
|
/* static */
|
|
std::shared_ptr<DebugInfoBase> ThreadLocalDebugInfo::get(
|
|
DebugInfoKind kind) {
|
|
auto cur = debug_info;
|
|
while (cur) {
|
|
if (cur->kind_ == kind) {
|
|
return cur->info_;
|
|
}
|
|
cur = cur->parent_info_;
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
/* static */
|
|
std::shared_ptr<ThreadLocalDebugInfo> ThreadLocalDebugInfo::current() {
|
|
return debug_info;
|
|
}
|
|
|
|
/* static */
|
|
void ThreadLocalDebugInfo::_forceCurrentDebugInfo(
|
|
const std::shared_ptr<ThreadLocalDebugInfo>& info) {
|
|
debug_info = info;
|
|
}
|
|
|
|
/* static */
|
|
void ThreadLocalDebugInfo::_push(
|
|
DebugInfoKind kind,
|
|
std::shared_ptr<DebugInfoBase> info) {
|
|
auto prev_info = debug_info;
|
|
debug_info = std::make_shared<ThreadLocalDebugInfo>();
|
|
debug_info->parent_info_ = prev_info;
|
|
debug_info->kind_ = kind;
|
|
debug_info->info_ = info;
|
|
}
|
|
|
|
/* static */
|
|
std::shared_ptr<DebugInfoBase> ThreadLocalDebugInfo::_pop(DebugInfoKind kind) {
|
|
TORCH_CHECK(
|
|
debug_info && debug_info->kind_ == kind,
|
|
"Expected debug info of type ", (size_t)kind);
|
|
auto res = debug_info;
|
|
debug_info = debug_info->parent_info_;
|
|
return res->info_;
|
|
}
|
|
|
|
/* static */
|
|
std::shared_ptr<DebugInfoBase> ThreadLocalDebugInfo::_peek(DebugInfoKind kind) {
|
|
TORCH_CHECK(
|
|
debug_info && debug_info->kind_ == kind,
|
|
"Expected debug info of type ",
|
|
(size_t)kind);
|
|
return debug_info->info_;
|
|
}
|
|
|
|
|
|
// Scoped installation of a debug-info payload on the calling thread.
//
// If `info` is empty the constructor does nothing and the guard is left as it
// was default-initialized (presumably inactive; the member initializers are
// in the header, which is not visible here). Otherwise the thread's current
// entry is remembered in prev_info_, a new entry of the given `kind` is
// pushed, and the guard is marked active so the destructor restores the
// remembered entry.
//
// `info` is a by-value sink parameter and is moved into _push() instead of
// being copied a second time.
DebugInfoGuard::DebugInfoGuard(
    DebugInfoKind kind,
    std::shared_ptr<DebugInfoBase> info) {
  if (!info) {
    return;
  }
  // Must stay a copy: _push() still reads debug_info to link the parent.
  prev_info_ = debug_info;
  ThreadLocalDebugInfo::_push(kind, std::move(info));
  active_ = true;
}
|
|
|
|
// Restores the calling thread's debug info to the entry saved in prev_info_
// by the constructor. Does nothing for a guard that was never activated
// (i.e. one constructed with an empty `info`).
DebugInfoGuard::~DebugInfoGuard() {
  if (!active_) {
    return;
  }
  debug_info = prev_info_;
}
|
|
|
|
// Used only for setting a debug info after crossing the thread boundary;
|
|
// in this case we assume that thread pool's thread does not have an
|
|
// active debug info
|
|
DebugInfoGuard::DebugInfoGuard(
|
|
std::shared_ptr<ThreadLocalDebugInfo> info) {
|
|
if (!info) {
|
|
return;
|
|
}
|
|
prev_info_ = debug_info;
|
|
debug_info = info;
|
|
active_ = true;
|
|
}
|
|
|
|
} // namespace c10
|