mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
torch.monitor - Initial C++ Stats (#68074)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68074 This is the first step of many PRs towards implementing the `torch.monitor` RFC https://github.com/pytorch/rfcs/pull/30 This defines the aggregation types, the `Stat` class and provides some simple collection of the stats. This doesn't match the RFC exactly as it incorporates some of the comments on the RFC as well as a few changes for performance. Changes: * added window_size to the stats. If specified it will always compute the stat using the `window_size` number of values. If there aren't enough values within that window it reports the previous stats. * This doesn't include the push metrics yet (will be coming). After more discussion it looks like the best way to handle this is to support a hybrid where the metric can set how frequently it'll be logged. For fixed window_size metrics it'll be logged each time it hits the window size. This will allow performant counters as well as lower frequency push counters (window_size=1). Performance considerations: * Updating the stats acquires a lock on that Stat object. This should be performant unless there's many-many threads writing to the same stat. Single thread will typically use futex so should be quite fast. * Adding/removing/fetching all stats sets a global lock on the stat list -- this shouldn't be an issue since these events happen infrequently. * Fetching stats accesses one stat at a time instead of a global lock. This means the exported values are linearizable but not serializable across multiple stats but I don't expect this to be an issue. Next steps: 1. Add StatCollector interface for push style metrics 1. Add pybind interfaces to expose to Python 1. Add default metric providers 1. Integrate into Kineto trace view Test Plan: buck test //caffe2/test/cpp/monitor:monitor CI Reviewed By: kiukchung Differential Revision: D32266032 fbshipit-source-id: dab8747b4712f5dba5644387817a3a0fda18b66a
This commit is contained in:
parent
68d8ab0cc6
commit
758d7dea9c
197
test/cpp/monitor/test_counters.cpp
Normal file
197
test/cpp/monitor/test_counters.cpp
Normal file
|
|
@ -0,0 +1,197 @@
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <torch/csrc/monitor/counters.h>
|
||||||
|
|
||||||
|
using namespace torch::monitor;
|
||||||
|
|
||||||
|
// Verifies MEAN and COUNT for a double-valued stat, and that count() tracks
// the open window and resets to zero once that window is closed.
TEST(MonitorTest, CounterDouble) {
  using Result = std::vector<std::pair<Aggregation, double>>;

  Stat<double> stat{"a", {MEAN, COUNT}};

  stat.add(5.0);
  ASSERT_EQ(stat.count(), 1);
  stat.add(6.0);
  ASSERT_EQ(stat.count(), 2);

  stat.closeWindow();
  const Result got = stat.get();
  // Closing the window moves the samples out of the open window.
  ASSERT_EQ(stat.count(), 0);

  const Result expected = {
      {MEAN, 5.5},
      {COUNT, 2.0},
  };
  ASSERT_EQ(got, expected);
}
|
||||||
|
|
||||||
|
// Verifies that SUM adds up every sample logged in the window.
TEST(MonitorTest, CounterInt64Sum) {
  using Result = std::vector<std::pair<Aggregation, int64_t>>;

  Stat<int64_t> stat{"a", {SUM}};
  stat.add(5);
  stat.add(6);
  stat.closeWindow();

  const Result got = stat.get();
  const Result expected = {
      {SUM, 11},
  };
  ASSERT_EQ(got, expected);
}
|
||||||
|
|
||||||
|
// Verifies that VALUE reports the most recently logged sample.
TEST(MonitorTest, CounterInt64Value) {
  using Result = std::vector<std::pair<Aggregation, int64_t>>;

  Stat<int64_t> stat{"a", {VALUE}};
  stat.add(5);
  stat.add(6);
  stat.closeWindow();

  const Result got = stat.get();
  const Result expected = {
      {VALUE, 6},
  };
  ASSERT_EQ(got, expected);
}
|
||||||
|
|
||||||
|
// Verifies MEAN for an integer stat, including the empty-window case, which
// must report zero rather than divide by zero.
TEST(MonitorTest, CounterInt64Mean) {
  using Result = std::vector<std::pair<Aggregation, int64_t>>;

  Stat<int64_t> stat{"a", {MEAN}};
  stat.add(0);
  stat.add(10);

  // Closes the currently open window and returns what it aggregated.
  const auto closeAndGet = [&stat]() {
    stat.closeWindow();
    return stat.get();
  };

  const Result withSamples = closeAndGet();
  const Result expectedWithSamples = {
      {MEAN, 5},
  };
  ASSERT_EQ(withSamples, expectedWithSamples);

  // zero samples case
  const Result withoutSamples = closeAndGet();
  const Result expectedWithoutSamples = {
      {MEAN, 0},
  };
  ASSERT_EQ(withoutSamples, expectedWithoutSamples);
}
|
||||||
|
|
||||||
|
// Verifies that COUNT reports the number of samples in the closed window and
// that count() follows the open window before and after it is closed.
TEST(MonitorTest, CounterInt64Count) {
  using Result = std::vector<std::pair<Aggregation, int64_t>>;

  Stat<int64_t> stat{"a", {COUNT}};
  ASSERT_EQ(stat.count(), 0);

  stat.add(0);
  ASSERT_EQ(stat.count(), 1);

  stat.add(10);
  ASSERT_EQ(stat.count(), 2);

  stat.closeWindow();
  const Result got = stat.get();
  ASSERT_EQ(stat.count(), 0);

  const Result expected = {
      {COUNT, 2},
  };
  ASSERT_EQ(got, expected);
}
|
||||||
|
|
||||||
|
// Verifies MIN and MAX, both for an empty window (reported as zero) and for a
// window holding a mix of negative and positive samples.
TEST(MonitorTest, CounterInt64MinMax) {
  using Result = std::vector<std::pair<Aggregation, int64_t>>;

  Stat<int64_t> stat{"a", {MIN, MAX}};

  // Closes the currently open window and returns what it aggregated.
  const auto closeAndGet = [&stat]() {
    stat.closeWindow();
    return stat.get();
  };

  // No samples logged yet.
  const Result emptyWindow = closeAndGet();
  const Result expectedEmpty = {
      {MAX, 0},
      {MIN, 0},
  };
  ASSERT_EQ(emptyWindow, expectedEmpty);

  stat.add(0);
  stat.add(5);
  stat.add(-5);
  stat.add(-6);
  stat.add(9);
  stat.add(2);

  const Result filledWindow = closeAndGet();
  const Result expectedFilled = {
      {MAX, 9},
      {MIN, -6},
  };
  ASSERT_EQ(filledWindow, expectedFilled);
}
|
||||||
|
|
||||||
|
// Verifies fixed-size windows: the window closes by itself once windowSize
// samples have been logged, and an explicit closeWindow() afterwards leaves
// the reported values untouched.
TEST(MonitorTest, CounterInt64WindowSize) {
  using Result = std::vector<std::pair<Aggregation, int64_t>>;

  Stat<int64_t> stat{
      "a",
      {COUNT, SUM},
      /*windowSize=*/3,
  };

  stat.add(1);
  stat.add(2);
  ASSERT_EQ(stat.count(), 2);

  // The third sample fills the window, which is then closed automatically.
  stat.add(3);
  ASSERT_EQ(stat.count(), 0);

  stat.closeWindow();
  const Result got = stat.get();
  const Result expected = {
      {COUNT, 3},
      {SUM, 6},
  };
  ASSERT_EQ(got, expected);

  // A second explicit close must not change what get() reports.
  stat.closeWindow();
  ASSERT_EQ(got, stat.get());
}
|
||||||
|
|
||||||
|
// Verifies closeAndGetStats(): it returns {double stats, int64 stats}, keyed
// by "<stat name>.<aggregation name>", and for fixed-size windows it reports
// the last *completed* window.
//
// The expectations below assume that `a` and `b` are the only Stat objects
// alive while this test runs, since closeAndGetStats() reports every
// registered Stat.
TEST(MonitorTest, CloseAndGetStats) {
  using DoubleStats = std::unordered_map<std::string, double>;
  using Int64Stats = std::unordered_map<std::string, int64_t>;

  Stat<int64_t> a{
      "a",
      {COUNT, SUM},
      /*windowSize=*/3,
  };
  Stat<double> b{
      "b",
      {MIN, MAX},
      /*windowSize=*/2,
  };

  a.add(1);
  b.add(1);

  {
    // Neither stat has filled its window yet, so both still report the
    // zero-initialized "previous" window.
    auto out = closeAndGetStats();
    std::pair<DoubleStats, Int64Stats> want = {
        {{"b.min", 0.0}, {"b.max", 0.0}},
        {{"a.count", 0}, {"a.sum", 0}},
    };
    ASSERT_EQ(out, want);
  }

  a.add(2);
  b.add(2);

  {
    // `b` (windowSize=2) has now completed a window holding {1, 2}; `a`
    // (windowSize=3) has only two samples and still reports zeros.
    auto out = closeAndGetStats();
    std::pair<DoubleStats, Int64Stats> want = {
        {{"b.min", 1.0}, {"b.max", 2.0}},
        {{"a.count", 0}, {"a.sum", 0}},
    };
    ASSERT_EQ(out, want);
  }
}
|
||||||
|
|
@ -129,6 +129,7 @@ libtorch_sources_common = sorted(core_sources_common + torch_unpickler_common)
|
||||||
libtorch_profiler_sources = [
|
libtorch_profiler_sources = [
|
||||||
"torch/csrc/autograd/profiler_legacy.cpp",
|
"torch/csrc/autograd/profiler_legacy.cpp",
|
||||||
"torch/csrc/autograd/profiler_kineto.cpp",
|
"torch/csrc/autograd/profiler_kineto.cpp",
|
||||||
|
"torch/csrc/monitor/counters.cpp",
|
||||||
]
|
]
|
||||||
|
|
||||||
libtorch_edge_profiler_sources = libtorch_profiler_sources + [
|
libtorch_edge_profiler_sources = libtorch_profiler_sources + [
|
||||||
|
|
|
||||||
100
torch/csrc/monitor/counters.cpp
Normal file
100
torch/csrc/monitor/counters.cpp
Normal file
|
|
@ -0,0 +1,100 @@
|
||||||
|
#include <torch/csrc/monitor/counters.h>
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace monitor {
|
||||||
|
|
||||||
|
// aggregationName returns the lowercase name of `agg` ("none", "value",
// "mean", "count", "sum", "max" or "min"). closeAndGetStat() uses it as the
// suffix of each exported stat key.
//
// Throws std::runtime_error if `agg` is not one of the Aggregation
// enumerators.
const char* aggregationName(Aggregation agg) {
  switch (agg) {
    case NONE:
      return "none";
    case VALUE:
      return "value";
    case MEAN:
      // Previously missing: a Stat using MEAN fell through to the throwing
      // default branch as soon as its values were exported.
      return "mean";
    case COUNT:
      return "count";
    case SUM:
      return "sum";
    case MAX:
      return "max";
    case MIN:
      return "min";
    default:
      throw std::runtime_error("unknown aggregation: " + std::to_string(agg));
  }
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
struct Stats {
|
||||||
|
std::mutex mu;
|
||||||
|
|
||||||
|
std::unordered_set<Stat<double>*> doubles;
|
||||||
|
std::unordered_set<Stat<int64_t>*> int64s;
|
||||||
|
};
|
||||||
|
|
||||||
|
Stats& stats() {
|
||||||
|
static Stats stats;
|
||||||
|
return stats;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
void registerStat(Stat<double>* stat) {
|
||||||
|
std::lock_guard<std::mutex> guard(stats().mu);
|
||||||
|
|
||||||
|
stats().doubles.insert(stat);
|
||||||
|
}
|
||||||
|
void registerStat(Stat<int64_t>* stat) {
|
||||||
|
std::lock_guard<std::mutex> guard(stats().mu);
|
||||||
|
|
||||||
|
stats().int64s.insert(stat);
|
||||||
|
}
|
||||||
|
void unregisterStat(Stat<double>* stat) {
|
||||||
|
std::lock_guard<std::mutex> guard(stats().mu);
|
||||||
|
|
||||||
|
stats().doubles.erase(stat);
|
||||||
|
}
|
||||||
|
void unregisterStat(Stat<int64_t>* stat) {
|
||||||
|
std::lock_guard<std::mutex> guard(stats().mu);
|
||||||
|
|
||||||
|
stats().int64s.erase(stat);
|
||||||
|
}
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void closeAndGetStat(Stat<T>* s, std::unordered_map<std::string, T>& m) {
|
||||||
|
s->closeWindow();
|
||||||
|
auto out = s->get();
|
||||||
|
for (auto& kv : out) {
|
||||||
|
std::stringstream key;
|
||||||
|
key << s->name();
|
||||||
|
key << ".";
|
||||||
|
key << aggregationName(kv.first);
|
||||||
|
m[key.str()] = kv.second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<
|
||||||
|
std::unordered_map<std::string, double>,
|
||||||
|
std::unordered_map<std::string, int64_t>>
|
||||||
|
closeAndGetStats() noexcept {
|
||||||
|
std::pair<
|
||||||
|
std::unordered_map<std::string, double>,
|
||||||
|
std::unordered_map<std::string, int64_t>>
|
||||||
|
out;
|
||||||
|
|
||||||
|
std::lock_guard<std::mutex> guard(stats().mu);
|
||||||
|
|
||||||
|
for (auto* s : stats().doubles) {
|
||||||
|
closeAndGetStat(s, out.first);
|
||||||
|
}
|
||||||
|
for (auto* s : stats().int64s) {
|
||||||
|
closeAndGetStat(s, out.second);
|
||||||
|
}
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace monitor
|
||||||
|
} // namespace torch
|
||||||
191
torch/csrc/monitor/counters.h
Normal file
191
torch/csrc/monitor/counters.h
Normal file
|
|
@ -0,0 +1,191 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <bitset>
|
||||||
|
#include <mutex>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace monitor {
|
||||||
|
|
||||||
|
// NUM_AGGREGATIONS is the number of Aggregation enumerators, NONE included.
// It is the width of the std::bitset used to store a set of aggregations.
constexpr int NUM_AGGREGATIONS = 7;

// Aggregation is the list of possible aggregations for Stats.
// Each enumerator's value is used as a bit position in a
// std::bitset<NUM_AGGREGATIONS>, so a set of them can be stored efficiently.
enum Aggregation {
  // NONE means no aggregations are set.
  NONE = 0,
  // VALUE exports the most recently set value.
  VALUE = 1,
  // MEAN computes the mean of the set values within the window. Zero if no
  // values.
  MEAN = 2,
  // COUNT tracks the number of times a value is set within the window.
  COUNT = 3,
  // SUM computes the sum of the values set within the window.
  SUM = 4,
  // MAX computes the maximum of the values set within the window. Zero if no
  // values.
  MAX = 5,
  // MIN computes the minimum of the values set within the window. Zero if no
  // values.
  MIN = 6,
};

// Keeps the bitset width in step with the enum if either one is changed.
static_assert(
    MIN + 1 == NUM_AGGREGATIONS,
    "NUM_AGGREGATIONS must cover every Aggregation enumerator");

// aggregationName returns the lowercase name of `agg`, e.g. "sum" for SUM.
// Throws std::runtime_error for a value that is not a known aggregation.
const char* aggregationName(Aggregation agg);

template <typename T>
class Stat;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
inline std::bitset<NUM_AGGREGATIONS> merge(
|
||||||
|
std::initializer_list<Aggregation>& list) {
|
||||||
|
std::bitset<NUM_AGGREGATIONS> a;
|
||||||
|
for (Aggregation b : list) {
|
||||||
|
a.set(b);
|
||||||
|
}
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
void registerStat(Stat<double>* stat);
|
||||||
|
void registerStat(Stat<int64_t>* stat);
|
||||||
|
void unregisterStat(Stat<double>* stat);
|
||||||
|
void unregisterStat(Stat<int64_t>* stat);
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class Stat {
|
||||||
|
private:
|
||||||
|
struct Values {
|
||||||
|
T value{0};
|
||||||
|
T sum{0};
|
||||||
|
T min{0};
|
||||||
|
T max{0};
|
||||||
|
int64_t count{0};
|
||||||
|
};
|
||||||
|
|
||||||
|
public:
|
||||||
|
Stat(
|
||||||
|
std::string name,
|
||||||
|
std::initializer_list<Aggregation> aggregations,
|
||||||
|
int64_t windowSize = -1)
|
||||||
|
: name_(std::move(name)),
|
||||||
|
aggregations_(merge(aggregations)),
|
||||||
|
windowSize_(windowSize) {
|
||||||
|
detail::registerStat(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
~Stat() {
|
||||||
|
detail::unregisterStat(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
// add adds the value v to the current window.
|
||||||
|
void add(T v) noexcept {
|
||||||
|
std::lock_guard<std::mutex> guard(mu_);
|
||||||
|
|
||||||
|
if (aggregations_.test(VALUE)) {
|
||||||
|
current_.value = v;
|
||||||
|
}
|
||||||
|
if (aggregations_.test(MEAN) || aggregations_.test(SUM)) {
|
||||||
|
current_.sum += v;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (aggregations_.test(MAX)) {
|
||||||
|
if (current_.max < v || current_.count == 0) {
|
||||||
|
current_.max = v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (aggregations_.test(MIN)) {
|
||||||
|
if (current_.min > v || current_.count == 0) {
|
||||||
|
current_.min = v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
current_.count += 1;
|
||||||
|
if (windowSize_ > 0 && current_.count >= windowSize_) {
|
||||||
|
saveCurrentLocked();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string& name() const noexcept {
|
||||||
|
return name_;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t windowSize() const noexcept {
|
||||||
|
return windowSize_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// count returns the number of items in the current open window.
|
||||||
|
int64_t count() noexcept {
|
||||||
|
std::lock_guard<std::mutex> guard(mu_);
|
||||||
|
|
||||||
|
return current_.count;
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeWindow finalizes the collected stats window so they can be accessed
|
||||||
|
// via get().
|
||||||
|
// If the Stat has a windowSize specified this doesn't do anything since the
|
||||||
|
// window is automatically closed when enough samples have been logged.
|
||||||
|
void closeWindow() noexcept {
|
||||||
|
if (windowSize_ <= 0) {
|
||||||
|
std::lock_guard<std::mutex> guard(mu_);
|
||||||
|
|
||||||
|
saveCurrentLocked();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::pair<Aggregation, T>> get() noexcept {
|
||||||
|
std::vector<std::pair<Aggregation, T>> out;
|
||||||
|
out.reserve(aggregations_.count());
|
||||||
|
|
||||||
|
std::lock_guard<std::mutex> guard(mu_);
|
||||||
|
|
||||||
|
if (aggregations_.test(VALUE)) {
|
||||||
|
out.emplace_back(VALUE, prev_.value);
|
||||||
|
}
|
||||||
|
if (aggregations_.test(MEAN)) {
|
||||||
|
if (prev_.count == 0) {
|
||||||
|
out.emplace_back(MEAN, 0);
|
||||||
|
} else {
|
||||||
|
out.emplace_back(MEAN, prev_.sum / prev_.count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (aggregations_.test(COUNT)) {
|
||||||
|
out.emplace_back(COUNT, prev_.count);
|
||||||
|
}
|
||||||
|
if (aggregations_.test(SUM)) {
|
||||||
|
out.emplace_back(SUM, prev_.sum);
|
||||||
|
}
|
||||||
|
if (aggregations_.test(MAX)) {
|
||||||
|
out.emplace_back(MAX, prev_.max);
|
||||||
|
}
|
||||||
|
if (aggregations_.test(MIN)) {
|
||||||
|
out.emplace_back(MIN, prev_.min);
|
||||||
|
}
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void saveCurrentLocked() {
|
||||||
|
prev_ = current_;
|
||||||
|
current_ = Values();
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string name_;
|
||||||
|
const std::bitset<NUM_AGGREGATIONS> aggregations_;
|
||||||
|
const int64_t windowSize_;
|
||||||
|
|
||||||
|
std::mutex mu_;
|
||||||
|
Values current_;
|
||||||
|
Values prev_;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::pair<
|
||||||
|
std::unordered_map<std::string, double>,
|
||||||
|
std::unordered_map<std::string, int64_t>>
|
||||||
|
closeAndGetStats() noexcept;
|
||||||
|
} // namespace monitor
|
||||||
|
} // namespace torch
|
||||||
Loading…
Reference in New Issue
Block a user