mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[NCCL] Timeout Loop Thread for Async Error Handling (#41050)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/41050 **This Commit:** We introduce a workVector to track live workNCCL objects corresponding to collective operations. Further, we introduce a workCleanupLoop, which busy-polls the vector of workNCCL objects and removes them upon completion. **This Stack:** The purpose of this stack is to fix the hanging behavior observed in when using PyTorch DDP training with NCCL. In various situations (desynchronization, high GPU utilization, etc.), NCCL collectives may hang due to waiting on an unresponsive worker. This stack detects such hanging behavior and aborts timed-out collectives by throwing a user-visible exception, all with minimal perf regression. Training can then be restarted from a previous checkpoint with something like torchelastic. Test Plan: See D22054298 for verification of correctness and performance Reviewed By: jiayisuse Differential Revision: D21916637 fbshipit-source-id: f8cadaab0071aaad1c4e31f9b089aa23cba0cfbe
This commit is contained in:
parent
15cbd1cf4b
commit
1df24fd457
|
|
@ -228,6 +228,7 @@ ncclResult_t ncclAlltoallv(
|
|||
} // namespace
|
||||
|
||||
const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 10000;
|
||||
const int64_t ProcessGroupNCCL::kWorkCleanupThreadSleepMillis = 1000;
|
||||
constexpr int64_t kWaitForAbortCommStoreKey = 1000;
|
||||
constexpr int64_t kSynchronizeBusyWaitMillis = 10;
|
||||
const int64_t ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis = 10 * 1000;
|
||||
|
|
@ -399,7 +400,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
|
|||
: ProcessGroup(rank, size),
|
||||
store_(store),
|
||||
ncclCommCounter_(0),
|
||||
terminateWatchdog_(false),
|
||||
terminateProcessGroup_(false),
|
||||
opTimeout_(opTimeout) {
|
||||
try {
|
||||
parseNcclBlockingWait();
|
||||
|
|
@ -424,11 +425,14 @@ ProcessGroupNCCL::ProcessGroupNCCL(
|
|||
ncclCommWatchdogThread_ =
|
||||
std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this);
|
||||
#endif
|
||||
|
||||
workCleanupThread_ = std::thread(&ProcessGroupNCCL::workCleanupLoop, this);
|
||||
}
|
||||
|
||||
ProcessGroupNCCL::~ProcessGroupNCCL() {
|
||||
terminateWatchdog_.store(true);
|
||||
terminateProcessGroup_.store(true);
|
||||
watchdogCV_.notify_one();
|
||||
workListCV_.notify_one();
|
||||
#ifdef ENABLE_NCCL_ERROR_CHECKING
|
||||
ncclCommWatchdogThread_.join();
|
||||
#endif
|
||||
|
|
@ -444,6 +448,7 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
|
|||
}
|
||||
}
|
||||
}
|
||||
workCleanupThread_.join();
|
||||
}
|
||||
|
||||
void ProcessGroupNCCL::ncclCommWatchdog() {
|
||||
|
|
@ -458,7 +463,7 @@ void ProcessGroupNCCL::ncclCommWatchdog() {
|
|||
}
|
||||
|
||||
void ProcessGroupNCCL::ncclCommWatchdogInternal() {
|
||||
while (!terminateWatchdog_.load()) {
|
||||
while (!terminateProcessGroup_.load()) {
|
||||
std::unordered_set<std::string> abortedCommIds;
|
||||
std::unordered_set<std::string> allCommIds;
|
||||
|
||||
|
|
@ -554,7 +559,32 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() {
|
|||
watchdogCV_.wait_for(
|
||||
lock,
|
||||
std::chrono::milliseconds(kWatchdogThreadSleepMillis),
|
||||
[&]() -> bool { return terminateWatchdog_.load(); });
|
||||
[&]() -> bool { return terminateProcessGroup_.load(); });
|
||||
}
|
||||
}
|
||||
|
||||
void ProcessGroupNCCL::workCleanupLoop() {
|
||||
while (!terminateProcessGroup_.load()) {
|
||||
std::unique_lock<std::mutex> lock(workListMutex_);
|
||||
// We busy-poll the work vector every kWatchdogThreadSleepMillis
|
||||
// milliseconds as long as the atomic is True.
|
||||
workListCV_.wait_for(
|
||||
lock,
|
||||
std::chrono::milliseconds(kWorkCleanupThreadSleepMillis),
|
||||
[&]() -> bool { return terminateProcessGroup_.load(); });
|
||||
|
||||
for (auto it = workList_.begin(); it != workList_.end();
|
||||
/* no increment*/) {
|
||||
auto& work = *it;
|
||||
if (work->isCompleted()) {
|
||||
// Remove all Completed WorkNCCL Objects from the Vector
|
||||
it = workList_.erase(it);
|
||||
} else {
|
||||
// Increment the iterator if the current WorkNCCL object is not
|
||||
// completed.
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -797,6 +827,14 @@ c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupNCCL::WorkNCCL::
|
|||
futureNCCLCallbackStreams_[deviceIndex]);
|
||||
}
|
||||
|
||||
void ProcessGroupNCCL::workEnqueue(
|
||||
std::shared_ptr<ProcessGroupNCCL::WorkNCCL> work) {
|
||||
if (!terminateProcessGroup_.load()) {
|
||||
std::lock_guard<std::mutex> lock(workListMutex_);
|
||||
workList_.emplace_back(std::move(work));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Fn, typename PreProcess, typename PostProcess>
|
||||
std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
|
||||
std::vector<at::Tensor>& inputs,
|
||||
|
|
@ -861,6 +899,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
|
|||
work->store_ = store_;
|
||||
}
|
||||
|
||||
workEnqueue(work);
|
||||
|
||||
return work;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#include <list>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
|
|
@ -478,8 +479,11 @@ class ProcessGroupNCCL : public ProcessGroup {
|
|||
// accordingly.
|
||||
void parseNcclBlockingWait();
|
||||
|
||||
void workCleanupLoop();
|
||||
|
||||
protected:
|
||||
static const int64_t kWatchdogThreadSleepMillis;
|
||||
static const int64_t kWorkCleanupThreadSleepMillis;
|
||||
|
||||
// The store is used to broadcast the NCCL unique ID of rank 0.
|
||||
std::shared_ptr<Store> store_;
|
||||
|
|
@ -521,8 +525,8 @@ class ProcessGroupNCCL : public ProcessGroup {
|
|||
// Watchdog thread which looks for errors on the cached NCCL communicators.
|
||||
std::thread ncclCommWatchdogThread_;
|
||||
|
||||
// Whether or not we should terminate the watchdog thread.
|
||||
std::atomic<bool> terminateWatchdog_;
|
||||
// Whether or not we should terminate the watchdog and workCleanup threads.
|
||||
std::atomic<bool> terminateProcessGroup_;
|
||||
|
||||
// Condition variable to control how long the watchdog thread waits.
|
||||
std::condition_variable watchdogCV_;
|
||||
|
|
@ -530,6 +534,21 @@ class ProcessGroupNCCL : public ProcessGroup {
|
|||
// Mutex for watchdog.
|
||||
std::mutex watchdogCVMutex_;
|
||||
|
||||
// Thread that removes NCCL Work upon timeout
|
||||
std::thread workCleanupThread_;
|
||||
|
||||
// Mutex to Guard workList_
|
||||
std::mutex workListMutex_;
|
||||
|
||||
// Condition Variable for timeout thread sleep
|
||||
std::condition_variable workListCV_;
|
||||
|
||||
// Vector to Store WorkNCCL pointers
|
||||
std::list<std::shared_ptr<ProcessGroupNCCL::WorkNCCL>> workList_;
|
||||
|
||||
// Add Work Pointer to workVector
|
||||
void workEnqueue(std::shared_ptr<ProcessGroupNCCL::WorkNCCL>);
|
||||
|
||||
// The CUDA steams used by NCCL kernels
|
||||
std::unordered_map<std::string, std::vector<at::cuda::CUDAStream>>
|
||||
ncclStreams_;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user