#include #include #include #include #include #include #include #include template < typename T, typename = torch::enable_if_t::value>> bool f(T&& m) { return false; } template torch::detail::enable_if_module_t f(T&& m) { return true; } TEST(TestStatic, AllOf) { ASSERT_TRUE(torch::all_of<>::value); ASSERT_TRUE(torch::all_of::value); ASSERT_TRUE((torch::all_of::value)); ASSERT_FALSE(torch::all_of::value); ASSERT_FALSE((torch::all_of::value)); ASSERT_FALSE((torch::all_of::value)); } TEST(TestStatic, AnyOf) { ASSERT_FALSE(torch::any_of<>::value); ASSERT_TRUE(bool((torch::any_of::value))); ASSERT_TRUE(bool((torch::any_of::value))); ASSERT_FALSE(bool((torch::any_of::value))); } TEST(TestStatic, EnableIfModule) { ASSERT_TRUE(f(torch::nn::LinearImpl(1, 2))); ASSERT_FALSE(f(5)); ASSERT_TRUE(torch::detail::check_not_lvalue_references()); ASSERT_TRUE((torch::detail::check_not_lvalue_references())); ASSERT_FALSE( (torch::detail::check_not_lvalue_references())); ASSERT_TRUE(torch::detail::check_not_lvalue_references()); ASSERT_FALSE(torch::detail::check_not_lvalue_references()); } TEST(TestStatic, Apply) { std::vector v; torch::apply([&v](int x) { v.push_back(x); }, 1, 2, 3, 4, 5); ASSERT_EQ(v.size(), 5); for (size_t i = 0; i < v.size(); ++i) { ASSERT_EQ(v.at(i), i + 1); } }