#include #include #include #include #include #include using namespace torch::nn; using namespace torch::detail; using Catch::Contains; using Catch::StartsWith; TEST_CASE("any-module") { torch::manual_seed(0); SECTION("int()") { struct M : torch::nn::Module { int forward() { return 123; } }; AnyModule any(M{}); REQUIRE(any.forward().get() == 123); } SECTION("int(int)") { struct M : torch::nn::Module { int forward(int x) { return x; } }; AnyModule any(M{}); REQUIRE(any.forward(5).get() == 5); } SECTION("const char*(const char*)") { struct M : torch::nn::Module { const char* forward(const char* x) { return x; } }; AnyModule any(M{}); REQUIRE(any.forward("hello").get() == std::string("hello")); } SECTION("string(int, const double)") { struct M : torch::nn::Module { std::string forward(int x, const double f) { return std::to_string(static_cast(x + f)); } }; AnyModule any(M{}); int x = 4; REQUIRE(any.forward(x, 3.14).get() == std::string("7")); } SECTION("Tensor(string, const string&, string&&)") { struct M : torch::nn::Module { torch::Tensor forward( std::string a, const std::string& b, std::string&& c) { const auto s = a + b + c; return torch::ones({static_cast(s.size())}); } }; AnyModule any(M{}); REQUIRE( any.forward(std::string("a"), std::string("ab"), std::string("abc")) .get() .sum() .toCInt() == 6); } SECTION("wrong argument type") { struct M : torch::nn::Module { int forward(float x) { return x; } }; AnyModule any(M{}); REQUIRE_THROWS_WITH( any.forward(5.0), StartsWith("Expected argument #0 to be of type float, " "but received value of type double")); } SECTION("wrong number of arguments") { struct M : torch::nn::Module { int forward(int a, int b) { return a + b; } }; AnyModule any(M{}); REQUIRE_THROWS_WITH( any.forward(), Contains("M's forward() method expects 2 arguments, but received 0")); REQUIRE_THROWS_WITH( any.forward(5), Contains("M's forward() method expects 2 arguments, but received 1")); REQUIRE_THROWS_WITH( any.forward(1, 2, 3), Contains("M's forward() method expects 2 arguments, but received 3")); } SECTION("get()") { struct M : torch::nn::Module { explicit M(int value_) : torch::nn::Module("M"), value(value_) {} int value; int forward(float x) { return x; } }; AnyModule any(M{5}); SECTION("good cast") { REQUIRE(any.get().value == 5); } SECTION("bad cast") { struct N : torch::nn::Module {}; REQUIRE_THROWS_WITH(any.get(), StartsWith("Attempted to cast module")); } } SECTION("ptr()") { struct M : torch::nn::Module { explicit M(int value_) : torch::nn::Module("M"), value(value_) {} int value; int forward(float x) { return x; } }; AnyModule any(M{5}); SECTION("base class cast") { auto ptr = any.ptr(); REQUIRE(ptr != nullptr); REQUIRE(ptr->name() == "M"); } SECTION("good downcast") { auto ptr = any.ptr(); REQUIRE(ptr != nullptr); REQUIRE(ptr->value == 5); } SECTION("bad downcast") { struct N : torch::nn::Module {}; REQUIRE_THROWS_WITH(any.ptr(), StartsWith("Attempted to cast module")); } } SECTION("default state is empty") { struct M : torch::nn::Module { explicit M(int value_) : value(value_) {} int value; int forward(float x) { return x; } }; AnyModule any; REQUIRE(any.is_empty()); any = std::make_shared(5); REQUIRE(!any.is_empty()); REQUIRE(any.get().value == 5); } SECTION("all methods throw for empty AnyModule") { struct M : torch::nn::Module { int forward(int x) { return x; } }; AnyModule any; REQUIRE(any.is_empty()); REQUIRE_THROWS_WITH( any.get(), StartsWith("Cannot call get() on an empty AnyModule")); REQUIRE_THROWS_WITH( any.ptr(), StartsWith("Cannot call ptr() on an empty AnyModule")); REQUIRE_THROWS_WITH( any.ptr(), StartsWith("Cannot call ptr() on an empty AnyModule")); REQUIRE_THROWS_WITH( any.type_info(), StartsWith("Cannot call type_info() on an empty AnyModule")); REQUIRE_THROWS_WITH( any.forward(5), StartsWith("Cannot call forward() on an empty AnyModule")); } SECTION("can move assign differentm modules") { struct M : torch::nn::Module { std::string forward(int x) { return std::to_string(x); } }; struct N : torch::nn::Module { int forward(float x) { return 3 + x; } }; AnyModule any; REQUIRE(any.is_empty()); any = std::make_shared(); REQUIRE(!any.is_empty()); REQUIRE(any.forward(5).get() == "5"); any = std::make_shared(); REQUIRE(!any.is_empty()); REQUIRE(any.forward(5.0f).get() == 8); } SECTION("constructs from ModuleHolder") { struct MImpl : torch::nn::Module { explicit MImpl(int value_) : torch::nn::Module("M"), value(value_) {} int value; int forward(float x) { return x; } }; struct M : torch::nn::ModuleHolder { using torch::nn::ModuleHolder::ModuleHolder; using torch::nn::ModuleHolder::get; }; AnyModule any(M{5}); REQUIRE(any.get().value == 5); REQUIRE(any.get()->value == 5); } SECTION("converts at::Tensor to torch::Tensor correctly") { struct M : torch::nn::Module { torch::Tensor forward(torch::Tensor input) { return input; } }; struct N : torch::nn::Module { at::Tensor forward(at::Tensor input) { return input; } }; { // When you get an at::Tensor by performing an operation on a // torch::Tensor, the tensor should be converted back to torch::Tensor // before being passed to the function (to avoid a type mismatch). AnyModule any(M{}); at::Tensor tensor_that_is_actually_a_variable = torch::ones(5) * 2; REQUIRE( any.forward(tensor_that_is_actually_a_variable) .get() .sum() .toCFloat() == 10); // But tensors that are really tensors should just error. REQUIRE_THROWS_WITH( any.forward(at::ones(5)), StartsWith( "Expected argument #0 to be of type torch::autograd::Variable, " "but received value of type at::Tensor")); } { // If the function does really accept an `at::Tensor`, this should still // work. AnyModule any(N{}); REQUIRE(any.forward(at::ones(5)).get().sum().toCFloat() == 5); } } } namespace torch { namespace nn { struct TestValue { template explicit TestValue(T&& value) : value_(std::forward(value)) {} AnyModule::Value operator()() { return std::move(value_); } AnyModule::Value value_; }; template AnyModule::Value make_value(T&& value) { return TestValue(std::forward(value))(); } } // namespace nn } // namespace torch TEST_CASE("any-value") { torch::manual_seed(0); SECTION("gets the correct value for the right type") { SECTION("int") { auto value = make_value(5); // const and non-const types have the same typeid() REQUIRE(value.try_get() != nullptr); REQUIRE(value.try_get() != nullptr); REQUIRE(value.get() == 5); } SECTION("const int") { auto value = make_value(5); REQUIRE(value.try_get() != nullptr); REQUIRE(value.try_get() != nullptr); REQUIRE(value.get() == 5); } SECTION("const char*") { auto value = make_value("hello"); REQUIRE(value.try_get() != nullptr); REQUIRE(value.get() == std::string("hello")); } SECTION("std::string") { auto value = make_value(std::string("hello")); REQUIRE(value.try_get() != nullptr); REQUIRE(value.get() == "hello"); } SECTION("pointers") { std::string s("hello"); std::string* p = &s; auto value = make_value(p); REQUIRE(value.try_get() != nullptr); REQUIRE(*value.get() == "hello"); } SECTION("references") { std::string s("hello"); const std::string& t = s; auto value = make_value(t); REQUIRE(value.try_get() != nullptr); REQUIRE(value.get() == "hello"); } } SECTION("try_get returns nullptr for the wrong type") { auto value = make_value(5); REQUIRE(value.try_get() != nullptr); REQUIRE(value.try_get() == nullptr); REQUIRE(value.try_get() == nullptr); REQUIRE(value.try_get() == nullptr); } SECTION("get throws for the wrong type") { auto value = make_value(5); REQUIRE(value.try_get() != nullptr); REQUIRE_THROWS_WITH( value.get(), StartsWith("Attempted to cast Value to float, " "but its actual type is int")); REQUIRE_THROWS_WITH( value.get(), StartsWith("Attempted to cast Value to long, " "but its actual type is int")); } SECTION("move is allowed") { auto value = make_value(5); SECTION("construction") { auto copy = make_value(std::move(value)); REQUIRE(copy.try_get() != nullptr); REQUIRE(copy.get() == 5); } SECTION("assignment") { auto copy = make_value(10); copy = std::move(value); REQUIRE(copy.try_get() != nullptr); REQUIRE(copy.get() == 5); } } SECTION("type_info is correct") { SECTION("int") { auto value = make_value(5); REQUIRE(value.type_info().hash_code() == typeid(int).hash_code()); } SECTION("const char") { auto value = make_value("hello"); REQUIRE(value.type_info().hash_code() == typeid(const char*).hash_code()); } SECTION("std::string") { auto value = make_value(std::string("hello")); REQUIRE(value.type_info().hash_code() == typeid(std::string).hash_code()); } } }