pytorch/torch/csrc/monitor/counters.cpp
Tristan Rice 758d7dea9c torch.monitor - Initial C++ Stats (#68074)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68074

This is the first of many PRs towards implementing the `torch.monitor` RFC: https://github.com/pytorch/rfcs/pull/30

This defines the aggregation types and the `Stat` class, and provides simple collection of the stats.

This doesn't match the RFC exactly, as it incorporates some of the comments on the RFC as well as a few changes for performance.

Changes:
* Added `window_size` to the stats. If specified, the stat is always computed over exactly `window_size` values. Until the current window has collected that many values, the stat reports the results of the previous window.
* This doesn't include the push metrics yet (coming later).
  After more discussion, it looks like the best way to handle this is a hybrid in which each metric sets how frequently it is logged. A metric with a fixed `window_size` is logged each time it reaches the window size. This allows performant counters as well as lower-frequency push counters (`window_size=1`).
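
The windowing behavior described above can be sketched roughly as follows. This is a minimal single-threaded illustration, not the actual `Stat` class: `WindowedSum`, `add` and `get` are hypothetical names, only a sum aggregation is modeled, and the real implementation also takes a per-stat lock.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical illustration only; not the torch::monitor::Stat class.
class WindowedSum {
 public:
  explicit WindowedSum(std::size_t windowSize) : windowSize_(windowSize) {}

  // Adds a value to the current window. Returns true when this call
  // completed a window, which is the point where a push-style logger
  // could emit the result.
  bool add(int64_t v) {
    current_ += v;
    ++count_;
    if (count_ < windowSize_) {
      return false;
    }
    // The window is full: publish it and start a new one.
    published_ = current_;
    current_ = 0;
    count_ = 0;
    return true;
  }

  // Reports the sum of the most recently completed window. While the
  // current window is still filling, the previous result is returned.
  int64_t get() const {
    return published_;
  }

 private:
  std::size_t windowSize_;
  std::size_t count_ = 0;
  int64_t current_ = 0;
  int64_t published_ = 0;
};
```

With a window size of 3, the first two calls to `add` leave `get()` at its previous value and the third call publishes the new sum. With a window size of 1, every `add` completes a window, which corresponds to the low-frequency push case.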

Performance considerations:
* Updating a stat acquires a lock on that `Stat` object. This should be performant unless a large number of threads write to the same stat. An uncontended lock taken by a single thread will typically use a futex, so it should be quite fast.
* Adding, removing, or fetching all stats acquires a global lock on the stat list. This shouldn't be an issue since these operations happen infrequently.
* Fetching stats locks one stat at a time instead of locking all stats at once. This means the exported values are linearizable per stat but not serializable across multiple stats; I don't expect this to be an issue.
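
The locking scheme in the list above can be modeled with a small sketch. All names here (`Counter`, `Registry`, `snapshot`) are hypothetical and the aggregation is reduced to a running sum; the point is only to show a per-object lock for updates combined with a registry-wide lock for add, remove and fetch.

```cpp
#include <cstdint>
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

// Hypothetical names throughout; a simplified model of the locking scheme.
class Counter {
 public:
  explicit Counter(std::string name) : name_(std::move(name)) {}

  // Updates only take this counter's own lock.
  void add(int64_t v) {
    std::lock_guard<std::mutex> guard(mu_);
    sum_ += v;
  }

  int64_t get() {
    std::lock_guard<std::mutex> guard(mu_);
    return sum_;
  }

  const std::string& name() const {
    return name_;
  }

 private:
  std::mutex mu_;
  std::string name_;
  int64_t sum_ = 0;
};

class Registry {
 public:
  void add(Counter* c) {
    std::lock_guard<std::mutex> guard(mu_);
    counters_.insert(c);
  }

  void remove(Counter* c) {
    std::lock_guard<std::mutex> guard(mu_);
    counters_.erase(c);
  }

  // Holds the registry lock for the whole walk, but each counter is locked
  // and read one at a time. Writers to other counters are never blocked, so
  // the result is not a single atomic snapshot across all counters.
  std::unordered_map<std::string, int64_t> snapshot() {
    std::unordered_map<std::string, int64_t> out;
    std::lock_guard<std::mutex> guard(mu_);
    for (Counter* c : counters_) {
      out[c->name()] = c->get();
    }
    return out;
  }

 private:
  std::mutex mu_;
  std::unordered_set<Counter*> counters_;
};
```

In this model a thread calling `Counter::add` only ever contends with other users of the same counter, while `Registry::snapshot` may observe one counter before and another after a concurrent pair of updates.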

Next steps:
1. Add StatCollector interface for push style metrics
1. Add pybind interfaces to expose to Python
1. Add default metric providers
1. Integrate into Kineto trace view

Test Plan:
buck test //caffe2/test/cpp/monitor:monitor

CI

Reviewed By: kiukchung

Differential Revision: D32266032

fbshipit-source-id: dab8747b4712f5dba5644387817a3a0fda18b66a
2021-11-18 21:46:23 -08:00


#include <torch/csrc/monitor/counters.h>

#include <mutex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

namespace torch {
namespace monitor {

// Returns the name used as the key suffix for an aggregation type.
const char* aggregationName(Aggregation agg) {
  switch (agg) {
    case NONE:
      return "none";
    case VALUE:
      return "value";
    case COUNT:
      return "count";
    case SUM:
      return "sum";
    case MAX:
      return "max";
    case MIN:
      return "min";
    default:
      throw std::runtime_error("unknown aggregation: " + std::to_string(agg));
  }
}

namespace {
// Global registry of live stats; `mu` guards both sets.
struct Stats {
  std::mutex mu;

  std::unordered_set<Stat<double>*> doubles;
  std::unordered_set<Stat<int64_t>*> int64s;
};

Stats& stats() {
  static Stats stats;
  return stats;
}
} // namespace

namespace detail {
void registerStat(Stat<double>* stat) {
  std::lock_guard<std::mutex> guard(stats().mu);

  stats().doubles.insert(stat);
}

void registerStat(Stat<int64_t>* stat) {
  std::lock_guard<std::mutex> guard(stats().mu);

  stats().int64s.insert(stat);
}

void unregisterStat(Stat<double>* stat) {
  std::lock_guard<std::mutex> guard(stats().mu);

  stats().doubles.erase(stat);
}

void unregisterStat(Stat<int64_t>* stat) {
  std::lock_guard<std::mutex> guard(stats().mu);

  stats().int64s.erase(stat);
}
} // namespace detail

// Closes the current window of `s` and writes each aggregated value into `m`
// under the key "<stat name>.<aggregation name>".
template <typename T>
void closeAndGetStat(Stat<T>* s, std::unordered_map<std::string, T>& m) {
  s->closeWindow();
  auto out = s->get();
  for (auto& kv : out) {
    std::stringstream key;
    key << s->name();
    key << ".";
    key << aggregationName(kv.first);
    m[key.str()] = kv.second;
  }
}

// Closes and collects every registered stat while holding the registry lock.
std::pair<
    std::unordered_map<std::string, double>,
    std::unordered_map<std::string, int64_t>>
closeAndGetStats() noexcept {
  std::pair<
      std::unordered_map<std::string, double>,
      std::unordered_map<std::string, int64_t>>
      out;

  std::lock_guard<std::mutex> guard(stats().mu);

  for (auto* s : stats().doubles) {
    closeAndGetStat(s, out.first);
  }
  for (auto* s : stats().int64s) {
    closeAndGetStat(s, out.second);
  }

  return out;
}

} // namespace monitor
} // namespace torch
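
The key format produced by `closeAndGetStat` above is the stat name and the aggregation name joined by a dot. A standalone sketch of that flattening step, using hypothetical stand-ins (`Agg`, `aggName`, `flatten`) rather than the real `torch::monitor` types, might look like this:

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-ins for illustration; not the torch::monitor types.
enum class Agg { Count, Sum };

const char* aggName(Agg a) {
  switch (a) {
    case Agg::Count:
      return "count";
    case Agg::Sum:
      return "sum";
  }
  return "unknown";
}

// Flattens one stat's per-aggregation values into a map keyed by
// "<stat name>.<aggregation name>".
std::unordered_map<std::string, int64_t> flatten(
    const std::string& statName,
    const std::vector<std::pair<Agg, int64_t>>& values) {
  std::unordered_map<std::string, int64_t> out;
  for (const auto& kv : values) {
    out[statName + "." + aggName(kv.first)] = kv.second;
  }
  return out;
}
```

For a stat named `latency` with count and sum aggregations, this yields the keys `latency.count` and `latency.sum`.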