#include <gtest/gtest.h>

#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/linear.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <test/cpp/api/support.h>

using namespace torch::nn;
using namespace torch::test;

struct AGIUnit : torch::nn::Module {};

namespace test {
struct AGIUnit : torch::nn::Module {};
struct AGIUnit2 : torch::nn::Module {
  AGIUnit2() : torch::nn::Module("Foo") {}
};
} // namespace test

struct ModuleTest : torch::test::SeedingFixture {};

TEST_F(ModuleTest, CanEnableAndDisableTrainingMode) {
  Linear module(3, 4);
  ASSERT_TRUE(module->is_training());

  module->eval();
  ASSERT_FALSE(module->is_training());

  module->train();
  ASSERT_TRUE(module->is_training());
}

TEST_F(ModuleTest, ZeroGrad) {
  Linear module(3, 4);
  auto weight = torch::ones({8, 3}, torch::requires_grad());
  auto loss = module->forward(weight).sum();
  loss.backward();
  for (auto& parameter : module->parameters()) {
    auto grad = parameter->grad();
    ASSERT_TRUE(grad.defined());
    ASSERT_NE(grad.sum().item<float>(), 0);
  }
  module->zero_grad();
  for (auto& parameter : module->parameters()) {
    auto grad = parameter->grad();
    ASSERT_TRUE(grad.defined());
    ASSERT_EQ(grad.sum().item<float>(), 0);
  }
}

TEST_F(ModuleTest, ZeroGradWithUndefined) {
  struct TestModule : torch::nn::Module {
    TestModule() {
      x = register_parameter("x", torch::ones(5, torch::requires_grad()));
      y = register_parameter("y", torch::ones(5, torch::requires_grad()));
    }
    torch::Tensor x, y;
  };

  TestModule module;
  auto z = module.x * 2;
  z.sum().backward();

  ASSERT_TRUE(module.x.grad().defined());
  ASSERT_FALSE(module.y.grad().defined());

  module.zero_grad();

  // zero_grad() must not create a gradient for `y`, which never received one.
  ASSERT_TRUE(module.x.grad().defined());
  ASSERT_FALSE(module.y.grad().defined());

  ASSERT_EQ(module.x.grad().sum().item<float>(), 0);
}

TEST_F(ModuleTest, RegisterModuleThrowsForEmptyOrDottedName) {
  struct TestModel : public torch::nn::Module {
    using torch::nn::Module::register_module;
  };
  ASSERT_THROWS_WITH(
      TestModel{}.register_module("name.with.dot", torch::nn::Linear(3, 4)),
      "Submodule name must not contain a dot (got 'name.with.dot')");
  ASSERT_THROWS_WITH(
      TestModel{}.register_module("", torch::nn::Linear(3, 4)),
      "Submodule name must not be empty");
}

TEST_F(ModuleTest, RegisterModuleThrowsForDuplicateModuleName) {
  struct TestModel : public torch::nn::Module {
    using torch::nn::Module::register_module;
  };
  TestModel model;
  model.register_module("linear", torch::nn::Linear(3, 4));
  ASSERT_THROWS_WITH(
      model.register_module("linear", torch::nn::Linear(3, 4)),
      "Submodule 'linear' already defined");
}

TEST_F(ModuleTest, RegisterParameterThrowsForEmptyOrDottedName) {
  struct TestModel : public torch::nn::Module {
    using torch::nn::Module::register_parameter;
  };
  ASSERT_THROWS_WITH(
      TestModel{}.register_parameter("name.with.dot", torch::ones(5)),
      "Parameter name must not contain a dot (got 'name.with.dot')");
  ASSERT_THROWS_WITH(
      TestModel{}.register_parameter("", torch::ones(5)),
      "Parameter name must not be empty");
}

TEST_F(ModuleTest, RegisterParameterThrowsForDuplicateModuleName) {
  struct TestModel : public torch::nn::Module {
    using torch::nn::Module::register_parameter;
  };
  TestModel model;
  model.register_parameter("p", torch::ones(5));
  ASSERT_THROWS_WITH(
      model.register_parameter("p", torch::ones(5)),
      "Parameter 'p' already defined");
}

TEST_F(ModuleTest, RegisterBufferThrowsForEmptyOrDottedName) {
  struct TestModel : public torch::nn::Module {
    using torch::nn::Module::register_buffer;
  };
  ASSERT_THROWS_WITH(
      TestModel{}.register_buffer("name.with.dot", torch::ones(5)),
      "Buffer name must not contain a dot (got 'name.with.dot')");
  ASSERT_THROWS_WITH(
      TestModel{}.register_buffer("", torch::ones(5)),
      "Buffer name must not be empty");
}
TEST_F(ModuleTest, RegisterBufferThrowsForDuplicateModuleName) {
  struct TestModel : public torch::nn::Module {
    using torch::nn::Module::register_buffer;
  };
  TestModel model;
  model.register_buffer("p", torch::ones(5));
  ASSERT_THROWS_WITH(
      model.register_buffer("p", torch::ones(5)),
      "Buffer 'p' already defined");
}

TEST_F(ModuleTest, CanGetName) {
  // CHECK instead of REQUIRE because demangling may fail.
  AGIUnit agi;
  // Call it twice just to make sure there are no bugs in the lazy
  // initialization semantics.
  EXPECT_TRUE(agi.name() == "AGIUnit");
  EXPECT_TRUE(agi.name() == "AGIUnit");
  EXPECT_TRUE(test::AGIUnit().name() == "test::AGIUnit");
  EXPECT_TRUE(test::AGIUnit2().name() == "Foo");
}

TEST_F(ModuleTest, TestAsCastsModulesCorrectly) {
  Linear module(3, 4);
  ASSERT_EQ(module->as<Linear>(), module.get());
  ASSERT_EQ(module->as<LinearImpl>(), module.get());
  ASSERT_EQ(module->as<Module>(), module.get());
  ASSERT_EQ(module->as<AGIUnit>(), nullptr);

  std::shared_ptr<Module> raw = module.ptr();
  ASSERT_EQ(raw->as<Linear>(), module.get());
  ASSERT_EQ(raw->as<LinearImpl>(), module.get());
  ASSERT_EQ(raw->as<Module>(), module.get());
  ASSERT_EQ(raw->as<AGIUnit>(), nullptr);

  Module& raw_ref = *raw.get();
  ASSERT_EQ(raw_ref.as<Linear>(), module.get());
  ASSERT_EQ(raw_ref.as<LinearImpl>(), module.get());
  ASSERT_EQ(raw_ref.as<Module>(), module.get());
  ASSERT_EQ(raw_ref.as<AGIUnit>(), nullptr);
  if (auto* linear = raw_ref.as<Linear>()) {
    ASSERT_EQ(linear->weight.ndimension(), 2);
  }

  AGIUnit unit;
  ASSERT_EQ(unit.as<Linear>(), nullptr);
  ASSERT_EQ(unit.as<LinearImpl>(), nullptr);
  ASSERT_EQ(unit.as<AGIUnit>(), &unit);
}

TEST_F(ModuleTest, Conversion_MultiCUDA) {
  Linear module(128, 64);
  for (auto& parameter : module->parameters()) {
    ASSERT_EQ(parameter->device(), torch::Device(torch::kCPU));
    ASSERT_EQ(parameter->dtype(), torch::kFloat32);
  }
  {
    module->to({torch::kCUDA, 0});
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter->device().type(), torch::Device::Type::CUDA);
      ASSERT_EQ(parameter->device().index(), 0);
    }
    module->to({torch::kCUDA, 1});
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter->device().type(), torch::Device::Type::CUDA);
      ASSERT_EQ(parameter->device().index(), 1);
    }
  }
  {
    module->to(torch::Device(torch::kCPU));
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter->device().type(), torch::Device::Type::CPU);
    }
  }
  {
    module->to(torch::kInt32);
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter->dtype(), torch::kInt32);
    }
  }
  {
    module->to(torch::kFloat64);
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter->dtype(), torch::kFloat64);
    }
  }
  {
    module->to(torch::Device(torch::kCUDA, 1), torch::kUInt8);
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter->device().type(), torch::Device::Type::CUDA);
      ASSERT_EQ(parameter->device().index(), 1);
    }
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter->dtype(), torch::kUInt8);
    }
  }
}

TEST_F(ModuleTest, CallingCloneOnModuleThatDoesNotOverrideCloneThrows) {
  struct UnCloneable : Module {};
  UnCloneable module;
  ASSERT_THROWS_WITH(module.clone(), "clone() has not been implemented");
}

TEST_F(ModuleTest, CallingCloneOnModuleThatDoesOverrideCloneDoesNotThrow) {
  struct Cloneable : Module {
    std::shared_ptr<Module> clone(
        torch::optional<torch::Device> device =
            torch::nullopt) const override {
      return nullptr;
    }
  };
  Cloneable module;
  ASSERT_NO_THROW({ module.clone(); });
}

TEST_F(ModuleTest, CloneCreatesDistinctParameters) {
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      reset();
    }
    void reset() override {
      l1 = register_module("l1", Linear(10, 3));
      l2 = register_module("l2", Linear(3, 5));
      l3 = register_module("l3", Linear(5, 100));
      buffer = register_buffer("buf", torch::ones({2, 2}));
    }
register_buffer("buf", torch::ones({2, 2})); } Linear l1{nullptr}, l2{nullptr}, l3{nullptr}; torch::Tensor buffer; }; auto module = std::make_shared(); torch::NoGradGuard no_grad; auto module2 = module->clone(); auto params1 = module->parameters(); auto params2 = module2->parameters(); ASSERT_EQ(params1.size(), 6); ASSERT_EQ(params2.size(), 6); for (auto& param : params1) { ASSERT_FALSE(pointer_equal(param.value, params2[param.key])); ASSERT_TRUE(param->allclose(params2[param.key])); param->add_(2); } for (auto& param : params1) { ASSERT_FALSE(param->allclose(params2[param.key])); } auto buffers1 = module->buffers(); auto buffers2 = module2->buffers(); ASSERT_EQ(buffers1.size(), 1); ASSERT_EQ(buffers2.size(), 1); for (auto& buffer : buffers1) { ASSERT_FALSE(pointer_equal(buffer.value, buffers2[buffer.key])); ASSERT_TRUE(buffer->allclose(buffers2[buffer.key])); buffer->add_(2); } for (auto& buffer : buffers1) { ASSERT_FALSE(buffer->allclose(buffers2[buffer.key])); } } TEST_F(ModuleTest, ClonePreservesExternalReferences) { struct TestModule : public Cloneable { TestModule() { reset(); } void reset() override { weight = register_parameter("weight", torch::ones({4, 4})); } torch::Tensor weight; }; auto module = std::make_shared(); { torch::NoGradGuard no_grad; module->weight += 1; } ASSERT_TRUE(pointer_equal(module->weight, module->parameters()["weight"])); ASSERT_TRUE(module->weight.allclose(module->parameters()["weight"])); auto module2 = std::dynamic_pointer_cast( std::shared_ptr(module->clone())); ASSERT_FALSE(pointer_equal(module2->weight, module->weight)); ASSERT_TRUE(pointer_equal(module2->weight, module2->parameters()["weight"])); ASSERT_TRUE(module2->weight.allclose(module2->parameters()["weight"])); ASSERT_TRUE(module2->weight.allclose(module->weight)); ASSERT_FALSE(pointer_equal(module2->weight, module->parameters()["weight"])); } TEST_F(ModuleTest, CloneCopiesTheValuesOfVariablesOfSubmodules) { struct TestModule : public Cloneable { TestModule() { reset(); } void reset() override { weight = register_parameter("weight", torch::ones({4, 4})); } torch::Tensor weight; int value = 0; }; struct NestedModule : public Cloneable { NestedModule() { reset(); } void reset() override { module = register_module("module", std::make_shared()); } std::shared_ptr module; }; auto a = std::make_shared(); { torch::NoGradGuard no_grad; a->module->weight += 1; a->module->value = 123; } auto b = std::dynamic_pointer_cast(a->clone()); ASSERT_FALSE(pointer_equal(b->module->weight, a->module->weight)); ASSERT_TRUE( pointer_equal(b->module->weight, b->module->parameters()["weight"])); ASSERT_TRUE(b->module->parameters()["weight"].allclose(a->module->weight)); ASSERT_TRUE(b->module->weight.allclose(a->module->weight)); ASSERT_EQ(b->module->value, a->module->value); } TEST_F(ModuleTest, CloneToDevicePreservesTheDeviceOfParameters_CUDA) { struct TestModule : public Cloneable { TestModule() { reset(); } void reset() override { l1 = register_module("l1", Linear(10, 3)); l2 = register_module("l2", Linear(3, 5)); l3 = register_module("l3", Linear(5, 100)); buffer = register_buffer("buf", torch::ones({2, 2})); } Linear l1{nullptr}, l2{nullptr}, l3{nullptr}; torch::Tensor buffer; }; TestModule m; torch::Device device(torch::kCUDA, 0); m.to(device); auto clone = m.clone(); for (const auto& parameter : clone->parameters()) { ASSERT_EQ(parameter->device().type(), device.type()); ASSERT_EQ(parameter->device().index(), device.index()); } for (const auto& buffer : clone->buffers()) { ASSERT_EQ(buffer->device().type(), 
TEST_F(ModuleTest, CloningToAParticularDevicePlacesAllParametersThere_CUDA) {
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      reset();
    }
    void reset() override {
      l1 = register_module("l1", Linear(10, 3));
      l2 = register_module("l2", Linear(3, 5));
      l3 = register_module("l3", Linear(5, 100));
      buffer = register_buffer("buf", torch::ones({2, 2}));
    }
    Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
    torch::Tensor buffer;
  };

  TestModule m;
  torch::Device device(torch::kCUDA, 1); // everything is on CPU here

  auto clone = m.clone(device);
  for (const auto& parameter : clone->parameters()) {
    ASSERT_EQ(parameter->device().type(), device.type());
    ASSERT_EQ(parameter->device().index(), device.index());
  }
  for (const auto& buffer : clone->buffers()) {
    ASSERT_EQ(buffer->device().type(), device.type());
    ASSERT_EQ(buffer->device().index(), device.index());
  }
}

struct ParameterTestModule : Module {
  ParameterTestModule() {
    a = register_parameter("a", torch::zeros({2, 2}));
    b = register_parameter("b", torch::ones({2, 2}));
    c = register_parameter("c", torch::ones({2, 2}) * 2);
  }

  torch::Tensor a, b, c;
};

TEST_F(ModuleTest, HasCorrectNumberOfParameters) {
  ParameterTestModule module;
  ASSERT_EQ(module.parameters().size(), 3);
}

TEST_F(ModuleTest, ContainsParametersWithTheCorrectName) {
  ParameterTestModule module;
  auto parameters = module.parameters();
  ASSERT_TRUE(parameters.contains("a"));
  ASSERT_TRUE(parameters.contains("b"));
  ASSERT_TRUE(parameters.contains("c"));
}

struct BufferTestModule : Module {
  BufferTestModule() {
    a = register_buffer("a", torch::zeros({2, 2}));
    b = register_buffer("b", torch::ones({2, 2}));
    c = register_buffer("c", torch::ones({2, 2}) * 2);
  }

  torch::Tensor a, b, c;
};

TEST_F(ModuleTest, HasCorrectNumberOfBuffers) {
  BufferTestModule module;
  ASSERT_EQ(module.buffers().size(), 3);
}

TEST_F(ModuleTest, ContainsBuffersWithTheCorrectName) {
  BufferTestModule module;
  auto buffers = module.buffers();
  ASSERT_TRUE(buffers.contains("a"));
  ASSERT_TRUE(buffers.contains("b"));
  ASSERT_TRUE(buffers.contains("c"));
}

struct AImpl : torch::nn::Module {
  AImpl() : x_(123) {}
  AImpl(int x) : x_(x) {}
  int x_;
};
TORCH_MODULE(A);

TEST_F(
    ModuleTest,
    DefaultConstructorOfModuleHolderCallsDefaultConstructorOfImpl) {
  A a;
  ASSERT_TRUE(a);
  ASSERT_FALSE(a.is_empty());
  ASSERT_EQ(a->x_, 123);
}

TEST_F(
    ModuleTest,
    ValueConstructorOfModuleHolderCallsCorrectConstructorInImpl) {
  A a(5);
  ASSERT_TRUE(a);
  ASSERT_FALSE(a.is_empty());
  ASSERT_EQ(a->x_, 5);
}

TEST_F(ModuleTest, NullptrConstructorLeavesTheModuleHolderInEmptyState) {
  A a = nullptr;
  ASSERT_FALSE(a);
  ASSERT_TRUE(a.is_empty());
  ASSERT_THROWS_WITH(a->x_, "Accessing empty ModuleHolder");
}