mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: The previous NCCL all gather doesn't work as expected. This is a fully working async version. Tested on both C++ and Python Frontend. Multi-node: ``` tengli@learnfair042:~/new_pytorch/pytorch/torch/lib/build/c10d/test$ TMPFILE="/private/home/tengli/temp/tengli-test" RANK=0 WORLD_SIZE=2 ./ProcessGroupNCCLTest Multi-node world size: 2 rank: 0 Allreduce test successful Broadcast test successful Reduce test successful Allgather test successful tengli@learnfair117:~/new_pytorch/pytorch/torch/lib/build/c10d/test$ TMPFILE="/private/home/tengli/temp/tengli-test" RANK=1 WORLD_SIZE=2 ./ProcessGroupNCCLTest Multi-node world size: 2 rank: 1 Allreduce test successful Broadcast test successful Reduce test successful Allgather test successful ``` CI test: ``` test_set_get (__main__.FileStoreTest) ... ok test_set_get (__main__.PrefixFileStoreTest) ... ok test_set_get (__main__.PrefixTCPStoreTest) ... ok test_allreduce_ops (__main__.ProcessGroupGlooTest) ... ok test_broadcast_ops (__main__.ProcessGroupGlooTest) ... ok test_allgather_ops (__main__.ProcessGroupNCCLTest) ... ok test_allreduce_ops (__main__.ProcessGroupNCCLTest) ... ok test_broadcast_ops (__main__.ProcessGroupNCCLTest) ... ok test_reduce_ops (__main__.ProcessGroupNCCLTest) ... ok test_common_errors (__main__.RendezvousFileTest) ... ok test_nominal (__main__.RendezvousFileTest) ... ok test_common_errors (__main__.RendezvousTCPTest) ... ok test_nominal (__main__.RendezvousTCPTest) ... ok test_unknown_handler (__main__.RendezvousTest) ... ok test_set_get (__main__.TCPStoreTest) ... ok ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/10932 Differential Revision: D9542067 Pulled By: teng-li fbshipit-source-id: 25513eddcc3119fd736875d69dfb631b10f4ac86
66 lines
1.7 KiB
C++
66 lines
1.7 KiB
C++
#pragma once
|
|
|
|
#include <memory>
|
|
|
|
#include <nccl.h>
|
|
|
|
// Evaluates the NCCL call `cmd` exactly once and throws std::runtime_error
// when the result is not ncclSuccess. The exception message contains the
// file and line of the macro expansion followed by the text returned by
// ncclGetErrorString() for the failing result code.
//
// The temporaries carry a c10d_nccl_ prefix so they cannot shadow a caller
// variable that appears inside `cmd` (for example one named `error`), and
// `cmd` is parenthesized so it is always treated as a single expression.
#define C10D_NCCL_CHECK(cmd)                                              \
  do {                                                                    \
    const ncclResult_t c10d_nccl_result_ = (cmd);                         \
    if (c10d_nccl_result_ != ncclSuccess) {                               \
      const std::string c10d_nccl_message_ =                              \
          "NCCL error in: " + std::string(__FILE__) + ":" +               \
          std::to_string(__LINE__) + ", " +                               \
          std::string(ncclGetErrorString(c10d_nccl_result_));             \
      throw std::runtime_error(c10d_nccl_message_);                       \
    }                                                                     \
  } while (0)
|
|
|
|
namespace c10d {
|
|
|
|
// RAII wrapper for NCCL communicator
|
|
class NCCLComm {
|
|
public:
|
|
explicit NCCLComm(ncclComm_t ncclComm) : ncclComm_(ncclComm) {}
|
|
|
|
NCCLComm() : NCCLComm(nullptr) {}
|
|
|
|
~NCCLComm() noexcept(false) {
|
|
if (ncclComm_) {
|
|
C10D_NCCL_CHECK(ncclCommDestroy(ncclComm_));
|
|
}
|
|
}
|
|
|
|
static std::shared_ptr<NCCLComm> create(
|
|
int numRanks,
|
|
int rank,
|
|
ncclUniqueId commId) {
|
|
auto comm = std::make_shared<NCCLComm>();
|
|
C10D_NCCL_CHECK(
|
|
ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank));
|
|
return comm;
|
|
}
|
|
|
|
// Must not be copyable
|
|
NCCLComm(const NCCLComm&) = delete;
|
|
NCCLComm& operator=(const NCCLComm&) = delete;
|
|
|
|
// Move constructable
|
|
NCCLComm(NCCLComm&& other) {
|
|
std::swap(ncclComm_, other.ncclComm_);
|
|
}
|
|
// Move assignable
|
|
NCCLComm& operator=(NCCLComm&& other) {
|
|
std::swap(ncclComm_, other.ncclComm_);
|
|
return *this;
|
|
}
|
|
|
|
ncclComm_t getNcclComm() {
|
|
return ncclComm_;
|
|
}
|
|
|
|
protected:
|
|
ncclComm_t ncclComm_;
|
|
};
|
|
|
|
} // namespace c10d
|