#include "caffe2/mpi/mpi_common.h"
|
|
|
|
#include <thread>
|
|
|
|
#include <c10/util/typeid.h>
|
|
#include "caffe2/utils/proto_utils.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
CAFFE_KNOWN_TYPE(MPICommonWorldWrapper);
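
// CAFFE_KNOWN_TYPE registers MPICommonWorldWrapper with caffe2's runtime type
// system, so instances of it can be stored in, and fetched from, Blobs.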

static std::mutex gCaffe2MPIMutex;

std::mutex& MPIMutex() {
  return gCaffe2MPIMutex;
}
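
// Typical guard pattern (a sketch, not part of the original file): serialize
// MPI calls through this mutex when MPI was initialized with only
// MPI_THREAD_SERIALIZED support; `buffer` and `count` are placeholders:
//
//   std::lock_guard<std::mutex> guard(MPIMutex());
//   MPI_CHECK(MPI_Bcast(buffer, count, MPI_FLOAT, 0, GlobalMPIComm()));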

static MPI_Comm gCaffe2MPIComm = MPI_COMM_WORLD;

MPI_Comm GlobalMPIComm() {
  return gCaffe2MPIComm;
}

void SetGlobalMPIComm(MPI_Comm new_comm) {
  if (gCaffe2MPIComm != MPI_COMM_WORLD) {
    MPI_Comm_free(&gCaffe2MPIComm);
  }
  gCaffe2MPIComm = new_comm;
}
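
// Illustrative sketch (assumed usage): replace the global communicator with a
// duplicate of MPI_COMM_WORLD; note that SetGlobalMPIComm frees the previous
// communicator unless it is MPI_COMM_WORLD itself:
//
//   MPI_Comm dup;
//   MPI_CHECK(MPI_Comm_dup(MPI_COMM_WORLD, &dup));
//   SetGlobalMPIComm(dup);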

int MPICommSize(MPI_Comm comm) {
  int comm_size;
  MPI_CHECK(MPI_Comm_size(comm, &comm_size));
  return comm_size;
}

int MPICommRank(MPI_Comm comm) {
  int comm_rank;
  MPI_CHECK(MPI_Comm_rank(comm, &comm_rank));
  return comm_rank;
}
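
// Example (a sketch): query the size of, and this process's rank on, the
// current global communicator:
//
//   const int size = MPICommSize(GlobalMPIComm());
//   const int rank = MPICommRank(GlobalMPIComm());
//   VLOG(1) << "rank " << rank << " of " << size;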

/**
 * Helper function used to set up an MPI intercommunicator and merge it into
 * a new intracommunicator.
 */
static MPI_Comm AssimilateComm(MPI_Comm intra, MPI_Comm inter) {
  MPI_Comm peer = MPI_COMM_NULL;
  MPI_Comm newInterComm = MPI_COMM_NULL;
  MPI_Comm newIntraComm = MPI_COMM_NULL;

  // The spawned rank will be the "high" rank in the new intra-comm
  int high = (MPI_COMM_NULL == intra) ? 1 : 0;

  // If this is one of the (two) ranks in the inter-comm,
  // create a new intra-comm from the inter-comm
  if (MPI_COMM_NULL != inter) {
    MPI_CHECK(MPI_Intercomm_merge(inter, high, &peer));
  } else {
    peer = MPI_COMM_NULL;
  }

  // Create a new inter-comm between the pre-existing intra-comm
  // (all of it, not only rank zero), and the remote (spawned) rank,
  // using the just-created intra-comm as the peer communicator.
  int tag = 12345;
  if (MPI_COMM_NULL != intra) {
    // This task is a member of the pre-existing intra-comm
    MPI_CHECK(MPI_Intercomm_create(intra, 0, peer, 1, tag, &newInterComm));
  } else {
    // This is the remote (spawned) task
    MPI_CHECK(
        MPI_Intercomm_create(MPI_COMM_SELF, 0, peer, 0, tag, &newInterComm));
  }

  // Now convert this inter-comm into an intra-comm
  MPI_CHECK(MPI_Intercomm_merge(newInterComm, high, &newIntraComm));

  // Clean up the intermediaries
  if (MPI_COMM_NULL != peer) {
    MPI_CHECK(MPI_Comm_free(&peer));
  }
  MPI_CHECK(MPI_Comm_free(&newInterComm));

  // Delete the original intra-comm
  if (MPI_COMM_NULL != intra && MPI_COMM_WORLD != intra &&
      GlobalMPIComm() != intra) {
    MPI_CHECK(MPI_Comm_free(&intra));
  }

  // Return the new intra-comm
  return newIntraComm;
}
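
// Usage sketch (mirroring MPISetupPeers below): the two endpoints of the new
// inter-comm pass it in; everyone else passes MPI_COMM_NULL for `inter`:
//
//   // server, after MPI_Comm_accept yields `icomm`:
//   SetGlobalMPIComm(AssimilateComm(GlobalMPIComm(), icomm));
//   // joining client, after MPI_Comm_connect yields `icomm`:
//   SetGlobalMPIComm(AssimilateComm(MPI_COMM_NULL, icomm));
//   // already-assimilated clients, while later peers join:
//   SetGlobalMPIComm(AssimilateComm(GlobalMPIComm(), MPI_COMM_NULL));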

void MPISetupPeers(
    const int replicas,
    const string& role,
    const string& job_path) {
  int flag;
  MPI_Initialized(&flag);
  if (!flag) {
    int mpi_ret;
    MPI_CHECK(MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &mpi_ret));
    if (mpi_ret != MPI_THREAD_MULTIPLE && mpi_ret != MPI_THREAD_SERIALIZED) {
      LOG(FATAL) << "This test requires the underlying MPI to support the "
                 << "MPI_THREAD_SERIALIZED or MPI_THREAD_MULTIPLE mode.";
      return;
    }
  }

  if (MPICommSize(MPI_COMM_WORLD) != 1) {
    LOG(ERROR) << "MPI_COMM_WORLD size is not 1: did you already run "
                  "MPISetupPeers? Note that if you execute your program with "
                  "mpirun to launch multiple local processes, you should not "
                  "call MPISetupPeers.";
    return;
  }

  if (role == "server") {
    // Open a port to accept connections.
    char port_name[MPI_MAX_PORT_NAME] = {'\0'};
    MPI_CHECK(MPI_Open_port(MPI_INFO_NULL, port_name));
    VLOG(1) << "MPI server: port: " << port_name;

    // Write the port name to the rendezvous file.
    CHECK(WriteStringToFile(std::string(port_name), job_path.c_str()));
    VLOG(1) << "MPI server: wrote to file: " << job_path;

    int comm_size = MPICommSize(GlobalMPIComm());
    while (comm_size < replicas) {
      MPI_Comm icomm;
      VLOG(1) << "MPI server: waiting for client "
              << "(" << comm_size << "/" << replicas << " have connected)";
      MPI_CHECK(
          MPI_Comm_accept(port_name, MPI_INFO_NULL, 0, MPI_COMM_SELF, &icomm));
      VLOG(1) << "MPI server: accepted client";
      MPI_Comm new_intra_comm = AssimilateComm(GlobalMPIComm(), icomm);
      SetGlobalMPIComm(new_intra_comm);
      comm_size = MPICommSize(new_intra_comm);
    }
  } else {
    // Poll the job path file until the server's port name appears.
    std::string port_name;
    while (!ReadStringFromFile(job_path.c_str(), &port_name) ||
           port_name.length() == 0) {
      /* sleep override */
      std::this_thread::sleep_for(std::chrono::seconds(1));
    }

    // Connect to the server.
    MPI_Comm icomm;
    VLOG(1) << "MPI client: connecting to port: " << port_name;
    MPI_CHECK(MPI_Comm_connect(
        const_cast<char*>(port_name.c_str()),
        MPI_INFO_NULL,
        0,
        GlobalMPIComm(),
        &icomm));

    VLOG(1) << "MPI client: connected";

    // Join the server's reference intracommunicator.
    MPI_Comm new_intra_comm = AssimilateComm(MPI_COMM_NULL, icomm);
    SetGlobalMPIComm(new_intra_comm);

    // Let other clients join the intracommunicator we're now a part of.
    while (MPICommSize(GlobalMPIComm()) < replicas) {
      MPI_Comm comm = AssimilateComm(GlobalMPIComm(), MPI_COMM_NULL);
      SetGlobalMPIComm(comm);
    }
  }

  // After all peers have assimilated, do a barrier.
  MPI_Barrier(GlobalMPIComm());
  VLOG(1) << "MPI using a communicator of size: "
          << MPICommSize(GlobalMPIComm());
}
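
// Illustrative end-to-end usage (a sketch; the path and replica count are
// made up): launch each process separately, one with role "server" and the
// rest with any other role string, all pointing at the same rendezvous file:
//
//   // process 0:
//   MPISetupPeers(4, "server", "/tmp/mpi_job");
//   // processes 1..3:
//   MPISetupPeers(4, "client", "/tmp/mpi_job");
//
// On return, GlobalMPIComm() spans all four processes and a barrier has been
// executed across them.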

} // namespace caffe2