#include #include #include #include #include #include #include #include namespace torch::jit { namespace { class ProfilesRegistry { public: bool empty() { return empty_.load(std::memory_order_relaxed); } void addProfile(ScriptProfile& p) { std::lock_guard g(mutex_); enabledProfiles_.emplace(&p); empty_.store(false, std::memory_order_relaxed); } void removeProfile(ScriptProfile& p) { std::lock_guard g(mutex_); enabledProfiles_.erase(&p); if (enabledProfiles_.empty()) { empty_.store(true, std::memory_order_relaxed); } } void send(std::unique_ptr datapoint) { auto shared = std::shared_ptr(std::move(datapoint)); std::lock_guard g(mutex_); for (auto* p : enabledProfiles_) { p->addDatapoint(shared); } } private: std::atomic empty_{true}; std::mutex mutex_; std::unordered_set enabledProfiles_; }; ProfilesRegistry& getProfilesRegistry() { static auto registry = std::ref(*new ProfilesRegistry{}); return registry; } auto initBindings() { torch::class_("profiling", "SourceRef") .def( "starting_lineno", [](const c10::intrusive_ptr& self) { return static_cast((*self)->starting_line_no()); }) .def("text", [](const c10::intrusive_ptr& self) { return (*self)->text_str().str(); }); torch::class_("profiling", "InstructionStats") .def( "count", [](const c10::intrusive_ptr& self) { return self->count; }) .def("duration_ns", [](const c10::intrusive_ptr& self) { return static_cast(self->duration.count()); }); torch::class_("profiling", "SourceStats") .def( "source", [](const c10::intrusive_ptr& self) { return c10::make_intrusive(self->getSourceRef()); }) .def("line_map", &SourceStats::getLineMap); torch::class_("profiling", "_ScriptProfile") .def(torch::init<>()) .def("enable", &ScriptProfile::enable) .def("disable", &ScriptProfile::disable) .def("_dump_stats", [](const c10::intrusive_ptr& self) { const auto& stats = self->dumpStats(); c10::List> ret; for (const auto& source : stats) { SourceStats::LineMap lineMap; for (const auto& line : source.second) { lineMap.insert( line.first, c10::make_intrusive(line.second)); } ret.push_back(c10::make_intrusive( source.first, std::move(lineMap))); } return ret; }); return nullptr; } const auto C10_UNUSED torchBindInitializer = initBindings(); } // namespace namespace profiling { InstructionSpan::InstructionSpan(Node& node) { datapoint_ = std::make_unique(node.sourceRange()); } InstructionSpan::~InstructionSpan() { datapoint_->end = std::chrono::steady_clock::now(); getProfilesRegistry().send(std::move(datapoint_)); } bool isProfilingOngoing() { return !getProfilesRegistry().empty(); } } // namespace profiling void ScriptProfile::enable() { if (!std::exchange(enabled_, true)) { getProfilesRegistry().addProfile(*this); } } void ScriptProfile::disable() { if (std::exchange(enabled_, false)) { getProfilesRegistry().removeProfile(*this); } } void ScriptProfile::addDatapoint( std::shared_ptr datapoint) { TORCH_CHECK(enabled_, "Cannot only add datapoint to disabled profilers."); datapoints_.push_back(std::move(datapoint)); } const ScriptProfile::SourceMap& ScriptProfile::dumpStats() { TORCH_CHECK(!enabled_, "Only disabled profilers are allowed to dump stats."); for (const auto& datapoint : datapoints_) { if (const auto& source = datapoint->sourceRange.source()) { if (auto fileLineCol = datapoint->sourceRange.file_line_col()) { auto it = sourceMap_.find(*source.get()); if (it == sourceMap_.end()) { it = sourceMap_.emplace(SourceRef{source}, LineMap{}).first; } auto& stats = it->second[std::get<1>(*fileLineCol)]; stats.count++; stats.duration += datapoint->end - datapoint->start; } } } datapoints_.clear(); return sourceMap_; } ScriptProfile::~ScriptProfile() { if (enabled_) { getProfilesRegistry().removeProfile(*this); } } } // namespace torch::jit