#include #include #include #include #include #include #include template < typename T, typename = torch::enable_if_t::value>> bool f(T&& m) { return false; } template torch::detail::enable_if_module_t f(T&& m) { return true; } TEST(TestStatic, AllOf) { ASSERT_TRUE(torch::all_of<>::value); ASSERT_TRUE(torch::all_of::value); ASSERT_TRUE((torch::all_of::value)); ASSERT_FALSE(torch::all_of::value); ASSERT_FALSE((torch::all_of::value)); ASSERT_FALSE((torch::all_of::value)); } TEST(TestStatic, AnyOf) { ASSERT_FALSE(torch::any_of<>::value); ASSERT_TRUE(bool((torch::any_of::value))); ASSERT_TRUE(bool((torch::any_of::value))); ASSERT_FALSE(bool((torch::any_of::value))); } TEST(TestStatic, EnableIfModule) { ASSERT_TRUE(f(torch::nn::LinearImpl(1, 2))); ASSERT_FALSE(f(5)); ASSERT_TRUE(torch::detail::check_not_lvalue_references()); ASSERT_TRUE((torch::detail::check_not_lvalue_references())); ASSERT_FALSE( (torch::detail::check_not_lvalue_references())); ASSERT_TRUE(torch::detail::check_not_lvalue_references()); ASSERT_FALSE(torch::detail::check_not_lvalue_references()); } struct A : torch::nn::Module { int forward() { return 5; } }; struct B : torch::nn::Module { std::string forward(torch::Tensor tensor) { return ""; } }; struct C : torch::nn::Module { float forward(torch::Tensor& tensor) { return 5.0; } }; struct D : torch::nn::Module { char forward(torch::Tensor&& tensor) { return 'x'; } }; struct E : torch::nn::Module {}; // Put in a function because macros don't handle the comma between arguments to // is_same well ... template void assert_has_expected_type() { using ReturnType = typename torch::detail::return_type_of_forward::type; constexpr bool is_expected_type = std::is_same::value; ASSERT_TRUE(is_expected_type) << Module().name(); } TEST(TestStatic, ReturnTypeOfForward) { assert_has_expected_type(); assert_has_expected_type(); assert_has_expected_type(); assert_has_expected_type(); assert_has_expected_type(); } TEST(TestStatic, Apply) { std::vector v; torch::apply([&v](int x) { v.push_back(x); }, 1, 2, 3, 4, 5); ASSERT_EQ(v.size(), 5); for (const auto i : c10::irange(v.size())) { ASSERT_EQ(v.at(i), i + 1); } }