mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
## Overview
This PR allows profiler users to access `Kineto` and `TorchOp` metadata as a JSON string through a new `metadata_json` attribute on `FunctionEvent` objects. The attribute is populated when the new `expose_kineto_event_metadata` flag in `ExperimentalConfig` is enabled.
## Testing
A unit test was added to validate functionality.
## Documentation
Added/updated function doc strings where appropriate.
## Example output
```python
import torch
from torch.profiler import profile
with profile(experimental_config=torch._C._profiler._ExperimentalConfig(expose_kineto_event_metadata=True)) as prof:
    res = torch.mm(torch.rand(1024, 1024), torch.rand(1024, 1024))
for event in prof.events():
    print(f'name: {event.key}, metadata: {event.metadata_json}')
```
```
name: aten::rand, metadata: "Ev Idx": 0
name: aten::empty, metadata: "Ev Idx": 1
name: aten::uniform_, metadata: "Ev Idx": 2
name: aten::rand, metadata: "Ev Idx": 3
name: aten::empty, metadata: "Ev Idx": 4
name: aten::uniform_, metadata: "Ev Idx": 5
name: aten::mm, metadata: "Ev Idx": 6
name: aten::resolve_conj, metadata: "Ev Idx": 7
name: aten::resolve_conj, metadata: "Ev Idx": 8
name: aten::resolve_conj, metadata: "Ev Idx": 9
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161624
Approved by: https://github.com/sraikund16
205 lines
6.8 KiB
C++
205 lines
6.8 KiB
C++
#include <torch/csrc/profiler/orchestration/observer.h>
|
|
|
|
#include <torch/csrc/profiler/util.h>
|
|
|
|
#include <utility>
|
|
|
|
namespace torch::profiler::impl {
|
|
|
|
using GlobalManager = GlobalStateManager<ProfilerStateBase>;
|
|
|
|
// ----------------------------------------------------------------------------
|
|
// -- Profiler Config ---------------------------------------------------------
|
|
// ----------------------------------------------------------------------------
|
|
// Builds an ExperimentalConfig by storing each argument in the member of the
// same name. The container and string arguments are taken by value and moved
// into place; the boolean flags are copied.
ExperimentalConfig::ExperimentalConfig(
    std::vector<std::string> profiler_metrics,
    bool profiler_measure_per_kernel,
    bool verbose,
    std::vector<std::string> performance_events,
    bool enable_cuda_sync_events,
    bool adjust_profiler_step,
    bool disable_external_correlation,
    bool profile_all_threads,
    bool capture_overload_names,
    bool record_python_gc_info,
    bool expose_kineto_event_metadata,
    std::string custom_profiler_config,
    bool adjust_timestamps)
    : profiler_metrics(std::move(profiler_metrics)),
      profiler_measure_per_kernel(profiler_measure_per_kernel),
      verbose(verbose),
      performance_events(std::move(performance_events)),
      enable_cuda_sync_events(enable_cuda_sync_events),
      adjust_profiler_step(adjust_profiler_step),
      disable_external_correlation(disable_external_correlation),
      profile_all_threads(profile_all_threads),
      capture_overload_names(capture_overload_names),
      record_python_gc_info(record_python_gc_info),
      expose_kineto_event_metadata(expose_kineto_event_metadata),
      custom_profiler_config(std::move(custom_profiler_config)),
      adjust_timestamps(adjust_timestamps) {}
|
|
|
|
// Boolean view of the config: true exactly when `profiler_metrics` is
// non-empty. None of the other fields influence the result.
/*explicit*/ ExperimentalConfig::operator bool() const {
  if (profiler_metrics.empty()) {
    return false;
  }
  return true;
}
|
|
|
|
// Builds a ProfilerConfig by storing each argument in the member of the same
// name. `experimental_config` and `trace_id` are taken by value and moved into
// place; the remaining arguments are copied.
ProfilerConfig::ProfilerConfig(
    ProfilerState state,
    bool report_input_shapes,
    bool profile_memory,
    bool with_stack,
    bool with_flops,
    bool with_modules,
    ExperimentalConfig experimental_config,
    std::string trace_id)
    : state(state),
      experimental_config(std::move(experimental_config)),
      report_input_shapes(report_input_shapes),
      profile_memory(profile_memory),
      with_stack(with_stack),
      with_flops(with_flops),
      with_modules(with_modules),
      trace_id(std::move(trace_id)) {}
|
|
|
|
bool ProfilerConfig::disabled() const {
|
|
return state == torch::profiler::impl::ProfilerState::Disabled;
|
|
}
|
|
|
|
bool ProfilerConfig::global() const {
|
|
return state == torch::profiler::impl::ProfilerState::KINETO_ONDEMAND;
|
|
}
|
|
|
|
bool ProfilerConfig::pushGlobalCallbacks() const {
|
|
return global() || experimental_config.profile_all_threads;
|
|
}
|
|
|
|
namespace {
// Slot indices of the ProfilerConfig fields within the list produced by
// ProfilerConfig::toIValue() and consumed by ProfilerConfig::fromIValue().
enum ProfilerIValueIdx {
  STATE = 0,
  REPORT_INPUT_SHAPES = 1,
  PROFILE_MEMORY = 2,
  // Sentinel holding the number of slots above; must be last in list.
  NUM_PROFILER_CFG_IVALUE_IDX = 3
};
} // namespace
|
|
|
|
// Packs `state`, `report_input_shapes` and `profile_memory` (in that order,
// matching ProfilerIValueIdx) into a generic list and returns it as an IValue.
// No other members of the config are written to the list.
at::IValue ProfilerConfig::toIValue() const {
  c10::impl::GenericList packed(at::AnyType::get());
  packed.reserve(NUM_PROFILER_CFG_IVALUE_IDX);

  // The enum-valued state is stored as a 64-bit integer.
  const auto state_as_int = static_cast<int64_t>(state);
  packed.emplace_back(state_as_int);
  packed.emplace_back(report_input_shapes);
  packed.emplace_back(profile_memory);

  return packed;
}
|
|
|
|
// Rebuilds a ProfilerConfig from a list IValue laid out according to
// ProfilerIValueIdx. Asserts that the IValue is a list with exactly
// NUM_PROFILER_CFG_IVALUE_IDX entries. Only the state, report_input_shapes and
// profile_memory arguments are passed to the constructor here.
ProfilerConfig ProfilerConfig::fromIValue(
    const at::IValue& profilerConfigIValue) {
  TORCH_INTERNAL_ASSERT(
      profilerConfigIValue.isList(),
      "Expected IValue to contain type c10::impl::GenericList");
  auto ivalues = profilerConfigIValue.toList();
  TORCH_INTERNAL_ASSERT(
      ivalues.size() == NUM_PROFILER_CFG_IVALUE_IDX,
      c10::str(
          "Expected exactly ",
          NUM_PROFILER_CFG_IVALUE_IDX,
          " ivalues to resconstruct ProfilerConfig."));

  // Decode each slot into a named local before constructing the result.
  const auto decoded_state =
      static_cast<ProfilerState>(ivalues.get(ProfilerIValueIdx::STATE).toInt());
  const bool decoded_report_input_shapes =
      ivalues.get(ProfilerIValueIdx::REPORT_INPUT_SHAPES).toBool();
  const bool decoded_profile_memory =
      ivalues.get(ProfilerIValueIdx::PROFILE_MEMORY).toBool();

  return ProfilerConfig(
      decoded_state, decoded_report_input_shapes, decoded_profile_memory);
}
|
|
|
|
// ----------------------------------------------------------------------------
|
|
// -- Profiler base class -----------------------------------------------------
|
|
// ----------------------------------------------------------------------------
|
|
// Initializes the memory-reporting base and takes ownership of `config` by
// moving it into `config_`.
/*explicit*/ ProfilerStateBase::ProfilerStateBase(ProfilerConfig config)
    : c10::MemoryReportingInfoBase(),
      config_(std::move(config)) {}
|
|
|
|
// If a callback handle is still set at destruction time, remove the callback
// and report the leak through SOFT_ASSERT.
ProfilerStateBase::~ProfilerStateBase() {
  if (!handle_) {
    return;
  }

  // Copy the handle first: removeCallback() resets `handle_`, and the original
  // value is needed for the diagnostic below.
  const auto handle = handle_;
  removeCallback();
  SOFT_ASSERT(false, "Leaked callback handle: ", handle);
}
|
|
|
|
// Returns a non-owning pointer to the current profiler state, or whatever the
// underlying lookup yields when none is registered. With `global == true` the
// GlobalManager is consulted; otherwise the thread-local debug info slot for
// PROFILER_STATE is used.
/*static*/ ProfilerStateBase* ProfilerStateBase::get(bool global) {
  ProfilerStateBase* out = nullptr;
  if (global) {
    out = GlobalManager::get();
  } else {
    out = static_cast<ProfilerStateBase*>(
        c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE));
  }

  // Debug-only sanity check: a state found here must have been configured for
  // the storage it was looked up in.
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      !out || out->config().pushGlobalCallbacks() == global);
  return out;
}
|
|
|
|
/*static*/ void ProfilerStateBase::push(
|
|
std::shared_ptr<ProfilerStateBase>&& state) {
|
|
TORCH_INTERNAL_ASSERT(state != nullptr);
|
|
if (state->config().pushGlobalCallbacks()) {
|
|
GlobalManager::push(std::move(state));
|
|
} else {
|
|
c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state);
|
|
}
|
|
}
|
|
|
|
namespace {
|
|
std::shared_ptr<ProfilerStateBase> popTLS() {
|
|
// If there is no active thread local profiler then we simply return null.
|
|
// However if there is an active profiler but it is not the top
|
|
// `DebugInfoBase`then `c10::ThreadLocalDebugInfo::_pop` will throw.
|
|
// TODO(robieta): make `noexcept` version.
|
|
return c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE)
|
|
? std::static_pointer_cast<ProfilerStateBase>(
|
|
c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE))
|
|
: nullptr;
|
|
}
|
|
} // namespace
|
|
|
|
/*static*/ std::shared_ptr<ProfilerStateBase> ProfilerStateBase::pop(
|
|
bool global) {
|
|
auto out = global ? GlobalManager::pop() : popTLS();
|
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!out || out->config().global() == global);
|
|
return out;
|
|
}
|
|
|
|
// Records `handle` as this state's callback handle. If a handle is already
// set, the previously registered callback is removed first and the situation
// is reported through SOFT_ASSERT.
void ProfilerStateBase::setCallbackHandle(at::CallbackHandle handle) {
  const bool already_registered = static_cast<bool>(handle_);
  if (already_registered) {
    // Remove before reporting so the old callback is not left behind.
    at::removeCallback(handle_);
    SOFT_ASSERT(
        false,
        "ProfilerStateBase already has a registered callback. "
        "Removing to avoid leaked callback.");
  }

  handle_ = handle;
}
|
|
|
|
void ProfilerStateBase::removeCallback() {
|
|
if (handle_) {
|
|
at::removeCallback(handle_);
|
|
handle_ = 0;
|
|
}
|
|
}
|
|
|
|
bool profilerEnabled() {
|
|
auto* state_ptr = ProfilerStateBase::get(/*global=*/false);
|
|
return state_ptr && !state_ptr->config().disabled();
|
|
}
|
|
|
|
// Reports the type of the non-global profiler state, or
// ActiveProfilerType::NONE when there is no such state.
TORCH_API ActiveProfilerType profilerType() {
  auto* state_ptr = ProfilerStateBase::get(/*global=*/false);
  if (state_ptr == nullptr) {
    return ActiveProfilerType::NONE;
  }
  return state_ptr->profilerType();
}
|
|
|
|
// Returns a copy of the config held by the non-global profiler state.
// Fails a TORCH_CHECK when no such state exists.
torch::profiler::impl::ProfilerConfig getProfilerConfig() {
  auto* state_ptr = ProfilerStateBase::get(/*global=*/false);

  // Looking up the config without an active profiler is a caller error.
  TORCH_CHECK(
      state_ptr,
      "Tried to access profiler config, but profiler is not enabled!");

  return state_ptr->config();
}
|
|
|
|
} // namespace torch::profiler::impl
|