// Tests for torch's compile-time helpers: all_of, any_of, enable_if_module_t,
// check_not_lvalue_references, and apply (one SECTION each below).
//
// NOTE(review): every #include below has lost its header name, and most
// template argument lists in this file appear to have been stripped (e.g.
// `std::vector v;` has no element type, `torch::all_of::value` has no
// arguments, the second `template` has no parameter list). As written this
// does not compile; restore the missing text from version control rather
// than guessing it — TODO confirm.
#include
#include
#include
#include
#include
#include
#include
#include

// Overload pair exercised by SECTION("enable_if_module_t"). The first overload
// is SFINAE-constrained through torch::enable_if_t and returns false; the
// second uses torch::detail::enable_if_module_t as its return type and returns
// true. The parameter `m` is unused: only which overload gets selected matters.
// NOTE(review): the enable_if_t condition's template arguments are missing;
// presumably it excludes module types so both overloads never apply at once.
template <
    typename T,
    typename = torch::enable_if_t::value>>
bool f(T&& m) {
  return false;
}

template torch::detail::enable_if_module_t f(T&& m) {
  return true;
}

TEST_CASE("static") {
  // all_of: an empty pack is expected to yield true. The remaining cases have
  // lost their template arguments; their expected results are, in order:
  // true, true, false, false, false.
  SECTION("all_of") {
    REQUIRE(torch::all_of<>::value == true);
    REQUIRE(torch::all_of::value == true);
    REQUIRE(torch::all_of::value == true);
    REQUIRE(torch::all_of::value == false);
    REQUIRE(torch::all_of::value == false);
    REQUIRE(torch::all_of::value == false);
  }
  // any_of: an empty pack is expected to yield false. Remaining expected
  // results (arguments missing): true, true, false, true.
  SECTION("any_of") {
    REQUIRE(torch::any_of<>::value == false);
    REQUIRE(torch::any_of::value == true);
    REQUIRE(torch::any_of::value == true);
    REQUIRE(torch::any_of::value == false);
    REQUIRE(torch::any_of::value == true);
  }
  // A module instance (a LinearImpl) must select the true-returning overload
  // of f; a plain int must select the false-returning one.
  SECTION("enable_if_module_t") {
    REQUIRE(f(torch::nn::LinearImpl({1, 2})) == true);
    REQUIRE(f(5) == false);
  }
  // NOTE(review): the explicit template arguments to each call are missing;
  // judging by the function name, the `false` cases presumably pass an
  // lvalue-reference type — confirm.
  SECTION("check_not_lvalue_references") {
    REQUIRE(torch::detail::check_not_lvalue_references() == true);
    REQUIRE(
        torch::detail::check_not_lvalue_references() == true);
    REQUIRE(
        torch::detail::check_not_lvalue_references() == false);
    REQUIRE(torch::detail::check_not_lvalue_references() == true);
    REQUIRE(
        torch::detail::check_not_lvalue_references() == false);
  }
  // apply: the test expects the callable to be invoked once per trailing
  // argument, in argument order, so v ends up holding 1, 2, 3, 4, 5.
  SECTION("apply") {
    // NOTE(review): element type missing; the lambda below pushes ints.
    std::vector v;
    torch::apply([&v](int x) { v.push_back(x); }, 1, 2, 3, 4, 5);
    REQUIRE(v.size() == 5);
    for (size_t i = 0; i < v.size(); ++i) {
      // NOTE(review): `1 + i` is size_t; if v holds int this is a
      // signed/unsigned comparison (-Wsign-compare) — consider a cast.
      REQUIRE(v.at(i) == 1 + i);
    }
  }
}