mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
135 lines
4.0 KiB
C++
135 lines
4.0 KiB
C++
#include <catch.hpp>
|
|
|
|
#include <torch/nn/modules/linear.h>
|
|
#include <torch/nn/modules/sequential.h>
|
|
|
|
#include <vector>
|
|
|
|
using namespace torch;
|
|
using namespace torch::nn;
|
|
|
|
using Catch::StartsWith;
|
|
|
|
// Tests for Sequential: construction, push_back(), indexed access through
// at<T>() / ptr() / operator[], and forward() chaining. Each SECTION is an
// independent scenario.
TEST_CASE("sequential") {
  SECTION("construction") {
    // The middle argument is passed without calling .build(), the other two
    // with it; all three must be counted as modules.
    // NOTE(review): presumably this exercises both accepted argument forms of
    // the variadic constructor -- confirm against Sequential's constructor.
    Sequential sequential(
        Linear(2, 3).build(), Linear(2, 3), Linear(2, 3).build());
    REQUIRE(sequential.size() == 3);
  }

  SECTION("push_back") {
    // A default-constructed Sequential starts out empty and grows by one per
    // push_back() call.
    Sequential sequential;
    REQUIRE(sequential.size() == 0);
    REQUIRE(sequential.is_empty());
    sequential.push_back(Linear(3, 4).build());
    REQUIRE(sequential.size() == 1);
    sequential.push_back(Linear(4, 5).build());
    REQUIRE(sequential.size() == 2);
  }

  SECTION("access") {
    // Keep our own handles to the modules so that the accessors can be checked
    // by pointer identity rather than by value.
    std::vector<std::shared_ptr<Linear>> modules = {
        Linear(2, 3).build(), Linear(3, 4).build(), Linear(4, 5).build()};

    Sequential sequential;
    for (auto& module : modules) {
      sequential.push_back(module);
    }
    REQUIRE(sequential.size() == 3);

    SECTION("at()") {
      SECTION("returns the correct module for a given index") {
        for (size_t i = 0; i < modules.size(); ++i) {
          REQUIRE(&sequential.at<Linear>(i) == modules[i].get());
        }
      }
      SECTION("throws for a bad index") {
        // size() is the first invalid index; checking it explicitly catches an
        // off-by-one in the bounds check that the larger offsets would miss.
        REQUIRE_THROWS_WITH(
            sequential.at<Linear>(modules.size()),
            StartsWith("Index out of range"));
        REQUIRE_THROWS_WITH(
            sequential.at<Linear>(modules.size() + 1),
            StartsWith("Index out of range"));
        REQUIRE_THROWS_WITH(
            sequential.at<Linear>(modules.size() + 1000000),
            StartsWith("Index out of range"));
      }
    }

    SECTION("ptr()") {
      SECTION("returns the correct module for a given index") {
        // The untyped ptr(), operator[] and the typed ptr<T>() must all refer
        // to the very same module object.
        for (size_t i = 0; i < modules.size(); ++i) {
          REQUIRE(sequential.ptr(i).get() == modules[i].get());
          REQUIRE(sequential[i].get() == modules[i].get());
          REQUIRE(sequential.ptr<Linear>(i).get() == modules[i].get());
        }
      }
      SECTION("throws for a bad index") {
        // Same boundary reasoning as for at(): size() itself is out of range.
        REQUIRE_THROWS_WITH(
            sequential.ptr(modules.size()),
            StartsWith("Index out of range"));
        REQUIRE_THROWS_WITH(
            sequential.ptr(modules.size() + 1),
            StartsWith("Index out of range"));
        REQUIRE_THROWS_WITH(
            sequential.ptr(modules.size() + 1000000),
            StartsWith("Index out of range"));
      }
    }
  }

  SECTION("forward") {
    SECTION("calling forward() on an empty sequential is disallowed") {
      Sequential empty;
      REQUIRE_THROWS_WITH(
          empty.forward<int>(),
          StartsWith("Cannot call forward() on an empty Sequential"));
    }

    SECTION("calling forward() on a non-empty sequential chains correctly") {
      // Each mock verifies the value it receives and hands on value + 1, so
      // the chain 1 -> 2 -> 3 -> 4 only succeeds if the modules run in
      // insertion order and each one is fed its predecessor's output.
      struct MockModule : nn::Module {
        explicit MockModule(int value) : expected(value) {}
        int expected;
        int forward(int value) {
          REQUIRE(value == expected);
          return value + 1;
        }
      };

      Sequential sequential(MockModule{1}, MockModule{2}, MockModule{3});

      REQUIRE(sequential.forward<int>(1) == 4);
    }

    SECTION("calling forward() with the wrong return type throws") {
      struct M : public nn::Module {
        int forward() {
          return 5;
        }
      };

      Sequential sequential(M{});
      // Asking for the module's real return type works ...
      REQUIRE(sequential.forward<int>() == 5);
      // ... while asking for a different type must fail with a message naming
      // both the actual and the requested type.
      REQUIRE_THROWS_WITH(
          sequential.forward<float>(),
          StartsWith("The type of the return value "
                     "is int, but you asked for type float"));
    }

    SECTION("The return type of forward() defaults to Variable") {
      // Identity module: forward() is called below without an explicit return
      // type and its result is used directly as a Variable.
      struct M : public nn::Module {
        autograd::Variable forward(autograd::Variable v) {
          return v;
        }
      };

      Sequential sequential(M{});
      auto variable =
          autograd::make_variable(at::CPU(at::kFloat).ones({3, 3}), true);
      REQUIRE(sequential.forward(variable).equal(variable));
    }
  }

  SECTION("returns the last value") {
    // Three Linear layers (10 -> 3 -> 5 -> 100) applied to a {1000, 10} input;
    // the result must be two-dimensional with sizes 1000 and 100.
    Sequential sequential(
        Linear(10, 3).build(), Linear(3, 5).build(), Linear(5, 100).build());

    auto x = Var(at::CPU(at::kFloat).randn({1000, 10}));
    auto y = sequential.forward<variable_list>(variable_list{x}).front();
    REQUIRE(y.ndimension() == 2);
    REQUIRE(y.size(0) == 1000);
    REQUIRE(y.size(1) == 100);
  }
}
|