#include #include #include namespace torch { namespace monitor { const char* aggregationName(Aggregation agg) { switch (agg) { case NONE: return "none"; case VALUE: return "value"; case COUNT: return "count"; case SUM: return "sum"; case MAX: return "max"; case MIN: return "min"; default: throw std::runtime_error("unknown aggregation: " + std::to_string(agg)); } } namespace { struct Stats { std::mutex mu; std::unordered_set*> doubles; std::unordered_set*> int64s; }; Stats& stats() { static Stats stats; return stats; } } // namespace namespace detail { void registerStat(Stat* stat) { std::lock_guard guard(stats().mu); stats().doubles.insert(stat); } void registerStat(Stat* stat) { std::lock_guard guard(stats().mu); stats().int64s.insert(stat); } void unregisterStat(Stat* stat) { std::lock_guard guard(stats().mu); stats().doubles.erase(stat); } void unregisterStat(Stat* stat) { std::lock_guard guard(stats().mu); stats().int64s.erase(stat); } } // namespace detail template void closeAndGetStat(Stat* s, std::unordered_map& m) { s->closeWindow(); auto out = s->get(); for (auto& kv : out) { std::stringstream key; key << s->name(); key << "."; key << aggregationName(kv.first); m[key.str()] = kv.second; } } std::pair< std::unordered_map, std::unordered_map> closeAndGetStats() noexcept { std::pair< std::unordered_map, std::unordered_map> out; std::lock_guard guard(stats().mu); for (auto* s : stats().doubles) { closeAndGetStat(s, out.first); } for (auto* s : stats().int64s) { closeAndGetStat(s, out.second); } return out; } } // namespace monitor } // namespace torch