pytorch/c10/util/ThreadLocalDebugInfo.h
Rohan Varma 70d2e4d1f6 [RPC profiling] Allow disableProfiler() to be called from another thread. (#44653)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44653

This changes the profiler, per an offline discussion with ilia-cher, so that the `disableProfiler()` event consolidation logic can be called from different threads (i.e. threads where the profiler was not explicitly enabled). This is needed to support the functionality enabled by D23638387, where we defer profiling event collection until an async callback executes, possibly on a different thread, in order to support RPC async function profiling.

This is done by introducing two flags, `cleanupTLSState` and `consolidate`, which control whether we should clean up thread-local settings (we don't do this when calling `disableProfiler()` on non-main threads) and whether we should consolidate all profiled events. Backward compatibility is ensured since both options are true by default.

Added a test in `test_misc.cpp` to test this.
ghstack-source-id: 112605620

Reviewed By: mrshenli

Differential Revision: D23638499

fbshipit-source-id: f5bbb0d41ef883c5e5870bc27e086b8b8908f46b
2020-09-22 21:16:58 -07:00

86 lines
2.5 KiB
C++

#pragma once
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
namespace c10 {
// Tag identifying which kind of debug information a thread-local entry holds.
// The underlying type is fixed to one byte (std::uint8_t from <cstdint>).
// Only PRODUCER_INFO has an explicit value (0); the remaining enumerators
// take consecutive values in declaration order, so inserting a new
// enumerator anywhere but at the end renumbers the ones after it.
enum class C10_API_ENUM DebugInfoKind : std::uint8_t {
  PRODUCER_INFO = 0,
  MOBILE_RUNTIME_INFO,
  PROFILER_STATE,
  TEST_INFO, // used only in tests
  TEST_INFO_2, // used only in tests
};
// Polymorphic base for the debug-info payloads that ThreadLocalDebugInfo
// stores and hands out as std::shared_ptr<DebugInfoBase>.
// The class has no state of its own. The destructor is virtual so that a
// derived payload is destroyed correctly through a DebugInfoBase pointer.
class C10_API DebugInfoBase {
 public:
  // Both special members do nothing beyond the compiler-generated behavior,
  // so they are defaulted rather than given empty bodies.
  DebugInfoBase() = default;
  virtual ~DebugInfoBase() = default;
};
// Thread-local debug information. It is carried along through the forward
// pass (async fork tasks included) and the backward pass, and is meant to
// let user code hand extra information from the higher layers (e.g. a
// model id) down to the lower levels (e.g. the operator observers used for
// debugging, logging, profiling, etc).
class C10_API ThreadLocalDebugInfo {
 public:
  // Returns the debug info registered under `kind`.
  // NOTE(review): what is returned when nothing of that kind is set is not
  // visible from this header -- presumably an empty pointer; confirm
  // against the implementation.
  static std::shared_ptr<DebugInfoBase> get(DebugInfoKind kind);

  // Returns the current ThreadLocalDebugInfo.
  static std::shared_ptr<ThreadLocalDebugInfo> current();

  // Internal; use DebugInfoGuard/ThreadLocalStateGuard instead of calling
  // this directly.
  static void _forceCurrentDebugInfo(
      const std::shared_ptr<ThreadLocalDebugInfo>& info);

  // Pushes a debug info struct of the given kind.
  static void _push(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);

  // Pops debug info. Throws if the most recently pushed debug info is not
  // of the given kind.
  static std::shared_ptr<DebugInfoBase> _pop(DebugInfoKind kind);

  // Peeks at debug info. Throws if the most recently pushed debug info is
  // not of the given kind.
  static std::shared_ptr<DebugInfoBase> _peek(DebugInfoKind kind);

 private:
  // DebugInfoGuard needs access to the private members below.
  friend class DebugInfoGuard;

  // Payload held by this entry.
  std::shared_ptr<DebugInfoBase> info_;
  // Kind tag for this entry (presumably describes info_ -- confirm).
  DebugInfoKind kind_;
  // NOTE(review): presumably the entry that was current before this one was
  // pushed, chaining entries together -- confirm against the implementation.
  std::shared_ptr<ThreadLocalDebugInfo> parent_info_;
};
// DebugInfoGuard is the scope-based way to set debug information.
// ThreadLocalDebugInfo is semantically immutable; values are set through a
// guard object and stay in effect for that guard's scope.
// A nested DebugInfoGuard adds to or overrides the values already visible
// in the scope, and the original values are restored when the scope exits.
// Users read the values back through ThreadLocalDebugInfo::get().
class C10_API DebugInfoGuard {
 public:
  // Sets `info` as the debug info of the given `kind` for this guard's
  // scope.
  DebugInfoGuard(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);

  // Builds a guard from an existing ThreadLocalDebugInfo.
  // NOTE(review): presumably makes `info` the current debug info and keeps
  // the previous one in prev_info_ for restoration -- confirm against the
  // implementation.
  explicit DebugInfoGuard(std::shared_ptr<ThreadLocalDebugInfo> info);

  // Ends the guard's scope (see the class comment for the restore contract).
  ~DebugInfoGuard();

  // A guard is tied to a single scope, so it is neither copyable nor
  // movable. The assignment operators were already unavailable because of
  // the user-declared constructors and destructor above; they are deleted
  // explicitly so that all five special members are spelled out.
  DebugInfoGuard(const DebugInfoGuard&) = delete;
  DebugInfoGuard(DebugInfoGuard&&) = delete;
  DebugInfoGuard& operator=(const DebugInfoGuard&) = delete;
  DebugInfoGuard& operator=(DebugInfoGuard&&) = delete;

 private:
  // NOTE(review): presumably records whether this guard installed anything
  // and therefore has work to undo on destruction -- confirm.
  bool active_ = false;
  // NOTE(review): presumably the debug info that was current before this
  // guard took effect -- confirm.
  std::shared_ptr<ThreadLocalDebugInfo> prev_info_ = nullptr;
};
} // namespace c10