pytorch/torch/lib/c10d/NCCLUtils.hpp
Teng Li a88463cd9a Working async version of AllGather, test fix and compiler warnings, and CI (#10932)
Summary:
The previous NCCL AllGather did not work as expected. This is a fully working async version, tested on both the C++ and Python frontends.

Multi-node:
```
tengli@learnfair042:~/new_pytorch/pytorch/torch/lib/build/c10d/test$ TMPFILE="/private/home/tengli/temp/tengli-test" RANK=0 WORLD_SIZE=2 ./ProcessGroupNCCLTest
Multi-node world size: 2 rank: 0
Allreduce test successful
Broadcast test successful
Reduce test successful
Allgather test successful

tengli@learnfair117:~/new_pytorch/pytorch/torch/lib/build/c10d/test$ TMPFILE="/private/home/tengli/temp/tengli-test" RANK=1 WORLD_SIZE=2 ./ProcessGroupNCCLTest
Multi-node world size: 2 rank: 1
Allreduce test successful
Broadcast test successful
Reduce test successful
Allgather test successful
```

CI test:
```
test_set_get (__main__.FileStoreTest) ... ok
test_set_get (__main__.PrefixFileStoreTest) ... ok
test_set_get (__main__.PrefixTCPStoreTest) ... ok
test_allreduce_ops (__main__.ProcessGroupGlooTest) ... ok
test_broadcast_ops (__main__.ProcessGroupGlooTest) ... ok
test_allgather_ops (__main__.ProcessGroupNCCLTest) ... ok
test_allreduce_ops (__main__.ProcessGroupNCCLTest) ... ok
test_broadcast_ops (__main__.ProcessGroupNCCLTest) ... ok
test_reduce_ops (__main__.ProcessGroupNCCLTest) ... ok
test_common_errors (__main__.RendezvousFileTest) ... ok
test_nominal (__main__.RendezvousFileTest) ... ok
test_common_errors (__main__.RendezvousTCPTest) ... ok
test_nominal (__main__.RendezvousTCPTest) ... ok
test_unknown_handler (__main__.RendezvousTest) ... ok
test_set_get (__main__.TCPStoreTest) ... ok
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10932

Differential Revision: D9542067

Pulled By: teng-li

fbshipit-source-id: 25513eddcc3119fd736875d69dfb631b10f4ac86
2018-08-28 12:40:14 -07:00


#pragma once

#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

#include <nccl.h>

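// Evaluates an NCCL call and, if it does not return ncclSuccess, throws a
// std::runtime_error tagged with the call site and NCCL's error string.
#define C10D_NCCL_CHECK(cmd)                                        \
  do {                                                              \
    ncclResult_t error = cmd;                                       \
    if (error != ncclSuccess) {                                     \
      std::string err = "NCCL error in: " + std::string(__FILE__) + \
          ":" + std::to_string(__LINE__) + ", " +                   \
          std::string(ncclGetErrorString(error));                   \
      throw std::runtime_error(err);                                \
    }                                                               \
  } while (0)
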
namespace c10d {

// RAII wrapper for NCCL communicator
class NCCLComm {
 public:
  explicit NCCLComm(ncclComm_t ncclComm) : ncclComm_(ncclComm) {}

  NCCLComm() : NCCLComm(nullptr) {}

  ~NCCLComm() noexcept(false) {
    if (ncclComm_) {
      C10D_NCCL_CHECK(ncclCommDestroy(ncclComm_));
    }
  }

  static std::shared_ptr<NCCLComm> create(
      int numRanks,
      int rank,
      ncclUniqueId commId) {
    auto comm = std::make_shared<NCCLComm>();
    C10D_NCCL_CHECK(
        ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank));
    return comm;
  }

  // Must not be copyable
  NCCLComm(const NCCLComm&) = delete;
  NCCLComm& operator=(const NCCLComm&) = delete;

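  // Move constructible: ncclComm_ starts as nullptr, so the swap leaves
  // `other` empty instead of holding an indeterminate pointer that its
  // destructor would then pass to ncclCommDestroy.
  NCCLComm(NCCLComm&& other) : ncclComm_(nullptr) {
    std::swap(ncclComm_, other.ncclComm_);
  }

  // Move assignable: our previous communicator (if any) moves into `other`
  // and is destroyed together with it.
  NCCLComm& operator=(NCCLComm&& other) {
    std::swap(ncclComm_, other.ncclComm_);
    return *this;
  }
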
  ncclComm_t getNcclComm() {
    return ncclComm_;
  }

 protected:
  ncclComm_t ncclComm_;
};

} // namespace c10d
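`C10D_NCCL_CHECK` follows a common pattern for C status-code APIs: evaluate the call once, compare against the success code, and throw with the file, line, and library error text. The sketch below reproduces that shape so it compiles without NCCL. `fakeResult_t`, `fakeSuccess`, `fakeInvalidArgument`, `fakeGetErrorString`, `FAKE_CHECK`, and `messageFrom` are hypothetical stand-ins, not part of NCCL or c10d.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-ins for ncclResult_t / ncclSuccess / ncclGetErrorString,
// so the pattern can be compiled without NCCL installed.
enum fakeResult_t { fakeSuccess = 0, fakeInvalidArgument = 4 };

inline const char* fakeGetErrorString(fakeResult_t r) {
  return r == fakeSuccess ? "no error" : "invalid argument";
}

// Same shape as C10D_NCCL_CHECK: evaluate `cmd` once and, on failure, throw a
// std::runtime_error carrying the call site and the library's error text.
#define FAKE_CHECK(cmd)                                              \
  do {                                                               \
    fakeResult_t error = (cmd);                                      \
    if (error != fakeSuccess) {                                      \
      std::string err = "error in: " + std::string(__FILE__) + ":" + \
          std::to_string(__LINE__) + ", " +                          \
          std::string(fakeGetErrorString(error));                    \
      throw std::runtime_error(err);                                 \
    }                                                                \
  } while (0)

inline fakeResult_t succeed() { return fakeSuccess; }
inline fakeResult_t fail() { return fakeInvalidArgument; }

// Returns the exception message, or an empty string if nothing was thrown.
// The unbraced if/else compiles only because the macro expands to a single
// do { ... } while (0) statement.
inline std::string messageFrom(bool shouldFail) {
  try {
    if (shouldFail)
      FAKE_CHECK(fail());
    else
      FAKE_CHECK(succeed());
  } catch (const std::runtime_error& e) {
    return e.what();
  }
  return "";
}
```

The `do { ... } while (0)` wrapper is what lets `FAKE_CHECK(...);` sit in an unbraced `if` without its trailing semicolon detaching the `else`; a bare `{ ... }` block would not.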
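`NCCLComm` is a move-only owner of a C handle. The following is a minimal sketch of the same ownership pattern against a hypothetical handle API, so the move semantics can be exercised without GPUs; `FakeState`, `fakeHandle_t`, `fakeCreate`, `fakeDestroy`, `liveHandles`, and `HandleOwner` are illustrative names, not NCCL functions. The default member initializer is the important detail: without it, the move constructor would swap an indeterminate pointer into `other`.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Hypothetical C-style handle API standing in for ncclComm_t /
// ncclCommInitRank / ncclCommDestroy; it only counts live handles.
struct FakeState { int id; };
using fakeHandle_t = FakeState*;

inline int& liveHandles() {
  static int count = 0;
  return count;
}

inline fakeHandle_t fakeCreate(int id) {
  ++liveHandles();
  return new FakeState{id};
}

inline void fakeDestroy(fakeHandle_t h) {
  assert(h != nullptr);  // destroying a null handle would indicate a bug
  --liveHandles();
  delete h;
}

// Move-only owner with the same shape as NCCLComm.
class HandleOwner {
 public:
  explicit HandleOwner(fakeHandle_t h) : handle_(h) {}
  HandleOwner() = default;
  ~HandleOwner() {
    if (handle_) {
      fakeDestroy(handle_);
    }
  }

  static std::shared_ptr<HandleOwner> create(int id) {
    return std::make_shared<HandleOwner>(fakeCreate(id));
  }

  HandleOwner(const HandleOwner&) = delete;
  HandleOwner& operator=(const HandleOwner&) = delete;

  // handle_ is already nullptr here, so `other` ends up empty.
  HandleOwner(HandleOwner&& other) noexcept {
    std::swap(handle_, other.handle_);
  }

  // Our previous handle (if any) moves into `other` and is released
  // when `other` is destroyed.
  HandleOwner& operator=(HandleOwner&& other) noexcept {
    std::swap(handle_, other.handle_);
    return *this;
  }

  fakeHandle_t get() const { return handle_; }

 private:
  fakeHandle_t handle_ = nullptr;
};
```

The swap-based move assignment does not release the old handle immediately; it hands it to the moved-from object, which releases it in its own destructor. That keeps assignment trivially non-throwing, at the cost of delaying the release.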