mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
130 lines
3.2 KiB
C++
130 lines
3.2 KiB
C++
#include "caffe2/core/net.h"
|
|
#include "caffe2/core/operator.h"
|
|
#include "caffe2/mpi/mpi_common.h"
|
|
#include "glog/logging.h"
|
|
#include "gflags/gflags.h"
|
|
#include "google/protobuf/text_format.h"
|
|
#include "gtest/gtest.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
// Text-format NetDef used by TestBroadcast.
// Op 0: ConstantFill writes "X" with shape [10] and value 0.0; the test
//       rewrites the "value" argument to the calling process's MPI rank.
// Op 1: Broadcast reads and writes "X" in place; its "root" argument (0 here)
//       is rewritten by the test for every root in turn.
// Each adjacent literal below holds one logical proto element; together they
// concatenate to the exact same string as a one-token-per-line layout.
const char kBcastNet[] =
    " name: \"bcast\""
    " op { output: \"X\" type: \"ConstantFill\""
    " arg { name: \"shape\" ints: 10 }"
    " arg { name: \"value\" f: 0.0 } }"
    " op { input: \"X\" output: \"X\" type: \"Broadcast\""
    " arg { name: \"root\" i: 0 } }";
|
|
|
|
// Every rank fills "X" with its own MPI rank, then "X" is broadcast from each
// possible root in turn. After each run, every rank must hold the root's rank
// in all 10 elements of "X".
TEST(MPITest, TestBroadcast) {
  NetDef net_def;
  CHECK(google::protobuf::TextFormat::ParseFromString(
      string(kBcastNet), &net_def));
  // Let's set the network's constant fill value to be the mpi rank.
  auto* value_arg = net_def.mutable_op(0)->mutable_arg(1);
  CHECK_EQ(value_arg->name(), "value");
  int rank = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  value_arg->set_f(rank);
  int size = 0;
  MPI_Comm_size(MPI_COMM_WORLD, &size);

  // Verify the Broadcast op's first argument really is "root" (mirroring the
  // "value" check above), so a reordered text proto fails loudly instead of
  // silently overwriting some other argument. The pointer stays valid across
  // iterations because net_def is not structurally modified in the loop.
  auto* root_arg = net_def.mutable_op(1)->mutable_arg(0);
  CHECK_EQ(root_arg->name(), "root");

  for (int root = 0; root < size; ++root) {
    root_arg->set_i(root);
    Workspace ws;
    unique_ptr<NetBase> net(CreateNet(net_def, &ws));
    // Fatal assertion: `net` is dereferenced immediately below, so a null net
    // must end the test rather than crash the process.
    ASSERT_NE(nullptr, net.get());
    EXPECT_TRUE(net->Verify());
    EXPECT_TRUE(net->Run());
    // Let's test the value.
    auto* blob = ws.GetBlob("X");
    ASSERT_NE(nullptr, blob);
    auto& X = blob->Get<Tensor<float, CPUContext> >();
    EXPECT_EQ(X.size(), 10);
    for (int i = 0; i < X.size(); ++i) {
      EXPECT_EQ(X.data()[i], root);
    }
  }
}
|
|
|
|
// Text-format NetDef used by TestAllreduce.
// Op 0: ConstantFill writes "X" with shape [10] and value 0.0; the test
//       rewrites the "value" argument to the calling process's MPI rank.
// Op 1: Allreduce reads "X" and writes the result to "X_reduced".
// Each adjacent literal below holds one logical proto element; together they
// concatenate to the exact same string as a one-token-per-line layout.
const char kAllreduceNet[] =
    " name: \"allreduce\""
    " op { output: \"X\" type: \"ConstantFill\""
    " arg { name: \"shape\" ints: 10 }"
    " arg { name: \"value\" f: 0.0 } }"
    " op { input: \"X\" output: \"X_reduced\" type: \"Allreduce\" }";
|
|
|
|
// Every rank fills "X" with its own MPI rank and all-reduces it into
// "X_reduced". "X" must be left untouched (still the local rank), and
// "X_reduced" must equal the sum of all ranks, 0 + 1 + ... + (size - 1).
TEST(MPITest, TestAllreduce) {
  NetDef net_def;
  CHECK(google::protobuf::TextFormat::ParseFromString(
      string(kAllreduceNet), &net_def));
  // Let's set the network's constant fill value to be the mpi rank.
  auto* value_arg = net_def.mutable_op(0)->mutable_arg(1);
  CHECK_EQ(value_arg->name(), "value");
  int rank = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  value_arg->set_f(rank);
  int size = 0;
  MPI_Comm_size(MPI_COMM_WORLD, &size);

  Workspace ws;
  unique_ptr<NetBase> net(CreateNet(net_def, &ws));
  // Fatal assertion: `net` is dereferenced immediately below, so a null net
  // must end the test rather than crash the process.
  ASSERT_NE(nullptr, net.get());
  EXPECT_TRUE(net->Verify());
  EXPECT_TRUE(net->Run());
  // Let's test the value.
  auto* x_blob = ws.GetBlob("X");
  ASSERT_NE(nullptr, x_blob);
  auto& X = x_blob->Get<Tensor<float, CPUContext> >();
  EXPECT_EQ(X.size(), 10);
  for (int i = 0; i < X.size(); ++i) {
    EXPECT_EQ(X.data()[i], rank);
  }
  auto* reduced_blob = ws.GetBlob("X_reduced");
  ASSERT_NE(nullptr, reduced_blob);
  auto& X_reduced = reduced_blob->Get<Tensor<float, CPUContext> >();
  EXPECT_EQ(X_reduced.size(), 10);
  // Sum of the arithmetic series 0 + 1 + ... + (size - 1).
  int expected_result = size * (size - 1) / 2;
  for (int i = 0; i < X_reduced.size(); ++i) {
    EXPECT_EQ(X_reduced.data()[i], expected_result);
  }
}
|
|
|
|
} // namespace caffe2
|
|
|
|
// Defines the --caffe_test_root string flag (default "gen/").
// NOTE(review): FLAGS_caffe_test_root is not read anywhere in this file;
// presumably kept for parity with other caffe2 test binaries. Confirm before removing.
DEFINE_string(caffe_test_root, "gen/", "The root of the caffe test folder.");
|
|
|
|
// Test entry point: initializes MPI (requesting full multi-threaded support),
// then gtest, gflags and glog, runs all tests, and finalizes MPI.
GTEST_API_ int main(int argc, char **argv) {
  // MPI_Init_thread receives &argc/&argv, so it runs before the other
  // libraries inspect the command line. The last out-parameter is the thread
  // support level the MPI library actually provides, which may be lower than
  // the level requested.
  int provided_thread_level = 0;
  const int mpi_init_status = MPI_Init_thread(
      &argc, &argv, MPI_THREAD_MULTIPLE, &provided_thread_level);
  testing::InitGoogleTest(&argc, argv);
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  google::InitGoogleLogging(argv[0]);
  CHECK_EQ(mpi_init_status, MPI_SUCCESS) << "MPI_Init_thread failed.";
  // The MPI thread levels are ordered SINGLE < FUNNELED < SERIALIZED <
  // MULTIPLE, so a smaller value means reduced thread support. Warn rather than
  // abort so single-threaded MPI builds can still run the tests.
  if (provided_thread_level < MPI_THREAD_MULTIPLE) {
    LOG(WARNING) << "MPI_THREAD_MULTIPLE was requested but the MPI library "
                 << "only provides thread level " << provided_thread_level
                 << "; multi-threaded MPI usage may be unsafe.";
  }
  const int test_result = RUN_ALL_TESTS();
  MPI_Finalize();
  return test_result;
}
|