mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: This PR is a large codemod to rewrite all C++ API tests with GoogleTest (gtest) instead of Catch. You can largely trust me to have correctly code-modded the tests, so it's not necessary to review every one of the 2000+ changed lines. However, I also made these additional changes: 1. Moved the cmake parts for these tests into their own `CMakeLists.txt` under `test/cpp/api` and called `add_subdirectory` from `torch/CMakeLists.txt` 2. Fixed the DataParallel tests, which weren't being compiled because `USE_CUDA` wasn't being set correctly at all. 3. Updated README ezyang ebetica Pull Request resolved: https://github.com/pytorch/pytorch/pull/11953 Differential Revision: D9998883 Pulled By: goldsborough fbshipit-source-id: affe3f320b0ca63e7e0019926a59076bb943db80
403 lines
12 KiB
C++
403 lines
12 KiB
C++
#include <gtest/gtest.h>

#include <torch/nn/cursor.h>
#include <torch/nn/module.h>
#include <torch/tensor.h>
#include <torch/utils.h>

#include <test/cpp/api/support.h>

#include <iostream>
#include <iterator>
#include <list>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
|
|
|
|
using namespace torch::nn;
|
|
using namespace torch::detail;
|
|
|
|
struct TestModule : public torch::nn::Module {
|
|
TestModule(int64_t size) {
|
|
tensor1 = register_parameter("tensor1", torch::randn({size}));
|
|
tensor2 = register_parameter("tensor2", torch::randn({size}));
|
|
}
|
|
|
|
torch::Tensor tensor1;
|
|
torch::Tensor tensor2;
|
|
};
|
|
|
|
struct Container : public torch::nn::Module {
|
|
template <typename... Ms>
|
|
explicit Container(Ms&&... ms) {
|
|
add(0, ms...);
|
|
}
|
|
|
|
void add(size_t) {}
|
|
|
|
template <typename Head, typename... Tail>
|
|
void add(size_t index, Head head, Tail... tail) {
|
|
add(std::to_string(index), std::move(head));
|
|
add(index + 1, tail...);
|
|
}
|
|
|
|
template <typename M>
|
|
void add(std::string name, M&& module) {
|
|
m.push_back(register_module(name, std::make_shared<M>(std::move(module))));
|
|
}
|
|
|
|
template <typename M>
|
|
void add(std::string name, std::shared_ptr<M>&& module) {
|
|
m.push_back(register_module(name, std::move(module)));
|
|
}
|
|
|
|
Module& operator[](size_t index) {
|
|
return *m.at(index);
|
|
}
|
|
|
|
std::vector<std::shared_ptr<Module>> m;
|
|
};
|
|
|
|
struct ModuleCursorFlatTest : torch::test::SeedingFixture {
|
|
ModuleCursorFlatTest()
|
|
: model(TestModule(1), TestModule(2), TestModule(3)),
|
|
cursor(model.modules()) {}
|
|
Container model;
|
|
ModuleCursor cursor;
|
|
};
|
|
|
|
TEST_F(ModuleCursorFlatTest, IteratesInTheCorrectOrder) {
  // Children are visited in registration order, after which the cursor ends.
  auto it = cursor.begin();
  for (size_t index = 0; index < 3; ++index, ++it) {
    ASSERT_EQ(&it->value, &model[index]);
  }
  ASSERT_EQ(it, cursor.end());
}
|
|
|
|
TEST_F(ModuleCursorFlatTest, NamesAreFlat) {
  // Keys are the bare registration names. The final check ensures the cursor
  // stops after the three direct children, so no further entries (such as
  // nested, dotted names) sneak in.
  auto iterator = cursor.begin();
  ASSERT_EQ(iterator->key, "0");
  ASSERT_EQ((++iterator)->key, "1");
  ASSERT_EQ((++iterator)->key, "2");
  ASSERT_EQ(++iterator, cursor.end());
}
|
|
|
|
TEST_F(ModuleCursorFlatTest, Apply) {
  // `apply` should hand the callback each child, in order; the counter
  // records how many times it ran.
  size_t visited = 0;
  auto check = [this, &visited](Module& module) {
    ASSERT_EQ(&module, &model[visited]);
    ++visited;
  };
  cursor.apply(check);
  ASSERT_EQ(visited, 3);
}
|
|
|
|
TEST_F(ModuleCursorFlatTest, ApplyItems) {
  // `apply_items` should pass every child together with its key ("0", "1",
  // "2"), in registration order.
  size_t count = 0;
  cursor.apply_items([this, &count](const std::string& key, Module& module) {
    // Claim this call's position before asserting. A fatal assertion only
    // returns from the lambda, so one failure must not shift the expected
    // values for the calls that follow.
    const size_t index = count++;
    ASSERT_EQ(key, std::to_string(index));
    ASSERT_EQ(&module, &model[index]);
  });
  ASSERT_EQ(count, 3);
}
|
|
|
|
TEST_F(ModuleCursorFlatTest, Map) {
  // Check the cursor size before mapping through a plain iterator into a
  // fixed-size vector. Without this, a cursor yielding more than three
  // modules would write past the end of `vector` instead of failing cleanly.
  ASSERT_EQ(cursor.size(), 3);

  std::vector<Module*> vector(3);
  cursor.map(vector.begin(), [](Module& module) { return &module; });
  ASSERT_EQ(vector[0], &model[0]);
  ASSERT_EQ(vector[1], &model[1]);
  ASSERT_EQ(vector[2], &model[2]);

  // An inserting iterator grows the destination one element per child.
  std::list<Module*> list;
  cursor.map(
      std::inserter(list, list.end()), [](Module& module) { return &module; });
  ASSERT_EQ(list.size(), 3);
  auto iterator = list.begin();
  ASSERT_EQ(*iterator++, &model[0]);
  ASSERT_EQ(*iterator++, &model[1]);
  ASSERT_EQ(*iterator++, &model[2]);
  ASSERT_EQ(iterator, list.end());
}
|
|
|
|
TEST_F(ModuleCursorFlatTest, MapItems) {
  // Collect (key, module) pairs into a map keyed by name. It should contain
  // exactly the three children, each under its registration name.
  std::map<std::string, Module*> output;
  auto make_entry = [](const std::string& key, Module& module) {
    return std::make_pair(key, &module);
  };
  cursor.map_items(std::inserter(output, output.end()), make_entry);
  ASSERT_EQ(output.size(), 3);
  for (size_t index = 0; index < 3; ++index) {
    ASSERT_TRUE(output.count(std::to_string(index)));
  }
  for (size_t index = 0; index < 3; ++index) {
    ASSERT_EQ(output.at(std::to_string(index)), &model[index]);
  }
}
|
|
|
|
TEST_F(ModuleCursorFlatTest, Count) {
  // The cursor reports one entry per child registered in the container.
  const auto expected = model.m.size();
  ASSERT_EQ(cursor.size(), expected);
}
|
|
|
|
TEST_F(ModuleCursorFlatTest, FindReturnsTheCorrectModulesWhenGivenAValidKey) {
  // Each registered name resolves to the address of the matching child.
  for (size_t index = 0; index < 3; ++index) {
    ASSERT_EQ(cursor.find(std::to_string(index)), &model[index]);
  }
}
|
|
|
|
TEST_F(ModuleCursorFlatTest, FindReturnsNullptrWhenGivenAnInvalidKey) {
  // Names that were never registered are reported as a null pointer.
  for (const char* missing : {"foo", "bar"}) {
    ASSERT_EQ(cursor.find(missing), nullptr);
  }
}
|
|
|
|
TEST_F(
    ModuleCursorFlatTest,
    AtWithKeyReturnsTheCorrectModulesWhenGivenAValidKey) {
  // `at(key)` returns a reference to the child registered under `key`.
  for (size_t index = 0; index < 3; ++index) {
    const std::string key = std::to_string(index);
    ASSERT_EQ(&cursor.at(key), &model[index]);
  }
}
|
|
|
|
TEST_F(ModuleCursorFlatTest, AtWithKeyThrowsWhenGivenAnInvalidKey) {
  // Neither name is registered, so `at` must throw an error naming the key.
  const std::string key_foo = "foo";
  const std::string key_bar = "bar";
  ASSERT_THROWS_WITH(cursor.at(key_foo), "No such key: 'foo'");
  ASSERT_THROWS_WITH(cursor.at(key_bar), "No such key: 'bar'");
}
|
|
|
|
TEST_F(
    ModuleCursorFlatTest,
    SubscriptOperatorWithKeyReturnsCorrectModulesWhenGivenAValidKey) {
  // `cursor[key]` behaves like `at(key)` for registered names.
  for (size_t index = 0; index < 3; ++index) {
    const std::string key = std::to_string(index);
    ASSERT_EQ(&cursor[key], &model[index]);
  }
}
|
|
|
|
TEST_F(ModuleCursorFlatTest, SubscriptOperatorWithKeyWhenGivenAnInvalidKey) {
  // Subscripting with an unregistered name throws, naming the offending key.
  const std::string key_foo = "foo";
  const std::string key_bar = "bar";
  ASSERT_THROWS_WITH(cursor[key_foo], "No such key: 'foo'");
  ASSERT_THROWS_WITH(cursor[key_bar], "No such key: 'bar'");
}
|
|
|
|
TEST_F(
    ModuleCursorFlatTest,
    AtWithIndexReturnsTheCorrectModulesWhenGivenAValidKey) {
  // `at(index)` returns the item at that position; its `value` is the child.
  for (size_t index = 0; index < 3; ++index) {
    ASSERT_EQ(&cursor.at(index).value, &model[index]);
  }
}
|
|
|
|
TEST_F(ModuleCursorFlatTest, AtWithIndexThrowsWhenGivenAnInvalidKey) {
  // Out-of-range positions are rejected with a message that reports both the
  // requested index and the cursor's size.
  const size_t small_bad_index = 5;
  const size_t large_bad_index = 123;
  ASSERT_THROWS_WITH(
      cursor.at(small_bad_index),
      "Index 5 is out of range for cursor of size 3");
  ASSERT_THROWS_WITH(
      cursor.at(large_bad_index),
      "Index 123 is out of range for cursor of size 3");
}
|
|
|
|
TEST_F(
    ModuleCursorFlatTest,
    SubscriptOperatorWithIndexReturnsCorrectModulesWhenGivenAValidKey) {
  // `cursor[index]` yields the item at that position, mirroring `at(index)`.
  for (size_t index = 0; index < 3; ++index) {
    ASSERT_EQ(&cursor[index].value, &model[index]);
  }
}
|
|
|
|
TEST_F(ModuleCursorFlatTest, SubscriptOperatorWithIndexWhenGivenAnInvalidKey) {
  // Subscripting past the end throws with the same message as `at(index)`.
  const size_t small_bad_index = 5;
  const size_t large_bad_index = 123;
  ASSERT_THROWS_WITH(
      cursor[small_bad_index], "Index 5 is out of range for cursor of size 3");
  ASSERT_THROWS_WITH(
      cursor[large_bad_index],
      "Index 123 is out of range for cursor of size 3");
}
|
|
|
|
TEST_F(ModuleCursorFlatTest, ContainReturnsTrueWhenKeyIsPresent) {
  // Every registered child name is reported as present.
  for (size_t index = 0; index < 3; ++index) {
    ASSERT_TRUE(cursor.contains(std::to_string(index)));
  }
}
|
|
|
|
struct ModuleCursorDeepTest : torch::test::SeedingFixture {
|
|
ModuleCursorDeepTest()
|
|
: model(
|
|
Container(TestModule(1), TestModule(2)),
|
|
TestModule(3),
|
|
Container(TestModule(4), Container(TestModule(5), TestModule(6)))) {
|
|
}
|
|
Container model;
|
|
};
|
|
|
|
TEST_F(ModuleCursorDeepTest, IteratesInTheCorrectOrder) {
  // `modules()` is expected to walk the tree depth-first, visiting each
  // container before its children. The final check requires the cursor to
  // end after the nine modules in the tree.
  auto cursor = model.modules();
  auto iterator = cursor.begin();

  ASSERT_EQ(&iterator->value, &model[0]);

  // Children of model[0].
  auto* seq = dynamic_cast<Container*>(&model[0]);
  ASSERT_NE(seq, nullptr);
  ASSERT_EQ(&(++iterator)->value, &(*seq)[0]);
  ASSERT_EQ(&(++iterator)->value, &(*seq)[1]);

  ASSERT_EQ(&(++iterator)->value, &model[1]);
  ASSERT_EQ(&(++iterator)->value, &model[2]);

  // Children of model[2].
  seq = dynamic_cast<Container*>(&model[2]);
  ASSERT_NE(seq, nullptr);
  ASSERT_EQ(&(++iterator)->value, &(*seq)[0]);
  ASSERT_EQ(&(++iterator)->value, &(*seq)[1]);

  // Children of model[2][1].
  seq = dynamic_cast<Container*>(&(*seq)[1]);
  ASSERT_NE(seq, nullptr);
  ASSERT_EQ(&(++iterator)->value, &(*seq)[0]);
  ASSERT_EQ(&(++iterator)->value, &(*seq)[1]);

  ASSERT_EQ(++iterator, cursor.end());
}
|
|
|
|
TEST_F(ModuleCursorDeepTest, ChildrenReturnsOnlyTheFirstLevelOfSubmodules) {
  // `children()` covers only the three direct submodules of `model`; a
  // nested name such as "0.0" must not be present.
  auto children = model.children();
  ASSERT_EQ(children.size(), 3);
  for (size_t index = 0; index < 3; ++index) {
    ASSERT_EQ(&children.at(std::to_string(index)), &model[index]);
  }
  ASSERT_FALSE(children.contains("0.0"));

  // Iteration yields the same children, keyed by position, in order.
  size_t position = 0;
  for (auto& child : children) {
    ASSERT_EQ(child.key, std::to_string(position));
    ASSERT_EQ(&child.value, &model[position]);
    ++position;
  }
}
|
|
|
|
struct ParameterCursorFlatTest : torch::test::SeedingFixture {
|
|
ParameterCursorFlatTest()
|
|
: first(std::make_shared<TestModule>(1)),
|
|
second(std::make_shared<TestModule>(2)),
|
|
model(first, second),
|
|
cursor(model.parameters()) {}
|
|
std::shared_ptr<TestModule> first, second;
|
|
Container model;
|
|
ParameterCursor cursor;
|
|
};
|
|
|
|
TEST(ParameterCursorTest, IteratesInTheCorrectOrderOverSimpleModels) {
  // This is a plain TEST (no seeding fixture), so seed explicitly.
  torch::manual_seed(0);
  TestModule model(1);
  auto cursor = model.parameters();
  auto iterator = cursor.begin();
  // Parameters appear in registration order, tensor1 then tensor2, and
  // nothing follows them.
  ASSERT_TRUE(iterator->value.equal(model.tensor1));
  ASSERT_TRUE((++iterator)->value.equal(model.tensor2));
  ASSERT_EQ(++iterator, cursor.end());
}
|
|
|
|
TEST_F(ParameterCursorFlatTest, IteratesInTheCorrectOrder) {
  // All parameters of `first` come before those of `second`, each in
  // registration order. The cursor must end after those four.
  auto iterator = cursor.begin();
  ASSERT_TRUE(iterator->value.equal(first->tensor1));
  ASSERT_TRUE((++iterator)->value.equal(first->tensor2));
  ASSERT_TRUE((++iterator)->value.equal(second->tensor1));
  ASSERT_TRUE((++iterator)->value.equal(second->tensor2));
  ASSERT_EQ(++iterator, cursor.end());
}
|
|
|
|
TEST_F(ParameterCursorFlatTest, ApplyItemsWorks) {
  // Expected (key, tensor) pairs in visiting order. A key is the child's
  // registration name, a dot, then the parameter name.
  const std::vector<std::pair<std::string, torch::Tensor>> expected = {
      {"0.tensor1", first->tensor1},
      {"0.tensor2", first->tensor2},
      {"1.tensor1", second->tensor1},
      {"1.tensor2", second->tensor2}};

  size_t count = 0;
  cursor.apply_items(
      [&expected, &count](const std::string& key, torch::Tensor& tensor) {
        // Claim the slot before asserting. A fatal assertion only returns
        // from this lambda, so one failure must not shift later expectations.
        const size_t index = count++;
        ASSERT_LT(index, expected.size());
        ASSERT_EQ(key, expected[index].first);
        ASSERT_TRUE(tensor.equal(expected[index].second));
      });
  ASSERT_EQ(count, 4);
}
|
|
|
|
struct ParameterCursorDeepTest : torch::test::SeedingFixture {
|
|
std::vector<std::shared_ptr<TestModule>> make_modules() {
|
|
std::vector<std::shared_ptr<TestModule>> modules;
|
|
for (size_t i = 1; i <= 6; ++i) {
|
|
modules.push_back(std::make_shared<TestModule>(i));
|
|
}
|
|
return modules;
|
|
}
|
|
|
|
ParameterCursorDeepTest()
|
|
: modules(make_modules()),
|
|
model(
|
|
Container(modules[0], modules[1]),
|
|
modules[2],
|
|
Container(modules[3], Container(modules[4], modules[5]))) {}
|
|
|
|
std::vector<std::shared_ptr<TestModule>> modules;
|
|
Container model;
|
|
};
|
|
|
|
TEST_F(ParameterCursorDeepTest, IteratesInTheCorrectOrderOverDeepModels) {
  // Walking the tree visits the leaves in the same order as `modules`, and
  // each leaf yields tensor1 then tensor2. Check against `end()` before every
  // dereference so a short cursor fails an assertion instead of reading past
  // the end. Also require the cursor to be exhausted afterwards.
  auto cursor = model.parameters();
  auto iterator = cursor.begin();
  for (const auto& module : modules) {
    ASSERT_NE(iterator, cursor.end());
    ASSERT_TRUE(iterator->value.equal(module->tensor1));
    ++iterator;
    ASSERT_NE(iterator, cursor.end());
    ASSERT_TRUE(iterator->value.equal(module->tensor2));
    ++iterator;
  }
  ASSERT_EQ(iterator, cursor.end());
}
|
|
|
|
TEST_F(ParameterCursorDeepTest, NamesAreHierarchical) {
  // Parameter keys join the child names along the path with dots, ending in
  // the parameter name. The cursor ends after the last one.
  const std::vector<std::string> expected_keys = {
      "0.0.tensor1",
      "0.0.tensor2",
      "0.1.tensor1",
      "0.1.tensor2",
      "1.tensor1",
      "1.tensor2",
      "2.0.tensor1",
      "2.0.tensor2",
      "2.1.0.tensor1",
      "2.1.0.tensor2",
      "2.1.1.tensor1",
      "2.1.1.tensor2"};

  auto cursor = model.parameters();
  auto iterator = cursor.begin();
  for (const auto& expected_key : expected_keys) {
    ASSERT_EQ(iterator->key, expected_key);
    ++iterator;
  }
  ASSERT_EQ(iterator, cursor.end());
}
|
|
|
|
/// Fixture for cursor tests that build their own models. It adds nothing
/// beyond `torch::test::SeedingFixture`, which presumably seeds the RNG
/// (confirm in test/cpp/api/support.h).
struct CursorTest : public torch::test::SeedingFixture {};
|
|
|
|
TEST_F(CursorTest, NonConstToConstConversion) {
  // Each non-const cursor must convert to its const counterpart, both by
  // direct construction and by implicit conversion (copy-initialization).
  // Distinct names avoid the original's shadowed `const_cursor`. The size
  // checks make sure the converted cursors are actually used and agree with
  // their source.
  auto first = std::make_shared<TestModule>(1);
  auto second = std::make_shared<TestModule>(2);
  Container model(first, second);

  {
    ConstModuleCursor direct(model.modules());
    ModuleCursor cursor = model.modules();
    ConstModuleCursor converted = cursor;
    ASSERT_EQ(direct.size(), cursor.size());
    ASSERT_EQ(converted.size(), cursor.size());
  }
  {
    ConstParameterCursor direct(model.parameters());
    ParameterCursor cursor = model.parameters();
    ConstParameterCursor converted = cursor;
    ASSERT_EQ(direct.size(), cursor.size());
    ASSERT_EQ(converted.size(), cursor.size());
  }
  {
    ConstBufferCursor direct(model.buffers());
    BufferCursor cursor = model.buffers();
    ConstBufferCursor converted = cursor;
    ASSERT_EQ(direct.size(), cursor.size());
    ASSERT_EQ(converted.size(), cursor.size());
  }
}
|
|
|
|
TEST_F(CursorTest, CanInvokeConstMethodOnConstCursor) {
  TestModule model(1);
  // The cursor object is const, so only const-qualified member functions
  // may be called on it. This test fails to compile if `contains` is not
  // const.
  const auto cursor = model.parameters();
  const bool has_tensor1 = cursor.contains("tensor1");
  ASSERT_TRUE(has_tensor1);
}
|