pytorch/torch/csrc/profiler/orchestration/observer.cpp
Mihai Polceanu 6fa3715c12 Expose Kineto event metadata in PyTorch Profiler events (#161624)
## Overview
This PR lets profiler users access `Kineto` and `TorchOp` event metadata as a JSON string through a new `metadata_json` attribute on `FunctionEvent` objects, enabled via a new `expose_kineto_event_metadata` flag in `ExperimentalConfig`.

## Testing
A unit test was added to validate functionality.
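
For illustration, a minimal sketch of the kind of check such a test can perform (the structure is an assumption, not the actual test):

```python
import torch
from torch.profiler import profile

config = torch._C._profiler._ExperimentalConfig(expose_kineto_event_metadata=True)
with profile(experimental_config=config) as prof:
    torch.mm(torch.rand(8, 8), torch.rand(8, 8))

# With the flag set, FunctionEvent objects expose the metadata_json attribute.
assert all(hasattr(event, "metadata_json") for event in prof.events())
```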

## Documentation
Added/updated function doc strings where appropriate.

## Example output
```python
import torch
from torch.profiler import profile

with profile(experimental_config=torch._C._profiler._ExperimentalConfig(expose_kineto_event_metadata=True)) as prof:
    res = torch.mm(torch.rand(1024, 1024), torch.rand(1024, 1024))

for event in prof.events():
    print(f'name: {event.key}, metadata: {event.metadata_json}')
```

```
name: aten::rand, metadata: "Ev Idx": 0
name: aten::empty, metadata: "Ev Idx": 1
name: aten::uniform_, metadata: "Ev Idx": 2
name: aten::rand, metadata: "Ev Idx": 3
name: aten::empty, metadata: "Ev Idx": 4
name: aten::uniform_, metadata: "Ev Idx": 5
name: aten::mm, metadata: "Ev Idx": 6
name: aten::resolve_conj, metadata: "Ev Idx": 7
name: aten::resolve_conj, metadata: "Ev Idx": 8
name: aten::resolve_conj, metadata: "Ev Idx": 9
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161624
Approved by: https://github.com/sraikund16
2025-09-25 14:58:30 +00:00


#include <torch/csrc/profiler/orchestration/observer.h>
#include <torch/csrc/profiler/util.h>

#include <utility>

namespace torch::profiler::impl {
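
// The global (on-demand) profiler state lives in a process-wide singleton;
// per-thread profiler state is stored in c10::ThreadLocalDebugInfo.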
using GlobalManager = GlobalStateManager<ProfilerStateBase>;

// ----------------------------------------------------------------------------
// -- Profiler Config ---------------------------------------------------------
// ----------------------------------------------------------------------------
ExperimentalConfig::ExperimentalConfig(
    std::vector<std::string> profiler_metrics,
    bool profiler_measure_per_kernel,
    bool verbose,
    std::vector<std::string> performance_events,
    bool enable_cuda_sync_events,
    bool adjust_profiler_step,
    bool disable_external_correlation,
    bool profile_all_threads,
    bool capture_overload_names,
    bool record_python_gc_info,
    bool expose_kineto_event_metadata,
    std::string custom_profiler_config,
    bool adjust_timestamps)
    : profiler_metrics{std::move(profiler_metrics)},
      profiler_measure_per_kernel{profiler_measure_per_kernel},
      verbose{verbose},
      performance_events(std::move(performance_events)),
      enable_cuda_sync_events{enable_cuda_sync_events},
      adjust_profiler_step{adjust_profiler_step},
      disable_external_correlation{disable_external_correlation},
      profile_all_threads{profile_all_threads},
      capture_overload_names{capture_overload_names},
      record_python_gc_info{record_python_gc_info},
      expose_kineto_event_metadata{expose_kineto_event_metadata},
      custom_profiler_config(std::move(custom_profiler_config)),
      adjust_timestamps{adjust_timestamps} {}

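// An ExperimentalConfig converts to true iff Kineto profiler metrics were
// requested.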
/*explicit*/ ExperimentalConfig::operator bool() const {
  return !profiler_metrics.empty();
}

ProfilerConfig::ProfilerConfig(
    ProfilerState state,
    bool report_input_shapes,
    bool profile_memory,
    bool with_stack,
    bool with_flops,
    bool with_modules,
    ExperimentalConfig experimental_config,
    std::string trace_id)
    : state{state},
      experimental_config{std::move(experimental_config)},
      report_input_shapes{report_input_shapes},
      profile_memory{profile_memory},
      with_stack{with_stack},
      with_flops{with_flops},
      with_modules{with_modules},
      trace_id{std::move(trace_id)} {}

bool ProfilerConfig::disabled() const {
  return state == torch::profiler::impl::ProfilerState::Disabled;
}

bool ProfilerConfig::global() const {
  return state == torch::profiler::impl::ProfilerState::KINETO_ONDEMAND;
}

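// Callbacks must be installed process-wide (rather than per thread) both for
// on-demand profiling and when the caller asked to profile all threads.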
bool ProfilerConfig::pushGlobalCallbacks() const {
  return global() || experimental_config.profile_all_threads;
}

namespace {
enum ProfilerIValueIdx {
  STATE = 0,
  REPORT_INPUT_SHAPES,
  PROFILE_MEMORY,
  NUM_PROFILER_CFG_IVALUE_IDX // must be last in list
};
} // namespace

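// Pack the serializable subset of the config (state, input shapes, memory
// profiling) into a GenericList in ProfilerIValueIdx order so that a config
// can be rebuilt later via fromIValue(), e.g. after crossing a process
// boundary.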
at::IValue ProfilerConfig::toIValue() const {
  c10::impl::GenericList eventIValueList(at::AnyType::get());
  eventIValueList.reserve(NUM_PROFILER_CFG_IVALUE_IDX);
  eventIValueList.emplace_back(static_cast<int64_t>(state));
  eventIValueList.emplace_back(report_input_shapes);
  eventIValueList.emplace_back(profile_memory);
  return eventIValueList;
}

ProfilerConfig ProfilerConfig::fromIValue(
    const at::IValue& profilerConfigIValue) {
  TORCH_INTERNAL_ASSERT(
      profilerConfigIValue.isList(),
      "Expected IValue to contain type c10::impl::GenericList");
  auto ivalues = profilerConfigIValue.toList();
  TORCH_INTERNAL_ASSERT(
      ivalues.size() == NUM_PROFILER_CFG_IVALUE_IDX,
      c10::str(
          "Expected exactly ",
          NUM_PROFILER_CFG_IVALUE_IDX,
          " ivalues to reconstruct ProfilerConfig."));
  return ProfilerConfig(
      static_cast<ProfilerState>(ivalues.get(ProfilerIValueIdx::STATE).toInt()),
      ivalues.get(ProfilerIValueIdx::REPORT_INPUT_SHAPES).toBool(),
      ivalues.get(ProfilerIValueIdx::PROFILE_MEMORY).toBool());
}

// ----------------------------------------------------------------------------
// -- Profiler base class -----------------------------------------------------
// ----------------------------------------------------------------------------
/*explicit*/ ProfilerStateBase::ProfilerStateBase(ProfilerConfig config)
    : c10::MemoryReportingInfoBase(), config_(std::move(config)) {}

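// A callback that is still registered at destruction time would be leaked, so
// remove it here and report the bug via SOFT_ASSERT.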
ProfilerStateBase::~ProfilerStateBase() {
  if (handle_) {
    auto handle = handle_;
    removeCallback();
    SOFT_ASSERT(false, "Leaked callback handle: ", handle);
  }
}

/*static*/ ProfilerStateBase* ProfilerStateBase::get(bool global) {
  auto* out = global
      ? GlobalManager::get()
      : static_cast<ProfilerStateBase*>(
            c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE));
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      !out || out->config().pushGlobalCallbacks() == global);
  return out;
}

/*static*/ void ProfilerStateBase::push(
    std::shared_ptr<ProfilerStateBase>&& state) {
  TORCH_INTERNAL_ASSERT(state != nullptr);
  if (state->config().pushGlobalCallbacks()) {
    GlobalManager::push(std::move(state));
  } else {
    c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state);
  }
}

namespace {
std::shared_ptr<ProfilerStateBase> popTLS() {
  // If there is no active thread-local profiler then we simply return null.
  // However, if there is an active profiler but it is not the top
  // `DebugInfoBase`, then `c10::ThreadLocalDebugInfo::_pop` will throw.
  // TODO(robieta): make `noexcept` version.
  return c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE)
      ? std::static_pointer_cast<ProfilerStateBase>(
            c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE))
      : nullptr;
}
} // namespace

/*static*/ std::shared_ptr<ProfilerStateBase> ProfilerStateBase::pop(
    bool global) {
  auto out = global ? GlobalManager::pop() : popTLS();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!out || out->config().global() == global);
  return out;
}

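// Store the handle for the profiler's RecordFunction callback so it can be
// removed when profiling stops; an already-set handle indicates a double
// registration, so the old callback is removed and the bug reported.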
void ProfilerStateBase::setCallbackHandle(at::CallbackHandle handle) {
  if (handle_) {
    at::removeCallback(handle_);
    SOFT_ASSERT(
        false,
        "ProfilerStateBase already has a registered callback. "
        "Removing to avoid leaked callback.");
  }
  handle_ = handle;
}

void ProfilerStateBase::removeCallback() {
  if (handle_) {
    at::removeCallback(handle_);
    handle_ = 0;
  }
}

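// Free-function conveniences that query the calling thread's profiler state.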
bool profilerEnabled() {
  auto* state_ptr = ProfilerStateBase::get(/*global=*/false);
  return state_ptr && !state_ptr->config().disabled();
}

TORCH_API ActiveProfilerType profilerType() {
  auto* state_ptr = ProfilerStateBase::get(/*global=*/false);
  return state_ptr == nullptr ? ActiveProfilerType::NONE
                              : state_ptr->profilerType();
}

torch::profiler::impl::ProfilerConfig getProfilerConfig() {
  auto* state_ptr = ProfilerStateBase::get(/*global=*/false);
  TORCH_CHECK(
      state_ptr,
      "Tried to access profiler config, but profiler is not enabled!");
  return state_ptr->config();
}

} // namespace torch::profiler::impl