Detect and handle NCCL errors appropriately in ProcessGroupNCCL. (#25012)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25012

Resubmitting https://github.com/pytorch/pytorch/pull/22907 with build fix.

This change adds the following functionality:
1) WorkNCCL isCompleted, isSuccess methods check for NCCL errors and set the
appropriate exception.
2) Added a watchdog thread to ProcessGroupNCCL which checks for errors in the
cached communicators and removes them from the cache.
3) Use ncclCommAbort in NCCLComm destructor since ncclCommDestroy can block
forever waiting for work.
4) Added a simulate_nccl_errors.py script to simulate NCCL errors.

https://github.com/pytorch/pytorch/issues/17882
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22907

Test Plan: 1) Run the simulate_nccl_errors.py to verify NCCL errors are caught.

Differential Revision: D16958078

fbshipit-source-id: 662b0b8b8ee250e2b6d15bdfc9306d71c4f66219
This commit is contained in:
Pritam Damania 2019-08-22 16:10:29 -07:00 committed by Facebook Github Bot
parent 1037652224
commit 149c646b74
9 changed files with 627 additions and 39 deletions

View File

@ -1,3 +1,5 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import multiprocessing
import sys
import tempfile
@ -114,6 +116,7 @@ class MultiProcessTestCase(TestCase):
def setUp(self):
super(MultiProcessTestCase, self).setUp()
self.skip_return_code_checks = []
self.rank = self.MAIN_PROCESS_RANK
self.file = tempfile.NamedTemporaryFile(delete=False)
self.processes = [self._spawn_process(rank) for rank in range(int(self.world_size))]
@ -143,7 +146,8 @@ class MultiProcessTestCase(TestCase):
for p in self.processes:
p.join(timeout)
elapsed_time = time.time() - start_time
self._check_return_codes(elapsed_time)
if fn in self.skip_return_code_checks:
self._check_return_codes(elapsed_time)
def _check_return_codes(self, elapsed_time):
"""

View File

@ -0,0 +1,38 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch.distributed as c10d
import torch
import argparse
import os
import logging
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='Simple script to simulate NCCL errors. The script is '
'supposed to be run on multiple different nodes simultaneously with '
'appropriate rank and world_size. The script run an allreduce() on '
'the rank 0 node and aborts all the other nodes to simulate an error '
'in NCCL')
parser.add_argument('addr', help='address of the master node to connect to.')
parser.add_argument('port', help='port of the master node to connect to.')
parser.add_argument('rank', help='rank of this node')
parser.add_argument('world_size', help='number of nodes in process group')
args = parser.parse_args()
rank = int(args.rank)
world_size = int(args.world_size)
port = int(args.port)
store = c10d.TCPStore(args.addr, port, world_size, rank == 0)
process_group = c10d.ProcessGroupNCCL(store, rank, world_size)
logging.info('Running first allreduce')
process_group.allreduce(torch.rand(10).cuda(rank)).wait()
if rank == 0:
logging.info('Running second allreduce only on rank 0')
work = process_group.allreduce(torch.rand(10).cuda(rank))
logging.info('Waiting for allreduce to complete...')
work.wait()
logging.info('Second allreduce successful: {}'.format(work.is_success()))
else:
logging.info('Aborting all other ranks.')
os.abort()

View File

@ -1,7 +1,10 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import copy
import math
import os
import random
import signal
import sys
import tempfile
import threading
@ -22,7 +25,7 @@ from torch.nn.parallel import DistributedDataParallel
from common_distributed import MultiProcessTestCase, \
requires_gloo, requires_nccl, \
skip_if_not_multigpu, skip_if_lt_x_gpu, skip_for_known_issues
skip_if_not_multigpu, skip_if_lt_x_gpu, skip_for_known_issues, get_timeout
from common_utils import TestCase, load_tests, run_tests
from common_utils import retry_on_address_already_in_use_error
@ -2792,12 +2795,26 @@ class ComputeBucketAssignmentTest(TestCase):
class CommTest(MultiProcessTestCase):
def setUp(self):
super(CommTest, self).setUp()
# Need to skip return code checking for these tests since the child
# processes don't exit cleanly.
self.skip_return_code_checks = [
self.test_nccl_errors_blocking_abort,
self.test_nccl_errors_blocking_sigkill,
self.test_nccl_errors_blocking_sigstop,
self.test_nccl_errors_blocking_sigterm,
]
self.op_timeout_sec = 1
def tearDown(self):
super(CommTest, self).tearDown()
try:
os.remove(self.file.name)
except OSError:
pass
os.environ["NCCL_BLOCKING_WAIT"] = "0"
@property
def world_size(self):
@ -2831,6 +2848,85 @@ class CommTest(MultiProcessTestCase):
self.assertEqual(tensors, target)
def _run_all_reduce(self, pg):
pg.allreduce(torch.rand(10).cuda(self.rank))
@requires_nccl()
@skip_if_not_multigpu
def test_nccl_errors_nonblocking(self):
store = c10d.FileStore(self.file.name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
process_group.allreduce(torch.rand(10).cuda(self.rank))
if self.rank == 0:
# This allreduce does not block Python thread as allreduce enqueues
# the cuda operation, and then wait only blocks the current cuda
# stream.
work = process_group.allreduce(torch.rand(10).cuda(self.rank))
work.wait()
# Now the work scheduled next should hang forever since the previous
# allreduce will never complete.
t = threading.Thread(target=self._run_all_reduce, args=(process_group,))
t.start()
t.join(int(get_timeout(self.id()) / 2))
self.assertTrue(t.is_alive())
def _test_nccl_errors_blocking(self, func):
os.environ["NCCL_BLOCKING_WAIT"] = "1"
store = c10d.FileStore(self.file.name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, "", timeout=timedelta(seconds=self.op_timeout_sec))
process_group.allreduce(torch.rand(10).cuda(self.rank))
if self.rank == 0:
work = process_group.allreduce(torch.rand(10).cuda(self.rank))
with self.assertRaises(RuntimeError):
# Operation would time out in blocking mode.
work.wait()
else:
func()
@requires_nccl()
@skip_if_not_multigpu
def test_nccl_errors_blocking_clean_exit(self):
self._test_nccl_errors_blocking(lambda : sys.exit(0))
@requires_nccl()
@skip_if_not_multigpu
def test_nccl_errors_blocking_abort(self):
self._test_nccl_errors_blocking(lambda : os.abort())
@requires_nccl()
@skip_if_not_multigpu
def test_nccl_errors_blocking_sigkill(self):
self._test_nccl_errors_blocking(lambda : os.kill(os.getpid(), signal.SIGKILL))
@requires_nccl()
@skip_if_not_multigpu
def test_nccl_errors_blocking_sigstop(self):
self._test_nccl_errors_blocking(lambda : os.kill(os.getpid(), signal.SIGSTOP))
if self.rank == 0:
time.sleep(2 * self.op_timeout_sec)
for i in range(1, len(self.processes)):
os.kill(self.processes[i].pid, signal.SIGCONT)
@requires_nccl()
@skip_if_not_multigpu
def test_nccl_errors_blocking_sigterm(self):
self._test_nccl_errors_blocking(lambda : os.kill(os.getpid(), signal.SIGTERM))
def _run_invalid_nccl_blocking_wait_env(self, val):
os.environ["NCCL_BLOCKING_WAIT"] = val
store = c10d.FileStore(self.file.name, self.world_size)
with self.assertRaises(RuntimeError):
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
@requires_nccl()
@skip_if_not_multigpu
def test_invalid_nccl_blocking_wait_env(self):
self._run_invalid_nccl_blocking_wait_env('abc')
self._run_invalid_nccl_blocking_wait_env('-1')
self._run_invalid_nccl_blocking_wait_env('2147483647')
self._run_invalid_nccl_blocking_wait_env('4294967295')
@requires_nccl()
@skip_if_not_multigpu
def test_broadcast_coalesced_nccl(self):

View File

@ -489,11 +489,14 @@ They are used in specifying strategies for reduction collectives, e.g.,
const std::shared_ptr<::c10d::Store>&,
int,
int,
const std::string&>(),
const std::string&,
const std::chrono::milliseconds&>(),
py::arg("store"),
py::arg("rank"),
py::arg("size"),
py::arg("groupName") = "");
py::arg("groupName") = "",
py::arg("timeout") = std::chrono::milliseconds(
::c10d::ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis));
#endif
#ifdef USE_C10D_MPI

View File

@ -1,8 +1,14 @@
#pragma once
#include <memory>
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
(NCCL_MINOR < 4)
#error "Need NCCL version 2.4+"
#elif defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
#error "Need NCCL version 2.4+"
#endif
#include <nccl.h>
#include <memory>
#define C10D_NCCL_CHECK(cmd) \
do { \
@ -20,13 +26,16 @@ namespace c10d {
// RAII wrapper for NCCL communicator
class NCCLComm {
public:
explicit NCCLComm(ncclComm_t ncclComm) : ncclComm_(ncclComm) {}
explicit NCCLComm(ncclComm_t ncclComm)
: ncclComm_(ncclComm), aborted_(false), ncclAsyncErr_(ncclSuccess) {}
NCCLComm() : NCCLComm(nullptr) {}
~NCCLComm() noexcept(false) {
if (ncclComm_) {
C10D_NCCL_CHECK(ncclCommDestroy(ncclComm_));
if (ncclComm_ && !aborted_) {
// Use ncclCommAbort instead of ncclCommDestroy here since ncclCommDestroy
// could block forever waiting for work to complete on the communicator.
ncclCommAbort();
}
}
@ -47,19 +56,57 @@ class NCCLComm {
// Move constructable
NCCLComm(NCCLComm&& other) {
std::swap(ncclComm_, other.ncclComm_);
std::swap(aborted_, other.aborted_);
std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
}
// Move assignable
NCCLComm& operator=(NCCLComm&& other) {
std::swap(ncclComm_, other.ncclComm_);
std::swap(aborted_, other.aborted_);
std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
return *this;
}
ncclComm_t getNcclComm() {
if (aborted_) {
throw std::runtime_error("NCCL communicator was aborted.");
}
return ncclComm_;
}
void ncclCommAbort() {
if (aborted_) {
// Should not abort twice.
return;
}
C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_));
aborted_ = true;
ncclComm_ = nullptr;
// Set an appropriate error so that we avoid using the communicator.
if (ncclAsyncErr_ == ncclSuccess) {
ncclAsyncErr_ = ncclSystemError;
}
}
bool isAborted() const {
return aborted_;
}
ncclResult_t checkForNcclError() {
if (ncclAsyncErr_ != ncclSuccess) {
return ncclAsyncErr_;
}
C10D_NCCL_CHECK(ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_));
return ncclAsyncErr_;
}
protected:
ncclComm_t ncclComm_;
bool aborted_;
ncclResult_t ncclAsyncErr_;
};
} // namespace c10d

View File

@ -111,33 +111,54 @@ void syncStreams(
} // namespace
const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 100;
constexpr int64_t kSynchronizeBusyWaitMillis = 10;
const int64_t ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis =
10 * 1000;
ProcessGroupNCCL::WorkNCCL::WorkNCCL(const std::vector<at::Device>& devices)
: devices_(devices) {
: devices_(devices), workStartTime_(std::chrono::steady_clock::now()) {
// Creates the CUDA event wrappers
// Note: The actual events are lazily created when first recorded to with
// DEFAULT_FLAGS = cudaEventDisableTiming.
cudaEvents_.resize(devices.size());
ncclComms_.resize(devices.size());
}
ProcessGroupNCCL::WorkNCCL::~WorkNCCL() {}
bool ProcessGroupNCCL::WorkNCCL::isCompleted() {
return finishedGPUExecution();
checkAndSetException();
return exception() || finishedGPUExecutionInternal();
}
bool ProcessGroupNCCL::WorkNCCL::isSuccess() const {
return true;
if (exception()) {
// Already detected an exception.
return false;
}
return !checkForNCCLErrors(ncclComms_) && finishedGPUExecutionInternal();
}
std::exception_ptr ProcessGroupNCCL::WorkNCCL::exception() const {
throw std::runtime_error(
"exception() is not supported by NCCL process "
"group's work, since isSuccess() will always return true, and "
"isCompleted() and wait() will either succeed or throw");
void ProcessGroupNCCL::WorkNCCL::checkAndSetException() {
if (exception()) {
// We already have an exception.
return;
}
auto exception_ptr = checkForNCCLErrors(ncclComms_);
std::unique_lock<std::mutex> lock(mutex_);
exception_ = exception_ptr;
}
// Helper that checks if the NCCL kernels are completed on the GPUs
bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecution() {
checkAndSetException();
return finishedGPUExecutionInternal();
}
bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const {
for (size_t i = 0; i < devices_.size(); ++i) {
// Checking the work's corresponding CUDA events' status
auto ret = cudaEventQuery(cudaEvents_[i]);
@ -151,6 +172,16 @@ bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecution() {
return true;
}
void ProcessGroupNCCL::WorkNCCL::checkAndThrowException() {
// Set the appropriate exception if found.
checkAndSetException();
// Throw an exception, only if we have a valid exception.
if (exception()) {
std::rethrow_exception(exception());
}
}
// Waiting on the work's corresponding CUDA events
void ProcessGroupNCCL::WorkNCCL::synchronize() {
for (size_t i = 0; i < devices_.size(); ++i) {
@ -163,6 +194,23 @@ void ProcessGroupNCCL::WorkNCCL::synchronize() {
AT_CUDA_CHECK(cudaDeviceSynchronize());
}
}
// In case of blocking, wait for the operation to complete.
if (blockingWait_) {
// Wait for the operation to complete.
while (!isCompleted()) {
auto currentTimepoint = std::chrono::steady_clock::now();
if (std::chrono::duration_cast<std::chrono::milliseconds>(
currentTimepoint - workStartTime_) > opTimeout_) {
throw std::runtime_error("Operation timed out!");
}
// Check for errors and throw appropriate exception.
checkAndThrowException();
std::this_thread::sleep_for(
std::chrono::milliseconds(kSynchronizeBusyWaitMillis));
}
checkAndThrowException();
}
}
// Same as calling synchronize().
@ -180,8 +228,32 @@ ProcessGroupNCCL::ProcessGroupNCCL(
const std::shared_ptr<Store>& store,
int rank,
int size,
const std::string& groupName)
: ProcessGroup(rank, size), store_(store), groupName_(groupName) {
const std::string& groupName,
const std::chrono::milliseconds& opTimeout)
: ProcessGroup(rank, size),
store_(store),
groupName_(groupName),
terminateWatchdog_(false),
opTimeout_(opTimeout) {
char* blockingWait = getenv(NCCL_BLOCKING_WAIT);
try {
if (blockingWait != nullptr) {
auto val = std::stoi(blockingWait);
if (val == 1) {
// Make wait() and synchronize() a blocking call.
blockingWait_ = true;
} else if (val != 0) {
throw std::runtime_error(
"Invalid value for environment variable: " +
std::string(NCCL_BLOCKING_WAIT));
}
}
} catch (std::exception& e) {
throw std::runtime_error(
"Invalid value for environment variable: " +
std::string(NCCL_BLOCKING_WAIT));
}
// Generate the Process Group ID for current PG, this needs to be identical
// for all processes
std::unique_lock<std::mutex> lock(pgTrackingLock_);
@ -194,11 +266,81 @@ ProcessGroupNCCL::ProcessGroupNCCL(
processGroupID_ = std::to_string(processGroupCounterMap_[groupKey]);
groupPgID_ = groupName_ + "_" + processGroupID_;
pgUniqueNCCLIDCnt_[groupPgID_] = -1;
ncclCommWatchdogThread_ =
std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this);
}
ProcessGroupNCCL::~ProcessGroupNCCL() {
std::unique_lock<std::mutex> lock(pgTrackingLock_);
pgUniqueNCCLIDCnt_.erase(groupPgID_);
terminateWatchdog_.store(true);
watchdogCV_.notify_one();
ncclCommWatchdogThread_.join();
}
void ProcessGroupNCCL::ncclCommWatchdog() {
while (!terminateWatchdog_.load()) {
{
// Loop through the cache of communicators for NCCL errors.
std::lock_guard<std::mutex> lock(devNCCLCommMapLock_);
for (auto it = devNCCLCommMap_.begin(); it != devNCCLCommMap_.end();) {
auto& ncclComms = it->second;
if (checkForNCCLErrors(ncclComms)) {
LOG(INFO) << "Received NCCL errors for communicators in the cache, "
"removing communicators from the cache and aborting the "
"communicators.";
if (blockingWait_) {
// We should not abort the communicators if we are performing a
// non-blocking wait(). The reason for this is that if we abort the
// nccl communicator, wait() might not throw exceptions and
// subsequent operations might run on garbage results.
// The current model is that when we call wait(), subsequent
// operations only run after this work is done or we hang forever
// waiting for the operation to complete.
for (const auto& ncclComm : ncclComms) {
ncclComm->ncclCommAbort();
}
}
// Remove communicators from the cache.
it = devNCCLCommMap_.erase(it);
} else {
it++;
}
}
}
std::unique_lock<std::mutex> lock(watchdogCVMutex_);
watchdogCV_.wait_for(
lock,
std::chrono::milliseconds(kWatchdogThreadSleepMillis),
[&]() -> bool { return terminateWatchdog_.load(); });
}
}
std::exception_ptr ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors(
const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) const {
return checkForNCCLErrorsInternal(ncclComms);
}
std::exception_ptr ProcessGroupNCCL::checkForNCCLErrors(
const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) {
return checkForNCCLErrorsInternal(ncclComms);
}
std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal(
const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) {
for (const auto& ncclComm : ncclComms) {
ncclResult_t ncclAsyncErr = ncclComm->checkForNcclError();
if (ncclAsyncErr != ncclSuccess) {
return std::make_exception_ptr(std::runtime_error(
"NCCL error: " + std::string(ncclGetErrorString(ncclAsyncErr))));
}
}
return nullptr;
}
void ProcessGroupNCCL::broadcastUniqueNCCLID(ncclUniqueId* ncclID) {
@ -249,10 +391,14 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
usedDeviceIdxs_.insert(device.index());
}
if (devNCCLCommMap_.find(devicesKey) != devNCCLCommMap_.end()) {
// Reuse the cached communicator if there is one.
return devNCCLCommMap_[devicesKey];
{
std::lock_guard<std::mutex> lock(devNCCLCommMapLock_);
if (devNCCLCommMap_.find(devicesKey) != devNCCLCommMap_.end()) {
// Reuse the cached communicator if there is one.
return devNCCLCommMap_[devicesKey];
}
}
// NCCL communicator not cached, create a new entry
std::vector<std::shared_ptr<NCCLComm>> ncclComms;
ncclComms.resize(devices.size());
@ -289,8 +435,6 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
C10D_NCCL_CHECK(ncclGroupEnd());
// Move the NCCL resource to cache
devNCCLCommMap_.emplace(devicesKey, std::move(ncclComms));
ncclStreams_.emplace(devicesKey, std::move(streamVal));
// Note: these events are created with the (default) cudaEventDisableTiming
@ -302,6 +446,11 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
std::make_tuple(devicesKey),
std::make_tuple(devices.size()));
// Hold the lock before modifying the cache.
std::lock_guard<std::mutex> lock(devNCCLCommMapLock_);
// Move the NCCL resource to cache
devNCCLCommMap_.emplace(devicesKey, std::move(ncclComms));
return devNCCLCommMap_[devicesKey];
}
@ -388,6 +537,11 @@ std::vector<at::Tensor> flatten_for_scatter_gather(
} // namespace
std::shared_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
std::vector<at::Device> devices) {
return std::make_shared<ProcessGroupNCCL::WorkNCCL>(devices);
}
template <typename Fn, typename PreProcess, typename PostProcess>
std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
std::vector<at::Tensor>& inputs,
@ -403,7 +557,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
syncStreams(devices, ncclEvents_[key], ncclStreams_[key]);
// Work itself will create the CUDA events on all GPUs of tensors
auto work = std::make_shared<ProcessGroupNCCL::WorkNCCL>(devices);
auto work = initWork(devices);
at::cuda::OptionalCUDAGuard gpuGuard;
@ -441,6 +595,9 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
for (size_t i = 0; i < inputs.size(); ++i) {
at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i];
work->cudaEvents_[i].record(ncclStream);
work->ncclComms_[i] = ncclComms[i];
work->blockingWait_ = blockingWait_;
work->opTimeout_ = opTimeout_;
}
return work;

View File

@ -1,6 +1,7 @@
#pragma once
#include <mutex>
#include <thread>
#include <unordered_map>
#include <c10d/NCCLUtils.hpp>
@ -12,6 +13,10 @@
namespace c10d {
// Environment variable which controls whether or not wait() is blocking or
// non-blocking.
constexpr const char* NCCL_BLOCKING_WAIT = "NCCL_BLOCKING_WAIT";
// ProcessGroupNCCL implements NCCL bindings for c10d.
//
// All functions of the class are expected to be called in the same order
@ -31,13 +36,6 @@ namespace c10d {
// either WorkNCCL::wait() or WorkNCCL::synchronize(), both achieves the same
// functionality and are synonyms.
//
// Note that WorkNCCL::isSuccess() and WorkNCCL::isCompleted() will always
// return true since ProcessGroupNCCL is single threaded. Every single NCCL
// or CUDA failure will simply raise std::runtime_error.
//
// Therefore, WorkNCCL::exception() is not supported since isSuccess() always
// returns true.
//
// Also note that WorkNCCL::finishedGPUExecution() is a helper function only
// provided by ProcessGroupNCCL to check if the NCCL operation of WorkNCCL has
// finished execution on the GPU (not just scheduled).
@ -67,19 +65,16 @@ class ProcessGroupNCCL : public ProcessGroup {
// Non-blocking operation.
bool isCompleted() override;
bool isSuccess() const override;
// Same as calling synchronize() for NCCL work.
void wait() override;
// Will always return true
bool isSuccess() const override;
// Let current stream wait on the completing of the NCCL work
// Throws on exceptions. Non-blocking operation.
// Throws on exceptions. Blocking operation, which will wait for work
// completion.
void synchronize() override;
// Will always throw because it should not be called (isSuccess() -> true).
std::exception_ptr exception() const override;
// Helper function that checks if the NCCL kernels have finished
// execution on the GPUs
bool finishedGPUExecution();
@ -91,9 +86,37 @@ class ProcessGroupNCCL : public ProcessGroup {
// The CUDA events tracking this work item on multiple CUDA devices
std::vector<at::cuda::CUDAEvent> cudaEvents_;
// The NCCL communicators used for this work item.
std::vector<std::shared_ptr<NCCLComm>> ncclComms_;
// Tensors used for barrier op
std::vector<at::Tensor> barrierTensors_;
// Clone of blockingWait_ from ProcessGroupNCCL.
bool blockingWait_ = false;
// Clonge of opTimeout_ from ProcessGroupNCCL.
std::chrono::milliseconds opTimeout_;
// Time point representing when the work started.
std::chrono::time_point<std::chrono::steady_clock> workStartTime_;
// Wrapper method for the static checkForNCCLErrors which can be overridden
// for tests.
virtual std::exception_ptr checkForNCCLErrors(
const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) const;
private:
// Checks for NCCL errors and sets an appropriate exception_ptr.
void checkAndSetException();
// Checks for NCCL errors and throws an appropriate exception.
void checkAndThrowException();
// Just checks whether GPU execution has completed, without modifying
// exception_ptr.
bool finishedGPUExecutionInternal() const;
friend class ProcessGroupNCCL;
};
@ -115,7 +138,9 @@ class ProcessGroupNCCL : public ProcessGroup {
const std::shared_ptr<Store>& store,
int rank,
int size,
const std::string& groupName = "");
const std::string& groupName = "",
const std::chrono::milliseconds& opTimeout =
std::chrono::milliseconds(kProcessGroupNCCLOpTimeoutMillis));
virtual ~ProcessGroupNCCL();
@ -169,6 +194,8 @@ class ProcessGroupNCCL : public ProcessGroup {
std::vector<at::Tensor>& tensors,
int tag) override;
static const int64_t kProcessGroupNCCLOpTimeoutMillis;
protected:
// Helper that broadcasts nccl unique ID to all ranks through the store
void broadcastUniqueNCCLID(ncclUniqueId* ncclID);
@ -179,6 +206,13 @@ class ProcessGroupNCCL : public ProcessGroup {
const std::string& devicesKey,
const std::vector<at::Device>& devices);
// Wrapper method which can be overridden for tests.
virtual std::exception_ptr checkForNCCLErrors(
const std::vector<std::shared_ptr<NCCLComm>>& ncclComms);
virtual std::shared_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
std::vector<at::Device> devices);
private:
// Helper that encapsulates work shared across all collective communication
// primitives. The callbacks have the following signatures:
@ -199,7 +233,25 @@ class ProcessGroupNCCL : public ProcessGroup {
PreProcess pre,
PostProcess post);
// Checks for NCCL errors on each of the communicators and returns an
// appropriate exception_ptr (nullptr if no errors).
static std::exception_ptr checkForNCCLErrorsInternal(
const std::vector<std::shared_ptr<NCCLComm>>& ncclComms);
// Function that runs as part of a separate thread and checks for errors on
// NCCL communicators. We need a separate thread to check for NCCL errors
// since we can't rely on the user calling certain methods like wait(),
// isCompleted() etc. to detect and remediate errors. In addition to this, we
// need a mechanism to safely abort and remove NCCL communicators from our
// cache. This can be done cleanly by having a thread for the ProcessGroupNCCL
// class. Attempting to modify the communicator cache from the WorkNCCL class
// might run into issues with object lifetime since the ProcessGroupNCCL
// object might get destroyed before the WorkNCCL object.
void ncclCommWatchdog();
protected:
static const int64_t kWatchdogThreadSleepMillis;
// Store that is used to exchange each Ranks's NCCL unique ID
std::shared_ptr<Store> store_;
@ -228,6 +280,21 @@ class ProcessGroupNCCL : public ProcessGroup {
std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
devNCCLCommMap_;
// Mutex to guard devNCCLCommMap_.
std::mutex devNCCLCommMapLock_;
// Watchdog thread which looks for errors on the cached NCCL communicators.
std::thread ncclCommWatchdogThread_;
// Whether or not we should terminate the watchdog thread.
std::atomic<bool> terminateWatchdog_;
// Condition variable to control how long the watchdog thread waits.
std::condition_variable watchdogCV_;
// Mutex for watchdog.
std::mutex watchdogCVMutex_;
// The CUDA steams used by NCCL kernels
std::unordered_map<std::string, std::vector<at::cuda::CUDAStream>>
ncclStreams_;
@ -266,6 +333,13 @@ class ProcessGroupNCCL : public ProcessGroup {
// is that different group can have different ranks and we need ensure that
// each group has its own uniform process group ID for all its ranks.
static std::unordered_map<std::string, ssize_t> processGroupCounterMap_;
// Whether or not wait() and synchronize() are blocking operations that wait
// for the operation to complete.
bool blockingWait_ = false;
// Timeout for operations. This is only used when blockingWait_ is enabled.
std::chrono::milliseconds opTimeout_;
};
} // namespace c10d

View File

@ -23,6 +23,7 @@ if(USE_CUDA)
endif()
if(USE_C10D_NCCL)
c10d_add_test(ProcessGroupNCCLTest.cpp c10d c10d_cuda_test)
c10d_add_test(ProcessGroupNCCLErrorsTest.cpp c10d c10d_cuda_test gtest_main)
endif()
else()
if(USE_C10D_GLOO)

View File

@ -0,0 +1,168 @@
#include <c10d/FileStore.hpp>
#include <c10d/ProcessGroupNCCL.hpp>
#include <c10d/test/CUDATest.hpp>
#include <c10d/test/TestUtils.hpp>
#include <gtest/gtest.h>
using namespace c10d::test;
class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
public:
WorkNCCLSimulateErrors(
const std::vector<at::Device>& devices,
bool simulate_error)
: WorkNCCL(devices), simulate_error_(simulate_error) {}
std::exception_ptr checkForNCCLErrors(
const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms)
const override {
if (simulate_error_) {
return std::make_exception_ptr(std::runtime_error("Error"));
}
return c10d::ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors(ncclComms);
}
private:
bool simulate_error_;
};
class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
public:
ProcessGroupNCCLSimulateErrors(
const std::shared_ptr<c10d::Store>& store,
int rank,
int size)
: ProcessGroupNCCL(store, rank, size), simulate_error_(false) {}
std::exception_ptr checkForNCCLErrors(
const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms) override {
if (simulate_error_) {
return std::make_exception_ptr(std::runtime_error("Error"));
}
return c10d::ProcessGroupNCCL::checkForNCCLErrors(ncclComms);
}
std::chrono::duration<int64_t, std::milli> getWatchdogSleepInterval() {
return std::chrono::milliseconds(
ProcessGroupNCCLSimulateErrors::kWatchdogThreadSleepMillis);
}
std::shared_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
std::vector<at::Device> devices) override {
return std::make_shared<WorkNCCLSimulateErrors>(devices, simulate_error_);
}
size_t getNCCLCommCacheSize() {
return devNCCLCommMap_.size();
}
void simulate_error() {
simulate_error_ = true;
}
void reset_error() {
simulate_error_ = false;
}
private:
bool simulate_error_;
};
class ProcessGroupNCCLErrorsTest : public ::testing::Test {
protected:
bool skipTest() {
// Skip test if no cuda devices found.
return cudaNumDevices() == 0;
}
void SetUp() override {
size_t numDevices = cudaNumDevices();
TemporaryFile file;
store_ = std::make_shared<::c10d::FileStore>(file.path, 1);
at::cuda::OptionalCUDAGuard deviceGuard;
tensors_.resize(numDevices);
for (auto i = 0; i < numDevices; ++i) {
deviceGuard.set_index(i);
tensors_[i] = at::ones({3, 3}, at::kCUDA);
}
}
void TearDown() override {
ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "0", 1) == 0);
}
std::vector<at::Tensor> tensors_;
std::shared_ptr<::c10d::FileStore> store_;
};
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) {
if (skipTest()) {
return;
}
ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0);
ProcessGroupNCCLSimulateErrors pg(store_, 0, 1);
auto work = pg.allreduce(tensors_);
work->wait();
EXPECT_TRUE(work->isSuccess());
EXPECT_EQ(1, pg.getNCCLCommCacheSize());
// Now run all reduce with errors.
pg.simulate_error();
work = pg.allreduce(tensors_);
EXPECT_THROW(work->wait(), std::runtime_error);
// Verify the work item failed.
EXPECT_TRUE(work->isCompleted());
EXPECT_FALSE(work->isSuccess());
EXPECT_THROW(work->wait(), std::runtime_error);
// Should remove the nccl communicators which hit errors from the cache.
std::this_thread::sleep_for(2 * pg.getWatchdogSleepInterval());
EXPECT_EQ(0, pg.getNCCLCommCacheSize());
// Verify we can recover from errors.
pg.reset_error();
work = pg.allreduce(tensors_);
work->wait();
EXPECT_TRUE(work->isSuccess());
EXPECT_EQ(1, pg.getNCCLCommCacheSize());
}
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) {
if (skipTest()) {
return;
}
ProcessGroupNCCLSimulateErrors pg(store_, 0, 1);
auto work = pg.allreduce(tensors_);
pg.barrier()->wait();
EXPECT_TRUE(work->isSuccess());
EXPECT_EQ(1, pg.getNCCLCommCacheSize());
// Now run all reduce with errors.
pg.simulate_error();
work = pg.allreduce(tensors_);
// Should not throw exceptions.
work->wait();
pg.barrier()->wait();
// Verify the work item failed.
EXPECT_TRUE(work->isCompleted());
EXPECT_FALSE(work->isSuccess());
// Should remove the nccl communicators which hit errors from the cache.
std::this_thread::sleep_for(2 * pg.getWatchdogSleepInterval());
EXPECT_EQ(0, pg.getNCCLCommCacheSize());
// Verify we can recover from errors.
pg.reset_error();
work = pg.allreduce(tensors_);
pg.barrier()->wait();
EXPECT_TRUE(work->isSuccess());
EXPECT_EQ(1, pg.getNCCLCommCacheSize());
}