mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68074 This is the first of many PRs towards implementing the `torch.monitor` RFC https://github.com/pytorch/rfcs/pull/30 This defines the aggregation types and the `Stat` class, and provides some simple collection of the stats. This doesn't match the RFC exactly, as it incorporates some of the comments on the RFC as well as a few changes for performance. Changes: * Added window_size to the stats. If specified, it will always compute the stat using the `window_size` number of values. If there aren't enough values within that window, it reports the previous stats. * This doesn't include the push metrics yet (will be coming). After more discussion, it looks like the best way to handle this is to support a hybrid where the metric can set how frequently it'll be logged. For fixed window_size metrics, it'll be logged each time it hits the window size. This will allow performant counters as well as lower-frequency push counters (window_size=1). Performance considerations: * Updating the stats acquires a lock on that Stat object. This should be performant unless there are many threads writing to the same stat. A single thread will typically use a futex, so it should be quite fast. * Adding/removing/fetching all stats takes a global lock on the stat list -- this shouldn't be an issue since these events happen infrequently. * Fetching stats accesses one stat at a time instead of holding a global lock. This means the exported values are linearizable but not serializable across multiple stats; I don't expect this to be an issue. Next steps: 1. Add StatCollector interface for push style metrics 2. Add pybind interfaces to expose to Python 3. Add default metric providers 4. Integrate into Kineto trace view Test Plan: buck test //caffe2/test/cpp/monitor:monitor CI Reviewed By: kiukchung Differential Revision: D32266032 fbshipit-source-id: dab8747b4712f5dba5644387817a3a0fda18b66a
198 lines
3.5 KiB
C++
198 lines
3.5 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include <torch/csrc/monitor/counters.h>
|
|
|
|
using namespace torch::monitor;
|
|
|
|
// Exercises a double-valued Stat configured with MEAN and COUNT.
// The test expects count() to follow each add() and to read 0 again once the
// window has been closed and the aggregated values have been fetched.
TEST(MonitorTest, CounterDouble) {
  Stat<double> stat{"a", {MEAN, COUNT}};

  stat.add(5.0);
  ASSERT_EQ(stat.count(), 1);
  stat.add(6.0);
  ASSERT_EQ(stat.count(), 2);

  stat.closeWindow();
  auto got = stat.get();
  ASSERT_EQ(stat.count(), 0);

  // Expected aggregates for the two samples 5.0 and 6.0.
  std::vector<std::pair<Aggregation, double>> expected = {
      {MEAN, 5.5},
      {COUNT, 2.0},
  };
  ASSERT_EQ(got, expected);
}
|
|
|
|
// Exercises the SUM aggregation on an int64 Stat: after adding 5 and 6 and
// closing the window, the test expects a single {SUM, 11} entry.
TEST(MonitorTest, CounterInt64Sum) {
  Stat<int64_t> stat{"a", {SUM}};

  stat.add(5);
  stat.add(6);
  stat.closeWindow();

  auto got = stat.get();
  std::vector<std::pair<Aggregation, int64_t>> expected = {
      {SUM, 11},
  };
  ASSERT_EQ(got, expected);
}
|
|
|
|
// Exercises the VALUE aggregation on an int64 Stat: after adding 5 and then 6
// and closing the window, the test expects {VALUE, 6}, i.e. the sample that
// was added last.
TEST(MonitorTest, CounterInt64Value) {
  Stat<int64_t> stat{"a", {VALUE}};

  stat.add(5);
  stat.add(6);
  stat.closeWindow();

  auto got = stat.get();
  std::vector<std::pair<Aggregation, int64_t>> expected = {
      {VALUE, 6},
  };
  ASSERT_EQ(got, expected);
}
|
|
|
|
// Exercises the MEAN aggregation on an int64 Stat over two consecutive
// windows: one holding the samples 0 and 10, and one closed without any
// samples having been added.
TEST(MonitorTest, CounterInt64Mean) {
  Stat<int64_t> stat{"a", {MEAN}};

  stat.add(0);
  stat.add(10);

  // First window: the test expects the mean of {0, 10} to be 5.
  stat.closeWindow();
  auto filledWindow = stat.get();
  std::vector<std::pair<Aggregation, int64_t>> expectedFilled = {
      {MEAN, 5},
  };
  ASSERT_EQ(filledWindow, expectedFilled);

  // Second window: zero samples case, the test expects a mean of 0.
  stat.closeWindow();
  auto emptyWindow = stat.get();
  std::vector<std::pair<Aggregation, int64_t>> expectedEmpty = {
      {MEAN, 0},
  };
  ASSERT_EQ(emptyWindow, expectedEmpty);
}
|
|
|
|
// Exercises the COUNT aggregation on an int64 Stat.
// count() is checked before any sample, after each add(), and again after the
// window has been closed; the closed window is expected to report {COUNT, 2}.
TEST(MonitorTest, CounterInt64Count) {
  Stat<int64_t> stat{"a", {COUNT}};
  ASSERT_EQ(stat.count(), 0);

  stat.add(0);
  ASSERT_EQ(stat.count(), 1);
  stat.add(10);
  ASSERT_EQ(stat.count(), 2);

  stat.closeWindow();
  auto got = stat.get();
  ASSERT_EQ(stat.count(), 0);

  std::vector<std::pair<Aggregation, int64_t>> expected = {
      {COUNT, 2},
  };
  ASSERT_EQ(got, expected);
}
|
|
|
|
// Exercises the MIN and MAX aggregations on an int64 Stat, first for a window
// closed with no samples and then for a window holding mixed-sign samples.
// In both cases the expected vector lists MAX before MIN.
TEST(MonitorTest, CounterInt64MinMax) {
  Stat<int64_t> stat{"a", {MIN, MAX}};

  // Window closed before any add(): the test expects both aggregates to be 0.
  stat.closeWindow();
  auto emptyWindow = stat.get();
  std::vector<std::pair<Aggregation, int64_t>> expectedEmpty = {
      {MAX, 0},
      {MIN, 0},
  };
  ASSERT_EQ(emptyWindow, expectedEmpty);

  // Feed the samples in the same order as before: 0, 5, -5, -6, 9, 2.
  for (const int64_t sample : {0, 5, -5, -6, 9, 2}) {
    stat.add(sample);
  }

  // Largest sample is 9, smallest is -6.
  stat.closeWindow();
  auto filledWindow = stat.get();
  std::vector<std::pair<Aggregation, int64_t>> expectedFilled = {
      {MAX, 9},
      {MIN, -6},
  };
  ASSERT_EQ(filledWindow, expectedFilled);
}
|
|
|
|
// Exercises an int64 Stat constructed with a window size of 3.
// The test expects count() to read 0 right after the third add(), and expects
// get() to keep returning the aggregates of that three-sample window even
// after further closeWindow() calls with no new samples.
TEST(MonitorTest, CounterInt64WindowSize) {
  Stat<int64_t> stat{"a", {COUNT, SUM}, /*windowSize=*/3};

  stat.add(1);
  stat.add(2);
  ASSERT_EQ(stat.count(), 2);
  stat.add(3);
  ASSERT_EQ(stat.count(), 0);

  stat.closeWindow();
  auto got = stat.get();
  std::vector<std::pair<Aggregation, int64_t>> expected = {
      {COUNT, 3},
      {SUM, 6},
  };
  ASSERT_EQ(got, expected);

  // Closing again without new samples must not change what get() returns.
  stat.closeWindow();
  ASSERT_EQ(got, stat.get());
}
|
|
|
|
// Creates an int64 stat "a" (COUNT, SUM, window size 3) and a double stat "b"
// (MIN, MAX, window size 2), then calls closeAndGetStats() after each round of
// add() calls.
//
// NOTE(review): `expected` is built in both rounds but is never compared
// against `collected`, so this test currently asserts nothing about the
// values closeAndGetStats() returns. Also note that the "a.*" keys sit in the
// double-valued map and the "b.*" keys in the int64-valued map even though
// "a" is a Stat<int64_t> and "b" is a Stat<double> -- confirm the intended
// expectations against closeAndGetStats() before adding an ASSERT_EQ here.
TEST(MonitorTest, CloseAndGetStats) {
  using StatMaps = std::pair<
      std::unordered_map<std::string, double>,
      std::unordered_map<std::string, int64_t>>;

  Stat<int64_t> countSum{"a", {COUNT, SUM}, /*windowSize=*/3};
  Stat<double> minMax{"b", {MIN, MAX}, 2};

  // Round 1: one sample per stat.
  countSum.add(1);
  minMax.add(1);
  {
    auto collected = closeAndGetStats();
    StatMaps expected = {
        {{"a.count", 1}, {"a.sum", 1}},
        {{"b.min", 0}, {"b.max", 0}},
    };
  }

  // Round 2: a second sample per stat.
  countSum.add(2);
  minMax.add(2);
  {
    auto collected = closeAndGetStats();
    StatMaps expected = {
        {{"a.count", 1}, {"a.sum", 2}},
        {{"b.min", 1}, {"b.max", 2}},
    };
  }
}
|