This PR implements `_allgather_base`, `reduce_scatter`, and `_reduce_scatter_base` in the MPI backend (`ProcessGroupMPI`), enabling support for Fully Sharded Data Parallel (FSDP) in environments that use MPI for distributed communication.

### Context

As noted in https://github.com/pytorch/pytorch/issues/85628, FSDP currently supports only the NCCL backend. Due to this limitation, FSDP cannot run on legacy HPC environments or clusters that rely on MPI. By implementing just these three collective operations, we can enable FSDP to work with the MPI backend. These collectives are implemented in a similar manner to existing operations such as `allgather`.

### Testing

We validated this PR using `pytorch/build/bin/ProcessGroupMPITest` with OpenMPI, and all tests passed successfully.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150162
Approved by: https://github.com/H-Huang
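For illustration only, here is a minimal sketch of how a training script might exercise this backend once these collectives are in place. The toy module, the `mpirun` launch, and the device placement are assumptions made for the example, not part of this PR:

```python
# Hypothetical sketch: FSDP on top of the MPI backend.
# Assumes an MPI launch such as: mpirun -np 4 python train_fsdp_mpi.py
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# With the MPI backend, rank and world size come from the MPI runtime.
dist.init_process_group(backend="mpi")

# Toy model standing in for a real network; CUDA tensors would additionally
# require a CUDA-aware MPI build (see cudaAwareMpiCheck() below).
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU())

# FSDP shards parameters across ranks and drives them through
# _allgather_base / reduce_scatter style collectives.
fsdp_model = FSDP(model)

out = fsdp_model(torch.randn(8, 1024))
out.sum().backward()

dist.destroy_process_group()
```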
#include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp>

#ifdef USE_C10D_MPI

#include <iostream>
#include <map>

#include <c10/core/DeviceGuard.h>
#include <c10/util/irange.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

#if defined(OPEN_MPI) && OPEN_MPI
#include <mpi-ext.h> // Needed for CUDA-aware check
#endif

namespace c10d {

#define MPI_CHECK(cmd)                                                   \
  do {                                                                   \
    int mpiStatus = cmd;                                                 \
    if (mpiStatus != MPI_SUCCESS) {                                      \
      std::string err = "MPI error in: " + std::string(__FILE__) + ":" + \
          std::to_string(__LINE__) +                                     \
          ", with error code: " + std::to_string(mpiStatus);             \
      TORCH_CHECK(false, err);                                           \
    }                                                                    \
  } while (0)

namespace {

// Op mapping
std::map<ReduceOp::RedOpType, MPI_Op> mpiOp = {
    {ReduceOp::MIN, MPI_MIN},
    {ReduceOp::MAX, MPI_MAX},
    {ReduceOp::SUM, MPI_SUM},
    {ReduceOp::PRODUCT, MPI_PROD},
};
// Type mapping
std::map<at::ScalarType, MPI_Datatype> mpiDatatype = {
    {at::kByte, MPI_UNSIGNED_CHAR},
    {at::kChar, MPI_CHAR},
    {at::kDouble, MPI_DOUBLE},
    {at::kFloat, MPI_FLOAT},
    {at::kInt, MPI_INT},
    {at::kLong, MPI_LONG},
    {at::kShort, MPI_SHORT},
};

// Checking CUDA-aware MPI support, currently we only support CUDA aware
// MPI ops through Open MPI
bool cudaAwareMpiCheck() {
  // Run time check
#if defined(MPIX_CUDA_AWARE_SUPPORT)
  if (MPIX_Query_cuda_support() == 1) {
    return true;
  } else {
    return false;
  }
#else // !defined(MPIX_CUDA_AWARE_SUPPORT)
  return false;
#endif // MPIX_CUDA_AWARE_SUPPORT
}

// Checking the input tensor's validity
void checkSingleTensorHelper(const at::Tensor& tensor) {
  if (!tensor.is_contiguous()) {
    TORCH_CHECK(false, "input tensor has to be contiguous");
  }
  if (tensor.is_sparse()) {
    TORCH_CHECK(false, "input tensor has to be dense");
  }
  if (tensor.is_cuda() && !cudaAwareMpiCheck()) {
    TORCH_CHECK(
        false,
        "CUDA tensor detected and the MPI used doesn't "
        "have CUDA-aware MPI support");
  }
}

void checkSingleTensor(const std::vector<at::Tensor>& tensors) {
  if (tensors.size() != 1) {
    TORCH_CHECK(
        false, "MPI process group does not support multi-GPU collectives");
  }
  checkSingleTensorHelper(tensors[0]);
}

void checkSameSizeAndType(
    const at::Tensor& t_in,
    const std::vector<at::Tensor>& tensors) {
  for (const auto& tensor : tensors) {
    if ((tensor.numel() != t_in.numel()) ||
        (tensor.scalar_type() != t_in.scalar_type())) {
      TORCH_CHECK(false, "Tensors are not equal in size or data type");
    }
    checkSingleTensorHelper(tensor);
  }
}

} // namespace

std::vector<at::Tensor> ProcessGroupMPI::WorkMPI::result() {
  return outputTensors_;
}

c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupMPI::WorkMPI::getFuture() {
  return future_;
}

void ProcessGroupMPI::WorkMPI::finishWorkMPIError(
    const std::exception_ptr& eptr) {
  future_->setError(eptr);
  finish(eptr);
}

void ProcessGroupMPI::WorkMPI::finishWorkMPI() {
  future_->markCompleted(at::IValue(outputTensors_));
  finish();
}

ProcessGroupMPI::AsyncWork::AsyncWork(
    MPI_Request request,
    std::vector<at::Tensor> outputTensors,
    const char* profilingTitle,
    const std::optional<std::vector<at::Tensor>>& inputTensors)
    : Work(-1, OpType::UNKNOWN, profilingTitle, inputTensors),
      outputTensors_(std::move(outputTensors)),
      request_(request) {
  memset(&status_, 0, sizeof(status_));
}

ProcessGroupMPI::AsyncWork::~AsyncWork() {
  if (request_ != MPI_REQUEST_NULL) {
    std::cerr
        << "Attempted destruction of AsyncWork before work has completed, "
        << "terminating the program." << '\n';
    std::terminate();
  }
}

bool ProcessGroupMPI::AsyncWork::isCompleted() {
  if (request_ == MPI_REQUEST_NULL) {
    return true;
  }

  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  int flag = 0;
  MPI_CHECK(MPI_Test(&request_, &flag, &status_));
  if (request_ != MPI_REQUEST_NULL) {
    return false;
  }

  // request_ == MPI_REQUEST_NULL; the work has completed
  // Populate exception if request was not successful
  if (status_.MPI_ERROR != MPI_SUCCESS) {
    populateException();
  }

  return true;
}

bool ProcessGroupMPI::AsyncWork::isSuccess() const {
  if (request_ != MPI_REQUEST_NULL) {
    TORCH_CHECK(
        false,
        "Invalid call to AsyncWork::isSuccess before work has completed");
  }

  return status_.MPI_ERROR == MPI_SUCCESS;
}

int ProcessGroupMPI::AsyncWork::sourceRank() const {
  return status_.MPI_SOURCE;
}

bool ProcessGroupMPI::AsyncWork::wait(std::chrono::milliseconds /* unused */) {
  if (request_ == MPI_REQUEST_NULL) {
    // AsyncWork needs to manually call profiling end callbacks if they are set,
    // since it does not call ProcessGroup::finish().
    if (Work::recordFunctionEndCallback_) {
      Work::recordFunctionEndCallback_();
      Work::recordFunctionEndCallback_ = nullptr;
    }
    return true;
  }

  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  MPI_CHECK(MPI_Wait(&request_, &status_));
  auto ok = (status_.MPI_ERROR == MPI_SUCCESS);

  // AsyncWork needs to manually call profiling end callbacks if they are set,
  // since it does not call ProcessGroup::finish().
  if (Work::recordFunctionEndCallback_) {
    Work::recordFunctionEndCallback_();
    Work::recordFunctionEndCallback_ = nullptr;
  }

  if (!ok) {
    populateException();
    std::rethrow_exception(exception_);
  }
  if (c10d::allow_inflight_collective_as_graph_input()) {
    c10d::unregister_work(
        c10::intrusive_ptr<
            ProcessGroupMPI::AsyncWork>::unsafe_reclaim_from_nonowning(this));
  }
  // Always return true, because abort API is not implemented.
  return true;
}

void ProcessGroupMPI::AsyncWork::abort(){
    TORCH_CHECK(false, "ProcessGroupMPI::AsyncWork::abort not implemented.")}

std::vector<at::Tensor> ProcessGroupMPI::AsyncWork::result() {
  return outputTensors_;
}

void ProcessGroupMPI::AsyncWork::populateException() {
  std::array<char, MPI_MAX_ERROR_STRING> buf{};
  int len = buf.size();
  MPI_CHECK(MPI_Error_string(status_.MPI_ERROR, buf.data(), &len));
  exception_ =
      std::make_exception_ptr(std::runtime_error(std::string(buf.data(), len)));
}

// Static global states
int ProcessGroupMPI::mpiThreadSupport_ = 0;
std::mutex ProcessGroupMPI::pgGlobalMutex_;

void ProcessGroupMPI::mpiExit() {
  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  MPI_CHECK(MPI_Finalize());
}

void ProcessGroupMPI::initMPIOnce() {
  // Initialize MPI environment. We only want to initialize once.
  static bool init_mpi_flag [[maybe_unused]] = []() {
    int mpi_was_initialized = 0;
    MPI_CHECK(MPI_Initialized(&mpi_was_initialized));
    if (mpi_was_initialized == 0) {
      MPI_CHECK(MPI_Init_thread(
          nullptr, nullptr, MPI_THREAD_SERIALIZED, &mpiThreadSupport_));
      if (mpiThreadSupport_ < MPI_THREAD_SERIALIZED) {
        TORCH_CHECK(
            false,
            "Used MPI implementation doesn't have the "
            "minimum level of threading support: "
            "MPI_THREAD_SERIALIZED. This is required by "
            "c10d package");
      }
      if (std::atexit(ProcessGroupMPI::mpiExit)) {
        TORCH_CHECK(false, "Fail to register the MPI exit handler");
      }
    } else {
      TORCH_WARN_ONCE("MPI was previously initialized.");
    }
    return true;
  }();
}

c10::intrusive_ptr<ProcessGroupMPI> ProcessGroupMPI::createProcessGroupMPI(
    std::vector<int> ranks) {
  // Once initialization
  initMPIOnce();

  MPI_Comm groupComm = MPI_COMM_WORLD;
  int rank = -1;
  int size = -1;

  {
    std::lock_guard<std::mutex> globalLock(pgGlobalMutex_);

    // If no ranks are specified, assume we're creating the root group
    if (!ranks.empty()) {
      MPI_Group worldGroup{};
      MPI_Group ranksGroup{};
      MPI_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
      MPI_CHECK(
          MPI_Group_incl(worldGroup, ranks.size(), ranks.data(), &ranksGroup));
      // `MPI_Comm_create` can be flaky in certain cases.
      // See: https://github.com/pytorch/pytorch/issues/53899
      constexpr int kMaxNumRetries = 3;
      bool groupComm_updated = false;
      MPI_Barrier(MPI_COMM_WORLD);
      for (const auto i : c10::irange(kMaxNumRetries)) {
        (void)i;
        if (MPI_Comm_create(MPI_COMM_WORLD, ranksGroup, &groupComm)) {
          groupComm_updated = true;
          break;
        }
      }
      MPI_CHECK(groupComm_updated);
      MPI_CHECK(MPI_Group_free(&worldGroup));
      MPI_CHECK(MPI_Group_free(&ranksGroup));
    }

    // Fetch rank and world size for this group (MPI_COMM_WORLD or new)
    if (groupComm != MPI_COMM_NULL) {
      MPI_CHECK(MPI_Comm_rank(groupComm, &rank));
      MPI_CHECK(MPI_Comm_size(groupComm, &size));

      if (rank < 0 || size < 0) {
        TORCH_CHECK(false, "Failed to get the world_size / rank");
      }
    }
  }

  // If this process is not part of the group, we don't construct a
  // process group instance. This is in line with the semantics of the
  // other process group types.
  if (groupComm == MPI_COMM_NULL) {
    return c10::intrusive_ptr<ProcessGroupMPI>();
  }

  return c10::make_intrusive<ProcessGroupMPI>(rank, size, groupComm);
}

ProcessGroupMPI::ProcessGroupMPI(int rank, int size, MPI_Comm pgComm)
    : Backend(rank, size), stop_(false), pgComm_(pgComm) {
  if (pgComm_ == MPI_COMM_NULL) {
    TORCH_CHECK(false, "pgComm_ must not be MPI_COMM_NULL");
  }

  // Start the worker thread accepting MPI calls
  workerThread_ = std::thread(&ProcessGroupMPI::runLoop, this);

  init();
}

ProcessGroupMPI::~ProcessGroupMPI() {
  destroy();
}

void ProcessGroupMPI::destroy() {
  std::unique_lock<std::mutex> lock(pgMutex_);
  queueConsumeCV_.wait(lock, [&] { return queue_.empty(); });

  // Queue is empty, signal stop
  stop_ = true;

  // Release lock to allow threads to terminate
  lock.unlock();
  queueProduceCV_.notify_all();

  // Join the single worker thread
  workerThread_.join();
}

void ProcessGroupMPI::abort() {
  destroy();
  MPI_Abort(pgComm_, EXIT_FAILURE);
}

void ProcessGroupMPI::runLoop() {
  std::unique_lock<std::mutex> lock(pgMutex_);

  while (!stop_) {
    if (queue_.empty()) {
      queueProduceCV_.wait(lock);
      continue;
    }

    auto workTuple = std::move(queue_.front());

    queue_.pop_front();

    auto& workEntry = std::get<0>(workTuple);
    auto& work = std::get<1>(workTuple);

    lock.unlock();
    queueConsumeCV_.notify_one();

    try {
      workEntry->run(workEntry);
      work->finishWorkMPI();
    } catch (...) {
      work->finishWorkMPIError(std::current_exception());
    }

    lock.lock();
  }
}

c10::intrusive_ptr<Work> ProcessGroupMPI::enqueue(
    std::unique_ptr<WorkEntry> entry,
    const char* profilingTitle,
    const std::optional<std::vector<at::Tensor>>& inputTensors) {
  auto work =
      c10::make_intrusive<WorkMPI>(entry->dst, profilingTitle, inputTensors);
  std::unique_lock<std::mutex> lock(pgMutex_);
  queue_.emplace_back(std::move(entry), work);
  lock.unlock();
  queueProduceCV_.notify_one();
  return work;
}

c10::intrusive_ptr<Work> ProcessGroupMPI::broadcast(
    std::vector<at::Tensor>& tensors,
    const BroadcastOptions& opts) {
  checkSingleTensor(tensors);
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Bcast(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            opts.rootRank,
            pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:broadcast",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::allreduce(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
  checkSingleTensor(tensors);

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Allreduce(
            MPI_IN_PLACE,
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            mpiOp.at(opts.reduceOp),
            pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_reduce",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::allreduce_coalesced(
    std::vector<at::Tensor>& tensors,
    const AllreduceCoalescedOptions& opts) {
  TORCH_CHECK(false, "allreduce_coalesced is currently not supported with MPI");
}

c10::intrusive_ptr<Work> ProcessGroupMPI::reduce(
    std::vector<at::Tensor>& tensors,
    const ReduceOptions& opts) {
  checkSingleTensor(tensors);

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        auto dataPtr = (entry->src)[0].data_ptr();
        void* sendbuf = (rank_ == opts.rootRank) ? MPI_IN_PLACE : dataPtr;
        void* recvbuf = (rank_ == opts.rootRank) ? dataPtr : nullptr;

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Reduce(
            sendbuf,
            recvbuf,
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            mpiOp.at(opts.reduceOp),
            opts.rootRank,
            pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:reduce",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts) {
  checkSingleTensor(inputTensors);
  if (outputTensors.size() != 1) {
    TORCH_CHECK(
        false,
        "MPI process group only supports a single "
        "tensor op");
  }
  if (static_cast<size_t>(size_) != outputTensors[0].size()) {
    TORCH_CHECK(
        false,
        "All gather: number of output tensors should equal "
        "to the world size");
  }

  checkSameSizeAndType(inputTensors[0], outputTensors[0]);

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        std::vector<at::Tensor> outputDataVec = entry->dst;
        auto flatOutputTensor = newLikeFlat(outputDataVec);

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Allgather(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            flatOutputTensor.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            pgComm_));

        for (const auto i : c10::irange(outputDataVec.size())) {
          outputDataVec[i].copy_(flatOutputTensor[static_cast<int64_t>(i)]);
        }
      };
  auto entry = std::make_unique<WorkEntry>(
      &inputTensors, &outputTensors[0], std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_gather",
      std::optional<std::vector<at::Tensor>>(inputTensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::allgather_coalesced(
    std::vector<std::vector<at::Tensor>>& /* unused */,
    std::vector<at::Tensor>& /* unused */,
    const AllgatherOptions& /* unused */) {
  TORCH_CHECK(false, "ProcessGroupMPI does not support allgather_coalesced");
}

c10::intrusive_ptr<Work> ProcessGroupMPI::gather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const GatherOptions& opts) {
  checkSingleTensor(inputTensors);

  if (rank_ != opts.rootRank) {
    if (!outputTensors.empty()) {
      TORCH_CHECK(
          false,
          "Gather: number of output tensors should be 0 "
          "for non-root");
    }
  } else {
    if (outputTensors.size() != 1) {
      TORCH_CHECK(false, "Gather: multi-GPU collective is not supported");
    }
    if (static_cast<size_t>(size_) != outputTensors[0].size()) {
      TORCH_CHECK(
          false,
          "Gather: number of output tensors should equal "
          "to the world size");
    }
    checkSameSizeAndType(inputTensors[0], outputTensors[0]);
  }

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        void* recvbuf = nullptr;
        at::Tensor flatOutputTensor;

        std::vector<at::Tensor> dstdata = entry->dst;
        if (rank_ == opts.rootRank) {
          flatOutputTensor = newLikeFlat(dstdata);
          recvbuf = flatOutputTensor.data_ptr();
        }

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Gather(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            recvbuf,
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            opts.rootRank,
            pgComm_));

        if (rank_ == opts.rootRank) {
          const std::vector<at::Tensor>& outputDataVec = entry->dst;
          // copy the flattened output tensors to the outputs
          for (const auto i : c10::irange(outputDataVec.size())) {
            outputDataVec.at(i).copy_(
                flatOutputTensor[static_cast<int64_t>(i)]);
          }
        }
      };

  if (rank_ == opts.rootRank) {
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, &outputTensors[0], std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:gather",
        std::optional<std::vector<at::Tensor>>(inputTensors));
  } else {
    auto entry =
        std::make_unique<WorkEntry>(&inputTensors, nullptr, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:gather",
        std::optional<std::vector<at::Tensor>>(inputTensors));
  }
}

c10::intrusive_ptr<Work> ProcessGroupMPI::scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ScatterOptions& opts) {
  checkSingleTensor(outputTensors);

  if (rank_ != opts.rootRank) {
    if (!inputTensors.empty()) {
      TORCH_CHECK(
          false,
          "Scatter: number of input tensors should be 0 "
          "for non-root");
    }
  } else {
    if (inputTensors.size() != 1) {
      TORCH_CHECK(false, "Scatter: multi-GPU collective is not supported");
    }
    if (static_cast<size_t>(size_) != inputTensors[0].size()) {
      TORCH_CHECK(
          false,
          "Scatter: number of input tensors should equal "
          "to the world size");
    }
    checkSameSizeAndType(outputTensors[0], inputTensors[0]);
  }

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->dst)[0];
        void* sendbuf = nullptr;
        at::Tensor flatInputTensor;

        if (rank_ == opts.rootRank) {
          std::vector<at::Tensor>& inputDataVec = entry->src;
          flatInputTensor = newLikeFlat(inputDataVec);
          sendbuf = flatInputTensor.data_ptr();

          // copy the input tensors to the flatten large send buffer
          for (const auto i : c10::irange(inputDataVec.size())) {
            flatInputTensor[static_cast<int64_t>(i)].copy_(inputDataVec.at(i));
          }
        }

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Scatter(
            sendbuf,
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            opts.rootRank,
            pgComm_));
      };

  if (rank_ == opts.rootRank) {
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors[0], &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:scatter",
        !inputTensors.empty()
            ? std::optional<std::vector<at::Tensor>>(inputTensors[0])
            : std::nullopt);
  } else {
    auto entry = std::make_unique<WorkEntry>(
        nullptr, &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:scatter",
        !inputTensors.empty()
            ? std::optional<std::vector<at::Tensor>>(inputTensors[0])
            : std::nullopt);
  }
}

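// Note: reduce_scatter below, together with _allgather_base and
// _reduce_scatter_base further down, was added so that FSDP can run on the
// MPI backend (https://github.com/pytorch/pytorch/pull/150162).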
c10::intrusive_ptr<Work> ProcessGroupMPI::reduce_scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ReduceScatterOptions& opts) {
  checkSingleTensor(outputTensors);
  if (inputTensors.size() != 1) {
    TORCH_CHECK(
        false,
        "MPI process group only supports a single "
        "tensor op");
  }
  if (static_cast<size_t>(size_) != inputTensors[0].size()) {
    TORCH_CHECK(
        false,
        "Reduce scatter: number of input tensors should equal "
        "to the world size");
  }
  checkSameSizeAndType(outputTensors[0], inputTensors[0]);

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->dst)[0];
        auto flatInputTensor = newLikeFlat(entry->src);
        for (const auto i : c10::irange(entry->src.size())) {
          flatInputTensor[static_cast<int64_t>(i)].copy_(entry->src[i]);
        }
        int recvcount = flatInputTensor.numel() / size_;

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Reduce_scatter_block(
            flatInputTensor.data_ptr(),
            data.data_ptr(),
            recvcount,
            mpiDatatype.at(data.scalar_type()),
            mpiOp.at(opts.reduceOp),
            pgComm_));
      };

  auto entry = std::make_unique<WorkEntry>(
      &inputTensors[0], &outputTensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:reduce_scatter",
      std::optional<std::vector<at::Tensor>>(inputTensors[0]));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    std::vector<int64_t>& outputSplitSizes,
    std::vector<int64_t>& inputSplitSizes,
    const AllToAllOptions& opts) {
  checkSingleTensorHelper(inputTensor);
  checkSingleTensorHelper(outputTensor);

  if (outputSplitSizes.empty() && inputSplitSizes.empty()) {
    // We can use alltoall
    TORCH_CHECK(
        outputTensor.numel() == inputTensor.numel() &&
            outputTensor.type() == inputTensor.type(),
        "Tensors are not equal in size or data type");
    TORCH_CHECK(
        outputTensor.size(0) % size_ == 0,
        "Tensor's dim 0 does not divide equally across group size");

    std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
        [this](std::unique_ptr<WorkEntry>& entry) {
          auto srcdata = (entry->src)[0];
          auto dstdata = (entry->dst)[0];
          c10::DeviceGuard guard(srcdata.device());
          std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
          MPI_CHECK(MPI_Alltoall(
              srcdata.data_ptr(),
              srcdata.numel() / size_,
              mpiDatatype.at(srcdata.scalar_type()),
              dstdata.data_ptr(),
              dstdata.numel() / size_,
              mpiDatatype.at(dstdata.scalar_type()),
              pgComm_));
        };
    std::vector<at::Tensor> inputTensors = {inputTensor};
    std::vector<at::Tensor> outputTensors = {outputTensor};
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:all_to_all",
        std::optional<std::vector<at::Tensor>>(inputTensors));
  } else {
    // Need alltoallv
    c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
    c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_);
    std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
        [this, inputSplitSizes, outputSplitSizes](
            std::unique_ptr<WorkEntry>& entry) {
          auto srcdata = (entry->src)[0];
          auto dstdata = (entry->dst)[0];
          std::vector<int> send_lengths(size_);
          std::vector<int> recv_lengths(size_);
          std::vector<int> send_offsets(size_);
          std::vector<int> recv_offsets(size_);
          c10d::computeLengthsAndOffsets(
              inputSplitSizes, srcdata, &send_lengths, &send_offsets);
          c10d::computeLengthsAndOffsets(
              outputSplitSizes, dstdata, &recv_lengths, &recv_offsets);
          c10::DeviceGuard guard(srcdata.device());
          std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
          MPI_CHECK(MPI_Alltoallv(
              srcdata.data_ptr(),
              send_lengths.data(),
              send_offsets.data(),
              mpiDatatype.at(srcdata.scalar_type()),
              dstdata.data_ptr(),
              recv_lengths.data(),
              recv_offsets.data(),
              mpiDatatype.at(dstdata.scalar_type()),
              pgComm_));
        };
    std::vector<at::Tensor> inputTensors = {inputTensor};
    std::vector<at::Tensor> outputTensors = {outputTensor};
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:all_to_all",
        std::optional<std::vector<at::Tensor>>(inputTensors));
  }
}

c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllToAllOptions& opts) {
  TORCH_CHECK(
      inputTensors.size() == static_cast<size_t>(size_),
      "Number of input tensors are not equal to group size");
  TORCH_CHECK(
      outputTensors.size() == static_cast<size_t>(size_),
      "Number of output tensors are not equal to group size");
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        std::vector<int> send_lengths(size_);
        std::vector<int> recv_lengths(size_);
        std::vector<int> send_offsets(size_);
        std::vector<int> recv_offsets(size_);
        auto srcdata = entry->src;
        auto dstdata = entry->dst;
        auto src_len = c10d::computeLengthsAndOffsets(
            srcdata, &send_lengths, &send_offsets);
        auto dst_len = c10d::computeLengthsAndOffsets(
            dstdata, &recv_lengths, &recv_offsets);
        std::vector<int64_t> send_lengthsL(
            send_lengths.begin(), send_lengths.end());
        std::vector<int64_t> recv_lengthsL(
            recv_lengths.begin(), recv_lengths.end());
        at::Tensor srcFlatData =
            at::empty({static_cast<int64_t>(src_len)}, srcdata[0].options());
        at::Tensor dstFlatData =
            at::empty({static_cast<int64_t>(dst_len)}, dstdata[0].options());
        auto srcFlatDataSplits =
            srcFlatData.split_with_sizes(c10::IntArrayRef(send_lengthsL), 0);
        for (const auto i : c10::irange(size_)) {
          srcFlatDataSplits[i].copy_(srcdata[i].view({-1}));
        }
        c10::DeviceGuard guard1(srcdata[0].device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Alltoallv(
            srcFlatData.data_ptr(),
            send_lengths.data(),
            send_offsets.data(),
            mpiDatatype.at(srcdata[0].scalar_type()),
            dstFlatData.data_ptr(),
            recv_lengths.data(),
            recv_offsets.data(),
            mpiDatatype.at(dstdata[0].scalar_type()),
            pgComm_));

        auto dstFlatDataSplits =
            dstFlatData.split_with_sizes(c10::IntArrayRef(recv_lengthsL), 0);
        for (const auto i : c10::irange(size_)) {
          dstdata[i].view({-1}).copy_(dstFlatDataSplits[i]);
        }
      };
  auto entry = std::make_unique<WorkEntry>(
      &inputTensors, &outputTensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_to_all",
      std::optional<std::vector<at::Tensor>>(inputTensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::send(
    std::vector<at::Tensor>& tensors,
    int dstRank,
    int tag) {
  checkSingleTensor(tensors);

  auto& tensor = tensors[0];
  MPI_Request request = MPI_REQUEST_NULL;

  {
    c10::DeviceGuard guard(tensor.device());
    std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
    MPI_CHECK(MPI_Isend(
        tensor.data_ptr(),
        tensor.numel(),
        mpiDatatype.at(tensor.scalar_type()),
        dstRank,
        tag,
        pgComm_,
        &request));
  }

  return c10::make_intrusive<AsyncWork>(
      request,
      std::vector<at::Tensor>(),
      "mpi:send",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::recv(
    std::vector<at::Tensor>& tensors,
    int srcRank,
    int tag) {
  checkSingleTensor(tensors);

  auto& tensor = tensors[0];
  MPI_Request request = MPI_REQUEST_NULL;

  {
    c10::DeviceGuard guard(tensor.device());
    std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
    MPI_CHECK(MPI_Irecv(
        tensor.data_ptr(),
        tensor.numel(),
        mpiDatatype.at(tensor.scalar_type()),
        srcRank,
        tag,
        pgComm_,
        &request));
  }

  return c10::make_intrusive<AsyncWork>(
      request,
      tensors,
      "mpi:recv",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::recvAnysource(
    std::vector<at::Tensor>& tensors,
    int tag) {
  checkSingleTensor(tensors);

  auto& tensor = tensors[0];
  MPI_Request request = MPI_REQUEST_NULL;

  {
    c10::DeviceGuard guard(tensor.device());
    std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
    MPI_CHECK(MPI_Irecv(
        tensor.data_ptr(),
        tensor.numel(),
        mpiDatatype.at(tensor.scalar_type()),
        MPI_ANY_SOURCE,
        tag,
        pgComm_,
        &request));
  }

  return c10::make_intrusive<AsyncWork>(
      request,
      tensors,
      "mpi:recvAnySource",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::barrier(const BarrierOptions& opts) {
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Barrier(pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(nullptr, nullptr, std::move(runFunc));
  return enqueue(std::move(entry), "mpi:barrier", std::nullopt);
}

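// _allgather_base gathers a single, already-flattened input tensor from every
// rank directly into one preallocated output tensor (no per-rank output list);
// one of the collectives FSDP relies on (see pytorch/pytorch#150162).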
c10::intrusive_ptr<Work> ProcessGroupMPI::_allgather_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    const AllgatherOptions& opts) {
  TORCH_CHECK(
      outputTensor.numel() == inputTensor.numel() * size_,
      "All gather: output tensor size must be equal to input tensor size times the world size");

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        auto dstdata = (entry->dst)[0];
        auto srcdata = (entry->src)[0];
        c10::DeviceGuard guard(srcdata.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Allgather(
            srcdata.data_ptr(),
            srcdata.numel(),
            mpiDatatype.at(srcdata.scalar_type()),
            dstdata.data_ptr(),
            srcdata.numel(),
            mpiDatatype.at(dstdata.scalar_type()),
            pgComm_));
      };

  auto inputTensors = std::vector<at::Tensor>({inputTensor});
  auto outputTensors = std::vector<at::Tensor>({outputTensor});
  auto entry = std::make_unique<WorkEntry>(
      &inputTensors, &outputTensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:_allgather_base",
      std::optional<std::vector<at::Tensor>>(inputTensors));
}

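// _reduce_scatter_base reduces a flat input tensor across ranks and leaves
// each rank with its own contiguous shard in outputTensor; also added for
// FSDP support on the MPI backend (see pytorch/pytorch#150162).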
c10::intrusive_ptr<Work> ProcessGroupMPI::_reduce_scatter_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    const ReduceScatterOptions& opts) {
  TORCH_CHECK(
      outputTensor.numel() * size_ == inputTensor.numel(),
      "Reduce scatter: input tensor size must be equal to output tensor size times the world size");

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto dstdata = (entry->dst)[0];
        auto srcdata = (entry->src)[0];
        c10::DeviceGuard guard(srcdata.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Reduce_scatter_block(
            srcdata.data_ptr(),
            dstdata.data_ptr(),
            dstdata.numel(),
            mpiDatatype.at(srcdata.scalar_type()),
            mpiOp.at(opts.reduceOp),
            pgComm_));
      };

  auto inputTensors = std::vector<at::Tensor>({inputTensor});
  auto outputTensors = std::vector<at::Tensor>({outputTensor});
  auto entry = std::make_unique<WorkEntry>(
      &inputTensors, &outputTensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:_reduce_scatter_base",
      std::optional<std::vector<at::Tensor>>(inputTensors));
}

} // namespace c10d

#endif // USE_C10D_MPI