mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add a custom profiler configuration option (#151656)
We aim to pass some configuration options to our custom Kineto backend via ExperimentalConfig,, so we added a `custom_profiler_config` parameter. Requires https://github.com/pytorch/kineto/pull/1077 , Pull Request resolved: https://github.com/pytorch/pytorch/pull/151656 Approved by: https://github.com/sraikund16
This commit is contained in:
parent
b60569ed94
commit
f860992db5
|
|
@ -207,6 +207,7 @@ class ExperimentalConfigWrapper {
|
|||
configss << "\nCUPTI_PROFILER_ENABLE_PER_KERNEL="
|
||||
<< (config_.profiler_measure_per_kernel ? "true" : "false")
|
||||
<< "\n";
|
||||
configss << "CUSTOM_CONFIG=" << config_.custom_profiler_config << "\n";
|
||||
LOG(INFO) << "Generated config = " << configss.str();
|
||||
|
||||
libkineto::api().activityProfiler().prepareTrace(
|
||||
|
|
@ -239,6 +240,18 @@ static const std::string setTraceID(const std::string& trace_id) {
|
|||
configss << "REQUEST_GROUP_TRACE_ID=" << trace_id << "\n";
|
||||
return configss.str();
|
||||
}
|
||||
|
||||
static const std::string appendCustomConfig(
|
||||
const std::string& config,
|
||||
const std::string& custom_profiler_config) {
|
||||
if (custom_profiler_config.empty()) {
|
||||
return config;
|
||||
}
|
||||
std::stringstream configss;
|
||||
configss << config;
|
||||
configss << "CUSTOM_CONFIG=" << custom_profiler_config << "\n";
|
||||
return configss.str();
|
||||
}
|
||||
#endif
|
||||
|
||||
void prepareTrace(
|
||||
|
|
@ -295,7 +308,9 @@ void prepareTrace(
|
|||
return;
|
||||
}
|
||||
|
||||
const std::string configStr = setTraceID(trace_id);
|
||||
const std::string traceIdStr = setTraceID(trace_id);
|
||||
const std::string configStr =
|
||||
appendCustomConfig(traceIdStr, config.custom_profiler_config);
|
||||
|
||||
libkineto::api().activityProfiler().prepareTrace(k_activities, configStr);
|
||||
#endif // USE_KINETO
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ ExperimentalConfig::ExperimentalConfig(
|
|||
bool disable_external_correlation,
|
||||
bool profile_all_threads,
|
||||
bool capture_overload_names,
|
||||
std::string custom_profiler_config,
|
||||
bool adjust_timestamps)
|
||||
: profiler_metrics{std::move(profiler_metrics)},
|
||||
profiler_measure_per_kernel{profiler_measure_per_kernel},
|
||||
|
|
@ -31,6 +32,7 @@ ExperimentalConfig::ExperimentalConfig(
|
|||
disable_external_correlation{disable_external_correlation},
|
||||
profile_all_threads{profile_all_threads},
|
||||
capture_overload_names{capture_overload_names},
|
||||
custom_profiler_config(std::move(custom_profiler_config)),
|
||||
adjust_timestamps{adjust_timestamps} {}
|
||||
|
||||
/*explicit*/ ExperimentalConfig::operator bool() const {
|
||||
|
|
|
|||
|
|
@ -62,6 +62,7 @@ struct TORCH_API ExperimentalConfig {
|
|||
bool disable_external_correlation = false,
|
||||
bool profile_all_threads = false,
|
||||
bool capture_overload_names = false,
|
||||
std::string custom_profiler_config = "",
|
||||
bool adjust_timestamps = false);
|
||||
explicit operator bool() const;
|
||||
|
||||
|
|
@ -101,6 +102,12 @@ struct TORCH_API ExperimentalConfig {
|
|||
* function schema and stored in the profile */
|
||||
bool capture_overload_names;
|
||||
|
||||
/*
|
||||
* A custom_profiler_config option is introduced to allow custom backends
|
||||
* to apply custom configurations as needed.
|
||||
*/
|
||||
std::string custom_profiler_config;
|
||||
|
||||
/*
|
||||
* Controls whether or not timestamp adjustment occurs after profiling.
|
||||
* The purpose of this is to adjust Vulkan event timelines to align with those
|
||||
|
|
|
|||
|
|
@ -340,7 +340,8 @@ void initPythonBindings(PyObject* module) {
|
|||
bool /* adjust_profiler_step */,
|
||||
bool /* disable_external_correlation*/,
|
||||
bool /* profile_all_threads */,
|
||||
bool /* capture_overload_names */
|
||||
bool /* capture_overload_names */,
|
||||
std::string /* custom_profiler_config*/
|
||||
>(),
|
||||
"An experimental config for Kineto features. Please note that"
|
||||
"backward compatibility is not guaranteed.\n"
|
||||
|
|
@ -359,6 +360,7 @@ void initPythonBindings(PyObject* module) {
|
|||
" disable_external_correlation (bool) : whether to disable external correlation\n",
|
||||
" profile_all_threads (bool) : whether to profile all threads\n",
|
||||
" capture_overload_names (bool) : whether to include ATen overload names in the profile\n",
|
||||
" custom_profiler_config (string) : Used to pass some configurations to the custom profiler backend.\n",
|
||||
py::arg("profiler_metrics") = std::vector<std::string>(),
|
||||
py::arg("profiler_measure_per_kernel") = false,
|
||||
py::arg("verbose") = false,
|
||||
|
|
@ -367,7 +369,8 @@ void initPythonBindings(PyObject* module) {
|
|||
py::arg("adjust_profiler_step") = false,
|
||||
py::arg("disable_external_correlation") = false,
|
||||
py::arg("profile_all_threads") = false,
|
||||
py::arg("capture_overload_names") = false)
|
||||
py::arg("capture_overload_names") = false,
|
||||
py::arg("custom_profiler_config") = "")
|
||||
.def(py::pickle(
|
||||
[](const ExperimentalConfig& p) { // __getstate__
|
||||
py::list py_metrics;
|
||||
|
|
@ -390,6 +393,7 @@ void initPythonBindings(PyObject* module) {
|
|||
p.disable_external_correlation,
|
||||
p.profile_all_threads,
|
||||
p.capture_overload_names,
|
||||
p.custom_profiler_config,
|
||||
p.performance_events);
|
||||
},
|
||||
[](const py::tuple& t) { // __setstate__
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user