diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp
index f93b78dbc45..19aa6f09d8d 100644
--- a/torch/csrc/distributed/c10d/NCCLUtils.cpp
+++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp
@@ -8,6 +8,87 @@
 
 namespace c10d {
 
+NCCLComm::NCCLComm(ncclComm_t ncclComm) : ncclComm_(ncclComm) {}
+
+NCCLComm::~NCCLComm() noexcept {
+  // (kwen2501) Making CUDA/NCCL calls in this destructor can hit CUDA driver
+  // shutdown error if CUDA context has exited first. Thus, we are not
+  // destroying or aborting NCCL communicators here. We just detect and warn
+  // about the risk of memory leak. Normally, a user would have called
+  // `destroy_process_group` or `abort_process_group`, and such risk would be
+  // avoided.
+  LockType lock(mutex_);
+  if (ncclComm_ && initialized_ && !aborted_) {
+    TORCH_WARN_ONCE(
+        "WARNING: NCCL communicator hasn't been destroyed. This may cause "
+        "memory leaks. To avoid the risk, you can call `destroy_process_group` "
+        "during normal exit or `_abort_process_group` when handling failures.")
+  }
+}
+
+// NOLINTNEXTLINE(*-noexcept-move-*)
+NCCLComm::NCCLComm(NCCLComm&& other) {
+  // Using other's lock, as it reads other's states
+  // Can not use this.mutex_, as this object is being constructed.
+  LockType lock(other.mutex_);
+  std::swap(ncclComm_, other.ncclComm_);
+  std::swap(aborted_, other.aborted_);
+  std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
+  std::swap(initialized_, other.initialized_);
+  std::swap(nonBlocking_, other.nonBlocking_);
+  std::swap(deviceIndex_, other.deviceIndex_);
+}
+
+ncclUniqueId NCCLComm::getNcclId() {
+  return ncclId_;
+}
+
+std::shared_ptr<NCCLComm> NCCLComm::create(
+    int numRanks,
+    int rank,
+    ncclUniqueId commId,
+    at::DeviceIndex deviceIndex) {
+  at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex);
+  auto comm = std::make_shared<NCCLComm>();
+  C10D_NCCL_CHECK(
+      ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank),
+      std::nullopt);
+  comm->ncclId_ = commId;
+  comm->rank_ = rank;
+  comm->deviceIndex_ = deviceIndex;
+  comm->initialized_ = true;
+  // Old style comm is always blocking.
+  comm->nonBlocking_ = false;
+  return comm;
+}
+
+#ifdef NCCL_HAS_COMM_NONBLOCKING
+std::shared_ptr<NCCLComm> NCCLComm::create(
+    int numRanks,
+    int rank,
+    ncclUniqueId commId,
+    at::DeviceIndex deviceIndex,
+    ncclConfig_t& config) {
+  at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex);
+  auto comm = std::make_shared<NCCLComm>();
+  comm->nonBlocking_ = config.blocking == 0;
+  LOG(INFO) << "Rank " << rank << ": creating NCCL communicator with mode: "
+            << (comm->nonBlocking_ ? "nonblocking" : "blocking");
+  C10D_NCCL_CHECK_NONBLOCKING(
+      ncclCommInitRankConfig(
+          &(comm->ncclComm_), numRanks, commId, rank, &config),
+      std::nullopt);
+  comm->ncclId_ = commId;
+  comm->rank_ = rank;
+  comm->deviceIndex_ = deviceIndex;
+  // Under blocking mode, comm is initialized immediately after NCCL init
+  // returns; Under nonblocking mode, we check whether comm is initialized the
+  // *next* time ncclComm_ is accessed.
+  comm->initialized_ = !comm->nonBlocking_;
+  return comm;
+}
+#endif
+
 ncclComm_t NCCLComm::getNcclComm() {
   LockType lock(mutex_);
   if (aborted_) {
@@ -56,6 +137,11 @@ void NCCLComm::waitReady(bool longInterval) {
   }
 }
 
+std::optional<std::string> NCCLComm::getNcclCommFailureReason() const {
+  LockType lock(mutex_);
+  return commFailureReason_;
+}
+
 // TODO: why do we have `!defined(FBCODE_CAFFE2)` here?
 #if defined(NCCL_HAS_COMM_SPLIT) && !defined(FBCODE_CAFFE2)
 // last argument to split() API is not used to support
@@ -147,6 +233,162 @@ void NCCLComm::destroy() {
   aborted_ = true;
 }
 
+void NCCLComm::abort(std::optional<std::string> commFailureReason) {
+  LockType lock(mutex_);
+  at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex_);
+#ifdef ENABLE_NCCL_ERROR_CHECKING
+  if (aborted_ && !initialized_) {
+    // Should not abort twice.
+    return;
+  }
+
+#ifdef NCCL_HAS_COMM_REGISTER
+  // Deregister all registered segments before aborting.
+  for (auto& it : registeredSegmentHandles_) {
+    void* handle = it.second;
+    C10D_NCCL_CHECK(
+        ::ncclCommDeregister(ncclComm_, handle),
+        c10::str(
+            "Failed to deregister segment handle ",
+            handle,
+            " on ncclComm_ ",
+            ncclComm_));
+  }
+  registeredSegmentHandles_.clear();
+#endif
+
+  // Set true failure reason if provided by ProcessGroupNCCL (e.g. work
+  // timeout)
+  commFailureReason_ = commFailureReason;
+  LOG(INFO) << "Aborting ncclComm_ " << ncclComm_ << " with reason: "
+            << (commFailureReason ? *commFailureReason
+                                  : "No abort reason provided.");
+#ifndef NCCL_HAS_COMM_NONBLOCKING
+  C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
+#else
+  C10D_NCCL_CHECK_TIMEOUT(
+      ::ncclCommAbort(ncclComm_), ncclComm_, commFailureReason_);
+#endif
+  aborted_ = true;
+  ncclComm_ = nullptr;
+
+  // Set an appropriate error so that we avoid using the communicator.
+  if (ncclAsyncErr_ == ncclSuccess) {
+    ncclAsyncErr_ = ncclSystemError;
+  }
+#else
+  // This is a NOOP, if error checks are disabled.
+  return;
+#endif
+}
+
+bool NCCLComm::isInitialized() const {
+  LockType lock(mutex_);
+  return initialized_;
+}
+
+bool NCCLComm::isAborted() const {
+  LockType lock(mutex_);
+  return aborted_;
+}
+
+uint64_t NCCLComm::getCommSplitCounter() const {
+  return ncclCommSplitCounter_;
+}
+
+ncclResult_t NCCLComm::checkForNcclError() {
+  LockType lock(mutex_);
+#ifdef ENABLE_NCCL_ERROR_CHECKING
+  if (ncclAsyncErr_ != ncclSuccess) {
+    return ncclAsyncErr_;
+  }
+  C10D_NCCL_CHECK(
+      ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_);
+  return ncclAsyncErr_;
+#else
+  // Always return success, if error checks are disabled.
+  return ncclSuccess;
+#endif
+}
+
+ncclResult_t NCCLComm::registerSegment(void* ptr, size_t size) {
+  LockType lock(mutex_);
+#ifdef NCCL_HAS_COMM_REGISTER
+  // We register only segments from cache allocator
+  // which are guaranteed to be with disjoint addr ranges. Thus, a ptr always
+  // maps to a unique handle and should not be registered before the current
+  // ptr is deregistered and freed.
+  TORCH_CHECK(
+      registeredSegmentHandles_.count(ptr) == 0,
+      "Segment with ptr ",
+      ptr,
+      " has already been registered on ncclComm_ ",
+      ncclComm_);
+
+  void* handle = nullptr;
+  // Use getNcclComm to make sure comm is ready before calling nccl APIs
+  auto comm = getNcclComm();
+  C10D_NCCL_CHECK(
+      ncclCommRegister(comm, ptr, size, &handle),
+      c10::str(
+          "Failed to register segment with ptr ",
+          ptr,
+          ", size ",
+          size,
+          " on ncclComm_ ",
+          comm));
+  registeredSegmentHandles_[ptr] = handle;
+  return ncclSuccess;
+#else
+  return ncclInvalidUsage;
+#endif
+}
+
+ncclResult_t NCCLComm::deregisterSegment(void* ptr) {
+  LockType lock(mutex_);
+#ifdef NCCL_HAS_COMM_REGISTER
+  TORCH_CHECK(
+      registeredSegmentHandles_.count(ptr) == 1,
+      "Segment with ptr ",
+      ptr,
+      " is not registered on ncclComm_ ",
+      ncclComm_);
+
+  void* handle = registeredSegmentHandles_[ptr];
+  // Use getNcclComm to make sure comm is ready before calling nccl APIs
+  auto comm = getNcclComm();
+  C10D_NCCL_CHECK(
+      ncclCommDeregister(comm, handle),
+      c10::str(
+          "Failed to deregister segment handle ",
+          handle,
+          ", with ptr ",
+          ptr,
+          " on ncclComm_ ",
+          comm));
+  registeredSegmentHandles_.erase(ptr);
+  return ncclSuccess;
+#else
+  return ncclInvalidUsage;
+#endif
+}
+
+std::string NCCLComm::repr() const {
+  return c10::str((void*)ncclComm_);
+}
+
+#if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP)
+std::unordered_map<std::string, std::string> NCCLComm::ncclCommDump() {
+  std::unordered_map<std::string, std::string> dump;
+  if (isAborted()) {
+    LOG(INFO) << "Communicator was aborted before trying to dump its state.";
+    return dump;
+  }
+  C10D_NCCL_CHECK(::ncclCommDump(ncclComm_, dump), std::nullopt);
+  return dump;
+}
+#endif
+
 std::string getNcclVersion() {
   static c10::once_flag ncclGetVersionFlag;
   static std::string versionString;
diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp
index 77690ba4a7b..ffb3a1f3dca 100644
--- a/torch/csrc/distributed/c10d/NCCLUtils.hpp
+++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp
@@ -214,44 +214,17 @@ class NCCLComm {
   using LockType = std::unique_lock<std::mutex>;
 
  public:
-  explicit NCCLComm(ncclComm_t ncclComm) : ncclComm_(ncclComm) {}
+  explicit NCCLComm(ncclComm_t ncclComm);
 
   NCCLComm() = default;
 
-  ~NCCLComm() noexcept {
-    // (kwen2501) Making CUDA/NCCL calls in this destructor can hit CUDA driver
-    // shutdown error if CUDA context has exited first. Thus, we are not
-    // destroying or aborting NCCL communicators here. We just detect and warn
-    // about the risk of memory leak. Normally, a user would have called
-    // `destroy_process_group` or `abort_process_group`, and such risk would be
-    // avoided.
-    LockType lock(mutex_);
-    if (ncclComm_ && initialized_ && !aborted_) {
-      TORCH_WARN_ONCE(
-          "WARNING: NCCL communicator hasn't been destroyed. This may cause "
-          "memory leaks. To avoid the risk, you can call `destroy_process_group` "
-          "during normal exit or `_abort_process_group` when handling failures.")
-    }
-  }
+  ~NCCLComm() noexcept;
 
   static std::shared_ptr<NCCLComm> create(
       int numRanks,
       int rank,
       ncclUniqueId commId,
-      at::DeviceIndex deviceIndex) {
-    at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex);
-    auto comm = std::make_shared<NCCLComm>();
-    C10D_NCCL_CHECK(
-        ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank),
-        std::nullopt);
-    comm->ncclId_ = commId;
-    comm->rank_ = rank;
-    comm->deviceIndex_ = deviceIndex;
-    comm->initialized_ = true;
-    // Old style comm is always blocking.
-    comm->nonBlocking_ = false;
-    return comm;
-  }
+      at::DeviceIndex deviceIndex);
 
 #ifdef NCCL_HAS_COMM_NONBLOCKING
   static std::shared_ptr<NCCLComm> create(
@@ -259,25 +232,7 @@ class NCCLComm {
       int rank,
       ncclUniqueId commId,
       at::DeviceIndex deviceIndex,
-      ncclConfig_t& config) {
-    at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex);
-    auto comm = std::make_shared<NCCLComm>();
-    comm->nonBlocking_ = config.blocking == 0;
-    LOG(INFO) << "Rank " << rank << ": creating NCCL communicator with mode: "
-              << (comm->nonBlocking_ ? "nonblocking" : "blocking");
-    C10D_NCCL_CHECK_NONBLOCKING(
-        ncclCommInitRankConfig(
-            &(comm->ncclComm_), numRanks, commId, rank, &config),
-        std::nullopt);
-    comm->ncclId_ = commId;
-    comm->rank_ = rank;
-    comm->deviceIndex_ = deviceIndex;
-    // Under blocking mode, comm is initialized immediately after NCCL init
-    // returns; Under nonblocking mode, we check whether comm is initialized the
-    // *next* time ncclComm_ is accessed.
-    comm->initialized_ = !comm->nonBlocking_;
-    return comm;
-  }
+      ncclConfig_t& config);
 
   static std::shared_ptr<NCCLComm> split(
       NCCLComm* source,
@@ -288,20 +243,10 @@ class NCCLComm {
 #endif
 
 #if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP)
-  std::unordered_map<std::string, std::string> ncclCommDump() {
-    std::unordered_map<std::string, std::string> dump;
-    if (isAborted()) {
-      LOG(INFO) << "Communicator was aborted before trying to dump its state.";
-      return dump;
-    }
-    C10D_NCCL_CHECK(::ncclCommDump(ncclComm_, dump), std::nullopt);
-    return dump;
-  }
+  std::unordered_map<std::string, std::string> ncclCommDump();
 #endif
 
-  ncclUniqueId getNcclId() {
-    return ncclId_;
-  }
+  ncclUniqueId getNcclId();
 
   // Must not be copyable
   NCCLComm(const NCCLComm&) = delete;
@@ -312,17 +257,7 @@ class NCCLComm {
 
   // Move constructable
   // NOLINTNEXTLINE(*-noexcept-move-*)
-  NCCLComm(NCCLComm&& other) {
-    // Using other's lock, as it reads other's states
-    // Can not use this.mutex_, as this object is being constructed.
-    LockType lock(other.mutex_);
-    std::swap(ncclComm_, other.ncclComm_);
-    std::swap(aborted_, other.aborted_);
-    std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
-    std::swap(initialized_, other.initialized_);
-    std::swap(nonBlocking_, other.nonBlocking_);
-    std::swap(deviceIndex_, other.deviceIndex_);
-  }
+  NCCLComm(NCCLComm&& other);
 
   ncclComm_t getNcclComm();
 
@@ -337,59 +272,9 @@ class NCCLComm {
   // ncclSuccess.
   void waitReady(bool longInterval);
 
-  std::optional<std::string> getNcclCommFailureReason() const {
-    LockType lock(mutex_);
-    return commFailureReason_;
-  }
+  std::optional<std::string> getNcclCommFailureReason() const;
 
-  void abort(std::optional<std::string> commFailureReason = std::nullopt) {
-    LockType lock(mutex_);
-    at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex_);
-#ifdef ENABLE_NCCL_ERROR_CHECKING
-    if (aborted_ && !initialized_) {
-      // Should not abort twice.
-      return;
-    }
-
-#ifdef NCCL_HAS_COMM_REGISTER
-    // Deregister all registered segments before aborting.
-    for (auto& it : registeredSegmentHandles_) {
-      void* handle = it.second;
-      C10D_NCCL_CHECK(
-          ::ncclCommDeregister(ncclComm_, handle),
-          c10::str(
-              "Failed to deregister segment handle ",
-              handle,
-              " on ncclComm_ ",
-              ncclComm_));
-    }
-    registeredSegmentHandles_.clear();
-#endif
-
-    // Set true failure reason if provided by ProcessGroupNCCL (e.g. work
-    // timeout)
-    commFailureReason_ = commFailureReason;
-    LOG(INFO) << "Aborting ncclComm_ " << ncclComm_ << " with reason: "
-              << (commFailureReason ? *commFailureReason
-                                    : "No abort reason provided.");
-#ifndef NCCL_HAS_COMM_NONBLOCKING
-    C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
-#else
-    C10D_NCCL_CHECK_TIMEOUT(
-        ::ncclCommAbort(ncclComm_), ncclComm_, commFailureReason_);
-#endif
-    aborted_ = true;
-    ncclComm_ = nullptr;
-
-    // Set an appropriate error so that we avoid using the communicator.
-    if (ncclAsyncErr_ == ncclSuccess) {
-      ncclAsyncErr_ = ncclSystemError;
-    }
-#else
-    // This is a NOOP, if error checks are disabled.
-    return;
-#endif
-  }
+  void abort(std::optional<std::string> commFailureReason = std::nullopt);
 
   // Finalize a communicator -- asking it to flush its operations. When the
   // communicator is marked as nonblocking, this is a nonblocking function;
@@ -399,100 +284,19 @@ class NCCLComm {
   // Destroy a communicator. This is a blocking function.
   void destroy();
 
-  bool isInitialized() const {
-    LockType lock(mutex_);
-    return initialized_;
-  }
+  bool isInitialized() const;
 
-  bool isAborted() const {
-    LockType lock(mutex_);
-    return aborted_;
-  }
+  bool isAborted() const;
 
-  uint64_t getCommSplitCounter() const {
-    return ncclCommSplitCounter_;
-  }
+  uint64_t getCommSplitCounter() const;
 
-  ncclResult_t checkForNcclError() {
-    LockType lock(mutex_);
-#ifdef ENABLE_NCCL_ERROR_CHECKING
-    if (ncclAsyncErr_ != ncclSuccess) {
-      return ncclAsyncErr_;
-    }
-    C10D_NCCL_CHECK(
-        ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_);
-    return ncclAsyncErr_;
-#else
-    // Always return success, if error checks are disabled.
-    return ncclSuccess;
-#endif
-  }
+  ncclResult_t checkForNcclError();
 
-  ncclResult_t registerSegment(void* ptr, size_t size) {
-    LockType lock(mutex_);
-#ifdef NCCL_HAS_COMM_REGISTER
-    // We register only segments from cache allocator
-    // which are guaranteed to be with disjoint addr ranges. Thus, a ptr always
-    // maps to a unique handle and should not be registered before the current
-    // ptr is deregistered and freed.
-    TORCH_CHECK(
-        registeredSegmentHandles_.count(ptr) == 0,
-        "Segment with ptr ",
-        ptr,
-        " has already been registered on ncclComm_ ",
-        ncclComm_);
+  ncclResult_t registerSegment(void* ptr, size_t size);
 
-    void* handle = nullptr;
-    // Use getNcclComm to make sure comm is ready before calling nccl APIs
-    auto comm = getNcclComm();
-    C10D_NCCL_CHECK(
-        ncclCommRegister(comm, ptr, size, &handle),
-        c10::str(
-            "Failed to register segment with ptr ",
-            ptr,
-            ", size ",
-            size,
-            " on ncclComm_ ",
-            comm));
-    registeredSegmentHandles_[ptr] = handle;
-    return ncclSuccess;
-#else
-    return ncclInvalidUsage;
-#endif
-  }
+  ncclResult_t deregisterSegment(void* ptr);
 
-  ncclResult_t deregisterSegment(void* ptr) {
-    LockType lock(mutex_);
-#ifdef NCCL_HAS_COMM_REGISTER
-    TORCH_CHECK(
-        registeredSegmentHandles_.count(ptr) == 1,
-        "Segment with ptr ",
-        ptr,
-        " is not registered on ncclComm_ ",
-        ncclComm_);
-
-    void* handle = registeredSegmentHandles_[ptr];
-    // Use getNcclComm to make sure comm is ready before calling nccl APIs
-    auto comm = getNcclComm();
-    C10D_NCCL_CHECK(
-        ncclCommDeregister(comm, handle),
-        c10::str(
-            "Failed to deregister segment handle ",
-            handle,
-            ", with ptr ",
-            ptr,
-            " on ncclComm_ ",
-            comm));
-    registeredSegmentHandles_.erase(ptr);
-    return ncclSuccess;
-#else
-    return ncclInvalidUsage;
-#endif
-  }
-
-  std::string repr() const {
-    return c10::str((void*)ncclComm_);
-  }
+  std::string repr() const;
 
   friend class ProcessGroupNCCL;