mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: This adds a C++ event handler corresponding to the Python one mentioned in the RFC. It also changes the counters to be entirely push-driven instead of polled. The two window types are "fixed count" and "interval": the first is based on the number of logged events, and the second is based on time windows. There is currently no active ticker for interval windows, so a regular stream of events is needed to ensure events are produced. A follow-up diff can add support for things like HHWheel or a simple ticker. Pull Request resolved: https://github.com/pytorch/pytorch/pull/68783 Test Plan: buck test //caffe2/test/cpp/monitor:monitor Reviewed By: kiukchung Differential Revision: D32606547 fbshipit-source-id: a00d0364092d7d8a98e0b18e503c0ca8ede2bead
70 lines
1.3 KiB
C++
70 lines
1.3 KiB
C++
#include <torch/csrc/monitor/counters.h>
|
|
#include <torch/csrc/monitor/events.h>
|
|
|
|
#include <sstream>
|
|
#include <unordered_set>
|
|
|
|
namespace torch {
|
|
namespace monitor {
|
|
|
|
// Maps an Aggregation enumerator to its lowercase display name
// (e.g. MEAN -> "mean", COUNT -> "count").
//
// The returned pointer refers to a string literal: it is valid for the whole
// program lifetime and must not be freed by the caller.
//
// Throws std::runtime_error if `agg` is not one of the enumerators handled
// below; the message includes the numeric value that was passed in.
const char* aggregationName(Aggregation agg) {
  const char* name = nullptr;

  switch (agg) {
    case NONE:
      name = "none";
      break;
    case VALUE:
      name = "value";
      break;
    case MEAN:
      name = "mean";
      break;
    case COUNT:
      name = "count";
      break;
    case SUM:
      name = "sum";
      break;
    case MAX:
      name = "max";
      break;
    case MIN:
      name = "min";
      break;
    default:
      // Unrecognized value: leave `name` unset and report it below.
      break;
  }

  if (name == nullptr) {
    throw std::runtime_error("unknown aggregation: " + std::to_string(agg));
  }
  return name;
}
|
|
|
|
namespace {
|
|
struct Stats {
|
|
std::mutex mu;
|
|
|
|
std::unordered_set<Stat<double>*> doubles;
|
|
std::unordered_set<Stat<int64_t>*> int64s;
|
|
};
|
|
|
|
Stats& stats() {
|
|
static Stats stats;
|
|
return stats;
|
|
}
|
|
} // namespace
|
|
|
|
namespace detail {
|
|
void registerStat(Stat<double>* stat) {
|
|
std::lock_guard<std::mutex> guard(stats().mu);
|
|
|
|
stats().doubles.insert(stat);
|
|
}
|
|
void registerStat(Stat<int64_t>* stat) {
|
|
std::lock_guard<std::mutex> guard(stats().mu);
|
|
|
|
stats().int64s.insert(stat);
|
|
}
|
|
void unregisterStat(Stat<double>* stat) {
|
|
std::lock_guard<std::mutex> guard(stats().mu);
|
|
|
|
stats().doubles.erase(stat);
|
|
}
|
|
void unregisterStat(Stat<int64_t>* stat) {
|
|
std::lock_guard<std::mutex> guard(stats().mu);
|
|
|
|
stats().int64s.erase(stat);
|
|
}
|
|
} // namespace detail
|
|
|
|
} // namespace monitor
|
|
} // namespace torch
|