#include #include #include #include #include #include using namespace torch; using namespace torch::nn; using Catch::StartsWith; TEST_CASE("sequential") { SECTION("construction") { Sequential sequential( Linear(2, 3).build(), Linear(2, 3), Linear(2, 3).build()); REQUIRE(sequential.size() == 3); } SECTION("push_back") { Sequential sequential; REQUIRE(sequential.size() == 0); REQUIRE(sequential.is_empty()); sequential.push_back(Linear(3, 4).build()); REQUIRE(sequential.size() == 1); sequential.push_back(Linear(4, 5).build()); REQUIRE(sequential.size() == 2); } SECTION("access") { std::vector> modules = { Linear(2, 3).build(), Linear(3, 4).build(), Linear(4, 5).build()}; Sequential sequential; for (auto& module : modules) { sequential.push_back(module); } REQUIRE(sequential.size() == 3); SECTION("at()") { SECTION("returns the correct module for a given index") { for (size_t i = 0; i < modules.size(); ++i) { REQUIRE(&sequential.at(i) == modules[i].get()); } } SECTION("throws for a bad index") { REQUIRE_THROWS_WITH( sequential.at(modules.size() + 1), StartsWith("Index out of range")); REQUIRE_THROWS_WITH( sequential.at(modules.size() + 1000000), StartsWith("Index out of range")); } } SECTION("ptr()") { SECTION("returns the correct module for a given index") { for (size_t i = 0; i < modules.size(); ++i) { REQUIRE(sequential.ptr(i).get() == modules[i].get()); REQUIRE(sequential[i].get() == modules[i].get()); REQUIRE(sequential.ptr(i).get() == modules[i].get()); } } SECTION("throws for a bad index") { REQUIRE_THROWS_WITH( sequential.ptr(modules.size() + 1), StartsWith("Index out of range")); REQUIRE_THROWS_WITH( sequential.ptr(modules.size() + 1000000), StartsWith("Index out of range")); } } } SECTION("forward") { SECTION("calling forward() on an empty sequential is disallowed") { Sequential empty; REQUIRE_THROWS_WITH( empty.forward(), StartsWith("Cannot call forward() on an empty Sequential")); } SECTION("calling forward() on a non-empty sequential chains correctly") { struct MockModule : nn::Module { explicit MockModule(int value) : expected(value) {} int expected; int forward(int value) { REQUIRE(value == expected); return value + 1; } }; Sequential sequential(MockModule{1}, MockModule{2}, MockModule{3}); REQUIRE(sequential.forward(1) == 4); } SECTION("calling forward() with the wrong return type throws") { struct M : public nn::Module { int forward() { return 5; } }; Sequential sequential(M{}); REQUIRE(sequential.forward() == 5); REQUIRE_THROWS_WITH( sequential.forward(), StartsWith("The type of the return value " "is int, but you asked for type float")); } SECTION("The return type of forward() defaults to Variable") { struct M : public nn::Module { autograd::Variable forward(autograd::Variable v) { return v; } }; Sequential sequential(M{}); auto variable = torch::ones({3, 3}, at::requires_grad()); REQUIRE(sequential.forward(variable).equal(variable)); } } SECTION("returns the last value") { Sequential sequential( Linear(10, 3).build(), Linear(3, 5).build(), Linear(5, 100).build()); auto x = torch::randn({1000, 10}, at::requires_grad()); auto y = sequential.forward>(std::vector{x}) .front(); REQUIRE(y.ndimension() == 2); REQUIRE(y.size(0) == 1000); REQUIRE(y.size(1) == 100); } }