pytorch/test/cpp/api/module.cpp
Peter Goldsborough ab0c72ab6f Replace cursors with OrderedDict (#13427)
Summary:
This is a pre-cursor diff to Python <-> C++ frontend integration -- I have a follow-up PR coming for that. This PR changes the C++ frontend module interface to replace the custom "cursor"s I introduced some time ago with `OrderedDict`. I introduced cursors at the time as a convenient way of applying functions and query operations on a modules' parameters, buffers and modules, allowing things like `module.parameters().map(my_func)`. However, I noticed that (1) this functionality is easily implement-able on top of a regular data structure and (2) more importantly,  using OrderedDicts is much, much easier for Python integration. This is especially true given that ScriptModule today also uses OrderedDict. Since C++ frontend modules and ScriptModules will soon too share as many implementation details as possible, it is overall the best move to ditch the custom cursor datastructure and pervasively use OrderedDict everywhere.
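For a concrete sense of the transition, the cursor call quoted above translates roughly as follows (a hedged sketch, not code from this diff; `my_func` stands for any callable over tensors):

    // Before, with cursors:
    //   module.parameters().map(my_func);
    // After: parameters() returns a plain std::vector<torch::Tensor>.
    for (auto& parameter : module.parameters()) {
      my_func(parameter);
    }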

For this I did:

1. Changed the C++ frontend module interface to more closely match the Python one by providing `parameters()`, `named_parameters()` and the other methods Python provides. This is very important for the follow-up diff, which binds these methods into Python for interop with Python modules.
2. Replaced the `Cursor::apply()` method with `nn::Module::apply`. This is one more unifying step between Python and C++, since Python modules have an `apply` function too (see the sketch after this list).
3. Deleted all uses of Cursor.
4. Tidied up and beefed up the `OrderedDict` class. In particular, `OrderedDict::Item` now stores an `std::pair` under the hood, because that is trivial to bind into Python and saved me a lot of headaches. `key` and `value` become methods instead of fields, which they should have been from the very start anyway: encapsulation is exactly what makes these kinds of changes possible.
5. Added many tests for the OrderedDict use in `nn::Module`.
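To make items 1, 2 and 4 concrete, here is a hedged sketch of the new surface (illustrative only, assuming `<torch/torch.h>` and `<iostream>` are included):

    torch::nn::Linear linear(3, 4);

    // Items 1 and 4: named_parameters() returns a torch::OrderedDict whose
    // items expose key() and value() as methods, backed by an std::pair.
    for (const auto& item : linear->named_parameters()) {
      std::cout << item.key() << " has " << item.value().numel() << " elements\n";
    }

    // Item 2: nn::Module::apply visits this module and all submodules,
    // mirroring Python's Module.apply.
    linear->apply([](torch::nn::Module& module) {
      std::cout << module.name() << '\n';
    });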

ebetica ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13427

Differential Revision: D12894092

Pulled By: goldsborough

fbshipit-source-id: 715770c95a9643753a1db26d7f9da9a78619a15d
2018-11-07 11:10:05 -08:00

#include <gtest/gtest.h>

#include <torch/nn/module.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/rnn.h>
#include <torch/nn/modules/sequential.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <test/cpp/api/support.h>

using namespace torch::nn;
using namespace torch::test;

struct AGIUnit : torch::nn::Module {};

namespace test {
struct AGIUnit : torch::nn::Module {};
struct AGIUnit2 : torch::nn::Module {
  AGIUnit2() : torch::nn::Module("Foo") {}
};
} // namespace test

struct ModuleTest : torch::test::SeedingFixture {};

TEST_F(ModuleTest, CanEnableAndDisableTrainingMode) {
  Linear module(3, 4);
  ASSERT_TRUE(module->is_training());

  module->eval();
  ASSERT_FALSE(module->is_training());

  module->train();
  ASSERT_TRUE(module->is_training());
}

TEST_F(ModuleTest, ZeroGrad) {
  Linear module(3, 4);
  auto weight = torch::ones({8, 3}, torch::requires_grad());
  auto loss = module->forward(weight).sum();
  loss.backward();

  for (auto& parameter : module->parameters()) {
    auto grad = parameter.grad();
    ASSERT_TRUE(grad.defined());
    ASSERT_NE(grad.sum().item<float>(), 0);
  }

  module->zero_grad();

  for (auto& parameter : module->parameters()) {
    auto grad = parameter.grad();
    ASSERT_TRUE(grad.defined());
    ASSERT_EQ(grad.sum().item<float>(), 0);
  }
}

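// zero_grad() must zero existing gradients but must not materialize gradients
// for parameters that never took part in a backward pass.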
TEST_F(ModuleTest, ZeroGradWithUndefined) {
  struct TestModule : torch::nn::Module {
    TestModule() {
      x = register_parameter("x", torch::ones(5, torch::requires_grad()));
      y = register_parameter("y", torch::ones(5, torch::requires_grad()));
    }
    torch::Tensor x, y;
  };

  TestModule module;
  auto z = module.x * 2;
  z.sum().backward();

  ASSERT_TRUE(module.x.grad().defined());
  ASSERT_FALSE(module.y.grad().defined());

  module.zero_grad();

  ASSERT_TRUE(module.x.grad().defined());
  ASSERT_FALSE(module.y.grad().defined());
  ASSERT_EQ(module.x.grad().sum().item<float>(), 0);
}

TEST_F(ModuleTest, RegisterModuleThrowsForEmptyOrDottedName) {
  struct TestModel : public torch::nn::Module {
    using torch::nn::Module::register_module;
  };
  ASSERT_THROWS_WITH(
      TestModel{}.register_module("name.with.dot", torch::nn::Linear(3, 4)),
      "Submodule name must not contain a dot (got 'name.with.dot')");
  ASSERT_THROWS_WITH(
      TestModel{}.register_module("", torch::nn::Linear(3, 4)),
      "Submodule name must not be empty");
}

TEST_F(ModuleTest, RegisterModuleThrowsForDuplicateModuleName) {
  struct TestModel : public torch::nn::Module {
    using torch::nn::Module::register_module;
  };
  TestModel model;
  model.register_module("linear", torch::nn::Linear(3, 4));
  ASSERT_THROWS_WITH(
      model.register_module("linear", torch::nn::Linear(3, 4)),
      "Submodule 'linear' already defined");
}

TEST_F(ModuleTest, RegisterParameterThrowsForEmptyOrDottedName) {
  struct TestModel : public torch::nn::Module {
    using torch::nn::Module::register_parameter;
  };
  ASSERT_THROWS_WITH(
      TestModel{}.register_parameter("name.with.dot", torch::ones(5)),
      "Parameter name must not contain a dot (got 'name.with.dot')");
  ASSERT_THROWS_WITH(
      TestModel{}.register_parameter("", torch::ones(5)),
      "Parameter name must not be empty");
}

TEST_F(ModuleTest, RegisterParameterThrowsForDuplicateParameterName) {
  struct TestModel : public torch::nn::Module {
    using torch::nn::Module::register_parameter;
  };
  TestModel model;
  model.register_parameter("p", torch::ones(5));
  ASSERT_THROWS_WITH(
      model.register_parameter("p", torch::ones(5)),
      "Parameter 'p' already defined");
}

TEST_F(ModuleTest, RegisterBufferThrowsForEmptyOrDottedName) {
  struct TestModel : public torch::nn::Module {
    using torch::nn::Module::register_buffer;
  };
  ASSERT_THROWS_WITH(
      TestModel{}.register_buffer("name.with.dot", torch::ones(5)),
      "Buffer name must not contain a dot (got 'name.with.dot')");
  ASSERT_THROWS_WITH(
      TestModel{}.register_buffer("", torch::ones(5)),
      "Buffer name must not be empty");
}

TEST_F(ModuleTest, RegisterBufferThrowsForDuplicateBufferName) {
  struct TestModel : public torch::nn::Module {
    using torch::nn::Module::register_buffer;
  };
  TestModel model;
  model.register_buffer("p", torch::ones(5));
  ASSERT_THROWS_WITH(
      model.register_buffer("p", torch::ones(5)), "Buffer 'p' already defined");
}

TEST_F(ModuleTest, CanGetName) {
  AGIUnit agi;
  // Use the non-fatal EXPECT_* macros here because demangling may fail and we
  // still want the remaining checks to run. Call name() twice to make sure
  // there are no bugs in the lazy initialization semantics.
  EXPECT_TRUE(agi.name() == "AGIUnit");
  EXPECT_TRUE(agi.name() == "AGIUnit");
  EXPECT_TRUE(test::AGIUnit().name() == "test::AGIUnit");
  EXPECT_TRUE(test::AGIUnit2().name() == "Foo");
}

TEST_F(ModuleTest, AsCastsModulesCorrectly) {
  Linear module(3, 4);
  ASSERT_EQ(module->as<Linear>(), module.get());
  ASSERT_EQ(module->as<LinearImpl>(), module.get());
  ASSERT_EQ(module->as<Module>(), module.get());
  ASSERT_EQ(module->as<AGIUnit>(), nullptr);

  std::shared_ptr<Module> raw = module.ptr();
  ASSERT_EQ(raw->as<Linear>(), module.get());
  ASSERT_EQ(raw->as<LinearImpl>(), module.get());
  ASSERT_EQ(raw->as<Module>(), module.get());
  ASSERT_EQ(raw->as<AGIUnit>(), nullptr);

  Module& raw_ref = *raw.get();
  ASSERT_EQ(raw_ref.as<Linear>(), module.get());
  ASSERT_EQ(raw_ref.as<LinearImpl>(), module.get());
  ASSERT_EQ(raw_ref.as<Module>(), module.get());
  ASSERT_EQ(raw_ref.as<AGIUnit>(), nullptr);
  if (auto* linear = raw_ref.as<Linear>()) {
    ASSERT_EQ(linear->weight.ndimension(), 2);
  }

  AGIUnit unit;
  ASSERT_EQ(unit.as<Linear>(), nullptr);
  ASSERT_EQ(unit.as<LinearImpl>(), nullptr);
  ASSERT_EQ(unit.as<AGIUnit>(), &unit);
}

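// The _MultiCUDA suffix marks tests that should only run when at least two
// CUDA devices are available (this test moves parameters to device index 1).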
TEST_F(ModuleTest, Conversion_MultiCUDA) {
  Linear module(128, 64);
  for (auto& parameter : module->parameters()) {
    ASSERT_EQ(parameter.device(), torch::Device(torch::kCPU));
    ASSERT_EQ(parameter.dtype(), torch::kFloat32);
  }
  {
    module->to({torch::kCUDA, 0});
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
      ASSERT_EQ(parameter.device().index(), 0);
    }
    module->to({torch::kCUDA, 1});
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
      ASSERT_EQ(parameter.device().index(), 1);
    }
  }
  {
    module->to(torch::Device(torch::kCPU));
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.device().type(), torch::Device::Type::CPU);
    }
  }
  {
    module->to(torch::kInt32);
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.dtype(), torch::kInt32);
    }
  }
  {
    module->to(torch::kFloat64);
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.dtype(), torch::kFloat64);
    }
  }
  {
    module->to(torch::Device(torch::kCUDA, 1), torch::kUInt8);
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
      ASSERT_EQ(parameter.device().index(), 1);
      ASSERT_EQ(parameter.dtype(), torch::kUInt8);
    }
  }
}

TEST_F(ModuleTest, CallingCloneOnModuleThatDoesNotOverrideCloneThrows) {
  struct UnCloneable : Module {};
  UnCloneable module;
  ASSERT_THROWS_WITH(module.clone(), "clone() has not been implemented");
}

TEST_F(ModuleTest, CallingCloneOnModuleThatDoesOverrideCloneDoesNotThrow) {
  struct Cloneable : Module {
    std::shared_ptr<Module> clone(
        torch::optional<torch::Device> device =
            torch::nullopt) const override {
      return nullptr;
    }
  };
  Cloneable module;
  ASSERT_NO_THROW({ module.clone(); });
}

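// Cloneable<T>::clone() rebuilds the module via reset() and then copies
// parameter and buffer values over, so a clone shares values but not storage
// with the original.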
TEST_F(ModuleTest, CloneCreatesDistinctParameters) {
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      reset();
    }
    void reset() override {
      l1 = register_module("l1", Linear(10, 3));
      l2 = register_module("l2", Linear(3, 5));
      l3 = register_module("l3", Linear(5, 100));
      buffer = register_buffer("buf", torch::ones({2, 2}));
    }
    Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
    torch::Tensor buffer;
  };

  auto module = std::make_shared<TestModule>();

  torch::NoGradGuard no_grad;

  auto module2 = module->clone();
  auto params1 = module->named_parameters();
  auto params2 = module2->named_parameters();
  ASSERT_EQ(params1.size(), 6);
  ASSERT_EQ(params2.size(), 6);
  for (auto& param : params1) {
    ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()]));
    ASSERT_TRUE(param->allclose(params2[param.key()]));
    param->add_(2);
  }
  for (auto& param : params1) {
    ASSERT_FALSE(param->allclose(params2[param.key()]));
  }

  auto buffers1 = module->named_buffers();
  auto buffers2 = module2->named_buffers();
  ASSERT_EQ(buffers1.size(), 1);
  ASSERT_EQ(buffers2.size(), 1);
  for (auto& buffer : buffers1) {
    ASSERT_FALSE(pointer_equal(buffer.value(), buffers2[buffer.key()]));
    ASSERT_TRUE(buffer->allclose(buffers2[buffer.key()]));
    buffer->add_(2);
  }
  for (auto& buffer : buffers1) {
    ASSERT_FALSE(buffer->allclose(buffers2[buffer.key()]));
  }
}

TEST_F(ModuleTest, ClonePreservesExternalReferences) {
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      reset();
    }
    void reset() override {
      weight = register_parameter("weight", torch::ones({4, 4}));
    }
    torch::Tensor weight;
  };

  auto module = std::make_shared<TestModule>();
  {
    torch::NoGradGuard no_grad;
    module->weight += 1;
  }
  ASSERT_TRUE(
      pointer_equal(module->weight, module->named_parameters()["weight"]));
  ASSERT_TRUE(module->weight.allclose(module->named_parameters()["weight"]));

  auto module2 = std::dynamic_pointer_cast<TestModule>(
      std::shared_ptr<Module>(module->clone()));
  ASSERT_FALSE(pointer_equal(module2->weight, module->weight));
  ASSERT_TRUE(
      pointer_equal(module2->weight, module2->named_parameters()["weight"]));
  ASSERT_TRUE(module2->weight.allclose(module2->named_parameters()["weight"]));
  ASSERT_TRUE(module2->weight.allclose(module->weight));
  ASSERT_FALSE(
      pointer_equal(module2->weight, module->named_parameters()["weight"]));
}

TEST_F(ModuleTest, CloneCopiesTheValuesOfVariablesOfSubmodules) {
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      reset();
    }
    void reset() override {
      weight = register_parameter("weight", torch::ones({4, 4}));
    }
    torch::Tensor weight;
    int value = 0;
  };
  struct NestedModule : public Cloneable<NestedModule> {
    NestedModule() {
      reset();
    }
    void reset() override {
      module = register_module("module", std::make_shared<TestModule>());
    }
    std::shared_ptr<TestModule> module;
  };

  auto a = std::make_shared<NestedModule>();
  {
    torch::NoGradGuard no_grad;
    a->module->weight += 1;
    a->module->value = 123;
  }

  auto b = std::dynamic_pointer_cast<NestedModule>(a->clone());

  ASSERT_FALSE(pointer_equal(b->module->weight, a->module->weight));
  ASSERT_TRUE(pointer_equal(
      b->module->weight, b->module->named_parameters()["weight"]));
  ASSERT_TRUE(
      b->module->named_parameters()["weight"].allclose(a->module->weight));
  ASSERT_TRUE(b->module->weight.allclose(a->module->weight));
  ASSERT_EQ(b->module->value, a->module->value);
}

TEST_F(ModuleTest, CloneToDevicePreservesTheDeviceOfParameters_CUDA) {
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      reset();
    }
    void reset() override {
      l1 = register_module("l1", Linear(10, 3));
      l2 = register_module("l2", Linear(3, 5));
      l3 = register_module("l3", Linear(5, 100));
      buffer = register_buffer("buf", torch::ones({2, 2}));
    }
    Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
    torch::Tensor buffer;
  };

  TestModule m;
  torch::Device device(torch::kCUDA, 0);

  m.to(device);

  auto clone = m.clone();
  for (const auto& parameter : clone->parameters()) {
    ASSERT_EQ(parameter.device().type(), device.type());
    ASSERT_EQ(parameter.device().index(), device.index());
  }
  for (const auto& buffer : clone->buffers()) {
    ASSERT_EQ(buffer.device().type(), device.type());
    ASSERT_EQ(buffer.device().index(), device.index());
  }
}

TEST_F(
    ModuleTest,
    CloningToAParticularDevicePlacesAllParametersThere_MultiCUDA) {
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      reset();
    }
    void reset() override {
      l1 = register_module("l1", Linear(10, 3));
      l2 = register_module("l2", Linear(3, 5));
      l3 = register_module("l3", Linear(5, 100));
      buffer = register_buffer("buf", torch::ones({2, 2}));
    }
    Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
    torch::Tensor buffer;
  };

  TestModule m;
  torch::Device device(torch::kCUDA, 1);

  // Everything is still on the CPU here; clone(device) should place the
  // clone's parameters and buffers on the target device.
  auto clone = m.clone(device);
  for (const auto& parameter : clone->parameters()) {
    ASSERT_EQ(parameter.device().type(), device.type());
    ASSERT_EQ(parameter.device().index(), device.index());
  }
  for (const auto& buffer : clone->buffers()) {
    ASSERT_EQ(buffer.device().type(), device.type());
    ASSERT_EQ(buffer.device().index(), device.index());
  }
}

struct ParameterTestModule : Module {
  ParameterTestModule() {
    a = register_parameter("a", torch::zeros({2, 2}));
    b = register_parameter("b", torch::ones({2, 2}));
    c = register_parameter("c", torch::ones({2, 2}) * 2);
  }
  torch::Tensor a, b, c;
};

TEST_F(ModuleTest, HasCorrectNumberOfParameters) {
  ParameterTestModule module;
  ASSERT_EQ(module.parameters().size(), 3);
  ASSERT_EQ(module.named_parameters().size(), 3);
}

TEST_F(ModuleTest, ContainsParametersWithTheCorrectName) {
  ParameterTestModule module;
  auto parameters = module.named_parameters();
  ASSERT_TRUE(parameters.contains("a"));
  ASSERT_TRUE(parameters.contains("b"));
  ASSERT_TRUE(parameters.contains("c"));
}

struct BufferTestModule : Module {
  BufferTestModule() {
    a = register_buffer("a", torch::zeros({2, 2}));
    b = register_buffer("b", torch::ones({2, 2}));
    c = register_buffer("c", torch::ones({2, 2}) * 2);
  }
  torch::Tensor a, b, c;
};

TEST_F(ModuleTest, HasCorrectNumberOfBuffers) {
  BufferTestModule module;
  ASSERT_EQ(module.buffers().size(), 3);
  ASSERT_EQ(module.named_buffers().size(), 3);
}

TEST_F(ModuleTest, ContainsBuffersWithTheCorrectName) {
  BufferTestModule module;
  auto buffers = module.named_buffers();
  ASSERT_TRUE(buffers.contains("a"));
  ASSERT_TRUE(buffers.contains("b"));
  ASSERT_TRUE(buffers.contains("c"));
}

struct AImpl : torch::nn::Module {
  AImpl() : x_(123) {}
  AImpl(int x) : x_(x) {}
  int x_;
};
TORCH_MODULE(A);

TEST_F(
    ModuleTest,
    DefaultConstructorOfModuleHolderCallsDefaultConstructorOfImpl) {
  A a;
  ASSERT_TRUE(a);
  ASSERT_FALSE(a.is_empty());
  ASSERT_EQ(a->x_, 123);
}

TEST_F(
    ModuleTest,
    ValueConstructorOfModuleHolderCallsCorrectConstructorInImpl) {
  A a(5);
  ASSERT_TRUE(a);
  ASSERT_FALSE(a.is_empty());
  ASSERT_EQ(a->x_, 5);
}

TEST_F(ModuleTest, NullptrConstructorLeavesTheModuleHolderInEmptyState) {
  A a = nullptr;
  ASSERT_FALSE(a);
  ASSERT_TRUE(a.is_empty());
  ASSERT_THROWS_WITH(a->x_, "Accessing empty ModuleHolder");
}

struct TestModule : public torch::nn::Module {
  TestModule(int64_t size) {
    p1 = register_parameter("p1", torch::randn({size}));
    p2 = register_parameter("p2", torch::randn({size}));
    b1 = register_buffer("b1", torch::randn({size}));
    b2 = register_buffer("b2", torch::randn({size}));
  }

  torch::Tensor forward(torch::Tensor input) {
    return input;
  }

  torch::Tensor p1, p2, b1, b2;
};

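// The tests below exercise the OrderedDict-based accessors -- modules(),
// named_modules(), children() and named_children() -- that replaced cursors.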
TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForFlatModel) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->modules();
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model.ptr(), model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].get(), expected[i].get());
  }
}

TEST_F(ModuleTest, ModulesExcludesSelfWhenIncludeSelfSetToFalse) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  std::vector<std::shared_ptr<torch::nn::Module>> modules =
      model->modules(/*include_self=*/false);
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].get(), expected[i].get());
  }
}

TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForFlatModel) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_modules();
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model.ptr(), model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    // The model itself has the empty name; its children are keyed "0", "1",
    // "2" after their position in the Sequential.
    ASSERT_EQ(modules[i].key(), i ? std::to_string(i - 1) : std::string());
    // Assert pointer equality.
    ASSERT_EQ(modules[i].value().get(), expected[i].get());
  }
}

TEST_F(ModuleTest, NamedModulesExcludesSelfWhenIncludeSelfSetToFalse) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_modules(
          /*name_prefix=*/std::string(), /*include_self=*/false);
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    ASSERT_EQ(modules[i].key(), std::to_string(i));
    // Assert pointer equality.
    ASSERT_EQ(modules[i].value().get(), expected[i].get());
  }
}

TEST_F(ModuleTest, ChildrenReturnsExpectedSubmodulesForFlatModel) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->children();
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].get(), expected[i].get());
  }
  // For this flat model, children() and modules() without self should agree.
  ASSERT_EQ(modules, model->modules(/*include_self=*/false));
}

TEST_F(ModuleTest, NamedChildrenReturnsExpectedNamedSubmodulesForFlatModel) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_children();
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    ASSERT_EQ(modules[i].key(), std::to_string(i));
    // Assert pointer equality.
    ASSERT_EQ(modules[i].value().get(), expected[i].get());
  }
}

TEST_F(ModuleTest, ParametersReturnsExpectedTensorsForFlatModel) {
  TestModule module(1);
  std::vector<torch::Tensor> parameters = module.parameters();
  ASSERT_EQ(parameters.size(), 2);
  ASSERT_EQ(parameters[0].data<float>(), module.p1.data<float>());
  ASSERT_EQ(parameters[1].data<float>(), module.p2.data<float>());
}

TEST_F(ModuleTest, NamedParametersReturnsExpectedTensorsForFlatModel) {
  TestModule module(1);
  torch::OrderedDict<std::string, torch::Tensor> parameters =
      module.named_parameters();
  ASSERT_EQ(parameters.size(), 2);
  ASSERT_EQ(parameters[0].key(), "p1");
  ASSERT_EQ(parameters[0]->data<float>(), module.p1.data<float>());
  ASSERT_EQ(parameters[1].key(), "p2");
  ASSERT_EQ(parameters[1]->data<float>(), module.p2.data<float>());
}

TEST_F(ModuleTest, BuffersReturnsExpectedTensorsForFlatModel) {
  TestModule module(1);
  std::vector<torch::Tensor> buffers = module.buffers();
  ASSERT_EQ(buffers.size(), 2);
  ASSERT_EQ(buffers[0].data<float>(), module.b1.data<float>());
  ASSERT_EQ(buffers[1].data<float>(), module.b2.data<float>());
}

TEST_F(ModuleTest, NamedBuffersReturnsExpectedTensorsForFlatModel) {
  TestModule module(1);
  torch::OrderedDict<std::string, torch::Tensor> buffers =
      module.named_buffers();
  ASSERT_EQ(buffers.size(), 2);
  ASSERT_EQ(buffers[0].key(), "b1");
  ASSERT_EQ(buffers[0]->data<float>(), module.b1.data<float>());
  ASSERT_EQ(buffers[1].key(), "b2");
  ASSERT_EQ(buffers[1]->data<float>(), module.b2.data<float>());
}

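// Helpers for the deep-model tests below: build a module tree whose
// depth-first, pre-order traversal yields the values 0 through 9.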
struct TestContainer : torch::nn::Module {
  TestContainer(int64_t number, std::vector<TestContainer> modules = {})
      : tensor(torch::tensor(number)) {
    for (size_t i = 0; i < modules.size(); ++i) {
      register_module(
          std::to_string(i),
          std::make_shared<TestContainer>(std::move(modules[i])));
    }
  }
  torch::Tensor tensor;
};

int64_t get_test_container_item(std::shared_ptr<torch::nn::Module> module) {
  return std::dynamic_pointer_cast<TestContainer>(module)
      ->tensor.item<int64_t>();
}

std::shared_ptr<TestContainer> make_deeply_nested_test_container() {
  return std::make_shared<TestContainer>(TestContainer(
      0,
      {TestContainer(1, {TestContainer(2), TestContainer(3)}),
       TestContainer(4),
       TestContainer(
           5,
           {TestContainer(6),
            TestContainer(7, {TestContainer(8), TestContainer(9)})})}));
}

std::vector<std::pair<std::string, int64_t>>
make_key_value_pairs_for_deeply_nested_container() {
  return {{"test_prefix", 0},
          {"test_prefix.0", 1},
          {"test_prefix.0.0", 2},
          {"test_prefix.0.1", 3},
          {"test_prefix.1", 4},
          {"test_prefix.2", 5},
          {"test_prefix.2.0", 6},
          {"test_prefix.2.1", 7},
          {"test_prefix.2.1.0", 8},
          {"test_prefix.2.1.1", 9}};
}

TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForDeepModel) {
  auto model = make_deeply_nested_test_container();
  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->modules();
  ASSERT_EQ(modules.size(), 10);
  for (size_t i = 0; i < modules.size(); ++i) {
    ASSERT_EQ(get_test_container_item(modules[i]), i);
  }
}

TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForDeepModel) {
  auto model = make_deeply_nested_test_container();
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_modules(/*name_prefix=*/"test_prefix");
  auto expected = make_key_value_pairs_for_deeply_nested_container();
  ASSERT_EQ(modules.size(), expected.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    ASSERT_EQ(modules[i].key(), expected[i].first);
    ASSERT_EQ(get_test_container_item(modules[i].value()), expected[i].second);
  }
}

TEST_F(ModuleTest, ChildrenReturnsExpectedSubmodulesForDeepModel) {
  auto model = make_deeply_nested_test_container();
  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->children();
  ASSERT_EQ(modules.size(), 3);
  ASSERT_EQ(get_test_container_item(modules[0]), 1);
  ASSERT_EQ(get_test_container_item(modules[1]), 4);
  ASSERT_EQ(get_test_container_item(modules[2]), 5);
}

TEST_F(ModuleTest, NamedChildrenReturnsExpectedNamedSubmodulesForDeepModel) {
  auto model = make_deeply_nested_test_container();
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_children();
  ASSERT_EQ(modules.size(), 3);
  ASSERT_EQ(modules[0].key(), "0");
  ASSERT_EQ(get_test_container_item(modules[0].value()), 1);
  ASSERT_EQ(modules[1].key(), "1");
  ASSERT_EQ(get_test_container_item(modules[1].value()), 4);
  ASSERT_EQ(modules[2].key(), "2");
  ASSERT_EQ(get_test_container_item(modules[2].value()), 5);
}

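// apply() should visit modules in the same depth-first, pre-order sequence
// that modules() returns.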
TEST_F(ModuleTest, ModuleApplyIteratesCorrectly) {
  auto model = make_deeply_nested_test_container();
  int64_t index = 0;
  model->apply([&index](torch::nn::Module& module) {
    ASSERT_EQ(module.as<TestContainer>()->tensor.item<int64_t>(), index++);
  });
  ASSERT_EQ(index, 10);
}

TEST_F(ModuleTest, ConstModuleApplyIteratesCorrectly) {
  std::shared_ptr<const TestContainer> model =
      make_deeply_nested_test_container();
  int64_t index = 0;
  model->apply([&index](const torch::nn::Module& module) {
    ASSERT_EQ(module.as<TestContainer>()->tensor.item<int64_t>(), index++);
  });
  ASSERT_EQ(index, 10);
}

TEST_F(ModuleTest, NamedModuleApplyIteratesCorrectly) {
  auto model = make_deeply_nested_test_container();
  auto expected = make_key_value_pairs_for_deeply_nested_container();
  int64_t index = 0;
  model->apply(
      [&index, &expected](const std::string& name, torch::nn::Module& module) {
        ASSERT_EQ(name, expected[index].first);
        ASSERT_EQ(
            module.as<TestContainer>()->tensor.item<int64_t>(),
            expected[index++].second);
      },
      /*name_prefix=*/"test_prefix");
  ASSERT_EQ(index, 10);
}

TEST_F(ModuleTest, ConstNamedModuleApplyIteratesCorrectly) {
  std::shared_ptr<const TestContainer> model =
      make_deeply_nested_test_container();
  auto expected = make_key_value_pairs_for_deeply_nested_container();
  int64_t index = 0;
  model->apply(
      [&index, &expected](
          const std::string& name, const torch::nn::Module& module) {
        ASSERT_EQ(name, expected[index].first);
        ASSERT_EQ(
            module.as<const TestContainer>()->tensor.item<int64_t>(),
            expected[index++].second);
      },
      /*name_prefix=*/"test_prefix");
  ASSERT_EQ(index, 10);
}

TEST_F(ModuleTest, ModulePointerApplyIteratesCorrectly) {
  auto model = make_deeply_nested_test_container();
  int64_t index = 0;
  model->apply([&index](const std::shared_ptr<torch::nn::Module>& module) {
    ASSERT_EQ(get_test_container_item(module), index++);
  });
  ASSERT_EQ(index, 10);
}

TEST_F(ModuleTest, NamedModulePointerApplyIteratesCorrectly) {
  auto model = make_deeply_nested_test_container();
  auto expected = make_key_value_pairs_for_deeply_nested_container();
  int64_t index = 0;
  model->apply(
      [&index, &expected](
          const std::string& name,
          const std::shared_ptr<torch::nn::Module>& module) {
        ASSERT_EQ(name, expected[index].first);
        ASSERT_EQ(get_test_container_item(module), expected[index++].second);
      },
      /*name_prefix=*/"test_prefix");
  ASSERT_EQ(index, 10);
}

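// modules() includes the calling module itself by default, which requires the
// module to be owned by a shared_ptr; a stack-allocated module cannot satisfy
// that and must throw.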
TEST_F(ModuleTest, ThrowsWhenAttemptingToGetTopLevelModuleAsSharedPtr) {
  {
    TestModule module(1);
    ASSERT_THROWS_WITH(
        module.modules(),
        "It looks like you attempted to retrieve "
        "your top-level module as a shared_ptr");
  }
  {
    TestModule module(1);
    ASSERT_NO_THROW(module.modules(/*include_self=*/false));
  }
  {
    auto module = std::make_shared<TestModule>(1);
    ASSERT_NO_THROW(module->modules());
  }
}