[PGNCCL] Move NCCLComm impl to cpp (#142826)

BE as titled. No behavior change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142826
Approved by: https://github.com/wconstab, https://github.com/c-p-i-o
This commit is contained in:
Ke Wen 2024-12-10 17:29:08 -08:00 committed by PyTorch MergeBot
parent 06075d3d18
commit cb354f8b47
2 changed files with 258 additions and 212 deletions

View File

@ -8,6 +8,87 @@
namespace c10d {
NCCLComm::NCCLComm(ncclComm_t ncclComm) : ncclComm_(ncclComm) {}
NCCLComm::~NCCLComm() noexcept {
// (kwen2501) Making CUDA/NCCL calls in this destructor can hit CUDA driver
// shutdown error if CUDA context has exited first. Thus, we are not
// destroying or aborting NCCL communicators here. We just detect and warn
// about the risk of memory leak. Normally, a user would have called
// `destroy_process_group` or `abort_process_group`, and such risk would be
// avoided.
LockType lock(mutex_);
if (ncclComm_ && initialized_ && !aborted_) {
TORCH_WARN_ONCE(
"WARNING: NCCL communicator hasn't been destroyed. This may cause "
"memory leaks. To avoid the risk, you can call `destroy_process_group` "
"during normal exit or `_abort_process_group` when handling failures.")
}
}
// NOLINTNEXTLINE(*-noexcept-move-*)
NCCLComm::NCCLComm(NCCLComm&& other) {
// Using other's lock, as it reads other's states
// Can not use this.mutex_, as this object is being constructed.
LockType lock(other.mutex_);
std::swap(ncclComm_, other.ncclComm_);
std::swap(aborted_, other.aborted_);
std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
std::swap(initialized_, other.initialized_);
std::swap(nonBlocking_, other.nonBlocking_);
std::swap(deviceIndex_, other.deviceIndex_);
}
ncclUniqueId NCCLComm::getNcclId() {
return ncclId_;
}
std::shared_ptr<NCCLComm> NCCLComm::create(
int numRanks,
int rank,
ncclUniqueId commId,
at::DeviceIndex deviceIndex) {
at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex);
auto comm = std::make_shared<NCCLComm>();
C10D_NCCL_CHECK(
ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank),
std::nullopt);
comm->ncclId_ = commId;
comm->rank_ = rank;
comm->deviceIndex_ = deviceIndex;
comm->initialized_ = true;
// Old style comm is always blocking.
comm->nonBlocking_ = false;
return comm;
}
#ifdef NCCL_HAS_COMM_NONBLOCKING
std::shared_ptr<NCCLComm> NCCLComm::create(
int numRanks,
int rank,
ncclUniqueId commId,
at::DeviceIndex deviceIndex,
ncclConfig_t& config) {
at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex);
auto comm = std::make_shared<NCCLComm>();
comm->nonBlocking_ = config.blocking == 0;
LOG(INFO) << "Rank " << rank << ": creating NCCL communicator with mode: "
<< (comm->nonBlocking_ ? "nonblocking" : "blocking");
C10D_NCCL_CHECK_NONBLOCKING(
ncclCommInitRankConfig(
&(comm->ncclComm_), numRanks, commId, rank, &config),
std::nullopt);
comm->ncclId_ = commId;
comm->rank_ = rank;
comm->deviceIndex_ = deviceIndex;
// Under blocking mode, comm is initialized immediately after NCCL init
// returns; Under nonblocking mode, we check whether comm is initialized the
// *next* time ncclComm_ is accessed.
comm->initialized_ = !comm->nonBlocking_;
return comm;
}
#endif
ncclComm_t NCCLComm::getNcclComm() {
LockType lock(mutex_);
if (aborted_) {
@ -56,6 +137,11 @@ void NCCLComm::waitReady(bool longInterval) {
}
}
std::optional<std::string> NCCLComm::getNcclCommFailureReason() const {
LockType lock(mutex_);
return commFailureReason_;
}
// TODO: why do we have `!defined(FBCODE_CAFFE2)` here?
#if defined(NCCL_HAS_COMM_SPLIT) && !defined(FBCODE_CAFFE2)
// last argument to split() API is not used to support
@ -147,6 +233,162 @@ void NCCLComm::destroy() {
aborted_ = true;
}
void NCCLComm::abort(std::optional<std::string> commFailureReason) {
LockType lock(mutex_);
at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex_);
#ifdef ENABLE_NCCL_ERROR_CHECKING
if (aborted_ && !initialized_) {
// Should not abort twice.
return;
}
#ifdef NCCL_HAS_COMM_REGISTER
// Deregister all registered segments before aborting.
for (auto& it : registeredSegmentHandles_) {
void* handle = it.second;
C10D_NCCL_CHECK(
::ncclCommDeregister(ncclComm_, handle),
c10::str(
"Failed to deregister segment handle ",
handle,
" on ncclComm_ ",
ncclComm_));
}
registeredSegmentHandles_.clear();
#endif
// Set true failure reason if provided by ProcessGroupNCCL (e.g. work
// timeout)
commFailureReason_ = commFailureReason;
LOG(INFO) << "Aborting ncclComm_ " << ncclComm_ << " with reason: "
<< (commFailureReason ? *commFailureReason
: "No abort reason provided.");
#ifndef NCCL_HAS_COMM_NONBLOCKING
C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
#else
C10D_NCCL_CHECK_TIMEOUT(
::ncclCommAbort(ncclComm_), ncclComm_, commFailureReason_);
#endif
aborted_ = true;
ncclComm_ = nullptr;
// Set an appropriate error so that we avoid using the communicator.
if (ncclAsyncErr_ == ncclSuccess) {
ncclAsyncErr_ = ncclSystemError;
}
#else
// This is a NOOP, if error checks are disabled.
return;
#endif
}
bool NCCLComm::isInitialized() const {
LockType lock(mutex_);
return initialized_;
}
bool NCCLComm::isAborted() const {
LockType lock(mutex_);
return aborted_;
}
uint64_t NCCLComm::getCommSplitCounter() const {
return ncclCommSplitCounter_;
}
ncclResult_t NCCLComm::checkForNcclError() {
LockType lock(mutex_);
#ifdef ENABLE_NCCL_ERROR_CHECKING
if (ncclAsyncErr_ != ncclSuccess) {
return ncclAsyncErr_;
}
C10D_NCCL_CHECK(
ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_);
return ncclAsyncErr_;
#else
// Always return success, if error checks are disabled.
return ncclSuccess;
#endif
}
ncclResult_t NCCLComm::registerSegment(void* ptr, size_t size) {
LockType lock(mutex_);
#ifdef NCCL_HAS_COMM_REGISTER
// We register only segments from cache allocator
// which are guaranteed to be with disjoint addr ranges. Thus, a ptr always
// maps to a unique handle and should not be registered before the current
// ptr is deregistered and freed.
TORCH_CHECK(
registeredSegmentHandles_.count(ptr) == 0,
"Segment with ptr ",
ptr,
" has already been registered on ncclComm_ ",
ncclComm_);
void* handle = nullptr;
// Use getNcclComm to make sure comm is ready before calling nccl APIs
auto comm = getNcclComm();
C10D_NCCL_CHECK(
ncclCommRegister(comm, ptr, size, &handle),
c10::str(
"Failed to register segment with ptr ",
ptr,
", size ",
size,
" on ncclComm_ ",
comm));
registeredSegmentHandles_[ptr] = handle;
return ncclSuccess;
#else
return ncclInvalidUsage;
#endif
}
ncclResult_t NCCLComm::deregisterSegment(void* ptr) {
LockType lock(mutex_);
#ifdef NCCL_HAS_COMM_REGISTER
TORCH_CHECK(
registeredSegmentHandles_.count(ptr) == 1,
"Segment with ptr ",
ptr,
" is not registered on ncclComm_ ",
ncclComm_);
void* handle = registeredSegmentHandles_[ptr];
// Use getNcclComm to make sure comm is ready before calling nccl APIs
auto comm = getNcclComm();
C10D_NCCL_CHECK(
ncclCommDeregister(comm, handle),
c10::str(
"Failed to deregister segment handle ",
handle,
", with ptr ",
ptr,
" on ncclComm_ ",
comm));
registeredSegmentHandles_.erase(ptr);
return ncclSuccess;
#else
return ncclInvalidUsage;
#endif
}
std::string NCCLComm::repr() const {
return c10::str((void*)ncclComm_);
}
#if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP)
std::unordered_map<std::string, std::string> NCCLComm::ncclCommDump() {
std::unordered_map<std::string, std::string> dump;
if (isAborted()) {
LOG(INFO) << "Communicator was aborted before trying to dump its state.";
return dump;
}
C10D_NCCL_CHECK(::ncclCommDump(ncclComm_, dump), std::nullopt);
return dump;
}
#endif
std::string getNcclVersion() {
static c10::once_flag ncclGetVersionFlag;
static std::string versionString;

View File

@ -214,44 +214,17 @@ class NCCLComm {
using LockType = std::unique_lock<MutexType>;
public:
explicit NCCLComm(ncclComm_t ncclComm) : ncclComm_(ncclComm) {}
explicit NCCLComm(ncclComm_t ncclComm);
NCCLComm() = default;
~NCCLComm() noexcept {
// (kwen2501) Making CUDA/NCCL calls in this destructor can hit CUDA driver
// shutdown error if CUDA context has exited first. Thus, we are not
// destroying or aborting NCCL communicators here. We just detect and warn
// about the risk of memory leak. Normally, a user would have called
// `destroy_process_group` or `abort_process_group`, and such risk would be
// avoided.
LockType lock(mutex_);
if (ncclComm_ && initialized_ && !aborted_) {
TORCH_WARN_ONCE(
"WARNING: NCCL communicator hasn't been destroyed. This may cause "
"memory leaks. To avoid the risk, you can call `destroy_process_group` "
"during normal exit or `_abort_process_group` when handling failures.")
}
}
~NCCLComm() noexcept;
static std::shared_ptr<NCCLComm> create(
int numRanks,
int rank,
ncclUniqueId commId,
at::DeviceIndex deviceIndex) {
at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex);
auto comm = std::make_shared<NCCLComm>();
C10D_NCCL_CHECK(
ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank),
std::nullopt);
comm->ncclId_ = commId;
comm->rank_ = rank;
comm->deviceIndex_ = deviceIndex;
comm->initialized_ = true;
// Old style comm is always blocking.
comm->nonBlocking_ = false;
return comm;
}
at::DeviceIndex deviceIndex);
#ifdef NCCL_HAS_COMM_NONBLOCKING
static std::shared_ptr<NCCLComm> create(
@ -259,25 +232,7 @@ class NCCLComm {
int rank,
ncclUniqueId commId,
at::DeviceIndex deviceIndex,
ncclConfig_t& config) {
at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex);
auto comm = std::make_shared<NCCLComm>();
comm->nonBlocking_ = config.blocking == 0;
LOG(INFO) << "Rank " << rank << ": creating NCCL communicator with mode: "
<< (comm->nonBlocking_ ? "nonblocking" : "blocking");
C10D_NCCL_CHECK_NONBLOCKING(
ncclCommInitRankConfig(
&(comm->ncclComm_), numRanks, commId, rank, &config),
std::nullopt);
comm->ncclId_ = commId;
comm->rank_ = rank;
comm->deviceIndex_ = deviceIndex;
// Under blocking mode, comm is initialized immediately after NCCL init
// returns; Under nonblocking mode, we check whether comm is initialized the
// *next* time ncclComm_ is accessed.
comm->initialized_ = !comm->nonBlocking_;
return comm;
}
ncclConfig_t& config);
static std::shared_ptr<NCCLComm> split(
NCCLComm* source,
@ -288,20 +243,10 @@ class NCCLComm {
#endif
#if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP)
std::unordered_map<std::string, std::string> ncclCommDump() {
std::unordered_map<std::string, std::string> dump;
if (isAborted()) {
LOG(INFO) << "Communicator was aborted before trying to dump its state.";
return dump;
}
C10D_NCCL_CHECK(::ncclCommDump(ncclComm_, dump), std::nullopt);
return dump;
}
std::unordered_map<std::string, std::string> ncclCommDump();
#endif
ncclUniqueId getNcclId() {
return ncclId_;
}
ncclUniqueId getNcclId();
// Must not be copyable
NCCLComm(const NCCLComm&) = delete;
@ -312,17 +257,7 @@ class NCCLComm {
// Move constructable
// NOLINTNEXTLINE(*-noexcept-move-*)
NCCLComm(NCCLComm&& other) {
// Using other's lock, as it reads other's states
// Can not use this.mutex_, as this object is being constructed.
LockType lock(other.mutex_);
std::swap(ncclComm_, other.ncclComm_);
std::swap(aborted_, other.aborted_);
std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
std::swap(initialized_, other.initialized_);
std::swap(nonBlocking_, other.nonBlocking_);
std::swap(deviceIndex_, other.deviceIndex_);
}
NCCLComm(NCCLComm&& other);
ncclComm_t getNcclComm();
@ -337,59 +272,9 @@ class NCCLComm {
// ncclSuccess.
void waitReady(bool longInterval);
std::optional<std::string> getNcclCommFailureReason() const {
LockType lock(mutex_);
return commFailureReason_;
}
std::optional<std::string> getNcclCommFailureReason() const;
void abort(std::optional<std::string> commFailureReason = std::nullopt) {
LockType lock(mutex_);
at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex_);
#ifdef ENABLE_NCCL_ERROR_CHECKING
if (aborted_ && !initialized_) {
// Should not abort twice.
return;
}
#ifdef NCCL_HAS_COMM_REGISTER
// Deregister all registered segments before aborting.
for (auto& it : registeredSegmentHandles_) {
void* handle = it.second;
C10D_NCCL_CHECK(
::ncclCommDeregister(ncclComm_, handle),
c10::str(
"Failed to deregister segment handle ",
handle,
" on ncclComm_ ",
ncclComm_));
}
registeredSegmentHandles_.clear();
#endif
// Set true failure reason if provided by ProcessGroupNCCL (e.g. work
// timeout)
commFailureReason_ = commFailureReason;
LOG(INFO) << "Aborting ncclComm_ " << ncclComm_ << " with reason: "
<< (commFailureReason ? *commFailureReason
: "No abort reason provided.");
#ifndef NCCL_HAS_COMM_NONBLOCKING
C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
#else
C10D_NCCL_CHECK_TIMEOUT(
::ncclCommAbort(ncclComm_), ncclComm_, commFailureReason_);
#endif
aborted_ = true;
ncclComm_ = nullptr;
// Set an appropriate error so that we avoid using the communicator.
if (ncclAsyncErr_ == ncclSuccess) {
ncclAsyncErr_ = ncclSystemError;
}
#else
// This is a NOOP, if error checks are disabled.
return;
#endif
}
void abort(std::optional<std::string> commFailureReason = std::nullopt);
// Finalize a communicator -- asking it to flush its operations. When the
// communicator is marked as nonblocking, this is a nonblocking function;
@ -399,100 +284,19 @@ class NCCLComm {
// Destroy a communicator. This is a blocking function.
void destroy();
bool isInitialized() const {
LockType lock(mutex_);
return initialized_;
}
bool isInitialized() const;
bool isAborted() const {
LockType lock(mutex_);
return aborted_;
}
bool isAborted() const;
uint64_t getCommSplitCounter() const {
return ncclCommSplitCounter_;
}
uint64_t getCommSplitCounter() const;
ncclResult_t checkForNcclError() {
LockType lock(mutex_);
#ifdef ENABLE_NCCL_ERROR_CHECKING
if (ncclAsyncErr_ != ncclSuccess) {
return ncclAsyncErr_;
}
C10D_NCCL_CHECK(
ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_);
return ncclAsyncErr_;
#else
// Always return success, if error checks are disabled.
return ncclSuccess;
#endif
}
ncclResult_t checkForNcclError();
ncclResult_t registerSegment(void* ptr, size_t size) {
LockType lock(mutex_);
#ifdef NCCL_HAS_COMM_REGISTER
// We register only segments from cache allocator
// which are guaranteed to be with disjoint addr ranges. Thus, a ptr always
// maps to a unique handle and should not be registered before the current
// ptr is deregistered and freed.
TORCH_CHECK(
registeredSegmentHandles_.count(ptr) == 0,
"Segment with ptr ",
ptr,
" has already been registered on ncclComm_ ",
ncclComm_);
ncclResult_t registerSegment(void* ptr, size_t size);
void* handle = nullptr;
// Use getNcclComm to make sure comm is ready before calling nccl APIs
auto comm = getNcclComm();
C10D_NCCL_CHECK(
ncclCommRegister(comm, ptr, size, &handle),
c10::str(
"Failed to register segment with ptr ",
ptr,
", size ",
size,
" on ncclComm_ ",
comm));
registeredSegmentHandles_[ptr] = handle;
return ncclSuccess;
#else
return ncclInvalidUsage;
#endif
}
ncclResult_t deregisterSegment(void* ptr);
ncclResult_t deregisterSegment(void* ptr) {
LockType lock(mutex_);
#ifdef NCCL_HAS_COMM_REGISTER
TORCH_CHECK(
registeredSegmentHandles_.count(ptr) == 1,
"Segment with ptr ",
ptr,
" is not registered on ncclComm_ ",
ncclComm_);
void* handle = registeredSegmentHandles_[ptr];
// Use getNcclComm to make sure comm is ready before calling nccl APIs
auto comm = getNcclComm();
C10D_NCCL_CHECK(
ncclCommDeregister(comm, handle),
c10::str(
"Failed to deregister segment handle ",
handle,
", with ptr ",
ptr,
" on ncclComm_ ",
comm));
registeredSegmentHandles_.erase(ptr);
return ncclSuccess;
#else
return ncclInvalidUsage;
#endif
}
std::string repr() const {
return c10::str((void*)ncclComm_);
}
std::string repr() const;
friend class ProcessGroupNCCL;