pytorch/c10/test/util/Enumerate_test.cpp
dolpm a765e2ddda [nativert] port enumerate from folly to c10::util (#152481)
Summary:
nativert RFC: https://github.com/zhxchen17/rfcs/blob/master/RFC-0043-torch-native-runtime.md

To land the runtime into PyTorch core, we will gradually land logical parts of the code into the Github issue and get each piece properly reviewed.

This diff ports an enumeration util from folly into c10.

Test Plan: CI

Differential Revision: D73881042

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152481
Approved by: https://github.com/Skylion007, https://github.com/zhxchen17, https://github.com/cyyever
2025-05-01 21:41:05 +00:00

270 lines
6.2 KiB
C++

/*
* Ported from folly/container/test/EnumerateTest.cpp
*/
#include <c10/util/Enumerate.h>
#include <gtest/gtest.h>
#include <array>
namespace {

// Type trait: `value` is true exactly when T is an lvalue reference to a
// const-qualified type (e.g. `const std::string&`) and false otherwise. The
// tests apply it to `decltype(*it)` and to structured-binding names to check
// that enumeration hands out read-only elements where it should.
template <class T>
struct IsConstReference {
  static constexpr bool value = false;
};

template <class T>
struct IsConstReference<const T&> {
  static constexpr bool value = true;
};

// Adds up the elements by dereferencing each enumerate proxy. Declared
// constexpr so a test can evaluate it inside a static_assert.
constexpr int basicSum(const std::array<int, 3>& test) {
  int total = 0;
  for (auto entry : c10::enumerate(test)) {
    total += *entry;
  }
  return total;
}

// Same sum, but unpacks each proxy with a C++17 structured binding; the index
// half of the binding is deliberately ignored.
constexpr int cpp17StructuredBindingSum(const std::array<int, 3>& test) {
  int total = 0;
  for (auto&& [ignoredIndex, element] : c10::enumerate(test)) {
    total += element;
  }
  return total;
}

} // namespace
TEST(Enumerate, Basic) {
  // Bind each proxy by value and check the index, both element accessors
  // (operator* and operator->), and that writes go through to the vector.
  std::vector<std::string> words = {"abc", "a", "ab"};
  size_t expected = 0;
  for (auto entry : c10::enumerate(words)) {
    EXPECT_EQ(entry.index, expected);
    EXPECT_EQ(*entry, words[expected]);
    EXPECT_EQ(entry->size(), words[expected].size());
    // Assigning through the proxy must modify the underlying element.
    const std::string replacement = "x";
    *entry = replacement;
    EXPECT_EQ(replacement, words[expected]);
    expected++;
  }
  EXPECT_EQ(expected, words.size());
}
TEST(Enumerate, BasicRRef) {
  // Same checks as a by-value binding, but each proxy is bound through a
  // forwarding reference (auto&&).
  std::vector<std::string> names = {"abc", "a", "ab"};
  size_t pos = 0;
  for (auto&& entry : c10::enumerate(names)) {
    EXPECT_EQ(entry.index, pos);
    EXPECT_EQ(*entry, names[pos]);
    EXPECT_EQ(entry->size(), names[pos].size());
    // Assigning through the proxy must modify the underlying element.
    const std::string replacement = "x";
    *entry = replacement;
    EXPECT_EQ(replacement, names[pos]);
    pos += 1;
  }
  EXPECT_EQ(pos, names.size());
}
TEST(Enumerate, BasicConst) {
  // A const proxy copy (`const auto`) must expose elements as const
  // references, even though the vector itself is mutable.
  std::vector<std::string> items = {"abc", "a", "ab"};
  size_t pos = 0;
  for (const auto entry : c10::enumerate(items)) {
    static_assert(
        IsConstReference<decltype(*entry)>::value, "Const enumeration");
    EXPECT_EQ(entry.index, pos);
    EXPECT_EQ(*entry, items[pos]);
    EXPECT_EQ(entry->size(), items[pos].size());
    pos++;
  }
  EXPECT_EQ(pos, items.size());
}
TEST(Enumerate, BasicConstRef) {
  // Binding the proxy as `const auto&` must likewise yield const element
  // references.
  std::vector<std::string> strings = {"abc", "a", "ab"};
  size_t expected = 0;
  for (const auto& entry : c10::enumerate(strings)) {
    static_assert(
        IsConstReference<decltype(*entry)>::value, "Const enumeration");
    EXPECT_EQ(entry.index, expected);
    EXPECT_EQ(*entry, strings[expected]);
    EXPECT_EQ(entry->size(), strings[expected].size());
    expected += 1;
  }
  EXPECT_EQ(expected, strings.size());
}
TEST(Enumerate, BasicConstRRef) {
  // Binding the proxy as `const auto&&` must also yield const element
  // references.
  std::vector<std::string> words = {"abc", "a", "ab"};
  size_t pos = 0;
  for (const auto&& entry : c10::enumerate(words)) {
    static_assert(
        IsConstReference<decltype(*entry)>::value, "Const enumeration");
    EXPECT_EQ(entry.index, pos);
    EXPECT_EQ(*entry, words[pos]);
    EXPECT_EQ(entry->size(), words[pos].size());
    pos++;
  }
  EXPECT_EQ(pos, words.size());
}
TEST(Enumerate, BasicVecBool) {
  // std::vector<bool> dereferences to a proxy reference type rather than
  // bool&; indices and values must still come out right.
  std::vector<bool> flags = {true, false, false, true};
  size_t pos = 0;
  for (auto entry : c10::enumerate(flags)) {
    EXPECT_EQ(entry.index, pos);
    EXPECT_EQ(*entry, flags[pos]);
    pos += 1;
  }
  EXPECT_EQ(pos, flags.size());
}
TEST(Enumerate, BasicVecBoolRRef) {
  // std::vector<bool> variant of the forwarding-reference test: each
  // enumerate proxy is bound through `auto&&` (as in BasicRRef) rather than
  // copied. Binding by value here would merely duplicate BasicVecBool.
  std::vector<bool> v = {true, false, false, true};
  size_t i = 0;
  for (auto&& it : c10::enumerate(v)) {
    EXPECT_EQ(it.index, i);
    EXPECT_EQ(*it, v[i]);
    ++i;
  }
  EXPECT_EQ(i, v.size());
}
TEST(Enumerate, Temporary) {
  // The range expression is a temporary copy of the vector passed straight
  // into enumerate, so this exercises enumerating an rvalue range.
  std::vector<std::string> words = {"abc", "a", "ab"};
  size_t pos = 0;
  for (auto&& entry : c10::enumerate(std::vector<std::string>(words))) {
    EXPECT_EQ(entry.index, pos);
    EXPECT_EQ(*entry, words[pos]);
    EXPECT_EQ(entry->size(), words[pos].size());
    pos++;
  }
  EXPECT_EQ(pos, words.size());
}
TEST(Enumerate, BasicConstArg) {
  // Enumerating a const container must yield const element references, even
  // though the proxy itself is bound through a non-const `auto&&`.
  const std::vector<std::string> strings = {"abc", "a", "ab"};
  size_t expected = 0;
  for (auto&& entry : c10::enumerate(strings)) {
    static_assert(
        IsConstReference<decltype(*entry)>::value,
        "Enumerating a const vector");
    EXPECT_EQ(entry.index, expected);
    EXPECT_EQ(*entry, strings[expected]);
    EXPECT_EQ(entry->size(), strings[expected].size());
    expected += 1;
  }
  EXPECT_EQ(expected, strings.size());
}
TEST(Enumerate, TemporaryConstEnumerate) {
  // Enumerate a temporary copy of the vector while binding each proxy as
  // `const auto&&`; elements must then be exposed as const references.
  std::vector<std::string> names = {"abc", "a", "ab"};
  size_t pos = 0;
  for (const auto&& entry :
       c10::enumerate(std::vector<std::string>(names))) {
    static_assert(
        IsConstReference<decltype(*entry)>::value, "Const enumeration");
    EXPECT_EQ(entry.index, pos);
    EXPECT_EQ(*entry, names[pos]);
    EXPECT_EQ(entry->size(), names[pos].size());
    ++pos;
  }
  EXPECT_EQ(pos, names.size());
}
TEST(Enumerate, EmptyRange) {
  // The loop body must never execute when the container is empty.
  std::vector<std::string> nothing;
  for (auto&& entry : c10::enumerate(nothing)) {
    static_cast<void>(entry); // Only here to keep the variable "used".
    ADD_FAILURE();
  }
}
// Minimal range over a NUL-terminated C string. Its end() returns a distinct
// empty Sentinel type, so begin() and end() have different types. The
// Cpp17Support test uses it to check that c10::enumerate accepts such ranges.
class CStringRange {
 public:
  // Empty tag returned by end(). A position is compared against it with the
  // free operator== defined just after this class.
  struct Sentinel {};

  explicit CStringRange(const char* str) : str_(str) {}

  const char* begin() const {
    return str_;
  }
  Sentinel end() const {
    return {};
  }

 private:
  const char* str_;
};

// A position "equals" the sentinel once it points at the terminating NUL.
static bool operator==(const char* pos, CStringRange::Sentinel /*unused*/) {
  return pos[0] == '\0';
}
TEST(Enumerate, Cpp17Support) {
  // Enumerate a range whose end() is a different type (a sentinel) from
  // begin(). Counting the iterations makes the test fail if enumeration
  // yields nothing, yields indices out of order, or runs past the NUL: the
  // sentinel compares equal at the terminator, so only the 4 characters of
  // "test" should be visited.
  std::array<char, 5> test = {"test"};
  size_t count = 0;
  for (const auto&& it : c10::enumerate(CStringRange{test.data()})) {
    ASSERT_LT(it.index, test.size());
    EXPECT_EQ(it.index, count);
    EXPECT_EQ(*it, test[it.index]);
    ++count;
  }
  EXPECT_EQ(count, test.size() - 1);
}
TEST(Enumerate, Cpp17StructuredBindingConstRef) {
  // Unpack each proxy into [index, element] through a const lvalue reference.
  // The counter makes sure every element is visited, in index order, so the
  // test cannot pass vacuously on an empty enumeration.
  std::vector<std::string> test = {"abc", "a", "ab"};
  size_t count = 0;
  for (const auto& [index, str] : c10::enumerate(test)) {
    ASSERT_LT(index, test.size());
    EXPECT_EQ(index, count);
    EXPECT_EQ(str, test[index]);
    ++count;
  }
  EXPECT_EQ(count, test.size());
}
TEST(Enumerate, Cpp17StructuredBindingConstRRef) {
  // Unpack each proxy into [index, element] through a const rvalue reference.
  // The counter makes sure every element is visited, in index order, so the
  // test cannot pass vacuously on an empty enumeration.
  std::vector<std::string> test = {"abc", "a", "ab"};
  size_t count = 0;
  for (const auto&& [index, str] : c10::enumerate(test)) {
    ASSERT_LT(index, test.size());
    EXPECT_EQ(index, count);
    EXPECT_EQ(str, test[index]);
    ++count;
  }
  EXPECT_EQ(count, test.size());
}
TEST(Enumerate, Cpp17StructuredBindingConstVector) {
  // Structured bindings over a const vector: the element binding must be a
  // const reference even though the proxy is bound with plain `auto&&`. The
  // counter makes sure every element is visited, in index order.
  const std::vector<std::string> test = {"abc", "a", "ab"};
  size_t count = 0;
  for (auto&& [index, str] : c10::enumerate(test)) {
    static_assert(
        IsConstReference<decltype(str)>::value, "Enumerating const vector");
    ASSERT_LT(index, test.size());
    EXPECT_EQ(index, count);
    EXPECT_EQ(str, test[index]);
    ++count;
  }
  EXPECT_EQ(count, test.size());
}
TEST(Enumerate, Cpp17StructuredBindingModify) {
  // Assigning to the element half of a structured binding must write into
  // the vector; afterwards every entry should be zero.
  std::vector<int> values = {1, 2, 3, 4, 5};
  for (auto&& [idx, element] : c10::enumerate(values)) {
    element = 0;
  }
  for (const auto& value : values) {
    EXPECT_EQ(value, 0);
  }
}
TEST(Enumerate, BasicConstexpr) {
  // basicSum must be usable in a constant expression (static_assert) and
  // give the same result when evaluated at run time.
  constexpr std::array<int, 3> values = {1, 2, 3};
  static_assert(basicSum(values) == 6, "Basic enumerating is not constexpr");
  EXPECT_EQ(basicSum(values), 6);
}
TEST(Enumerate, Cpp17StructuredBindingConstexpr) {
  // The structured-binding sum must be usable in a constant expression
  // (static_assert) and give the same result when evaluated at run time.
  constexpr std::array<int, 3> values = {1, 2, 3};
  static_assert(
      cpp17StructuredBindingSum(values) == 6,
      "C++17 structured binding enumerating is not constexpr");
  EXPECT_EQ(cpp17StructuredBindingSum(values), 6);
}