[NCCL] Timeout Loop Thread for Async Error Handling (#41050)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41050

**This Commit:**
We introduce a workVector to track live workNCCL objects corresponding to collective operations. Further, we introduce a workCleanupLoop, which busy-polls the vector of workNCCL objects and removes them upon completion.

**This Stack:**
The purpose of this stack is to fix the hanging behavior observed in when using PyTorch DDP training with NCCL. In various situations (desynchronization, high GPU utilization, etc.), NCCL collectives may hang due to waiting on an unresponsive worker. This stack detects such hanging behavior and aborts timed-out collectives by throwing a user-visible exception, all with minimal perf regression. Training can then be restarted from a previous checkpoint with something like torchelastic.

Test Plan: See D22054298 for verification of correctness and performance

Reviewed By: jiayisuse

Differential Revision: D21916637

fbshipit-source-id: f8cadaab0071aaad1c4e31f9b089aa23cba0cfbe
This commit is contained in:
Omkar Salpekar 2020-09-09 12:16:19 -07:00 committed by Facebook GitHub Bot
parent 15cbd1cf4b
commit 1df24fd457
2 changed files with 65 additions and 6 deletions

View File

@ -228,6 +228,7 @@ ncclResult_t ncclAlltoallv(
} // namespace
const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 10000;
const int64_t ProcessGroupNCCL::kWorkCleanupThreadSleepMillis = 1000;
constexpr int64_t kWaitForAbortCommStoreKey = 1000;
constexpr int64_t kSynchronizeBusyWaitMillis = 10;
const int64_t ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis = 10 * 1000;
@ -399,7 +400,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
: ProcessGroup(rank, size),
store_(store),
ncclCommCounter_(0),
terminateWatchdog_(false),
terminateProcessGroup_(false),
opTimeout_(opTimeout) {
try {
parseNcclBlockingWait();
@ -424,11 +425,14 @@ ProcessGroupNCCL::ProcessGroupNCCL(
ncclCommWatchdogThread_ =
std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this);
#endif
workCleanupThread_ = std::thread(&ProcessGroupNCCL::workCleanupLoop, this);
}
ProcessGroupNCCL::~ProcessGroupNCCL() {
terminateWatchdog_.store(true);
terminateProcessGroup_.store(true);
watchdogCV_.notify_one();
workListCV_.notify_one();
#ifdef ENABLE_NCCL_ERROR_CHECKING
ncclCommWatchdogThread_.join();
#endif
@ -444,6 +448,7 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
}
}
}
workCleanupThread_.join();
}
void ProcessGroupNCCL::ncclCommWatchdog() {
@ -458,7 +463,7 @@ void ProcessGroupNCCL::ncclCommWatchdog() {
}
void ProcessGroupNCCL::ncclCommWatchdogInternal() {
while (!terminateWatchdog_.load()) {
while (!terminateProcessGroup_.load()) {
std::unordered_set<std::string> abortedCommIds;
std::unordered_set<std::string> allCommIds;
@ -554,7 +559,32 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() {
watchdogCV_.wait_for(
lock,
std::chrono::milliseconds(kWatchdogThreadSleepMillis),
[&]() -> bool { return terminateWatchdog_.load(); });
[&]() -> bool { return terminateProcessGroup_.load(); });
}
}
// Body of workCleanupThread_: repeatedly scans workList_ and erases
// WorkNCCL entries whose collective has completed. The loop exits once
// terminateProcessGroup_ is set (the destructor sets it and notifies
// workListCV_ so shutdown is not delayed by a full sleep interval).
void ProcessGroupNCCL::workCleanupLoop() {
while (!terminateProcessGroup_.load()) {
std::unique_lock<std::mutex> lock(workListMutex_);
// Sleep up to kWorkCleanupThreadSleepMillis between scans, waking
// early if terminateProcessGroup_ becomes true. (The predicate only
// short-circuits the wait on termination; a timeout simply means
// "do another scan".)
workListCV_.wait_for(
lock,
std::chrono::milliseconds(kWorkCleanupThreadSleepMillis),
[&]() -> bool { return terminateProcessGroup_.load(); });
// Manual iteration because erase() invalidates the erased iterator;
// we advance via erase()'s return value instead of ++it.
for (auto it = workList_.begin(); it != workList_.end();
/* no increment*/) {
auto& work = *it;
if (work->isCompleted()) {
// Drop completed WorkNCCL objects from the list; erase() returns
// the next valid iterator.
it = workList_.erase(it);
} else {
// Work still in flight — keep it and move to the next entry.
++it;
}
}
}
}
@ -797,6 +827,14 @@ c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupNCCL::WorkNCCL::
futureNCCLCallbackStreams_[deviceIndex]);
}
// Registers a live WorkNCCL object with workList_ so workCleanupLoop()
// can track it and reap it once the collective completes. Enqueueing is
// skipped after shutdown has begun (terminateProcessGroup_ set), since
// the cleanup thread is on its way out and would never drain the entry.
void ProcessGroupNCCL::workEnqueue(
    std::shared_ptr<ProcessGroupNCCL::WorkNCCL> work) {
  if (terminateProcessGroup_.load()) {
    return;
  }
  std::lock_guard<std::mutex> lock(workListMutex_);
  workList_.push_back(std::move(work));
}
template <typename Fn, typename PreProcess, typename PostProcess>
std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
std::vector<at::Tensor>& inputs,
@ -861,6 +899,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
work->store_ = store_;
}
workEnqueue(work);
return work;
}

View File

@ -1,5 +1,6 @@
#pragma once
#include <list>
#include <mutex>
#include <thread>
#include <unordered_map>
@ -478,8 +479,11 @@ class ProcessGroupNCCL : public ProcessGroup {
// accordingly.
void parseNcclBlockingWait();
void workCleanupLoop();
protected:
static const int64_t kWatchdogThreadSleepMillis;
static const int64_t kWorkCleanupThreadSleepMillis;
// The store is used to broadcast the NCCL unique ID of rank 0.
std::shared_ptr<Store> store_;
@ -521,8 +525,8 @@ class ProcessGroupNCCL : public ProcessGroup {
// Watchdog thread which looks for errors on the cached NCCL communicators.
std::thread ncclCommWatchdogThread_;
// Whether or not we should terminate the watchdog thread.
std::atomic<bool> terminateWatchdog_;
// Whether or not we should terminate the watchdog and workCleanup threads.
std::atomic<bool> terminateProcessGroup_;
// Condition variable to control how long the watchdog thread waits.
std::condition_variable watchdogCV_;
@ -530,6 +534,21 @@ class ProcessGroupNCCL : public ProcessGroup {
// Mutex for watchdog.
std::mutex watchdogCVMutex_;
// Thread that polls workList_ and removes completed NCCL Work objects
std::thread workCleanupThread_;
// Mutex to Guard workList_
std::mutex workListMutex_;
// Condition Variable for timeout thread sleep
std::condition_variable workListCV_;
// List storing pointers to the live WorkNCCL objects of in-flight collectives
std::list<std::shared_ptr<ProcessGroupNCCL::WorkNCCL>> workList_;
// Appends a WorkNCCL pointer to workList_
void workEnqueue(std::shared_ptr<ProcessGroupNCCL::WorkNCCL>);
// The CUDA steams used by NCCL kernels
std::unordered_map<std::string, std::vector<at::cuda::CUDAStream>>
ncclStreams_;