[c10d][PGNCCL] Make watchdog thread a class (#155831)

Extracting both the monitor thread and the watchdog thread into their own classes makes the dependencies of each thread explicit and simplifies the upcoming consolidation work for each thread (moving from one thread per PG instance to one per PG class).
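
As a rough illustration of the pattern, here is a minimal sketch (hypothetical `ProcessGroup`/`Watchdog` names, not the actual ProcessGroupNCCL code): the process group owns the thread wrapper through a `std::unique_ptr`, the wrapper owns its `std::thread`, condition variable, and termination flag, and it keeps a raw back-pointer to its owner because the owner always outlives the thread.

```cpp
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <thread>

class ProcessGroup;  // hypothetical stand-in for ProcessGroupNCCL

// Thread wrapped in a class: owns its std::thread, condition variable and
// termination flag, and exposes start()/notify()/stop()/join().
class Watchdog {
 public:
  explicit Watchdog(ProcessGroup* pg) : pg_(pg) {}

  void start() { thread_ = std::thread([this] { runLoop(); }); }
  void notify() { cv_.notify_one(); }  // wake the loop early
  void stop() {
    terminate_.store(true);
    notify();
  }
  void join() {
    if (thread_.joinable()) {
      thread_.join();
    }
  }

 private:
  void runLoop() {
    std::unique_lock<std::mutex> lock(mutex_);
    while (!terminate_.load()) {
      // Poll periodically, or wake up immediately on notify()/stop();
      // a real loop would walk the work list owned by pg_ here.
      cv_.wait_for(lock, std::chrono::milliseconds(100),
                   [this] { return terminate_.load(); });
    }
  }

  ProcessGroup* pg_;  // raw pointer on purpose: the owner outlives this thread
  std::thread thread_;
  std::mutex mutex_;
  std::condition_variable cv_;
  std::atomic<bool> terminate_{false};
};

// Owner side: construct the wrapper before starting it, and stop/join it
// before the members the thread depends on are destroyed.
class ProcessGroup {
 public:
  ProcessGroup() : watchdog_(std::make_unique<Watchdog>(this)) {
    watchdog_->start();
  }
  ~ProcessGroup() {
    watchdog_->stop();
    watchdog_->join();
  }

 private:
  std::unique_ptr<Watchdog> watchdog_;
};

int main() {
  ProcessGroup pg;  // watchdog runs until pg is destroyed
}
```

In the PR itself, the diff below additionally moves watchdog-only state (the heartbeat counter, desync debugger, and error-propagation flags) from `ProcessGroupNCCL` into the new `Watchdog` class, and rewires `abort()`, `shutdown()`, and the destructor to call `watchdog_->notify()`/`watchdog_->join()`.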

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155831
Approved by: https://github.com/d4l3k, https://github.com/kwen2501
fduwjj 2025-06-12 13:56:35 -07:00 committed by PyTorch MergeBot
parent c5d00e150a
commit ce44877961
3 changed files with 452 additions and 381 deletions


@@ -164,7 +164,7 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors {
// so we have this hack to manually set the desync debug flag after PG
// creation.
void forceSetDesyncDebugFlag() {
desyncDebug_ = true;
watchdog_->setDesyncDebug(true);
}
private:


@@ -511,7 +511,6 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(
bool isP2P,
const char* profilingTitle,
const std::optional<std::vector<at::Tensor>>& inputs,
bool desyncDebug,
bool enableTiming,
bool cudaEventCacheEnabled,
DebugLevel distDebugLevel)
@@ -952,15 +951,8 @@ ProcessGroupNCCL::ProcessGroupNCCL(
blockingWait_ = getCvarBool(TORCH_NCCL_BLOCKING_WAIT, false);
asyncErrorHandling_ = static_cast<ErrorHandlingMode>(
getCvarInt(TORCH_NCCL_ASYNC_ERROR_HANDLING, 3 /*SkipCleanUp*/));
desyncDebug_ = getCvarBool(TORCH_NCCL_DESYNC_DEBUG, false) ||
(dist_debug_level_ >= DebugLevel::Detail);
rethrowCUDAErrors_ = getCvarBool(TORCH_NCCL_RETHROW_CUDA_ERRORS, true);
propagatePgError_ = getCvarBool(TORCH_NCCL_PROPAGATE_ERROR, false);
enableNanCheck_ = getCvarBool(TORCH_NCCL_NAN_CHECK, false);
heartbeat_ = 1ULL;
cudaEventCacheEnabled_.store(getCvarBool(TORCH_NCCL_CUDA_EVENT_CACHE, true));
waitTimeoutDumpInMilSec_ =
getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 15 * 1000 /*15 Sec*/);
traceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 2000);
enableCollectiveHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail);
// store_ usually is wrapped with PrefixStore and the prefix is different
@@ -970,9 +962,11 @@ ProcessGroupNCCL::ProcessGroupNCCL(
PrefixStore* prefixStore = dynamic_cast<PrefixStore*>(store_.get());
globalStore_ =
prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store_;
auto desyncDebug = getCvarBool(TORCH_NCCL_DESYNC_DEBUG, false) ||
(dist_debug_level_ >= DebugLevel::Detail);
#ifdef ENABLE_NCCL_ERROR_CHECKING
enableTiming_.store(
getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_);
getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug);
#endif // ENABLE_NCCL_ERROR_CHECKING
if (getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false)) {
TORCH_WARN_ONCE(
@@ -984,7 +978,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
<< logPrefix()
<< "TORCH_NCCL_BLOCKING_WAIT is enabled, NO watchdog thread is created.";
} else {
if (desyncDebug_ && asyncErrorHandling_ == NoHandling) {
if (desyncDebug && asyncErrorHandling_ == NoHandling) {
LOG(INFO)
<< logPrefix()
<< "TORCH_NCCL_DESYNC_DEBUG and TORCH_NCCL_ASYNC_ERROR_HANDLING "
@@ -994,17 +988,17 @@ ProcessGroupNCCL::ProcessGroupNCCL(
}
}
// Initialize the heartbeat monitor instance. This has to be done before
// the watchdog thread is launched to avoid the error.
// Initialize the heartbeat monitor/watchdog instance. This has to be done
// before the corresponding thread is launched to avoid the error.
heartbeatMonitor_ = std::make_unique<HeartbeatMonitor>(this);
watchdog_ = std::make_unique<Watchdog>(this);
#ifdef ENABLE_NCCL_ERROR_CHECKING
// in blockingWait mode, we don't need to enable the watchdog thread to check
// the timeout or nccl error because the main thread would throw an exception
// and it is the user's responsibility to handle the exception.
if (!blockingWait_) {
ncclCommWatchdogThread_ =
std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this);
watchdog_->start();
}
#endif // ENABLE_NCCL_ERROR_CHECKING
@@ -1024,8 +1018,6 @@ ProcessGroupNCCL::ProcessGroupNCCL(
LOG(INFO) << logPrefix() << "ProcessGroupNCCL environments: "
<< "NCCL version: " << ncclVersion
<< ", TORCH_NCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_
<< ", TORCH_NCCL_PROPAGATE_ERROR: " << propagatePgError_
<< ", TORCH_NCCL_DESYNC_DEBUG: " << desyncDebug_
<< ", TORCH_NCCL_ENABLE_TIMING: " << enableTiming_.load()
<< ", TORCH_NCCL_BLOCKING_WAIT: " << blockingWait_
<< ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug
@@ -1050,11 +1042,6 @@ ProcessGroupNCCL::ProcessGroupNCCL(
// This call is idempotent.
attachAllocatorHooks();
}
// Enable Desync Debugger per user setting
if (desyncDebug_) {
desyncDebugger_.init(rank, size, globalRank(), getUid(), store_);
}
}
void ProcessGroupNCCL::eagerConnectSingleDevice(at::Device device) {
@@ -1271,7 +1258,7 @@ void ProcessGroupNCCL::waitForPendingWorks() {
// completedWorkList_ before it finishes.
// 3. We have three threads and two locks.
// a. main thread (this function) grabs two locks atomically
// b. watchdog thread (watchdogHandler function) always grabs
// b. watchdog thread (runLoop function) always grabs
// workMetaListMutex_
// first and then grabs completedWorkListMutex_.
// c. hook thread (runHookLoop function) only grabs
@@ -1425,7 +1412,7 @@ void ProcessGroupNCCL::abort() {
// communicators and signal the threads to exit. Joining on the threads could
// potentially block and hence avoid it in this method.
terminateProcessGroup_.store(true);
workMetaListCV_.notify_one();
watchdog_->notify();
// lauch abort asynchrounously and wait for it to complete or timeout
LOG(INFO) << logPrefix()
@@ -1480,10 +1467,8 @@ void ProcessGroupNCCL::shutdown() {
// anymore because I am going to destroy them now
LOG(INFO) << logPrefix() << "Operations flushed, joining watchdog thread.";
terminateProcessGroup_.store(true);
workMetaListCV_.notify_one();
if (ncclCommWatchdogThread_.joinable()) {
ncclCommWatchdogThread_.join();
}
watchdog_->notify();
watchdog_->join();
if (onCompletionHookThread_.joinable()) {
onCompletionHookThread_.join();
}
@@ -1542,15 +1527,12 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
// Make sure we've told threads to stop; doesn't hurt if we'd done so before.
// Tell watchdog and onCompletionHook:
terminateProcessGroup_.store(true);
workMetaListCV_.notify_one();
watchdog_->notify();
// Tell heartbeat thread:
heartbeatMonitor_->stop();
// Wait for all threads to finish before returning
if (ncclCommWatchdogThread_.joinable()) {
ncclCommWatchdogThread_.join();
LOG(INFO) << logPrefix() << "ProcessGroupNCCL watchdog thread joined.";
}
watchdog_->join();
heartbeatMonitor_->join();
if (onCompletionHookThread_.joinable()) {
onCompletionHookThread_.join();
@@ -1639,6 +1621,10 @@ void ProcessGroupNCCL::HeartbeatMonitor::setLastWorkListUpdateTime(
lastWorkListUpdateTime_ = time;
}
int ProcessGroupNCCL::HeartbeatMonitor::getDumpTimeout() const {
return waitTimeoutDumpInMilSec_;
}
ProcessGroupNCCL::HeartbeatMonitor::HeartbeatMonitor(ProcessGroupNCCL* pg) {
pg_ = pg;
heartbeatTimeoutInSec_ =
@@ -1969,27 +1955,69 @@ void ProcessGroupNCCL::HeartbeatMonitor::runLoop() {
}
}
void ProcessGroupNCCL::ncclCommWatchdog() {
ProcessGroupNCCL::Watchdog::Watchdog(ProcessGroupNCCL* pg) {
pg_ = pg;
heartbeat_ = 1ULL;
rethrowCUDAErrors_ = getCvarBool(TORCH_NCCL_RETHROW_CUDA_ERRORS, true);
propagatePgError_ = getCvarBool(TORCH_NCCL_PROPAGATE_ERROR, false);
desyncDebug_ = getCvarBool(TORCH_NCCL_DESYNC_DEBUG, false) ||
(pg_->dist_debug_level_ >= DebugLevel::Detail);
// print out ENV settings for the watchdog thread.
LOG(INFO) << pg_->logPrefix() << "PGNCCL Watchdog environments: "
<< "TORCH_NCCL_RETHROW_CUDA_ERRORS: " << rethrowCUDAErrors_
<< ", TORCH_NCCL_PROPAGATE_ERROR: " << propagatePgError_
<< ", TORCH_NCCL_DESYNC_DEBUG: " << desyncDebug_;
// Enable Desync Debugger per user setting
if (desyncDebug_) {
desyncDebugger_.init(
pg_->getRank(),
pg_->getSize(),
pg_->globalRank(),
pg_->getUid(),
pg_->store_);
}
}
void ProcessGroupNCCL::Watchdog::notify() {
workMetaListCV_.notify_one();
}
void ProcessGroupNCCL::Watchdog::start() {
TORCH_CHECK(
!ncclCommWatchdogThread_.joinable(), "Watchdog thread already started");
ncclCommWatchdogThread_ = std::thread(&ProcessGroupNCCL::Watchdog::run, this);
}
void ProcessGroupNCCL::Watchdog::join() {
if (ncclCommWatchdogThread_.joinable()) {
ncclCommWatchdogThread_.join();
LOG(INFO) << pg_->logPrefix() << "ProcessGroupNCCL watchdog thread joined.";
}
}
void ProcessGroupNCCL::Watchdog::run() {
c10::setThreadName("pt_nccl_watchdg");
try {
VLOG(2) << logPrefix() << "Process group watchdog thread started!";
heartbeatMonitor_->start();
watchdogHandler();
VLOG(2) << logPrefix()
VLOG(2) << pg_->logPrefix() << "Process group watchdog thread started!";
pg_->heartbeatMonitor_->start();
runLoop();
VLOG(2) << pg_->logPrefix()
<< "Process group watchdog thread terminated normally";
} catch (std::exception& e) {
if (std::string(e.what()).find("driver shutting down") !=
std::string::npos) {
VLOG(2)
<< logPrefix()
<< pg_->logPrefix()
<< "main process destroyed cuda before watchdog loop exited, terminating watchdog."
<< " (Watchdog caught exception: " << e.what();
} else {
// Append error message reported from watchdogHandler
// Append error message reported from runLoop
const auto exitMsg = c10::str(
logPrefix(),
pg_->logPrefix(),
"Process group watchdog thread terminated with exception: ",
e.what());
LOG(ERROR) << exitMsg;
@@ -2004,7 +2032,7 @@ void ProcessGroupNCCL::ncclCommWatchdog() {
}
} catch (...) {
const auto exitMsg = c10::str(
logPrefix(),
pg_->logPrefix(),
"Process group watchdog thread terminated with exception: unknown");
LOG(ERROR) << exitMsg;
watchDogException_ =
@@ -2013,6 +2041,308 @@ void ProcessGroupNCCL::ncclCommWatchdog() {
}
}
int ProcessGroupNCCL::Watchdog::getSignalSrcRank(
c10::intrusive_ptr<Store>& store,
const std::string& signal) {
// This function is 'non blocking'. We first 'check' if the key exists in the
// store, then read/get the value only if the key exists.
int srcRank = -1;
bool signalExists = false;
try {
signalExists = store->check({signal});
} catch (const std::exception& e) {
LOG(WARNING) << pg_->logPrefix() << "Failed to check the signal " << signal
<< " on TCPStore, " << e.what();
}
if (!signalExists) {
return srcRank;
}
// key exists, now read and parse the value (source rank)
std::vector<uint8_t> vec;
try {
vec = store->get(std::string(signal));
} catch (const std::exception& e) {
LOG(ERROR) << pg_->logPrefix() << "Failed to get source rank of the signal "
<< signal << " from TCPStore." << e.what();
}
TORCH_CHECK_WITH(
DistBackendError,
vec.size() == sizeof(int),
"Invalid size for the timeout rank ID");
std::memcpy(&srcRank, vec.data(), vec.size());
return srcRank;
}
void ProcessGroupNCCL::Watchdog::checkAndSetRemoteError() {
// if the error is already set, no need to check again
if (pg_->getError() != ErrorType::SUCCESS) {
return;
}
// key/signal to read from the tcpstore is a string and pg specific:
// format is: remote_error:pg_uid
int remoteErrorRank = getSignalSrcRank(
pg_->store_, std::string(kStoreErrorSignalKey) + ':' + pg_->pg_uid_);
if (remoteErrorRank != -1) {
std::lock_guard<std::mutex> lock(pg_->errorMutex_);
pg_->error_ = ErrorType::REMOTE_ERROR;
LOG(ERROR) << c10::str(
pg_->logPrefix(),
" remote error detected from rank: ",
remoteErrorRank);
}
}
void ProcessGroupNCCL::Watchdog::runLoop() {
bool done = false;
pg_->heartbeatMonitor_->setLastWorkListUpdateTime(
std::chrono::steady_clock::now());
auto lastStatusUpdateTime = std::chrono::steady_clock::now();
std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList;
while (!done || !pg_->terminateProcessGroup_.load()) {
std::unique_lock<std::mutex> lock(pg_->workMetaListMutex_);
// We busy-poll the work vector every kWatchdogThreadSleepMillis
// milliseconds as long as the atomic is True.
workMetaListCV_.wait_for(
lock,
std::chrono::milliseconds(kWatchdogThreadSleepMillis),
[&]() -> bool { return pg_->terminateProcessGroup_.load(); });
// Bump up heart beat by one.
heartbeat_++;
// Some versions of GLOG support less-spammy version of LOG_EVERY_MS
// in which case we don't want to spam the logs.
#ifdef LOG_EVERY_MS
// Log the progress of this PG periodically
C10_LOG_EVERY_MS(INFO, kWorkStatusUpdatePeriodMs) << c10::str(
logPrefix(),
"NCCL Work update periodically: ",
"last enqueued NCCL work: ",
pg_->pgStatus_->lastEnqueuedSeq,
", last completed NCCL work: ",
pg_->pgStatus_->lastCompletedSeq,
".");
#endif // LOG_EVERY_MS
auto logger = ::c10d::C10dLogger::getLogger();
if (logger &&
computeDeltaMS(
lastStatusUpdateTime, std::chrono::steady_clock::now()) >=
kWorkStatusUpdatePeriodMs) {
::c10d::C10dLoggingData data;
// logging integers
data.integers["pg_id"] = static_cast<int64_t>(pg_->local_id_);
data.integers["rank"] = pg_->rank_;
data.integers["global_rank"] = pg_->globalRank();
data.integers["last_enqueued_work"] = pg_->pgStatus_->lastEnqueuedSeq;
data.integers["last_started_work"] = pg_->pgStatus_->lastStartedSeq;
data.integers["last_completed_work"] = pg_->pgStatus_->lastCompletedSeq;
data.integers["last_enqueued_numel_in"] =
static_cast<int64_t>(pg_->pgStatus_->lastEnqueuedNumelIn);
data.integers["last_enqueued_numel_out"] =
static_cast<int64_t>(pg_->pgStatus_->lastEnqueuedNumelOut);
data.integers["last_completed_numel_in"] =
static_cast<int64_t>(pg_->pgStatus_->lastCompletedNumelIn);
data.integers["last_completed_numel_out"] =
static_cast<int64_t>(pg_->pgStatus_->lastCompletedNumelOut);
data.integers["last_started_numel_in"] =
static_cast<int64_t>(pg_->pgStatus_->lastStartedNumelIn);
data.integers["last_started_numel_out"] =
static_cast<int64_t>(pg_->pgStatus_->lastStartedNumelOut);
// logging strings
data.strings["last_enqueued_work_name"] =
pg_->pgStatus_->lastEnqueuedWorkName;
data.strings["last_started_work_name"] =
pg_->pgStatus_->lastStartedWorkName;
data.strings["last_completed_work_name"] =
pg_->pgStatus_->lastCompletedWorkName;
data.strings["pg_name"] = pg_->pg_uid_;
data.strings["pg_desc"] = pg_->pg_desc_;
logger->log(data);
lastStatusUpdateTime = std::chrono::steady_clock::now();
}
if (propagatePgError_) {
// Check and set remote error if it has not been set before
checkAndSetRemoteError();
}
for (auto it = pg_->workMetaList_.begin(); it != pg_->workMetaList_.end();
/* no increment */) {
auto& work = *it;
// When terminateProcessGroup_ is true, communicators have already been
// aborted, So cannot check exception based on them. But watchdog needs to
// finish the check for the works that have already been enqueued to
// workMetaList_
// check NCCL errors first
if (!pg_->terminateProcessGroup_.load()) {
work.checkAndSetException();
}
if (work.exception()) {
// set the error to the first error found
std::lock_guard<std::mutex> lock(pg_->errorMutex_);
if (pg_->error_ == ErrorType::SUCCESS) {
pg_->error_ = ErrorType::COMM_ERROR;
}
}
// Then check if work has timed out
// Skip if work has encountered an error
bool timedout = !work.exception() && work.checkTimeout();
// Report desync state in case of timeout (if TORCH_NCCL_DESYNC_DEBUG is
// turned on; otherwise, run() is no-op)
if (timedout) {
std::lock_guard<std::mutex> lock(pg_->errorMutex_);
if (pg_->error_ == ErrorType::SUCCESS) {
pg_->error_ = ErrorType::TIMEOUT;
}
desyncDebugger_.run();
}
// If work hits an exception (either an error or timeout)
if (work.exception()) {
LOG(ERROR) << c10::str(
pg_->logPrefix(),
" failure detected by watchdog at work sequence id: ",
work.seq_,
" PG status: last enqueued work: ",
pg_->pgStatus_->lastEnqueuedSeq,
", last completed work: ",
pg_->pgStatus_->lastCompletedSeq);
// Print the traceback of the collective at call time
work.printTraceback();
// broadcast remote error signal to all other ranks in this specific PG.
// key/signal to write in the tcpstore is a string and pg specific:
// format is: remote_error:pg_uid
if (propagatePgError_) {
pg_->broadcastSignal(
pg_->store_,
std::string(kStoreErrorSignalKey) + ':' + pg_->pg_uid_,
pg_->rank_);
}
// try to notify other ranks via global TCPStore to dump the flight
// recorder when a collective timeout or exception happens. Flight
// recorder behavior is independent of desync Debug.
pg_->broadcastDumpSignal();
// Give time for dumping before throwing exception for all ranks.
// It is hard to presume or control what the pattern of watchdog might
// look like, so it is better to let all ranks universally sleep for a
// short period of time, in this case, 60 seconds, which is also the
// maximum time we leave for FR dump.
std::this_thread::sleep_for(std::chrono::milliseconds(
pg_->heartbeatMonitor_->getDumpTimeout() * 4));
if (SHOULD_CLEAN_UP(pg_->asyncErrorHandling_)) {
// Abort work and corresponding communicators
work.abort();
// PG level abort, which would abort all other communicators on this
// rank
pg_->abortComms();
}
// Throw exception
work.handleException(pg_->asyncErrorHandling_);
}
// Work status logging for desync debug
desyncDebugger_.logWorkStart(work);
// a work could be started but not completed, so we should not update
// lastStartedSeq and lastStartedOpName if the work state is checked
// multiple times after the start
if (pg_->pgStatus_->lastStartedSeq < static_cast<int64_t>(work.seq_) &&
work.isStarted()) {
pg_->pgStatus_->lastStartedSeq = static_cast<int64_t>(work.seq_);
pg_->pgStatus_->lastStartedWorkName = opTypeToString(work.opType_);
pg_->pgStatus_->lastStartedNumelIn = work.numelIn_;
pg_->pgStatus_->lastStartedNumelOut = work.numelOut_;
}
// allow watchdog to do an event query on a side thread
at::cuda::CUDAGuard device_guard(work.ncclEndEvent_->device_index());
at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeThreadLocal};
// Clean up completed work
if (work.isCompleted()) {
// In case user didn't call `work.wait()` with async collectives,
// watchdog would unstage the stashed tensors when detecting completion
// of the collective, to prevent ProcessGroupNCCL from holding reference
// to those tensors forever.
// work.stashed_for_allocator_safety_->unstash();
// Update: it seems directly unstashing from watchdog thread would cause
// some rare problems. We thus move the unstashing to main thread,
// triggered by a next user call, see `workEnqueue`. But `work` is going
// to be destructed, so we transfer the work's shelf to a shelves
// structure owned by the PG.
if (!work.stashed_for_allocator_safety_->empty()) {
std::lock_guard<std::mutex> lock(pg_->shelvesMutex_);
// We are just pushing back a shared_ptr here, so the cost should be
// minimal
pg_->shelvesToUnstash_.push_back(work.stashed_for_allocator_safety_);
}
// Work status logging for desync debug
desyncDebugger_.logWorkEnd(work);
if (work.futureWorkResult_ && work.finishedGPUExecutionInternal() &&
!work.futureWorkResult_->completed()) {
work.futureWorkResult_->markCompleted(
at::IValue(static_cast<uint8_t>(WorkResult::SUCCESS)));
}
{
// Reset the timeout and first work if the work is completed.
std::lock_guard<std::mutex> timeoutLock(pg_->mtxTimeoutExtension_);
if (work.ownedEphermeralTimeout_.count() > 0) {
pg_->ephemeralTimeoutActive_ -= work.ownedEphermeralTimeout_;
pg_->ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_;
}
}
pg_->pgStatus_->lastCompletedSeq = static_cast<int64_t>(work.seq_);
pg_->pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_);
pg_->pgStatus_->lastCompletedNumelIn = work.numelIn_;
pg_->pgStatus_->lastCompletedNumelOut = work.numelOut_;
FlightRecorderCUDA::get()->retire_id(work.trace_id_, true);
if (pg_->onCompletionHook_) {
// Move Work object to completedWorkList_ to be consumed by the hook
// thread
{
const std::lock_guard<std::mutex> lock(
pg_->completedWorkListMutex_);
pg_->completedWorkList_.splice(
pg_->completedWorkList_.end(), pg_->workMetaList_, it++);
}
pg_->completedWorkListCV_.notify_one();
} else {
it = pg_->workMetaList_.erase(it);
pg_->heartbeatMonitor_->setLastWorkListUpdateTime(
std::chrono::steady_clock::now());
}
} else {
// Increment the iterator if the current WorkNCCL object is not
// completed.
++it;
}
// Increment heartbeat after each work processed,
// in case processing is slowed down (but not hung) by cuda api contention
heartbeat_++;
}
done = pg_->workMetaList_.empty();
}
}
uint64_t ProcessGroupNCCL::Watchdog::getHeartbt() const {
return heartbeat_.load();
}
void ProcessGroupNCCL::Watchdog::setDesyncDebug(bool desyncDebug) {
desyncDebug_ = desyncDebug;
}
// Initialize and enable DesyncDebugger
void ProcessGroupNCCL::DesyncDebugger::init(
int rank,
@@ -2182,39 +2512,6 @@ void ProcessGroupNCCL::broadcastSignal(
}
}
int ProcessGroupNCCL::getSignalSrcRank(
c10::intrusive_ptr<Store>& store,
const std::string& signal) {
// This function is 'non blocking'. We first 'check' if the key exists in the
// store, then read/get the value only if the key exists.
int srcRank = -1;
bool signalExists = false;
try {
signalExists = store->check({signal});
} catch (const std::exception& e) {
LOG(WARNING) << logPrefix() << "Failed to check the signal " << signal
<< " on TCPStore, " << e.what();
}
if (!signalExists) {
return srcRank;
}
// key exists, now read and parse the value (source rank)
std::vector<uint8_t> vec;
try {
vec = store->get(std::string(signal));
} catch (const std::exception& e) {
LOG(ERROR) << logPrefix() << "Failed to get source rank of the signal "
<< signal << " from TCPStore." << e.what();
}
TORCH_CHECK_WITH(
DistBackendError,
vec.size() == sizeof(int),
"Invalid size for the timeout rank ID");
std::memcpy(&srcRank, vec.data(), vec.size());
return srcRank;
}
void ProcessGroupNCCL::broadcastDumpSignal() {
// broadcast dump signal to all other global ranks.
broadcastSignal(globalStore_, std::string(kStoreDumpKey), globalRank());
@@ -2226,23 +2523,6 @@ void ProcessGroupNCCL::broadcastDumpSignal() {
}
}
void ProcessGroupNCCL::checkAndSetRemoteError() {
// if the error is already set, no need to check again
if (getError() != ErrorType::SUCCESS) {
return;
}
// key/signal to read from the tcpstore is a string and pg specific:
// format is: remote_error:pg_uid
int remoteErrorRank = getSignalSrcRank(
store_, std::string(kStoreErrorSignalKey) + ':' + pg_uid_);
if (remoteErrorRank != -1) {
std::lock_guard<std::mutex> lock(errorMutex_);
error_ = ErrorType::REMOTE_ERROR;
LOG(ERROR) << c10::str(
logPrefix(), " remote error detected from rank: ", remoteErrorRank);
}
}
// NCCL recommends to evenly distribute ncclUniqueIds across the ranks
// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#init-rank-config
// Lets consider an example where:
@@ -2269,243 +2549,6 @@ static int getRootIndex(const int rank, const int nRanks, const int nIds) {
}
}
void ProcessGroupNCCL::watchdogHandler() {
bool done = false;
heartbeatMonitor_->setLastWorkListUpdateTime(
std::chrono::steady_clock::now());
auto lastStatusUpdateTime = std::chrono::steady_clock::now();
std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList;
while (!done || !terminateProcessGroup_.load()) {
std::unique_lock<std::mutex> lock(workMetaListMutex_);
// We busy-poll the work vector every kWatchdogThreadSleepMillis
// milliseconds as long as the atomic is True.
workMetaListCV_.wait_for(
lock,
std::chrono::milliseconds(kWatchdogThreadSleepMillis),
[&]() -> bool { return terminateProcessGroup_.load(); });
// Bump up heart beat by one.
heartbeat_++;
// Some versions of GLOG support less-spammy version of LOG_EVERY_MS
// in which case we don't want to spam the logs.
#ifdef LOG_EVERY_MS
// Log the progress of this PG periodically
C10_LOG_EVERY_MS(INFO, kWorkStatusUpdatePeriodMs) << c10::str(
logPrefix(),
"NCCL Work update periodically: ",
"last enqueued NCCL work: ",
pgStatus_->lastEnqueuedSeq,
", last completed NCCL work: ",
pgStatus_->lastCompletedSeq,
".");
#endif // LOG_EVERY_MS
auto logger = ::c10d::C10dLogger::getLogger();
if (logger &&
computeDeltaMS(
lastStatusUpdateTime, std::chrono::steady_clock::now()) >=
kWorkStatusUpdatePeriodMs) {
::c10d::C10dLoggingData data;
// logging integers
data.integers["pg_id"] = static_cast<int64_t>(local_id_);
data.integers["rank"] = rank_;
data.integers["global_rank"] = globalRank();
data.integers["last_enqueued_work"] = pgStatus_->lastEnqueuedSeq;
data.integers["last_started_work"] = pgStatus_->lastStartedSeq;
data.integers["last_completed_work"] = pgStatus_->lastCompletedSeq;
data.integers["last_enqueued_numel_in"] =
static_cast<int64_t>(pgStatus_->lastEnqueuedNumelIn);
data.integers["last_enqueued_numel_out"] =
static_cast<int64_t>(pgStatus_->lastEnqueuedNumelOut);
data.integers["last_completed_numel_in"] =
static_cast<int64_t>(pgStatus_->lastCompletedNumelIn);
data.integers["last_completed_numel_out"] =
static_cast<int64_t>(pgStatus_->lastCompletedNumelOut);
data.integers["last_started_numel_in"] =
static_cast<int64_t>(pgStatus_->lastStartedNumelIn);
data.integers["last_started_numel_out"] =
static_cast<int64_t>(pgStatus_->lastStartedNumelOut);
// logging strings
data.strings["last_enqueued_work_name"] = pgStatus_->lastEnqueuedWorkName;
data.strings["last_started_work_name"] = pgStatus_->lastStartedWorkName;
data.strings["last_completed_work_name"] =
pgStatus_->lastCompletedWorkName;
data.strings["pg_name"] = pg_uid_;
data.strings["pg_desc"] = pg_desc_;
logger->log(data);
lastStatusUpdateTime = std::chrono::steady_clock::now();
}
if (propagatePgError_) {
// Check and set remote error if it has not been set before
checkAndSetRemoteError();
}
for (auto it = workMetaList_.begin(); it != workMetaList_.end();
/* no increment */) {
auto& work = *it;
// When terminateProcessGroup_ is true, communicators have already been
// aborted, So cannot check exception based on them. But watchdog needs to
// finish the check for the works that have already been enqueued to
// workMetaList_
// check NCCL errors first
if (!terminateProcessGroup_.load()) {
work.checkAndSetException();
}
if (work.exception()) {
// set the error to the first error found
std::lock_guard<std::mutex> lock(errorMutex_);
if (error_ == ErrorType::SUCCESS) {
error_ = ErrorType::COMM_ERROR;
}
}
// Then check if work has timed out
// Skip if work has encountered an error
bool timedout = !work.exception() && work.checkTimeout();
// Report desync state in case of timeout (if TORCH_NCCL_DESYNC_DEBUG is
// turned on; otherwise, run() is no-op)
if (timedout) {
std::lock_guard<std::mutex> lock(errorMutex_);
if (error_ == ErrorType::SUCCESS) {
error_ = ErrorType::TIMEOUT;
}
desyncDebugger_.run();
}
// If work hits an exception (either an error or timeout)
if (work.exception()) {
LOG(ERROR) << c10::str(
logPrefix(),
" failure detected by watchdog at work sequence id: ",
work.seq_,
" PG status: last enqueued work: ",
pgStatus_->lastEnqueuedSeq,
", last completed work: ",
pgStatus_->lastCompletedSeq);
// Print the traceback of the collective at call time
work.printTraceback();
// broadcast remote error signal to all other ranks in this specific PG.
// key/signal to write in the tcpstore is a string and pg specific:
// format is: remote_error:pg_uid
if (propagatePgError_) {
broadcastSignal(
store_, std::string(kStoreErrorSignalKey) + ':' + pg_uid_, rank_);
}
// try to notify other ranks via global TCPStore to dump the flight
// recorder when a collective timeout or exception happens. Flight
// recorder behavior is independent of desync Debug.
broadcastDumpSignal();
// Give time for dumping before throwing exception for all ranks.
// It is hard to presume or control what the pattern of watchdog might
// look like, so it is better to let all ranks universally sleep for a
// short period of time, in this case, 60 seconds, which is also the
// maximum time we leave for FR dump.
std::this_thread::sleep_for(
std::chrono::milliseconds(waitTimeoutDumpInMilSec_ * 4));
if (SHOULD_CLEAN_UP(asyncErrorHandling_)) {
// Abort work and corresponding communicators
work.abort();
// PG level abort, which would abort all other communicators on this
// rank
abortComms();
}
// Throw exception
work.handleException(asyncErrorHandling_);
}
// Work status logging for desync debug
desyncDebugger_.logWorkStart(work);
// a work could be started but not completed, so we should not update
// lastStartedSeq and lastStartedOpName if the work state is checked
// multiple times after the start
if (pgStatus_->lastStartedSeq < static_cast<int64_t>(work.seq_) &&
work.isStarted()) {
pgStatus_->lastStartedSeq = static_cast<int64_t>(work.seq_);
pgStatus_->lastStartedWorkName = opTypeToString(work.opType_);
pgStatus_->lastStartedNumelIn = work.numelIn_;
pgStatus_->lastStartedNumelOut = work.numelOut_;
}
// allow watchdog to do an event query on a side thread
at::cuda::CUDAGuard device_guard(work.ncclEndEvent_->device_index());
at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeThreadLocal};
// Clean up completed work
if (work.isCompleted()) {
// In case user didn't call `work.wait()` with async collectives,
// watchdog would unstage the stashed tensors when detecting completion
// of the collective, to prevent ProcessGroupNCCL from holding reference
// to those tensors forever.
// work.stashed_for_allocator_safety_->unstash();
// Update: it seems directly unstashing from watchdog thread would cause
// some rare problems. We thus move the unstashing to main thread,
// triggered by a next user call, see `workEnqueue`. But `work` is going
// to be destructed, so we transfer the work's shelf to a shelves
// structure owned by the PG.
if (!work.stashed_for_allocator_safety_->empty()) {
std::lock_guard<std::mutex> lock(shelvesMutex_);
// We are just pushing back a shared_ptr here, so the cost should be
// minimal
shelvesToUnstash_.push_back(work.stashed_for_allocator_safety_);
}
// Work status logging for desync debug
desyncDebugger_.logWorkEnd(work);
if (work.futureWorkResult_ && work.finishedGPUExecutionInternal() &&
!work.futureWorkResult_->completed()) {
work.futureWorkResult_->markCompleted(
at::IValue(static_cast<uint8_t>(WorkResult::SUCCESS)));
}
{
// Reset the timeout and first work if the work is completed.
std::lock_guard<std::mutex> timeoutLock(mtxTimeoutExtension_);
if (work.ownedEphermeralTimeout_.count() > 0) {
ephemeralTimeoutActive_ -= work.ownedEphermeralTimeout_;
ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_;
}
}
pgStatus_->lastCompletedSeq = static_cast<int64_t>(work.seq_);
pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_);
pgStatus_->lastCompletedNumelIn = work.numelIn_;
pgStatus_->lastCompletedNumelOut = work.numelOut_;
FlightRecorderCUDA::get()->retire_id(work.trace_id_, true);
if (onCompletionHook_) {
// Move Work object to completedWorkList_ to be consumed by the hook
// thread
{
const std::lock_guard<std::mutex> lock(completedWorkListMutex_);
completedWorkList_.splice(
completedWorkList_.end(), workMetaList_, it++);
}
completedWorkListCV_.notify_one();
} else {
it = workMetaList_.erase(it);
heartbeatMonitor_->setLastWorkListUpdateTime(
std::chrono::steady_clock::now());
}
} else {
// Increment the iterator if the current WorkNCCL object is not
// completed.
++it;
}
// Increment heartbeat after each work processed,
// in case processing is slowed down (but not hung) by cuda api contention
heartbeat_++;
}
done = workMetaList_.empty();
}
}
void ProcessGroupNCCL::runHookLoop() {
c10::setThreadName("pt_nccl_runhook");
@@ -3208,7 +3251,6 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
profilingTitle,
profilingTitle != nullptr ? std::optional<std::vector<at::Tensor>>(inputs)
: std::nullopt,
desyncDebug_,
enableTiming_.load(),
cudaEventCacheEnabled_.load(),
dist_debug_level_);
@@ -3329,7 +3371,7 @@ ProcessGroupNCCL::Options::Options(bool is_high_priority_stream)
static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04;
uint64_t ProcessGroupNCCL::getWatchdogHeartbt() const {
return heartbeat_.load();
return watchdog_->getHeartbt();
}
void ProcessGroupNCCL::startCoalescing() {


@@ -319,7 +319,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
bool isP2P = false,
const char* profilingTitle = nullptr,
const std::optional<std::vector<at::Tensor>>& inputs = std::nullopt,
bool desyncDebug = false,
bool enableTiming = false,
bool cudaEventCacheEnabled = false,
DebugLevel distDebugLevel = DebugLevel::Off);
@@ -621,6 +620,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
void setLastWorkListUpdateTime(
std::chrono::time_point<std::chrono::steady_clock> time);
int getDumpTimeout() const;
// Util function to get the timeout error message
std::string getNCCLWatchdogTimeoutErrorMsg(const std::string& extraMsg);
@@ -676,6 +677,80 @@ class TORCH_API ProcessGroupNCCL : public Backend {
std::chrono::time_point<std::chrono::steady_clock> lastWorkListUpdateTime_;
};
// Class that runs as a side thread to check whether the NCCL collective
// is timed out or errors on the cached NCCL communicators.
class Watchdog {
public:
Watchdog(ProcessGroupNCCL* pg);
virtual ~Watchdog() = default;
// Start the watchdog thread.
void start();
// Join the watchdog thread.
void join();
// Function that runs as part of a separate thread and checks for errors on
// NCCL communicators. We need a separate thread to check for NCCL errors
// since we can't rely on the user calling certain methods like wait(),
// isCompleted() etc. to detect and remediate errors. In addition to this,
// we need a mechanism to safely abort and remove NCCL communicators from
// our cache. This can be done cleanly by having a thread for the
// ProcessGroupNCCL class. Attempting to modify the communicator cache from
// the WorkNCCL class might run into issues with object lifetime since the
// ProcessGroupNCCL object might get destroyed before the WorkNCCL object.
void run();
// Watchdog's inside loop.
// Takes care of cleaning up completed work, and aborting upon failure or
// timeout.
void runLoop();
// Notify the loop inside watchdog.
void notify();
void checkAndSetRemoteError();
// A helper function to get the src rank of a signal from the Store. This is
// nonblocking function returning -1 if the signal is not available yet.
int getSignalSrcRank(
c10::intrusive_ptr<Store>& store,
const std::string& signal);
uint64_t getHeartbt() const;
void setDesyncDebug(bool desyncDebug);
private:
std::thread ncclCommWatchdogThread_;
// We need to keep a reference to the PG instance so that we can access
// the member functions of the PG instance. We store a raw pointer on
// purpose because the watchdog thread now still lives within the
// lifetime of the PG instance.
ProcessGroupNCCL* pg_;
// Whether the NCCL watchdog should rethrow CUDA errors.
bool rethrowCUDAErrors_ = false;
std::exception_ptr watchDogException_ = nullptr;
// Condition Variable for watchdog thread sleep
std::condition_variable workMetaListCV_;
// Heartbeat of watchdog thread.
std::atomic_uint64_t heartbeat_{};
// Whether or not to propagate detected errors to all ranks in the same PG
// through TCPStore.
bool propagatePgError_;
// Whether or not to enable timeout root cause analysis.
bool desyncDebug_;
DesyncDebugger desyncDebugger_;
};
// If you wish to create multiple process groups, each with a potentially
// different rank and size, you can do so by passing a new store instance
// to each one. If you have only a single store object, you can
@@ -947,6 +1022,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Instance of the heartbeat monitor thread.
std::unique_ptr<HeartbeatMonitor> heartbeatMonitor_;
// Instance of the watchdog thread.
std::unique_ptr<Watchdog> watchdog_;
// Helper that broadcasts nccl unique ID to all ranks through the store
void broadcastUniqueNCCLID(
ncclUniqueId* ncclID,
@@ -1082,17 +1160,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
static std::exception_ptr checkForNCCLErrorsInternal(
std::shared_ptr<NCCLComm>& ncclComm);
// Function that runs as part of a separate thread and checks for errors on
// NCCL communicators. We need a separate thread to check for NCCL errors
// since we can't rely on the user calling certain methods like wait(),
// isCompleted() etc. to detect and remediate errors. In addition to this, we
// need a mechanism to safely abort and remove NCCL communicators from our
// cache. This can be done cleanly by having a thread for the ProcessGroupNCCL
// class. Attempting to modify the communicator cache from the WorkNCCL class
// might run into issues with object lifetime since the ProcessGroupNCCL
// object might get destroyed before the WorkNCCL object.
void ncclCommWatchdog();
// Return the CUDA device most likely associated with this backend.
// If we aren't bound to a specific device, there is no strict
// guarantee that this heuristic is the correct assignment of ranks
@@ -1106,11 +1173,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// communicators from the cache and clears used device indices.
void destroyNCCLComms(const std::string& devNCCLCommMapKey);
// Watchdog's inside loop.
// Takes care of cleaning up completed work, and aborting upon failure or
// timeout.
void watchdogHandler();
void runHookLoop();
// Generates a prefix that is unique to this process group and rank, for
@@ -1146,12 +1208,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
const std::string& signal,
int srcRank);
// A helper function to get the src rank of a signal from the Store. This is
// nonblocking function returning -1 if the signal is not available yet.
int getSignalSrcRank(
c10::intrusive_ptr<Store>& store,
const std::string& signal);
protected:
// Function that directly trigger std::abort so that the whole process
// gets terminated.
@@ -1166,8 +1222,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
::c10d::C10dLoggingData& debugLog,
bool throwException = false);
void checkAndSetRemoteError();
// A helper function to guess the device id of the current rank, based on
// bounded device or used device. Do not use this function if you already know
// the device id to operate on.
@@ -1245,21 +1299,12 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Mutex to guard maps like devNCCLCommMap_.
std::mutex mutex_;
// Heartbeat of watchdog thread.
std::atomic_uint64_t heartbeat_{};
// timeout for the dump to finish.
int waitTimeoutDumpInMilSec_;
// Size of ring buffer where we store NCCL Traces for debugging.
int traceBufferSize_;
// We gate the cudaEventCache so that we can roll it out gradually.
std::atomic<bool> cudaEventCacheEnabled_{};
// Watchdog thread which looks for errors on the cached NCCL communicators.
std::thread ncclCommWatchdogThread_;
std::thread onCompletionHookThread_;
// Whether or not we should terminate the watchdog and workCleanup threads.
@@ -1286,9 +1331,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
bool writeDebugInfo_ = false;
// Condition Variable for watchdog thread sleep
std::condition_variable workMetaListCV_;
// Vector to store WorkNCCL pointers
std::list<ProcessGroupNCCL::WorkNCCL> workMetaList_;
@@ -1349,14 +1391,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
std::mutex errorMutex_;
// Whether or not to enable timeout root cause analysis.
bool desyncDebug_;
DesyncDebugger desyncDebugger_;
// Whether or not to propagate detected errors to all ranks in the same PG
// through TCPStore.
bool propagatePgError_;
// Whether or not to sleep after an exception is thrown in the watchdog.
bool sleepAfterException_{};
@@ -1375,9 +1409,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Whether or not TORCH_NCCL_AVOID_RECORD_STREAMS was set
bool avoidRecordStreams_ = false;
// Whether the NCCL watchdog should rethrow CUDA errors.
bool rethrowCUDAErrors_ = false;
// The number of active ncclGroupStart() calls. This counter will be increased
// by 1 when ncclGroupStart() is called and decreased by 1 when ncclGroupEnd()
// is called.
@@ -1395,8 +1426,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// the ProcessGroup
uint64_t op_id_{0};
std::exception_ptr watchDogException_ = nullptr;
// The number of ProcessGroupNCCL created on the current rank.
size_t local_id_;