diff --git a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp index 6af7e9230d7..b038db2eaab 100644 --- a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp @@ -164,7 +164,7 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { // so we have this hack to manually set the desync debug flag after PG // creation. void forceSetDesyncDebugFlag() { - desyncDebug_ = true; + watchdog_->setDesyncDebug(true); } private: diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 4e73129ea42..5a8ac191853 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -511,7 +511,6 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL( bool isP2P, const char* profilingTitle, const std::optional>& inputs, - bool desyncDebug, bool enableTiming, bool cudaEventCacheEnabled, DebugLevel distDebugLevel) @@ -952,15 +951,8 @@ ProcessGroupNCCL::ProcessGroupNCCL( blockingWait_ = getCvarBool(TORCH_NCCL_BLOCKING_WAIT, false); asyncErrorHandling_ = static_cast( getCvarInt(TORCH_NCCL_ASYNC_ERROR_HANDLING, 3 /*SkipCleanUp*/)); - desyncDebug_ = getCvarBool(TORCH_NCCL_DESYNC_DEBUG, false) || - (dist_debug_level_ >= DebugLevel::Detail); - rethrowCUDAErrors_ = getCvarBool(TORCH_NCCL_RETHROW_CUDA_ERRORS, true); - propagatePgError_ = getCvarBool(TORCH_NCCL_PROPAGATE_ERROR, false); enableNanCheck_ = getCvarBool(TORCH_NCCL_NAN_CHECK, false); - heartbeat_ = 1ULL; cudaEventCacheEnabled_.store(getCvarBool(TORCH_NCCL_CUDA_EVENT_CACHE, true)); - waitTimeoutDumpInMilSec_ = - getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 15 * 1000 /*15 Sec*/); traceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 2000); enableCollectiveHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail); // store_ usually is wrapped with PrefixStore and the prefix is different @@ -970,9 +962,11 @@ ProcessGroupNCCL::ProcessGroupNCCL( PrefixStore* prefixStore = dynamic_cast(store_.get()); globalStore_ = prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store_; + auto desyncDebug = getCvarBool(TORCH_NCCL_DESYNC_DEBUG, false) || + (dist_debug_level_ >= DebugLevel::Detail); #ifdef ENABLE_NCCL_ERROR_CHECKING enableTiming_.store( - getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_); + getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug); #endif // ENABLE_NCCL_ERROR_CHECKING if (getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false)) { TORCH_WARN_ONCE( @@ -984,7 +978,7 @@ ProcessGroupNCCL::ProcessGroupNCCL( << logPrefix() << "TORCH_NCCL_BLOCKING_WAIT is enabled, NO watchdog thread is created."; } else { - if (desyncDebug_ && asyncErrorHandling_ == NoHandling) { + if (desyncDebug && asyncErrorHandling_ == NoHandling) { LOG(INFO) << logPrefix() << "TORCH_NCCL_DESYNC_DEBUG and TORCH_NCCL_ASYNC_ERROR_HANDLING " @@ -994,17 +988,17 @@ ProcessGroupNCCL::ProcessGroupNCCL( } } - // Initialize the heartbeat monitor instance. This has to be done before - // the watchdog thread is launched to avoid the error. + // Initialize the heartbeat monitor/watchdog instance. This has to be done + // before the corresponding thread is launched to avoid the error. 
heartbeatMonitor_ = std::make_unique(this); + watchdog_ = std::make_unique(this); #ifdef ENABLE_NCCL_ERROR_CHECKING // in blockingWait mode, we don't need to enable the watchdog thread to check // the timeout or nccl error because the main thread would throw an exception // and it is the user's responsibility to handle the exception. if (!blockingWait_) { - ncclCommWatchdogThread_ = - std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this); + watchdog_->start(); } #endif // ENABLE_NCCL_ERROR_CHECKING @@ -1024,8 +1018,6 @@ ProcessGroupNCCL::ProcessGroupNCCL( LOG(INFO) << logPrefix() << "ProcessGroupNCCL environments: " << "NCCL version: " << ncclVersion << ", TORCH_NCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_ - << ", TORCH_NCCL_PROPAGATE_ERROR: " << propagatePgError_ - << ", TORCH_NCCL_DESYNC_DEBUG: " << desyncDebug_ << ", TORCH_NCCL_ENABLE_TIMING: " << enableTiming_.load() << ", TORCH_NCCL_BLOCKING_WAIT: " << blockingWait_ << ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug @@ -1050,11 +1042,6 @@ ProcessGroupNCCL::ProcessGroupNCCL( // This call is idempotent. attachAllocatorHooks(); } - - // Enable Desync Debugger per user setting - if (desyncDebug_) { - desyncDebugger_.init(rank, size, globalRank(), getUid(), store_); - } } void ProcessGroupNCCL::eagerConnectSingleDevice(at::Device device) { @@ -1271,7 +1258,7 @@ void ProcessGroupNCCL::waitForPendingWorks() { // completedWorkList_ before it finishes. // 3. We have three threads and two locks. // a. main thread (this function) grabs two locks atomically - // b. watchdog thread (watchdogHandler function) always grabs + // b. watchdog thread (runLoop function) always grabs // workMetaListMutex_ // first and then grabs completedWorkListMutex_. // c. hook thread (runHookLoop function) only grabs @@ -1425,7 +1412,7 @@ void ProcessGroupNCCL::abort() { // communicators and signal the threads to exit. Joining on the threads could // potentially block and hence avoid it in this method. terminateProcessGroup_.store(true); - workMetaListCV_.notify_one(); + watchdog_->notify(); // lauch abort asynchrounously and wait for it to complete or timeout LOG(INFO) << logPrefix() @@ -1480,10 +1467,8 @@ void ProcessGroupNCCL::shutdown() { // anymore because I am going to destroy them now LOG(INFO) << logPrefix() << "Operations flushed, joining watchdog thread."; terminateProcessGroup_.store(true); - workMetaListCV_.notify_one(); - if (ncclCommWatchdogThread_.joinable()) { - ncclCommWatchdogThread_.join(); - } + watchdog_->notify(); + watchdog_->join(); if (onCompletionHookThread_.joinable()) { onCompletionHookThread_.join(); } @@ -1542,15 +1527,12 @@ ProcessGroupNCCL::~ProcessGroupNCCL() { // Make sure we've told threads to stop; doesn't hurt if we'd done so before. 
// Tell watchdog and onCompletionHook: terminateProcessGroup_.store(true); - workMetaListCV_.notify_one(); + watchdog_->notify(); // Tell heartbeat thread: heartbeatMonitor_->stop(); // Wait for all threads to finish before returning - if (ncclCommWatchdogThread_.joinable()) { - ncclCommWatchdogThread_.join(); - LOG(INFO) << logPrefix() << "ProcessGroupNCCL watchdog thread joined."; - } + watchdog_->join(); heartbeatMonitor_->join(); if (onCompletionHookThread_.joinable()) { onCompletionHookThread_.join(); @@ -1639,6 +1621,10 @@ void ProcessGroupNCCL::HeartbeatMonitor::setLastWorkListUpdateTime( lastWorkListUpdateTime_ = time; } +int ProcessGroupNCCL::HeartbeatMonitor::getDumpTimeout() const { + return waitTimeoutDumpInMilSec_; +} + ProcessGroupNCCL::HeartbeatMonitor::HeartbeatMonitor(ProcessGroupNCCL* pg) { pg_ = pg; heartbeatTimeoutInSec_ = @@ -1969,27 +1955,69 @@ void ProcessGroupNCCL::HeartbeatMonitor::runLoop() { } } -void ProcessGroupNCCL::ncclCommWatchdog() { +ProcessGroupNCCL::Watchdog::Watchdog(ProcessGroupNCCL* pg) { + pg_ = pg; + heartbeat_ = 1ULL; + rethrowCUDAErrors_ = getCvarBool(TORCH_NCCL_RETHROW_CUDA_ERRORS, true); + propagatePgError_ = getCvarBool(TORCH_NCCL_PROPAGATE_ERROR, false); + desyncDebug_ = getCvarBool(TORCH_NCCL_DESYNC_DEBUG, false) || + (pg_->dist_debug_level_ >= DebugLevel::Detail); + + // print out ENV settings for the watchdog thread. + LOG(INFO) << pg_->logPrefix() << "PGNCCL Watchdog environments: " + << "TORCH_NCCL_RETHROW_CUDA_ERRORS: " << rethrowCUDAErrors_ + << ", TORCH_NCCL_PROPAGATE_ERROR: " << propagatePgError_ + << ", TORCH_NCCL_DESYNC_DEBUG: " << desyncDebug_; + + // Enable Desync Debugger per user setting + if (desyncDebug_) { + desyncDebugger_.init( + pg_->getRank(), + pg_->getSize(), + pg_->globalRank(), + pg_->getUid(), + pg_->store_); + } +} + +void ProcessGroupNCCL::Watchdog::notify() { + workMetaListCV_.notify_one(); +} + +void ProcessGroupNCCL::Watchdog::start() { + TORCH_CHECK( + !ncclCommWatchdogThread_.joinable(), "Watchdog thread already started"); + ncclCommWatchdogThread_ = std::thread(&ProcessGroupNCCL::Watchdog::run, this); +} + +void ProcessGroupNCCL::Watchdog::join() { + if (ncclCommWatchdogThread_.joinable()) { + ncclCommWatchdogThread_.join(); + LOG(INFO) << pg_->logPrefix() << "ProcessGroupNCCL watchdog thread joined."; + } +} + +void ProcessGroupNCCL::Watchdog::run() { c10::setThreadName("pt_nccl_watchdg"); try { - VLOG(2) << logPrefix() << "Process group watchdog thread started!"; - heartbeatMonitor_->start(); - watchdogHandler(); - VLOG(2) << logPrefix() + VLOG(2) << pg_->logPrefix() << "Process group watchdog thread started!"; + pg_->heartbeatMonitor_->start(); + runLoop(); + VLOG(2) << pg_->logPrefix() << "Process group watchdog thread terminated normally"; } catch (std::exception& e) { if (std::string(e.what()).find("driver shutting down") != std::string::npos) { VLOG(2) - << logPrefix() + << pg_->logPrefix() << "main process destroyed cuda before watchdog loop exited, terminating watchdog." << " (Watchdog caught exception: " << e.what(); } else { - // Append error message reported from watchdogHandler + // Append error message reported from runLoop const auto exitMsg = c10::str( - logPrefix(), + pg_->logPrefix(), "Process group watchdog thread terminated with exception: ", e.what()); LOG(ERROR) << exitMsg; @@ -2004,7 +2032,7 @@ void ProcessGroupNCCL::ncclCommWatchdog() { } } catch (...) 
{ const auto exitMsg = c10::str( - logPrefix(), + pg_->logPrefix(), "Process group watchdog thread terminated with exception: unknown"); LOG(ERROR) << exitMsg; watchDogException_ = @@ -2013,6 +2041,308 @@ void ProcessGroupNCCL::ncclCommWatchdog() { } } +int ProcessGroupNCCL::Watchdog::getSignalSrcRank( + c10::intrusive_ptr& store, + const std::string& signal) { + // This function is 'non blocking'. We first 'check' if the key exists in the + // store, then read/get the value only if the key exists. + int srcRank = -1; + bool signalExists = false; + try { + signalExists = store->check({signal}); + } catch (const std::exception& e) { + LOG(WARNING) << pg_->logPrefix() << "Failed to check the signal " << signal + << " on TCPStore, " << e.what(); + } + if (!signalExists) { + return srcRank; + } + + // key exists, now read and parse the value (source rank) + std::vector vec; + try { + vec = store->get(std::string(signal)); + } catch (const std::exception& e) { + LOG(ERROR) << pg_->logPrefix() << "Failed to get source rank of the signal " + << signal << " from TCPStore." << e.what(); + } + TORCH_CHECK_WITH( + DistBackendError, + vec.size() == sizeof(int), + "Invalid size for the timeout rank ID"); + std::memcpy(&srcRank, vec.data(), vec.size()); + return srcRank; +} + +void ProcessGroupNCCL::Watchdog::checkAndSetRemoteError() { + // if the error is already set, no need to check again + if (pg_->getError() != ErrorType::SUCCESS) { + return; + } + // key/signal to read from the tcpstore is a string and pg specific: + // format is: remote_error:pg_uid + int remoteErrorRank = getSignalSrcRank( + pg_->store_, std::string(kStoreErrorSignalKey) + ':' + pg_->pg_uid_); + if (remoteErrorRank != -1) { + std::lock_guard lock(pg_->errorMutex_); + pg_->error_ = ErrorType::REMOTE_ERROR; + LOG(ERROR) << c10::str( + pg_->logPrefix(), + " remote error detected from rank: ", + remoteErrorRank); + } +} + +void ProcessGroupNCCL::Watchdog::runLoop() { + bool done = false; + pg_->heartbeatMonitor_->setLastWorkListUpdateTime( + std::chrono::steady_clock::now()); + auto lastStatusUpdateTime = std::chrono::steady_clock::now(); + std::list completedWorkList; + + while (!done || !pg_->terminateProcessGroup_.load()) { + std::unique_lock lock(pg_->workMetaListMutex_); + // We busy-poll the work vector every kWatchdogThreadSleepMillis + // milliseconds as long as the atomic is True. + workMetaListCV_.wait_for( + lock, + std::chrono::milliseconds(kWatchdogThreadSleepMillis), + [&]() -> bool { return pg_->terminateProcessGroup_.load(); }); + // Bump up heart beat by one. + heartbeat_++; + +// Some versions of GLOG support less-spammy version of LOG_EVERY_MS +// in which case we don't want to spam the logs. 
+#ifdef LOG_EVERY_MS + // Log the progress of this PG periodically + C10_LOG_EVERY_MS(INFO, kWorkStatusUpdatePeriodMs) << c10::str( + logPrefix(), + "NCCL Work update periodically: ", + "last enqueued NCCL work: ", + pg_->pgStatus_->lastEnqueuedSeq, + ", last completed NCCL work: ", + pg_->pgStatus_->lastCompletedSeq, + "."); +#endif // LOG_EVERY_MS + auto logger = ::c10d::C10dLogger::getLogger(); + if (logger && + computeDeltaMS( + lastStatusUpdateTime, std::chrono::steady_clock::now()) >= + kWorkStatusUpdatePeriodMs) { + ::c10d::C10dLoggingData data; + // logging integers + data.integers["pg_id"] = static_cast(pg_->local_id_); + data.integers["rank"] = pg_->rank_; + data.integers["global_rank"] = pg_->globalRank(); + data.integers["last_enqueued_work"] = pg_->pgStatus_->lastEnqueuedSeq; + data.integers["last_started_work"] = pg_->pgStatus_->lastStartedSeq; + data.integers["last_completed_work"] = pg_->pgStatus_->lastCompletedSeq; + data.integers["last_enqueued_numel_in"] = + static_cast(pg_->pgStatus_->lastEnqueuedNumelIn); + data.integers["last_enqueued_numel_out"] = + static_cast(pg_->pgStatus_->lastEnqueuedNumelOut); + data.integers["last_completed_numel_in"] = + static_cast(pg_->pgStatus_->lastCompletedNumelIn); + data.integers["last_completed_numel_out"] = + static_cast(pg_->pgStatus_->lastCompletedNumelOut); + data.integers["last_started_numel_in"] = + static_cast(pg_->pgStatus_->lastStartedNumelIn); + data.integers["last_started_numel_out"] = + static_cast(pg_->pgStatus_->lastStartedNumelOut); + // logging strings + data.strings["last_enqueued_work_name"] = + pg_->pgStatus_->lastEnqueuedWorkName; + data.strings["last_started_work_name"] = + pg_->pgStatus_->lastStartedWorkName; + data.strings["last_completed_work_name"] = + pg_->pgStatus_->lastCompletedWorkName; + data.strings["pg_name"] = pg_->pg_uid_; + data.strings["pg_desc"] = pg_->pg_desc_; + logger->log(data); + lastStatusUpdateTime = std::chrono::steady_clock::now(); + } + + if (propagatePgError_) { + // Check and set remote error if it has not been set before + checkAndSetRemoteError(); + } + + for (auto it = pg_->workMetaList_.begin(); it != pg_->workMetaList_.end(); + /* no increment */) { + auto& work = *it; + // When terminateProcessGroup_ is true, communicators have already been + // aborted, So cannot check exception based on them. 
But watchdog needs to + // finish the check for the works that have already been enqueued to + // workMetaList_ + + // check NCCL errors first + if (!pg_->terminateProcessGroup_.load()) { + work.checkAndSetException(); + } + + if (work.exception()) { + // set the error to the first error found + std::lock_guard lock(pg_->errorMutex_); + if (pg_->error_ == ErrorType::SUCCESS) { + pg_->error_ = ErrorType::COMM_ERROR; + } + } + + // Then check if work has timed out + // Skip if work has encountered an error + bool timedout = !work.exception() && work.checkTimeout(); + + // Report desync state in case of timeout (if TORCH_NCCL_DESYNC_DEBUG is + // turned on; otherwise, run() is no-op) + if (timedout) { + std::lock_guard lock(pg_->errorMutex_); + if (pg_->error_ == ErrorType::SUCCESS) { + pg_->error_ = ErrorType::TIMEOUT; + } + desyncDebugger_.run(); + } + + // If work hits an exception (either an error or timeout) + if (work.exception()) { + LOG(ERROR) << c10::str( + pg_->logPrefix(), + " failure detected by watchdog at work sequence id: ", + work.seq_, + " PG status: last enqueued work: ", + pg_->pgStatus_->lastEnqueuedSeq, + ", last completed work: ", + pg_->pgStatus_->lastCompletedSeq); + + // Print the traceback of the collective at call time + work.printTraceback(); + + // broadcast remote error signal to all other ranks in this specific PG. + // key/signal to write in the tcpstore is a string and pg specific: + // format is: remote_error:pg_uid + if (propagatePgError_) { + pg_->broadcastSignal( + pg_->store_, + std::string(kStoreErrorSignalKey) + ':' + pg_->pg_uid_, + pg_->rank_); + } + + // try to notify other ranks via global TCPStore to dump the flight + // recorder when a collective timeout or exception happens. Flight + // recorder behavior is independent of desync Debug. + pg_->broadcastDumpSignal(); + // Give time for dumping before throwing exception for all ranks. + // It is hard to presume or control what the pattern of watchdog might + // look like, so it is better to let all ranks universally sleep for a + // short period of time, in this case, 60 seconds, which is also the + // maximum time we leave for FR dump. 
+ std::this_thread::sleep_for(std::chrono::milliseconds( + pg_->heartbeatMonitor_->getDumpTimeout() * 4)); + + if (SHOULD_CLEAN_UP(pg_->asyncErrorHandling_)) { + // Abort work and corresponding communicators + work.abort(); + // PG level abort, which would abort all other communicators on this + // rank + pg_->abortComms(); + } + // Throw exception + work.handleException(pg_->asyncErrorHandling_); + } + + // Work status logging for desync debug + desyncDebugger_.logWorkStart(work); + + // a work could be started but not completed, so we should not update + // lastStartedSeq and lastStartedOpName if the work state is checked + // multiple times after the start + if (pg_->pgStatus_->lastStartedSeq < static_cast(work.seq_) && + work.isStarted()) { + pg_->pgStatus_->lastStartedSeq = static_cast(work.seq_); + pg_->pgStatus_->lastStartedWorkName = opTypeToString(work.opType_); + pg_->pgStatus_->lastStartedNumelIn = work.numelIn_; + pg_->pgStatus_->lastStartedNumelOut = work.numelOut_; + } + + // allow watchdog to do an event query on a side thread + at::cuda::CUDAGuard device_guard(work.ncclEndEvent_->device_index()); + at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeThreadLocal}; + + // Clean up completed work + if (work.isCompleted()) { + // In case user didn't call `work.wait()` with async collectives, + // watchdog would unstage the stashed tensors when detecting completion + // of the collective, to prevent ProcessGroupNCCL from holding reference + // to those tensors forever. + // work.stashed_for_allocator_safety_->unstash(); + // Update: it seems directly unstashing from watchdog thread would cause + // some rare problems. We thus move the unstashing to main thread, + // triggered by a next user call, see `workEnqueue`. But `work` is going + // to be destructed, so we transfer the work's shelf to a shelves + // structure owned by the PG. + if (!work.stashed_for_allocator_safety_->empty()) { + std::lock_guard lock(pg_->shelvesMutex_); + // We are just pushing back a shared_ptr here, so the cost should be + // minimal + pg_->shelvesToUnstash_.push_back(work.stashed_for_allocator_safety_); + } + + // Work status logging for desync debug + desyncDebugger_.logWorkEnd(work); + + if (work.futureWorkResult_ && work.finishedGPUExecutionInternal() && + !work.futureWorkResult_->completed()) { + work.futureWorkResult_->markCompleted( + at::IValue(static_cast(WorkResult::SUCCESS))); + } + { + // Reset the timeout and first work if the work is completed. 
+ std::lock_guard timeoutLock(pg_->mtxTimeoutExtension_); + if (work.ownedEphermeralTimeout_.count() > 0) { + pg_->ephemeralTimeoutActive_ -= work.ownedEphermeralTimeout_; + pg_->ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_; + } + } + pg_->pgStatus_->lastCompletedSeq = static_cast(work.seq_); + pg_->pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_); + pg_->pgStatus_->lastCompletedNumelIn = work.numelIn_; + pg_->pgStatus_->lastCompletedNumelOut = work.numelOut_; + FlightRecorderCUDA::get()->retire_id(work.trace_id_, true); + if (pg_->onCompletionHook_) { + // Move Work object to completedWorkList_ to be consumed by the hook + // thread + { + const std::lock_guard lock( + pg_->completedWorkListMutex_); + pg_->completedWorkList_.splice( + pg_->completedWorkList_.end(), pg_->workMetaList_, it++); + } + pg_->completedWorkListCV_.notify_one(); + } else { + it = pg_->workMetaList_.erase(it); + pg_->heartbeatMonitor_->setLastWorkListUpdateTime( + std::chrono::steady_clock::now()); + } + } else { + // Increment the iterator if the current WorkNCCL object is not + // completed. + ++it; + } + // Increment heartbeat after each work processed, + // in case processing is slowed down (but not hung) by cuda api contention + heartbeat_++; + } + done = pg_->workMetaList_.empty(); + } +} + +uint64_t ProcessGroupNCCL::Watchdog::getHeartbt() const { + return heartbeat_.load(); +} + +void ProcessGroupNCCL::Watchdog::setDesyncDebug(bool desyncDebug) { + desyncDebug_ = desyncDebug; +} + // Initialize and enable DesyncDebugger void ProcessGroupNCCL::DesyncDebugger::init( int rank, @@ -2182,39 +2512,6 @@ void ProcessGroupNCCL::broadcastSignal( } } -int ProcessGroupNCCL::getSignalSrcRank( - c10::intrusive_ptr& store, - const std::string& signal) { - // This function is 'non blocking'. We first 'check' if the key exists in the - // store, then read/get the value only if the key exists. - int srcRank = -1; - bool signalExists = false; - try { - signalExists = store->check({signal}); - } catch (const std::exception& e) { - LOG(WARNING) << logPrefix() << "Failed to check the signal " << signal - << " on TCPStore, " << e.what(); - } - if (!signalExists) { - return srcRank; - } - - // key exists, now read and parse the value (source rank) - std::vector vec; - try { - vec = store->get(std::string(signal)); - } catch (const std::exception& e) { - LOG(ERROR) << logPrefix() << "Failed to get source rank of the signal " - << signal << " from TCPStore." << e.what(); - } - TORCH_CHECK_WITH( - DistBackendError, - vec.size() == sizeof(int), - "Invalid size for the timeout rank ID"); - std::memcpy(&srcRank, vec.data(), vec.size()); - return srcRank; -} - void ProcessGroupNCCL::broadcastDumpSignal() { // broadcast dump signal to all other global ranks. 
broadcastSignal(globalStore_, std::string(kStoreDumpKey), globalRank()); @@ -2226,23 +2523,6 @@ void ProcessGroupNCCL::broadcastDumpSignal() { } } -void ProcessGroupNCCL::checkAndSetRemoteError() { - // if the error is already set, no need to check again - if (getError() != ErrorType::SUCCESS) { - return; - } - // key/signal to read from the tcpstore is a string and pg specific: - // format is: remote_error:pg_uid - int remoteErrorRank = getSignalSrcRank( - store_, std::string(kStoreErrorSignalKey) + ':' + pg_uid_); - if (remoteErrorRank != -1) { - std::lock_guard lock(errorMutex_); - error_ = ErrorType::REMOTE_ERROR; - LOG(ERROR) << c10::str( - logPrefix(), " remote error detected from rank: ", remoteErrorRank); - } -} - // NCCL recommends to evenly distribute ncclUniqueIds across the ranks // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#init-rank-config // Let’s consider an example where: @@ -2269,243 +2549,6 @@ static int getRootIndex(const int rank, const int nRanks, const int nIds) { } } -void ProcessGroupNCCL::watchdogHandler() { - bool done = false; - heartbeatMonitor_->setLastWorkListUpdateTime( - std::chrono::steady_clock::now()); - auto lastStatusUpdateTime = std::chrono::steady_clock::now(); - std::list completedWorkList; - - while (!done || !terminateProcessGroup_.load()) { - std::unique_lock lock(workMetaListMutex_); - // We busy-poll the work vector every kWatchdogThreadSleepMillis - // milliseconds as long as the atomic is True. - workMetaListCV_.wait_for( - lock, - std::chrono::milliseconds(kWatchdogThreadSleepMillis), - [&]() -> bool { return terminateProcessGroup_.load(); }); - // Bump up heart beat by one. - heartbeat_++; - -// Some versions of GLOG support less-spammy version of LOG_EVERY_MS -// in which case we don't want to spam the logs. 
-#ifdef LOG_EVERY_MS - // Log the progress of this PG periodically - C10_LOG_EVERY_MS(INFO, kWorkStatusUpdatePeriodMs) << c10::str( - logPrefix(), - "NCCL Work update periodically: ", - "last enqueued NCCL work: ", - pgStatus_->lastEnqueuedSeq, - ", last completed NCCL work: ", - pgStatus_->lastCompletedSeq, - "."); -#endif // LOG_EVERY_MS - auto logger = ::c10d::C10dLogger::getLogger(); - if (logger && - computeDeltaMS( - lastStatusUpdateTime, std::chrono::steady_clock::now()) >= - kWorkStatusUpdatePeriodMs) { - ::c10d::C10dLoggingData data; - // logging integers - data.integers["pg_id"] = static_cast(local_id_); - data.integers["rank"] = rank_; - data.integers["global_rank"] = globalRank(); - data.integers["last_enqueued_work"] = pgStatus_->lastEnqueuedSeq; - data.integers["last_started_work"] = pgStatus_->lastStartedSeq; - data.integers["last_completed_work"] = pgStatus_->lastCompletedSeq; - data.integers["last_enqueued_numel_in"] = - static_cast(pgStatus_->lastEnqueuedNumelIn); - data.integers["last_enqueued_numel_out"] = - static_cast(pgStatus_->lastEnqueuedNumelOut); - data.integers["last_completed_numel_in"] = - static_cast(pgStatus_->lastCompletedNumelIn); - data.integers["last_completed_numel_out"] = - static_cast(pgStatus_->lastCompletedNumelOut); - data.integers["last_started_numel_in"] = - static_cast(pgStatus_->lastStartedNumelIn); - data.integers["last_started_numel_out"] = - static_cast(pgStatus_->lastStartedNumelOut); - // logging strings - data.strings["last_enqueued_work_name"] = pgStatus_->lastEnqueuedWorkName; - data.strings["last_started_work_name"] = pgStatus_->lastStartedWorkName; - data.strings["last_completed_work_name"] = - pgStatus_->lastCompletedWorkName; - data.strings["pg_name"] = pg_uid_; - data.strings["pg_desc"] = pg_desc_; - logger->log(data); - lastStatusUpdateTime = std::chrono::steady_clock::now(); - } - - if (propagatePgError_) { - // Check and set remote error if it has not been set before - checkAndSetRemoteError(); - } - - for (auto it = workMetaList_.begin(); it != workMetaList_.end(); - /* no increment */) { - auto& work = *it; - // When terminateProcessGroup_ is true, communicators have already been - // aborted, So cannot check exception based on them. But watchdog needs to - // finish the check for the works that have already been enqueued to - // workMetaList_ - - // check NCCL errors first - if (!terminateProcessGroup_.load()) { - work.checkAndSetException(); - } - - if (work.exception()) { - // set the error to the first error found - std::lock_guard lock(errorMutex_); - if (error_ == ErrorType::SUCCESS) { - error_ = ErrorType::COMM_ERROR; - } - } - - // Then check if work has timed out - // Skip if work has encountered an error - bool timedout = !work.exception() && work.checkTimeout(); - - // Report desync state in case of timeout (if TORCH_NCCL_DESYNC_DEBUG is - // turned on; otherwise, run() is no-op) - if (timedout) { - std::lock_guard lock(errorMutex_); - if (error_ == ErrorType::SUCCESS) { - error_ = ErrorType::TIMEOUT; - } - desyncDebugger_.run(); - } - - // If work hits an exception (either an error or timeout) - if (work.exception()) { - LOG(ERROR) << c10::str( - logPrefix(), - " failure detected by watchdog at work sequence id: ", - work.seq_, - " PG status: last enqueued work: ", - pgStatus_->lastEnqueuedSeq, - ", last completed work: ", - pgStatus_->lastCompletedSeq); - - // Print the traceback of the collective at call time - work.printTraceback(); - - // broadcast remote error signal to all other ranks in this specific PG. 
- // key/signal to write in the tcpstore is a string and pg specific: - // format is: remote_error:pg_uid - if (propagatePgError_) { - broadcastSignal( - store_, std::string(kStoreErrorSignalKey) + ':' + pg_uid_, rank_); - } - - // try to notify other ranks via global TCPStore to dump the flight - // recorder when a collective timeout or exception happens. Flight - // recorder behavior is independent of desync Debug. - broadcastDumpSignal(); - // Give time for dumping before throwing exception for all ranks. - // It is hard to presume or control what the pattern of watchdog might - // look like, so it is better to let all ranks universally sleep for a - // short period of time, in this case, 60 seconds, which is also the - // maximum time we leave for FR dump. - std::this_thread::sleep_for( - std::chrono::milliseconds(waitTimeoutDumpInMilSec_ * 4)); - - if (SHOULD_CLEAN_UP(asyncErrorHandling_)) { - // Abort work and corresponding communicators - work.abort(); - // PG level abort, which would abort all other communicators on this - // rank - abortComms(); - } - // Throw exception - work.handleException(asyncErrorHandling_); - } - - // Work status logging for desync debug - desyncDebugger_.logWorkStart(work); - - // a work could be started but not completed, so we should not update - // lastStartedSeq and lastStartedOpName if the work state is checked - // multiple times after the start - if (pgStatus_->lastStartedSeq < static_cast(work.seq_) && - work.isStarted()) { - pgStatus_->lastStartedSeq = static_cast(work.seq_); - pgStatus_->lastStartedWorkName = opTypeToString(work.opType_); - pgStatus_->lastStartedNumelIn = work.numelIn_; - pgStatus_->lastStartedNumelOut = work.numelOut_; - } - - // allow watchdog to do an event query on a side thread - at::cuda::CUDAGuard device_guard(work.ncclEndEvent_->device_index()); - at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeThreadLocal}; - - // Clean up completed work - if (work.isCompleted()) { - // In case user didn't call `work.wait()` with async collectives, - // watchdog would unstage the stashed tensors when detecting completion - // of the collective, to prevent ProcessGroupNCCL from holding reference - // to those tensors forever. - // work.stashed_for_allocator_safety_->unstash(); - // Update: it seems directly unstashing from watchdog thread would cause - // some rare problems. We thus move the unstashing to main thread, - // triggered by a next user call, see `workEnqueue`. But `work` is going - // to be destructed, so we transfer the work's shelf to a shelves - // structure owned by the PG. - if (!work.stashed_for_allocator_safety_->empty()) { - std::lock_guard lock(shelvesMutex_); - // We are just pushing back a shared_ptr here, so the cost should be - // minimal - shelvesToUnstash_.push_back(work.stashed_for_allocator_safety_); - } - - // Work status logging for desync debug - desyncDebugger_.logWorkEnd(work); - - if (work.futureWorkResult_ && work.finishedGPUExecutionInternal() && - !work.futureWorkResult_->completed()) { - work.futureWorkResult_->markCompleted( - at::IValue(static_cast(WorkResult::SUCCESS))); - } - { - // Reset the timeout and first work if the work is completed. 
- std::lock_guard timeoutLock(mtxTimeoutExtension_); - if (work.ownedEphermeralTimeout_.count() > 0) { - ephemeralTimeoutActive_ -= work.ownedEphermeralTimeout_; - ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_; - } - } - pgStatus_->lastCompletedSeq = static_cast(work.seq_); - pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_); - pgStatus_->lastCompletedNumelIn = work.numelIn_; - pgStatus_->lastCompletedNumelOut = work.numelOut_; - FlightRecorderCUDA::get()->retire_id(work.trace_id_, true); - if (onCompletionHook_) { - // Move Work object to completedWorkList_ to be consumed by the hook - // thread - { - const std::lock_guard lock(completedWorkListMutex_); - completedWorkList_.splice( - completedWorkList_.end(), workMetaList_, it++); - } - completedWorkListCV_.notify_one(); - } else { - it = workMetaList_.erase(it); - heartbeatMonitor_->setLastWorkListUpdateTime( - std::chrono::steady_clock::now()); - } - } else { - // Increment the iterator if the current WorkNCCL object is not - // completed. - ++it; - } - // Increment heartbeat after each work processed, - // in case processing is slowed down (but not hung) by cuda api contention - heartbeat_++; - } - done = workMetaList_.empty(); - } -} - void ProcessGroupNCCL::runHookLoop() { c10::setThreadName("pt_nccl_runhook"); @@ -3208,7 +3251,6 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( profilingTitle, profilingTitle != nullptr ? std::optional>(inputs) : std::nullopt, - desyncDebug_, enableTiming_.load(), cudaEventCacheEnabled_.load(), dist_debug_level_); @@ -3329,7 +3371,7 @@ ProcessGroupNCCL::Options::Options(bool is_high_priority_stream) static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04; uint64_t ProcessGroupNCCL::getWatchdogHeartbt() const { - return heartbeat_.load(); + return watchdog_->getHeartbt(); } void ProcessGroupNCCL::startCoalescing() { diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index c44613065c3..2719d68e674 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -319,7 +319,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { bool isP2P = false, const char* profilingTitle = nullptr, const std::optional>& inputs = std::nullopt, - bool desyncDebug = false, bool enableTiming = false, bool cudaEventCacheEnabled = false, DebugLevel distDebugLevel = DebugLevel::Off); @@ -621,6 +620,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { void setLastWorkListUpdateTime( std::chrono::time_point time); + int getDumpTimeout() const; + // Util function to get the timeout error message std::string getNCCLWatchdogTimeoutErrorMsg(const std::string& extraMsg); @@ -676,6 +677,80 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::chrono::time_point lastWorkListUpdateTime_; }; + // Class that runs as a side thread to check whether the NCCL collective + // is timed out or errors on the cached NCCL communicators. + class Watchdog { + public: + Watchdog(ProcessGroupNCCL* pg); + virtual ~Watchdog() = default; + + // Start the watchdog thread. + void start(); + + // Join the watchdog thread. + void join(); + + // Function that runs as part of a separate thread and checks for errors on + // NCCL communicators. We need a separate thread to check for NCCL errors + // since we can't rely on the user calling certain methods like wait(), + // isCompleted() etc. to detect and remediate errors. 
In addition to this, + // we need a mechanism to safely abort and remove NCCL communicators from + // our cache. This can be done cleanly by having a thread for the + // ProcessGroupNCCL class. Attempting to modify the communicator cache from + // the WorkNCCL class might run into issues with object lifetime since the + // ProcessGroupNCCL object might get destroyed before the WorkNCCL object. + void run(); + + // Watchdog's inside loop. + // Takes care of cleaning up completed work, and aborting upon failure or + // timeout. + void runLoop(); + + // Notify the loop inside watchdog. + void notify(); + + void checkAndSetRemoteError(); + + // A helper function to get the src rank of a signal from the Store. This is + // nonblocking function returning -1 if the signal is not available yet. + int getSignalSrcRank( + c10::intrusive_ptr& store, + const std::string& signal); + + uint64_t getHeartbt() const; + + void setDesyncDebug(bool desyncDebug); + + private: + std::thread ncclCommWatchdogThread_; + + // We need to keep a reference to the PG instance so that we can access + // the member functions of the PG instance. We store a raw pointer on + // purpose because the watchdog thread now still lives within the + // lifetime of the PG instance. + ProcessGroupNCCL* pg_; + + // Whether the NCCL watchdog should rethrow CUDA errors. + bool rethrowCUDAErrors_ = false; + + std::exception_ptr watchDogException_ = nullptr; + + // Condition Variable for watchdog thread sleep + std::condition_variable workMetaListCV_; + + // Heartbeat of watchdog thread. + std::atomic_uint64_t heartbeat_{}; + + // Whether or not to propagate detected errors to all ranks in the same PG + // through TCPStore. + bool propagatePgError_; + + // Whether or not to enable timeout root cause analysis. + bool desyncDebug_; + + DesyncDebugger desyncDebugger_; + }; + // If you wish to create multiple process groups, each with a potentially // different rank and size, you can do so by passing a new store instance // to each one. If you have only a single store object, you can @@ -947,6 +1022,9 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Instance of the heartbeat monitor thread. std::unique_ptr heartbeatMonitor_; + // Instance of the watchdog thread. + std::unique_ptr watchdog_; + // Helper that broadcasts nccl unique ID to all ranks through the store void broadcastUniqueNCCLID( ncclUniqueId* ncclID, @@ -1082,17 +1160,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { static std::exception_ptr checkForNCCLErrorsInternal( std::shared_ptr& ncclComm); - // Function that runs as part of a separate thread and checks for errors on - // NCCL communicators. We need a separate thread to check for NCCL errors - // since we can't rely on the user calling certain methods like wait(), - // isCompleted() etc. to detect and remediate errors. In addition to this, we - // need a mechanism to safely abort and remove NCCL communicators from our - // cache. This can be done cleanly by having a thread for the ProcessGroupNCCL - // class. Attempting to modify the communicator cache from the WorkNCCL class - // might run into issues with object lifetime since the ProcessGroupNCCL - // object might get destroyed before the WorkNCCL object. - void ncclCommWatchdog(); - // Return the CUDA device most likely associated with this backend. 
// If we aren't bound to a specific device, there is no strict // guarantee that this heuristic is the correct assignment of ranks @@ -1106,11 +1173,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { // communicators from the cache and clears used device indices. void destroyNCCLComms(const std::string& devNCCLCommMapKey); - // Watchdog's inside loop. - // Takes care of cleaning up completed work, and aborting upon failure or - // timeout. - void watchdogHandler(); - void runHookLoop(); // Generates a prefix that is unique to this process group and rank, for @@ -1146,12 +1208,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { const std::string& signal, int srcRank); - // A helper function to get the src rank of a signal from the Store. This is - // nonblocking function returning -1 if the signal is not available yet. - int getSignalSrcRank( - c10::intrusive_ptr& store, - const std::string& signal); - protected: // Function that directly trigger std::abort so that the whole process // gets terminated. @@ -1166,8 +1222,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { ::c10d::C10dLoggingData& debugLog, bool throwException = false); - void checkAndSetRemoteError(); - // A helper function to guess the device id of the current rank, based on // bounded device or used device. Do not use this function if you already know // the device id to operate on. @@ -1245,21 +1299,12 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Mutex to guard maps like devNCCLCommMap_. std::mutex mutex_; - // Heartbeat of watchdog thread. - std::atomic_uint64_t heartbeat_{}; - - // timeout for the dump to finish. - int waitTimeoutDumpInMilSec_; - // Size of ring buffer where we store NCCL Traces for debugging. int traceBufferSize_; // We gate the cudaEventCache so that we can roll it out gradually. std::atomic cudaEventCacheEnabled_{}; - // Watchdog thread which looks for errors on the cached NCCL communicators. - std::thread ncclCommWatchdogThread_; - std::thread onCompletionHookThread_; // Whether or not we should terminate the watchdog and workCleanup threads. @@ -1286,9 +1331,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { bool writeDebugInfo_ = false; - // Condition Variable for watchdog thread sleep - std::condition_variable workMetaListCV_; - // Vector to store WorkNCCL pointers std::list workMetaList_; @@ -1349,14 +1391,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::mutex errorMutex_; - // Whether or not to enable timeout root cause analysis. - bool desyncDebug_; - DesyncDebugger desyncDebugger_; - - // Whether or not to propagate detected errors to all ranks in the same PG - // through TCPStore. - bool propagatePgError_; - // Whether or not to sleep after an exception is thrown in the watchdog. bool sleepAfterException_{}; @@ -1375,9 +1409,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Whether or not TORCH_NCCL_AVOID_RECORD_STREAMS was set bool avoidRecordStreams_ = false; - // Whether the NCCL watchdog should rethrow CUDA errors. - bool rethrowCUDAErrors_ = false; - // The number of active ncclGroupStart() calls. This counter will be increased // by 1 when ncclGroupStart() is called and decreased by 1 when ncclGroupEnd() // is called. @@ -1395,8 +1426,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { // the ProcessGroup uint64_t op_id_{0}; - std::exception_ptr watchDogException_ = nullptr; - // The number of ProcessGroupNCCL created on the current rank. size_t local_id_;