diff --git a/torch/csrc/profiler/kineto_shim.cpp b/torch/csrc/profiler/kineto_shim.cpp index 872a62a311f..ec9994e15ec 100644 --- a/torch/csrc/profiler/kineto_shim.cpp +++ b/torch/csrc/profiler/kineto_shim.cpp @@ -207,6 +207,7 @@ class ExperimentalConfigWrapper { configss << "\nCUPTI_PROFILER_ENABLE_PER_KERNEL=" << (config_.profiler_measure_per_kernel ? "true" : "false") << "\n"; + configss << "CUSTOM_CONFIG=" << config_.custom_profiler_config << "\n"; LOG(INFO) << "Generated config = " << configss.str(); libkineto::api().activityProfiler().prepareTrace( @@ -239,6 +240,18 @@ static const std::string setTraceID(const std::string& trace_id) { configss << "REQUEST_GROUP_TRACE_ID=" << trace_id << "\n"; return configss.str(); } + +static const std::string appendCustomConfig( + const std::string& config, + const std::string& custom_profiler_config) { + if (custom_profiler_config.empty()) { + return config; + } + std::stringstream configss; + configss << config; + configss << "CUSTOM_CONFIG=" << custom_profiler_config << "\n"; + return configss.str(); +} #endif void prepareTrace( @@ -295,7 +308,9 @@ void prepareTrace( return; } - const std::string configStr = setTraceID(trace_id); + const std::string traceIdStr = setTraceID(trace_id); + const std::string configStr = + appendCustomConfig(traceIdStr, config.custom_profiler_config); libkineto::api().activityProfiler().prepareTrace(k_activities, configStr); #endif // USE_KINETO diff --git a/torch/csrc/profiler/orchestration/observer.cpp b/torch/csrc/profiler/orchestration/observer.cpp index 363fb206353..18b792a1abe 100644 --- a/torch/csrc/profiler/orchestration/observer.cpp +++ b/torch/csrc/profiler/orchestration/observer.cpp @@ -21,6 +21,7 @@ ExperimentalConfig::ExperimentalConfig( bool disable_external_correlation, bool profile_all_threads, bool capture_overload_names, + std::string custom_profiler_config, bool adjust_timestamps) : profiler_metrics{std::move(profiler_metrics)}, profiler_measure_per_kernel{profiler_measure_per_kernel}, @@ -31,6 +32,7 @@ ExperimentalConfig::ExperimentalConfig( disable_external_correlation{disable_external_correlation}, profile_all_threads{profile_all_threads}, capture_overload_names{capture_overload_names}, + custom_profiler_config(std::move(custom_profiler_config)), adjust_timestamps{adjust_timestamps} {} /*explicit*/ ExperimentalConfig::operator bool() const { diff --git a/torch/csrc/profiler/orchestration/observer.h b/torch/csrc/profiler/orchestration/observer.h index 54f109ae5c8..427736e6c63 100644 --- a/torch/csrc/profiler/orchestration/observer.h +++ b/torch/csrc/profiler/orchestration/observer.h @@ -62,6 +62,7 @@ struct TORCH_API ExperimentalConfig { bool disable_external_correlation = false, bool profile_all_threads = false, bool capture_overload_names = false, + std::string custom_profiler_config = "", bool adjust_timestamps = false); explicit operator bool() const; @@ -101,6 +102,12 @@ struct TORCH_API ExperimentalConfig { * function schema and stored in the profile */ bool capture_overload_names; + /* + * A custom_profiler_config option is introduced to allow custom backends + * to apply custom configurations as needed. + */ + std::string custom_profiler_config; + /* * Controls whether or not timestamp adjustment occurs after profiling. * The purpose of this is to adjust Vulkan event timelines to align with those diff --git a/torch/csrc/profiler/python/init.cpp b/torch/csrc/profiler/python/init.cpp index db08af05074..92f2f39a5da 100644 --- a/torch/csrc/profiler/python/init.cpp +++ b/torch/csrc/profiler/python/init.cpp @@ -340,7 +340,8 @@ void initPythonBindings(PyObject* module) { bool /* adjust_profiler_step */, bool /* disable_external_correlation*/, bool /* profile_all_threads */, - bool /* capture_overload_names */ + bool /* capture_overload_names */, + std::string /* custom_profiler_config*/ >(), "An experimental config for Kineto features. Please note that" "backward compatibility is not guaranteed.\n" @@ -359,6 +360,7 @@ void initPythonBindings(PyObject* module) { " disable_external_correlation (bool) : whether to disable external correlation\n", " profile_all_threads (bool) : whether to profile all threads\n", " capture_overload_names (bool) : whether to include ATen overload names in the profile\n", + " custom_profiler_config (string) : Used to pass some configurations to the custom profiler backend.\n", py::arg("profiler_metrics") = std::vector(), py::arg("profiler_measure_per_kernel") = false, py::arg("verbose") = false, @@ -367,7 +369,8 @@ void initPythonBindings(PyObject* module) { py::arg("adjust_profiler_step") = false, py::arg("disable_external_correlation") = false, py::arg("profile_all_threads") = false, - py::arg("capture_overload_names") = false) + py::arg("capture_overload_names") = false, + py::arg("custom_profiler_config") = "") .def(py::pickle( [](const ExperimentalConfig& p) { // __getstate__ py::list py_metrics; @@ -390,6 +393,7 @@ void initPythonBindings(PyObject* module) { p.disable_external_correlation, p.profile_all_threads, p.capture_overload_names, + p.custom_profiler_config, p.performance_events); }, [](const py::tuple& t) { // __setstate__