pytorch/test/cpp/monitor/test_counters.cpp
Jiawei Lv c80b5b8c8f Revert D33102715: Back out "Revert D32606547: torch/monitor: add C++ events and handlers"
Test Plan: revert-hammer

Differential Revision:
D33102715 (eb374de3f5)

Original commit changeset: 3816ff01c578

Original Phabricator Diff: D33102715 (eb374de3f5)

fbshipit-source-id: e262b6d8c80a05f3a67e024fedfbadefdbfe6e29
2021-12-16 09:39:57 -08:00

198 lines
3.5 KiB
C++

#include <gtest/gtest.h>
#include <torch/csrc/monitor/counters.h>
using namespace torch::monitor;
TEST(MonitorTest, CounterDouble) {
  // A double-valued stat with MEAN and COUNT aggregations. count() reports how
  // many samples are in the currently open window.
  Stat<double> stat{"a", {MEAN, COUNT}};

  stat.add(5.0);
  ASSERT_EQ(stat.count(), 1);
  stat.add(6.0);
  ASSERT_EQ(stat.count(), 2);

  // After closeWindow(), get() returns the aggregates of the closed window and
  // the open-window sample count is back to zero.
  stat.closeWindow();
  const auto published = stat.get();
  ASSERT_EQ(stat.count(), 0);

  const std::vector<std::pair<Aggregation, double>> expected{
      {MEAN, 5.5},
      {COUNT, 2.0},
  };
  ASSERT_EQ(published, expected);
}
TEST(MonitorTest, CounterInt64Sum) {
  // SUM adds up every sample recorded in the window: 5 + 6 == 11.
  Stat<int64_t> total{"a", {SUM}};
  for (const int64_t sample : {5, 6}) {
    total.add(sample);
  }
  total.closeWindow();

  const std::vector<std::pair<Aggregation, int64_t>> expected{{SUM, 11}};
  const auto reported = total.get();
  ASSERT_EQ(reported, expected);
}
TEST(MonitorTest, CounterInt64Value) {
  // VALUE reports the most recently added sample (6 after adding 5 then 6).
  Stat<int64_t> latest{"a", {VALUE}};
  latest.add(5);
  latest.add(6);
  latest.closeWindow();

  const std::vector<std::pair<Aggregation, int64_t>> expected{{VALUE, 6}};
  const auto reported = latest.get();
  ASSERT_EQ(reported, expected);
}
TEST(MonitorTest, CounterInt64Mean) {
  using Expected = std::vector<std::pair<Aggregation, int64_t>>;

  Stat<int64_t> avg{"a", {MEAN}};
  avg.add(0);
  avg.add(10);

  // First window: the mean of {0, 10}.
  avg.closeWindow();
  const auto firstWindow = avg.get();
  ASSERT_EQ(firstWindow, (Expected{{MEAN, 5}}));

  // Second window received no samples; an empty window reports a MEAN of 0.
  avg.closeWindow();
  const auto emptyWindow = avg.get();
  ASSERT_EQ(emptyWindow, (Expected{{MEAN, 0}}));
}
TEST(MonitorTest, CounterInt64Count) {
  // count() tracks samples in the open window; the COUNT aggregation reports
  // the total for the window once it is closed.
  Stat<int64_t> counter{"a", {COUNT}};
  ASSERT_EQ(counter.count(), 0);
  counter.add(0);
  ASSERT_EQ(counter.count(), 1);
  counter.add(10);
  ASSERT_EQ(counter.count(), 2);

  counter.closeWindow();
  const auto reported = counter.get();
  ASSERT_EQ(counter.count(), 0);

  const std::vector<std::pair<Aggregation, int64_t>> expected{{COUNT, 2}};
  ASSERT_EQ(reported, expected);
}
TEST(MonitorTest, CounterInt64MinMax) {
  using Expected = std::vector<std::pair<Aggregation, int64_t>>;

  // Requested as {MIN, MAX}; get() is expected to list MAX before MIN.
  Stat<int64_t> extremes{"a", {MIN, MAX}};

  // A window closed without any samples reports 0 for both extremes.
  extremes.closeWindow();
  const auto emptyWindow = extremes.get();
  ASSERT_EQ(emptyWindow, (Expected{{MAX, 0}, {MIN, 0}}));

  for (const int64_t sample : {0, 5, -5, -6, 9, 2}) {
    extremes.add(sample);
  }
  extremes.closeWindow();
  const auto filledWindow = extremes.get();
  ASSERT_EQ(filledWindow, (Expected{{MAX, 9}, {MIN, -6}}));
}
TEST(MonitorTest, CounterInt64WindowSize) {
  // With a window size of 3, the window closes on its own when the third
  // sample arrives, resetting the open-window count to zero.
  Stat<int64_t> windowed{"a", {COUNT, SUM}, /*windowSize=*/3};
  windowed.add(1);
  windowed.add(2);
  ASSERT_EQ(windowed.count(), 2);
  windowed.add(3);
  ASSERT_EQ(windowed.count(), 0);

  windowed.closeWindow();
  const auto snapshot = windowed.get();
  const std::vector<std::pair<Aggregation, int64_t>> expected{
      {COUNT, 3},
      {SUM, 6},
  };
  ASSERT_EQ(snapshot, expected);

  // Calling closeWindow() again leaves the reported stats unchanged.
  windowed.closeWindow();
  ASSERT_EQ(snapshot, windowed.get());
}
TEST(MonitorTest, CloseAndGetStats) {
  // closeAndGetStats() returns a pair of maps keyed "<stat name>.<aggregation>".
  // Values of double stats go in `.first` and values of int64_t stats go in
  // `.second`.
  //
  // The previous version of this test computed `out` and an expected value but
  // never compared them. Its expectations also put the int64 stat `a` in the
  // double map and the double stat `b` in the int64 map. These checks assert
  // directly against the correct map.

  // `a` has no fixed window size, so each closeAndGetStats() call closes its
  // window, the same way an explicit closeWindow() does in CounterInt64Sum.
  // It previously had windowSize=3, which contradicted the count/sum values
  // written for it.
  Stat<int64_t> a{
      "a",
      {COUNT, SUM},
  };
  // `b` closes its window automatically once it holds 2 samples.
  // NOTE(review): assumes an explicit close does not report a partial fixed-size
  // window (consistent with CounterInt64WindowSize) — confirm against counters.h.
  Stat<double> b{
      "b",
      {MIN, MAX},
      /*windowSize=*/2,
  };

  a.add(1);
  b.add(1);
  {
    const auto out = closeAndGetStats();
    const auto& int64Stats = out.second;
    ASSERT_EQ(int64Stats.at("a.count"), 1);
    ASSERT_EQ(int64Stats.at("a.sum"), 1);
    // `b` holds only 1 of its 2 samples, so its window has not closed yet and
    // nothing new is asserted for it here.
  }

  a.add(2);
  b.add(2);
  {
    const auto out = closeAndGetStats();
    const auto& doubleStats = out.first;
    const auto& int64Stats = out.second;
    // The first call closed a's previous window, so only the new sample counts.
    ASSERT_EQ(int64Stats.at("a.count"), 1);
    ASSERT_EQ(int64Stats.at("a.sum"), 2);
    // b's second sample filled its window, so both samples are aggregated.
    ASSERT_DOUBLE_EQ(doubleStats.at("b.min"), 1.0);
    ASSERT_DOUBLE_EQ(doubleStats.at("b.max"), 2.0);
  }
}