mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Adds Python Garbage Collection to Kineto Traces and Profiler FunctionEvents. Create custom cpp callback in profiler_python.cpp. Then define a python function with cpp and register that callback for all python garbage collection. We don't worry about thread safety in this case because we are only doing init/teardown for main thread while holding GIL. Currently we are hiding this behind experimental config because python tracing tends to be unstable especially when adding any new feature. If this is found to not add too much overhead we can set this to on by default. NOTE: To enable this you need both with_stack=True and the experimental config on! Test Plan: Ran trace with GC induced and saw it on trace Also added a test Rollback Plan: Differential Revision: D80491146 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161209 Approved by: https://github.com/ngimel
203 lines
6.7 KiB
C++
203 lines
6.7 KiB
C++
#include <torch/csrc/profiler/orchestration/observer.h>
|
|
|
|
#include <torch/csrc/profiler/util.h>
|
|
|
|
#include <utility>
|
|
|
|
namespace torch::profiler::impl {
|
|
|
|
using GlobalManager = GlobalStateManager<ProfilerStateBase>;
|
|
|
|
// ----------------------------------------------------------------------------
|
|
// -- Profiler Config ---------------------------------------------------------
|
|
// ----------------------------------------------------------------------------
|
|
// Builds the bundle of experimental (opt-in, potentially unstable) profiler
// options. Vector/string parameters are taken by value and moved into the
// corresponding members; bool flags are copied as-is. The initializer order
// matches the member declaration order in observer.h.
ExperimentalConfig::ExperimentalConfig(
    std::vector<std::string> profiler_metrics,
    bool profiler_measure_per_kernel,
    bool verbose,
    std::vector<std::string> performance_events,
    bool enable_cuda_sync_events,
    bool adjust_profiler_step,
    bool disable_external_correlation,
    bool profile_all_threads,
    bool capture_overload_names,
    bool record_python_gc_info,
    std::string custom_profiler_config,
    bool adjust_timestamps)
    : profiler_metrics(std::move(profiler_metrics)),
      profiler_measure_per_kernel(profiler_measure_per_kernel),
      verbose(verbose),
      performance_events(std::move(performance_events)),
      enable_cuda_sync_events(enable_cuda_sync_events),
      adjust_profiler_step(adjust_profiler_step),
      disable_external_correlation(disable_external_correlation),
      profile_all_threads(profile_all_threads),
      capture_overload_names(capture_overload_names),
      record_python_gc_info(record_python_gc_info),
      custom_profiler_config(std::move(custom_profiler_config)),
      adjust_timestamps(adjust_timestamps) {}
|
|
|
|
// An ExperimentalConfig is considered "set" only when at least one profiler
// metric was requested; all other flags do not affect truthiness.
/*explicit*/ ExperimentalConfig::operator bool() const {
  return profiler_metrics.size() != 0;
}
|
|
|
|
// Top-level profiler configuration. `experimental_config` and `trace_id` are
// taken by value and moved; everything else is a cheap copy. Note the
// initializer order (state, experimental_config, then the bool flags) follows
// the member declaration order, which differs from the parameter order.
ProfilerConfig::ProfilerConfig(
    ProfilerState state,
    bool report_input_shapes,
    bool profile_memory,
    bool with_stack,
    bool with_flops,
    bool with_modules,
    ExperimentalConfig experimental_config,
    std::string trace_id)
    : state(state),
      experimental_config(std::move(experimental_config)),
      report_input_shapes(report_input_shapes),
      profile_memory(profile_memory),
      with_stack(with_stack),
      with_flops(with_flops),
      with_modules(with_modules),
      trace_id(std::move(trace_id)) {}
|
|
|
|
bool ProfilerConfig::disabled() const {
|
|
return state == torch::profiler::impl::ProfilerState::Disabled;
|
|
}
|
|
|
|
bool ProfilerConfig::global() const {
|
|
return state == torch::profiler::impl::ProfilerState::KINETO_ONDEMAND;
|
|
}
|
|
|
|
bool ProfilerConfig::pushGlobalCallbacks() const {
|
|
return global() || experimental_config.profile_all_threads;
|
|
}
|
|
|
|
namespace {
// Indices of the fields serialized by ProfilerConfig::toIValue() and read
// back by ProfilerConfig::fromIValue(). The order here is the serialization
// contract and must match the emplace_back order in toIValue().
enum ProfilerIValueIdx {
  STATE = 0,
  REPORT_INPUT_SHAPES,
  PROFILE_MEMORY,
  NUM_PROFILER_CFG_IVALUE_IDX // must be last in list
};
} // namespace
|
|
|
|
// Serializes the subset of ProfilerConfig fields indexed by ProfilerIValueIdx
// into a generic IValue list (state, report_input_shapes, profile_memory).
// Inverse operation: ProfilerConfig::fromIValue().
at::IValue ProfilerConfig::toIValue() const {
  c10::impl::GenericList values{at::AnyType::get()};
  values.reserve(NUM_PROFILER_CFG_IVALUE_IDX);
  // Order must match ProfilerIValueIdx.
  values.push_back(static_cast<int64_t>(state));
  values.push_back(report_input_shapes);
  values.push_back(profile_memory);
  return values;
}
|
|
|
|
// Deserializes a ProfilerConfig previously produced by toIValue().
// Asserts (internal) that the IValue is a list of exactly
// NUM_PROFILER_CFG_IVALUE_IDX entries; fields not serialized by toIValue()
// take their default values in the returned config.
// Fix: corrected "resconstruct" -> "reconstruct" in the assert message.
ProfilerConfig ProfilerConfig::fromIValue(
    const at::IValue& profilerConfigIValue) {
  TORCH_INTERNAL_ASSERT(
      profilerConfigIValue.isList(),
      "Expected IValue to contain type c10::impl::GenericList");
  auto ivalues = profilerConfigIValue.toList();
  TORCH_INTERNAL_ASSERT(
      ivalues.size() == NUM_PROFILER_CFG_IVALUE_IDX,
      c10::str(
          "Expected exactly ",
          NUM_PROFILER_CFG_IVALUE_IDX,
          " ivalues to reconstruct ProfilerConfig."));
  return ProfilerConfig(
      static_cast<ProfilerState>(ivalues.get(ProfilerIValueIdx::STATE).toInt()),
      ivalues.get(ProfilerIValueIdx::REPORT_INPUT_SHAPES).toBool(),
      ivalues.get(ProfilerIValueIdx::PROFILE_MEMORY).toBool());
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
// -- Profiler base class -----------------------------------------------------
|
|
// ----------------------------------------------------------------------------
|
|
/*explicit*/ ProfilerStateBase::ProfilerStateBase(ProfilerConfig config)
|
|
: c10::MemoryReportingInfoBase(), config_(std::move(config)) {}
|
|
|
|
ProfilerStateBase::~ProfilerStateBase() {
  // A live handle at destruction time means stop/removeCallback was never
  // called. Clean up anyway, then report the leak (SOFT_ASSERT: warn rather
  // than crash).
  if (!handle_) {
    return;
  }
  auto leaked_handle = handle_;
  removeCallback();
  SOFT_ASSERT(false, "Leaked callback handle: ", leaked_handle);
}
|
|
|
|
// Fetches the active profiler state: the process-global instance when
// `global` is true, otherwise the one stored in this thread's
// ThreadLocalDebugInfo slot. Returns nullptr when no profiler is active.
/*static*/ ProfilerStateBase* ProfilerStateBase::get(bool global) {
  ProfilerStateBase* out = nullptr;
  if (global) {
    out = GlobalManager::get();
  } else {
    out = static_cast<ProfilerStateBase*>(
        c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE));
  }
  // Debug-only sanity check: the state we found must live in the registry
  // matching its own config (global vs. thread-local).
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      !out || out->config().pushGlobalCallbacks() == global);
  return out;
}
|
|
|
|
/*static*/ void ProfilerStateBase::push(
|
|
std::shared_ptr<ProfilerStateBase>&& state) {
|
|
TORCH_INTERNAL_ASSERT(state != nullptr);
|
|
if (state->config().pushGlobalCallbacks()) {
|
|
GlobalManager::push(std::move(state));
|
|
} else {
|
|
c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state);
|
|
}
|
|
}
|
|
|
|
namespace {
|
|
std::shared_ptr<ProfilerStateBase> popTLS() {
|
|
// If there is no active thread local profiler then we simply return null.
|
|
// However if there is an active profiler but it is not the top
|
|
// `DebugInfoBase`then `c10::ThreadLocalDebugInfo::_pop` will throw.
|
|
// TODO(robieta): make `noexcept` version.
|
|
return c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE)
|
|
? std::static_pointer_cast<ProfilerStateBase>(
|
|
c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE))
|
|
: nullptr;
|
|
}
|
|
} // namespace
|
|
|
|
/*static*/ std::shared_ptr<ProfilerStateBase> ProfilerStateBase::pop(
|
|
bool global) {
|
|
auto out = global ? GlobalManager::pop() : popTLS();
|
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!out || out->config().global() == global);
|
|
return out;
|
|
}
|
|
|
|
// Records the RecordFunction callback handle owned by this state. Replacing
// an existing handle is a programming error: the old callback is removed so
// it does not leak, and a soft assert reports the misuse.
void ProfilerStateBase::setCallbackHandle(at::CallbackHandle handle) {
  const bool already_registered = (handle_ != 0);
  if (already_registered) {
    at::removeCallback(handle_);
    SOFT_ASSERT(
        false,
        "ProfilerStateBase already has a registered callback. "
        "Removing to avoid leaked callback.");
  }

  handle_ = handle;
}
|
|
|
|
void ProfilerStateBase::removeCallback() {
|
|
if (handle_) {
|
|
at::removeCallback(handle_);
|
|
handle_ = 0;
|
|
}
|
|
}
|
|
|
|
bool profilerEnabled() {
|
|
auto* state_ptr = ProfilerStateBase::get(/*global=*/false);
|
|
return state_ptr && !state_ptr->config().disabled();
|
|
}
|
|
|
|
// Reports which profiler implementation is active on this thread, or NONE
// when no thread-local profiler state exists.
TORCH_API ActiveProfilerType profilerType() {
  auto* state = ProfilerStateBase::get(/*global=*/false);
  if (state == nullptr) {
    return ActiveProfilerType::NONE;
  }
  return state->profilerType();
}
|
|
|
|
// Returns a copy of the active thread-local profiler's configuration.
// Throws (via TORCH_CHECK) when no profiler is enabled on this thread.
torch::profiler::impl::ProfilerConfig getProfilerConfig() {
  auto* state = ProfilerStateBase::get(/*global=*/false);
  TORCH_CHECK(
      state, "Tried to access profiler config, but profiler is not enabled!");
  return state->config();
}
|
|
|
|
} // namespace torch::profiler::impl
|