[Profiler] Add Optional Flag to turn off external correlations (#142516)

Summary: External correlations are very noisy and often not useful. Add a flag, set at profiler init, to remove them entirely.

Test Plan: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/dynocli/devvm2185.cco0.facebook.com/rank-0.Dec_10_12_33_31.531106.pt.trace.json.gz&bucket=gpu_traces

Differential Revision: D67048206

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142516
Approved by: https://github.com/ngimel
This commit is contained in:
Shivam Raikundalia 2024-12-13 22:32:07 +00:00 committed by PyTorch MergeBot
parent bb574abe73
commit b29fc52f82
6 changed files with 65 additions and 10 deletions

View File

@ -2122,6 +2122,44 @@ assert KinetoStepTracker.current_step() == initial_step + 2 * niters
for step in range(len(test_schedule_expected_outputs)):
self.assertEqual(test_schedule(step), test_schedule_expected_outputs[step])
@skipIfTorchDynamo("profiler gets ignored if dynamo activated")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
@unittest.skipIf(not kineto_available(), "Kineto is required")
def test_disable_external_correlation(self):
    """Check that `disable_external_correlation` controls "External id" args.

    Profiles `self.payload(use_cuda=True)` twice, once with the experimental
    flag off and once with it on, exports a chrome trace each time, and
    inspects every event whose category is one of `cuda_runtime`,
    `gpu_memcpy` or `kernel`:

    * flag on  -> the event's args must not contain "External id";
    * flag off -> the event's args must contain a positive "External id"
      (events named `cudaDeviceSynchronize` are exempt from this check).

    Each run must also produce at least one event of every category above,
    so the per-event checks cannot pass vacuously.
    """
    cuda_external_id_events = {"cuda_runtime", "gpu_memcpy", "kernel"}
    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

    def check_correlations(event, disable_external_correlation):
        # Only the CUDA-side categories listed above are checked.
        if "cat" in event and event["cat"] in cuda_external_id_events:
            if disable_external_correlation:
                self.assertNotIn("External id", event["args"])
            elif event["name"] != "cudaDeviceSynchronize":
                self.assertIn("External id", event["args"])
                self.assertGreater(event["args"]["External id"], 0)

    def validate_json(prof, disable_external_correlation):
        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                events = json.load(f)["traceEvents"]
                seen_event_types = set()
                for event in events:
                    check_correlations(event, disable_external_correlation)
                    if "cat" in event:
                        seen_event_types.add(event["cat"])
                # Every category under test has to show up in the trace.
                missing = cuda_external_id_events - seen_event_types
                self.assertTrue(
                    cuda_external_id_events.issubset(seen_event_types),
                    f"trace is missing event categories: {sorted(missing)}",
                )

    # Run with External Id for CUDA events on and off
    for disable_external_correlation in [False, True]:
        with profile(
            activities=activities,
            experimental_config=torch._C._profiler._ExperimentalConfig(
                disable_external_correlation=disable_external_correlation
            ),
        ) as prof:
            self.payload(use_cuda=True)
        validate_json(prof, disable_external_correlation)
class SimpleNet(nn.Module):
def __init__(self) -> None:

View File

@ -527,10 +527,12 @@ void onFunctionExit(
nullptr, &fallback->device_event_end_, nullptr);
}
if (fn.scope() == at::RecordScope::USER_SCOPE) {
torch::profiler::impl::kineto::popUserCorrelationId();
} else {
torch::profiler::impl::kineto::popCorrelationId();
if (!config.experimental_config.disable_external_correlation) {
if (fn.scope() == at::RecordScope::USER_SCOPE) {
torch::profiler::impl::kineto::popUserCorrelationId();
} else {
torch::profiler::impl::kineto::popCorrelationId();
}
}
}

View File

@ -339,10 +339,12 @@ std::unique_ptr<KinetoObserverContext> ThreadLocalSubqueue::begin_op(
torch_ops_.inputs_outputs_.push(fn.inputs());
torch_ops_.kwinputs_.emplace_back(fn.kwinputs());
}
if (fn.scope() == at::RecordScope::USER_SCOPE) {
torch::profiler::impl::kineto::pushUserCorrelationId(corr_id);
} else {
torch::profiler::impl::kineto::pushCorrelationId(corr_id);
if (!config_.experimental_config.disable_external_correlation) {
if (fn.scope() == at::RecordScope::USER_SCOPE) {
torch::profiler::impl::kineto::pushUserCorrelationId(corr_id);
} else {
torch::profiler::impl::kineto::pushCorrelationId(corr_id);
}
}
#if !defined BUILD_LITE_INTERPRETER && !defined C10_MOBILE

View File

@ -18,6 +18,7 @@ ExperimentalConfig::ExperimentalConfig(
std::vector<std::string> performance_events,
bool enable_cuda_sync_events,
bool adjust_profiler_step,
bool disable_external_correlation,
bool adjust_timestamps)
: profiler_metrics{std::move(profiler_metrics)},
profiler_measure_per_kernel{profiler_measure_per_kernel},
@ -25,6 +26,7 @@ ExperimentalConfig::ExperimentalConfig(
performance_events(std::move(performance_events)),
enable_cuda_sync_events{enable_cuda_sync_events},
adjust_profiler_step{adjust_profiler_step},
disable_external_correlation{disable_external_correlation},
adjust_timestamps{adjust_timestamps} {}
/*explicit*/ ExperimentalConfig::operator bool() const {

View File

@ -58,6 +58,7 @@ struct TORCH_API ExperimentalConfig {
std::vector<std::string> performance_events = {},
bool enable_cuda_sync_events = false,
bool adjust_profiler_step = false,
bool disable_external_correlation = false,
bool adjust_timestamps = false);
explicit operator bool() const;
@ -81,6 +82,12 @@ struct TORCH_API ExperimentalConfig {
* affects only the start of profiler step events.
*/
bool adjust_profiler_step;
/*
* Controls whether or not external correlation is disabled. This is used to
* lower the number of events received by CUPTI, as correlation events are
* paired with runtime/GPU events for each kind of correlation.
*/
bool disable_external_correlation;
/*
* Controls whether or not timestamp adjustment occurs after profiling.

View File

@ -336,7 +336,8 @@ void initPythonBindings(PyObject* module) {
bool /* verbose */,
std::vector<std::string> /* performance_events */,
bool /* enable_cuda_sync_events */,
bool /* adjust_profiler_step */
bool /* adjust_profiler_step */,
bool /* disable_external_correlation*/
>(),
"An experimental config for Kineto features. Please note that"
"backward compatibility is not guaranteed.\n"
@ -352,12 +353,14 @@ void initPythonBindings(PyObject* module) {
" and currently disabled by default.\n"
" adjust_profiler_step (bool) : whether to adjust the profiler step to\n"
" match the parent python event duration. This feature is new and currently disabled by default.\n",
" disable_external_correlation (bool) : whether to disable external correlation\n",
py::arg("profiler_metrics") = std::vector<std::string>(),
py::arg("profiler_measure_per_kernel") = false,
py::arg("verbose") = false,
py::arg("performance_events") = std::vector<std::string>(),
py::arg("enable_cuda_sync_events") = false,
py::arg("adjust_profiler_step") = false)
py::arg("adjust_profiler_step") = false,
py::arg("disable_external_correlation") = false)
.def(py::pickle(
[](const ExperimentalConfig& p) { // __getstate__
py::list py_metrics;
@ -377,6 +380,7 @@ void initPythonBindings(PyObject* module) {
p.verbose,
p.enable_cuda_sync_events,
p.adjust_profiler_step,
p.disable_external_correlation,
p.performance_events);
},
[](const py::tuple& t) { // __setstate__