Detect and handle NCCL errors appropriately in ProcessGroupNCCL. (#25012)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25012

Resubmitting https://github.com/pytorch/pytorch/pull/22907 with a build fix. This change adds the following functionality:
1) The WorkNCCL isCompleted() and isSuccess() methods check for NCCL errors and set the appropriate exception.
2) Added a watchdog thread to ProcessGroupNCCL which checks the cached communicators for errors and removes failed communicators from the cache.
3) Use ncclCommAbort in the NCCLComm destructor, since ncclCommDestroy can block forever waiting for outstanding work.
4) Added a simulate_nccl_errors.py script to simulate NCCL errors.

Related issue: https://github.com/pytorch/pytorch/issues/17882
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22907

Test Plan: Run the simulate_nccl_errors.py script to verify that NCCL errors are caught.

Differential Revision: D16958078
fbshipit-source-id: 662b0b8b8ee250e2b6d15bdfc9306d71c4f66219
parent: 1037652224
commit: 149c646b74
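To make the new failure-handling behavior concrete, here is a minimal usage sketch pieced together from the tests added in this change. The FileStore path, the single-rank setup, the device index, and the one-second timeout are illustrative assumptions; only the NCCL_BLOCKING_WAIT environment variable and the new timeout argument come from this commit.

import os
from datetime import timedelta

import torch
import torch.distributed as c10d

# Opt in to blocking wait() so NCCL errors and timeouts surface as exceptions.
os.environ["NCCL_BLOCKING_WAIT"] = "1"

# Illustrative single-process setup; real jobs get store/rank/world_size from a launcher.
store = c10d.FileStore("/tmp/nccl_error_demo_store", 1)
pg = c10d.ProcessGroupNCCL(store, 0, 1, "", timeout=timedelta(seconds=1))

work = pg.allreduce(torch.rand(10).cuda(0))
try:
    # With NCCL_BLOCKING_WAIT=1, wait() blocks until the operation finishes and
    # raises RuntimeError if an NCCL error is detected or the timeout expires.
    work.wait()
except RuntimeError as e:
    print("NCCL failure detected:", e)

In the default (non-blocking) mode, wait() keeps its old semantics, and the new watchdog thread is what notices failed communicators and evicts them from the cache.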
test/common_distributed.py
@@ -1,3 +1,5 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import multiprocessing
 import sys
 import tempfile
@@ -114,6 +116,7 @@ class MultiProcessTestCase(TestCase):
     def setUp(self):
         super(MultiProcessTestCase, self).setUp()
+        self.skip_return_code_checks = []
         self.rank = self.MAIN_PROCESS_RANK
         self.file = tempfile.NamedTemporaryFile(delete=False)
         self.processes = [self._spawn_process(rank) for rank in range(int(self.world_size))]
@@ -143,7 +146,8 @@ class MultiProcessTestCase(TestCase):
         for p in self.processes:
             p.join(timeout)
         elapsed_time = time.time() - start_time
-        self._check_return_codes(elapsed_time)
+        if fn not in self.skip_return_code_checks:
+            self._check_return_codes(elapsed_time)

     def _check_return_codes(self, elapsed_time):
         """
test/simulate_nccl_errors.py (new file, 38 lines)
@@ -0,0 +1,38 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import torch.distributed as c10d
import torch
import argparse
import os
import logging

logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Simple script to simulate NCCL errors. The script is '
        'supposed to be run on multiple different nodes simultaneously with '
        'appropriate rank and world_size. The script runs an allreduce() on '
        'the rank 0 node and aborts all the other nodes to simulate an error '
        'in NCCL')
    parser.add_argument('addr', help='address of the master node to connect to.')
    parser.add_argument('port', help='port of the master node to connect to.')
    parser.add_argument('rank', help='rank of this node')
    parser.add_argument('world_size', help='number of nodes in process group')
    args = parser.parse_args()
    rank = int(args.rank)
    world_size = int(args.world_size)
    port = int(args.port)

    store = c10d.TCPStore(args.addr, port, world_size, rank == 0)
    process_group = c10d.ProcessGroupNCCL(store, rank, world_size)
    logging.info('Running first allreduce')
    process_group.allreduce(torch.rand(10).cuda(rank)).wait()
    if rank == 0:
        logging.info('Running second allreduce only on rank 0')
        work = process_group.allreduce(torch.rand(10).cuda(rank))
        logging.info('Waiting for allreduce to complete...')
        work.wait()
        logging.info('Second allreduce successful: {}'.format(work.is_success()))
    else:
        logging.info('Aborting all other ranks.')
        os.abort()
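For reference, the script above is meant to be launched once per node, roughly as: python test/simulate_nccl_errors.py MASTER_ADDR MASTER_PORT RANK WORLD_SIZE (this invocation is inferred from the argparse definitions; the commit itself does not spell out a command line). Rank 0 then hangs or fails on its second allreduce while every other rank calls os.abort(), which is exactly the failure the watchdog and blocking wait() are meant to surface.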
test/test_c10d.py
@@ -1,7 +1,10 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import copy
 import math
 import os
 import random
+import signal
 import sys
 import tempfile
 import threading
@@ -22,7 +25,7 @@ from torch.nn.parallel import DistributedDataParallel
 from common_distributed import MultiProcessTestCase, \
     requires_gloo, requires_nccl, \
-    skip_if_not_multigpu, skip_if_lt_x_gpu, skip_for_known_issues
+    skip_if_not_multigpu, skip_if_lt_x_gpu, skip_for_known_issues, get_timeout
 from common_utils import TestCase, load_tests, run_tests
 from common_utils import retry_on_address_already_in_use_error
@@ -2792,12 +2795,26 @@ class ComputeBucketAssignmentTest(TestCase):

 class CommTest(MultiProcessTestCase):
+    def setUp(self):
+        super(CommTest, self).setUp()
+        # Need to skip return code checking for these tests since the child
+        # processes don't exit cleanly.
+        self.skip_return_code_checks = [
+            self.test_nccl_errors_blocking_abort,
+            self.test_nccl_errors_blocking_sigkill,
+            self.test_nccl_errors_blocking_sigstop,
+            self.test_nccl_errors_blocking_sigterm,
+        ]
+        self.op_timeout_sec = 1
+
     def tearDown(self):
         super(CommTest, self).tearDown()
         try:
             os.remove(self.file.name)
         except OSError:
             pass
+        os.environ["NCCL_BLOCKING_WAIT"] = "0"

     @property
     def world_size(self):
@@ -2831,6 +2848,85 @@ class CommTest(MultiProcessTestCase):
         self.assertEqual(tensors, target)

+    def _run_all_reduce(self, pg):
+        pg.allreduce(torch.rand(10).cuda(self.rank))
+
+    @requires_nccl()
+    @skip_if_not_multigpu
+    def test_nccl_errors_nonblocking(self):
+        store = c10d.FileStore(self.file.name, self.world_size)
+        process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
+        process_group.allreduce(torch.rand(10).cuda(self.rank))
+        if self.rank == 0:
+            # This allreduce does not block Python thread as allreduce enqueues
+            # the cuda operation, and then wait only blocks the current cuda
+            # stream.
+            work = process_group.allreduce(torch.rand(10).cuda(self.rank))
+            work.wait()
+
+            # Now the work scheduled next should hang forever since the previous
+            # allreduce will never complete.
+            t = threading.Thread(target=self._run_all_reduce, args=(process_group,))
+            t.start()
+            t.join(int(get_timeout(self.id()) / 2))
+            self.assertTrue(t.is_alive())
+
+    def _test_nccl_errors_blocking(self, func):
+        os.environ["NCCL_BLOCKING_WAIT"] = "1"
+        store = c10d.FileStore(self.file.name, self.world_size)
+        process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, "", timeout=timedelta(seconds=self.op_timeout_sec))
+        process_group.allreduce(torch.rand(10).cuda(self.rank))
+        if self.rank == 0:
+            work = process_group.allreduce(torch.rand(10).cuda(self.rank))
+            with self.assertRaises(RuntimeError):
+                # Operation would time out in blocking mode.
+                work.wait()
+        else:
+            func()
+
+    @requires_nccl()
+    @skip_if_not_multigpu
+    def test_nccl_errors_blocking_clean_exit(self):
+        self._test_nccl_errors_blocking(lambda : sys.exit(0))
+
+    @requires_nccl()
+    @skip_if_not_multigpu
+    def test_nccl_errors_blocking_abort(self):
+        self._test_nccl_errors_blocking(lambda : os.abort())
+
+    @requires_nccl()
+    @skip_if_not_multigpu
+    def test_nccl_errors_blocking_sigkill(self):
+        self._test_nccl_errors_blocking(lambda : os.kill(os.getpid(), signal.SIGKILL))
+
+    @requires_nccl()
+    @skip_if_not_multigpu
+    def test_nccl_errors_blocking_sigstop(self):
+        self._test_nccl_errors_blocking(lambda : os.kill(os.getpid(), signal.SIGSTOP))
+        if self.rank == 0:
+            time.sleep(2 * self.op_timeout_sec)
+            for i in range(1, len(self.processes)):
+                os.kill(self.processes[i].pid, signal.SIGCONT)
+
+    @requires_nccl()
+    @skip_if_not_multigpu
+    def test_nccl_errors_blocking_sigterm(self):
+        self._test_nccl_errors_blocking(lambda : os.kill(os.getpid(), signal.SIGTERM))
+
+    def _run_invalid_nccl_blocking_wait_env(self, val):
+        os.environ["NCCL_BLOCKING_WAIT"] = val
+        store = c10d.FileStore(self.file.name, self.world_size)
+        with self.assertRaises(RuntimeError):
+            process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
+
+    @requires_nccl()
+    @skip_if_not_multigpu
+    def test_invalid_nccl_blocking_wait_env(self):
+        self._run_invalid_nccl_blocking_wait_env('abc')
+        self._run_invalid_nccl_blocking_wait_env('-1')
+        self._run_invalid_nccl_blocking_wait_env('2147483647')
+        self._run_invalid_nccl_blocking_wait_env('4294967295')
+
     @requires_nccl()
     @skip_if_not_multigpu
     def test_broadcast_coalesced_nccl(self):
torch/csrc/distributed/c10d/init.cpp
@@ -489,11 +489,14 @@ They are used in specifying strategies for reduction collectives, e.g.,
           const std::shared_ptr<::c10d::Store>&,
           int,
           int,
-          const std::string&>(),
+          const std::string&,
+          const std::chrono::milliseconds&>(),
       py::arg("store"),
       py::arg("rank"),
       py::arg("size"),
-      py::arg("groupName") = "");
+      py::arg("groupName") = "",
+      py::arg("timeout") = std::chrono::milliseconds(
+          ::c10d::ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis));
 #endif

 #ifdef USE_C10D_MPI
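With this binding change, the timeout can be passed from Python exactly as the new tests do, e.g. c10d.ProcessGroupNCCL(store, rank, world_size, "", timeout=timedelta(seconds=1)); when the argument is omitted it defaults to kProcessGroupNCCLOpTimeoutMillis (10 seconds, per the ProcessGroupNCCL.cpp change below).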
torch/lib/c10d/NCCLUtils.hpp
@@ -1,8 +1,14 @@
 #pragma once

-#include <memory>
+#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
+    (NCCL_MINOR < 4)
+#error "Need NCCL version 2.4+"
+#elif defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
+#error "Need NCCL version 2.4+"
+#endif

 #include <nccl.h>
+#include <memory>

 #define C10D_NCCL_CHECK(cmd)          \
   do {                                \
@@ -20,13 +26,16 @@ namespace c10d {
 // RAII wrapper for NCCL communicator
 class NCCLComm {
  public:
-  explicit NCCLComm(ncclComm_t ncclComm) : ncclComm_(ncclComm) {}
+  explicit NCCLComm(ncclComm_t ncclComm)
+      : ncclComm_(ncclComm), aborted_(false), ncclAsyncErr_(ncclSuccess) {}

   NCCLComm() : NCCLComm(nullptr) {}

   ~NCCLComm() noexcept(false) {
-    if (ncclComm_) {
-      C10D_NCCL_CHECK(ncclCommDestroy(ncclComm_));
+    if (ncclComm_ && !aborted_) {
+      // Use ncclCommAbort instead of ncclCommDestroy here since ncclCommDestroy
+      // could block forever waiting for work to complete on the communicator.
+      ncclCommAbort();
     }
   }
@@ -47,19 +56,57 @@ class NCCLComm {
   // Move constructable
   NCCLComm(NCCLComm&& other) {
     std::swap(ncclComm_, other.ncclComm_);
+    std::swap(aborted_, other.aborted_);
+    std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
   }

   // Move assignable
   NCCLComm& operator=(NCCLComm&& other) {
     std::swap(ncclComm_, other.ncclComm_);
+    std::swap(aborted_, other.aborted_);
+    std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
     return *this;
   }

   ncclComm_t getNcclComm() {
+    if (aborted_) {
+      throw std::runtime_error("NCCL communicator was aborted.");
+    }
     return ncclComm_;
   }

+  void ncclCommAbort() {
+    if (aborted_) {
+      // Should not abort twice.
+      return;
+    }
+
+    C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_));
+    aborted_ = true;
+    ncclComm_ = nullptr;
+
+    // Set an appropriate error so that we avoid using the communicator.
+    if (ncclAsyncErr_ == ncclSuccess) {
+      ncclAsyncErr_ = ncclSystemError;
+    }
+  }
+
+  bool isAborted() const {
+    return aborted_;
+  }
+
+  ncclResult_t checkForNcclError() {
+    if (ncclAsyncErr_ != ncclSuccess) {
+      return ncclAsyncErr_;
+    }
+    C10D_NCCL_CHECK(ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_));
+    return ncclAsyncErr_;
+  }
+
  protected:
   ncclComm_t ncclComm_;
+  bool aborted_;
+  ncclResult_t ncclAsyncErr_;
 };

 } // namespace c10d
torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -111,33 +111,54 @@ void syncStreams(
 } // namespace

+const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 100;
+constexpr int64_t kSynchronizeBusyWaitMillis = 10;
+const int64_t ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis =
+    10 * 1000;
+
 ProcessGroupNCCL::WorkNCCL::WorkNCCL(const std::vector<at::Device>& devices)
-    : devices_(devices) {
+    : devices_(devices), workStartTime_(std::chrono::steady_clock::now()) {
   // Creates the CUDA event wrappers
   // Note: The actual events are lazily created when first recorded to with
   // DEFAULT_FLAGS = cudaEventDisableTiming.
   cudaEvents_.resize(devices.size());
+  ncclComms_.resize(devices.size());
 }

 ProcessGroupNCCL::WorkNCCL::~WorkNCCL() {}

 bool ProcessGroupNCCL::WorkNCCL::isCompleted() {
-  return finishedGPUExecution();
+  checkAndSetException();
+  return exception() || finishedGPUExecutionInternal();
 }

 bool ProcessGroupNCCL::WorkNCCL::isSuccess() const {
-  return true;
+  if (exception()) {
+    // Already detected an exception.
+    return false;
+  }
+
+  return !checkForNCCLErrors(ncclComms_) && finishedGPUExecutionInternal();
 }

-std::exception_ptr ProcessGroupNCCL::WorkNCCL::exception() const {
-  throw std::runtime_error(
-      "exception() is not supported by NCCL process "
-      "group's work, since isSuccess() will always return true, and "
-      "isCompleted() and wait() will either succeed or throw");
+void ProcessGroupNCCL::WorkNCCL::checkAndSetException() {
+  if (exception()) {
+    // We already have an exception.
+    return;
+  }
+
+  auto exception_ptr = checkForNCCLErrors(ncclComms_);
+  std::unique_lock<std::mutex> lock(mutex_);
+  exception_ = exception_ptr;
 }

 // Helper that checks if the NCCL kernels are completed on the GPUs
 bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecution() {
+  checkAndSetException();
+  return finishedGPUExecutionInternal();
+}
+
+bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const {
   for (size_t i = 0; i < devices_.size(); ++i) {
     // Checking the work's corresponding CUDA events' status
     auto ret = cudaEventQuery(cudaEvents_[i]);
@@ -151,6 +172,16 @@ bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecution() {
   return true;
 }

+void ProcessGroupNCCL::WorkNCCL::checkAndThrowException() {
+  // Set the appropriate exception if found.
+  checkAndSetException();
+
+  // Throw an exception, only if we have a valid exception.
+  if (exception()) {
+    std::rethrow_exception(exception());
+  }
+}
+
 // Waiting on the work's corresponding CUDA events
 void ProcessGroupNCCL::WorkNCCL::synchronize() {
   for (size_t i = 0; i < devices_.size(); ++i) {
@@ -163,6 +194,23 @@ void ProcessGroupNCCL::WorkNCCL::synchronize() {
       AT_CUDA_CHECK(cudaDeviceSynchronize());
     }
   }
+
+  // In case of blocking, wait for the operation to complete.
+  if (blockingWait_) {
+    // Wait for the operation to complete.
+    while (!isCompleted()) {
+      auto currentTimepoint = std::chrono::steady_clock::now();
+      if (std::chrono::duration_cast<std::chrono::milliseconds>(
+              currentTimepoint - workStartTime_) > opTimeout_) {
+        throw std::runtime_error("Operation timed out!");
+      }
+      // Check for errors and throw appropriate exception.
+      checkAndThrowException();
+      std::this_thread::sleep_for(
+          std::chrono::milliseconds(kSynchronizeBusyWaitMillis));
+    }
+    checkAndThrowException();
+  }
 }

 // Same as calling synchronize().
@@ -180,8 +228,32 @@ ProcessGroupNCCL::ProcessGroupNCCL(
     const std::shared_ptr<Store>& store,
     int rank,
     int size,
-    const std::string& groupName)
-    : ProcessGroup(rank, size), store_(store), groupName_(groupName) {
+    const std::string& groupName,
+    const std::chrono::milliseconds& opTimeout)
+    : ProcessGroup(rank, size),
+      store_(store),
+      groupName_(groupName),
+      terminateWatchdog_(false),
+      opTimeout_(opTimeout) {
+  char* blockingWait = getenv(NCCL_BLOCKING_WAIT);
+  try {
+    if (blockingWait != nullptr) {
+      auto val = std::stoi(blockingWait);
+      if (val == 1) {
+        // Make wait() and synchronize() a blocking call.
+        blockingWait_ = true;
+      } else if (val != 0) {
+        throw std::runtime_error(
+            "Invalid value for environment variable: " +
+            std::string(NCCL_BLOCKING_WAIT));
+      }
+    }
+  } catch (std::exception& e) {
+    throw std::runtime_error(
+        "Invalid value for environment variable: " +
+        std::string(NCCL_BLOCKING_WAIT));
+  }
+
   // Generate the Process Group ID for current PG, this needs to be identical
   // for all processes
   std::unique_lock<std::mutex> lock(pgTrackingLock_);
@@ -194,11 +266,81 @@ ProcessGroupNCCL::ProcessGroupNCCL(
   processGroupID_ = std::to_string(processGroupCounterMap_[groupKey]);
   groupPgID_ = groupName_ + "_" + processGroupID_;
   pgUniqueNCCLIDCnt_[groupPgID_] = -1;
+
+  ncclCommWatchdogThread_ =
+      std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this);
 }

 ProcessGroupNCCL::~ProcessGroupNCCL() {
   std::unique_lock<std::mutex> lock(pgTrackingLock_);
   pgUniqueNCCLIDCnt_.erase(groupPgID_);
+
+  terminateWatchdog_.store(true);
+  watchdogCV_.notify_one();
+  ncclCommWatchdogThread_.join();
+}
+
+void ProcessGroupNCCL::ncclCommWatchdog() {
+  while (!terminateWatchdog_.load()) {
+    {
+      // Loop through the cache of communicators for NCCL errors.
+      std::lock_guard<std::mutex> lock(devNCCLCommMapLock_);
+      for (auto it = devNCCLCommMap_.begin(); it != devNCCLCommMap_.end();) {
+        auto& ncclComms = it->second;
+        if (checkForNCCLErrors(ncclComms)) {
+          LOG(INFO) << "Received NCCL errors for communicators in the cache, "
+                       "removing communicators from the cache and aborting the "
+                       "communicators.";
+
+          if (blockingWait_) {
+            // We should not abort the communicators if we are performing a
+            // non-blocking wait(). The reason for this is that if we abort the
+            // nccl communicator, wait() might not throw exceptions and
+            // subsequent operations might run on garbage results.
+            // The current model is that when we call wait(), subsequent
+            // operations only run after this work is done or we hang forever
+            // waiting for the operation to complete.
+            for (const auto& ncclComm : ncclComms) {
+              ncclComm->ncclCommAbort();
+            }
+          }
+
+          // Remove communicators from the cache.
+          it = devNCCLCommMap_.erase(it);
+        } else {
+          it++;
+        }
+      }
+    }
+
+    std::unique_lock<std::mutex> lock(watchdogCVMutex_);
+    watchdogCV_.wait_for(
+        lock,
+        std::chrono::milliseconds(kWatchdogThreadSleepMillis),
+        [&]() -> bool { return terminateWatchdog_.load(); });
+  }
+}
+
+std::exception_ptr ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors(
+    const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) const {
+  return checkForNCCLErrorsInternal(ncclComms);
+}
+
+std::exception_ptr ProcessGroupNCCL::checkForNCCLErrors(
+    const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) {
+  return checkForNCCLErrorsInternal(ncclComms);
+}
+
+std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal(
+    const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) {
+  for (const auto& ncclComm : ncclComms) {
+    ncclResult_t ncclAsyncErr = ncclComm->checkForNcclError();
+    if (ncclAsyncErr != ncclSuccess) {
+      return std::make_exception_ptr(std::runtime_error(
+          "NCCL error: " + std::string(ncclGetErrorString(ncclAsyncErr))));
+    }
+  }
+
+  return nullptr;
 }

 void ProcessGroupNCCL::broadcastUniqueNCCLID(ncclUniqueId* ncclID) {
@@ -249,10 +391,14 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
     usedDeviceIdxs_.insert(device.index());
   }

-  if (devNCCLCommMap_.find(devicesKey) != devNCCLCommMap_.end()) {
-    // Reuse the cached communicator if there is one.
-    return devNCCLCommMap_[devicesKey];
+  {
+    std::lock_guard<std::mutex> lock(devNCCLCommMapLock_);
+    if (devNCCLCommMap_.find(devicesKey) != devNCCLCommMap_.end()) {
+      // Reuse the cached communicator if there is one.
+      return devNCCLCommMap_[devicesKey];
+    }
   }

   // NCCL communicator not cached, create a new entry
   std::vector<std::shared_ptr<NCCLComm>> ncclComms;
   ncclComms.resize(devices.size());
@@ -289,8 +435,6 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
   C10D_NCCL_CHECK(ncclGroupEnd());

-  // Move the NCCL resource to cache
-  devNCCLCommMap_.emplace(devicesKey, std::move(ncclComms));
   ncclStreams_.emplace(devicesKey, std::move(streamVal));

   // Note: these events are created with the (default) cudaEventDisableTiming
@@ -302,6 +446,11 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
       std::make_tuple(devicesKey),
       std::make_tuple(devices.size()));

+  // Hold the lock before modifying the cache.
+  std::lock_guard<std::mutex> lock(devNCCLCommMapLock_);
+
+  // Move the NCCL resource to cache
+  devNCCLCommMap_.emplace(devicesKey, std::move(ncclComms));
   return devNCCLCommMap_[devicesKey];
 }
@@ -388,6 +537,11 @@ std::vector<at::Tensor> flatten_for_scatter_gather(
 } // namespace

+std::shared_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
+    std::vector<at::Device> devices) {
+  return std::make_shared<ProcessGroupNCCL::WorkNCCL>(devices);
+}
+
 template <typename Fn, typename PreProcess, typename PostProcess>
 std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
     std::vector<at::Tensor>& inputs,
@@ -403,7 +557,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
   syncStreams(devices, ncclEvents_[key], ncclStreams_[key]);

   // Work itself will create the CUDA events on all GPUs of tensors
-  auto work = std::make_shared<ProcessGroupNCCL::WorkNCCL>(devices);
+  auto work = initWork(devices);

   at::cuda::OptionalCUDAGuard gpuGuard;
@@ -441,6 +595,9 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
   for (size_t i = 0; i < inputs.size(); ++i) {
     at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i];
     work->cudaEvents_[i].record(ncclStream);
+    work->ncclComms_[i] = ncclComms[i];
+    work->blockingWait_ = blockingWait_;
+    work->opTimeout_ = opTimeout_;
   }

   return work;
torch/lib/c10d/ProcessGroupNCCL.hpp
@@ -1,6 +1,7 @@
 #pragma once

 #include <mutex>
+#include <thread>
 #include <unordered_map>

 #include <c10d/NCCLUtils.hpp>
@@ -12,6 +13,10 @@
 namespace c10d {

+// Environment variable which controls whether or not wait() is blocking or
+// non-blocking.
+constexpr const char* NCCL_BLOCKING_WAIT = "NCCL_BLOCKING_WAIT";
+
 // ProcessGroupNCCL implements NCCL bindings for c10d.
 //
 // All functions of the class are expected to be called in the same order
@@ -31,13 +36,6 @@ namespace c10d {
 // either WorkNCCL::wait() or WorkNCCL::synchronize(), both achieves the same
 // functionality and are synonyms.
 //
-// Note that WorkNCCL::isSuccess() and WorkNCCL::isCompleted() will always
-// return true since ProcessGroupNCCL is single threaded. Every single NCCL
-// or CUDA failure will simply raise std::runtime_error.
-//
-// Therefore, WorkNCCL::exception() is not supported since isSuccess() always
-// returns true.
-//
 // Also note that WorkNCCL::finishedGPUExecution() is a helper function only
 // provided by ProcessGroupNCCL to check if the NCCL operation of WorkNCCL has
 // finished execution on the GPU (not just scheduled).
@@ -67,19 +65,16 @@ class ProcessGroupNCCL : public ProcessGroup {
     // Non-blocking operation.
     bool isCompleted() override;

+    bool isSuccess() const override;
+
     // Same as calling synchronize() for NCCL work.
     void wait() override;

-    // Will always return true
-    bool isSuccess() const override;
-
     // Let current stream wait on the completing of the NCCL work
-    // Throws on exceptions. Non-blocking operation.
+    // Throws on exceptions. Blocking operation, which will wait for work
+    // completion.
     void synchronize() override;

-    // Will always throw because it should not be called (isSuccess() -> true).
-    std::exception_ptr exception() const override;
-
     // Helper function that checks if the NCCL kernels have finished
     // execution on the GPUs
     bool finishedGPUExecution();
@@ -91,9 +86,37 @@ class ProcessGroupNCCL : public ProcessGroup {
     // The CUDA events tracking this work item on multiple CUDA devices
     std::vector<at::cuda::CUDAEvent> cudaEvents_;

+    // The NCCL communicators used for this work item.
+    std::vector<std::shared_ptr<NCCLComm>> ncclComms_;
+
     // Tensors used for barrier op
     std::vector<at::Tensor> barrierTensors_;

+    // Clone of blockingWait_ from ProcessGroupNCCL.
+    bool blockingWait_ = false;
+
+    // Clone of opTimeout_ from ProcessGroupNCCL.
+    std::chrono::milliseconds opTimeout_;
+
+    // Time point representing when the work started.
+    std::chrono::time_point<std::chrono::steady_clock> workStartTime_;
+
+    // Wrapper method for the static checkForNCCLErrors which can be overridden
+    // for tests.
+    virtual std::exception_ptr checkForNCCLErrors(
+        const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) const;
+
+   private:
+    // Checks for NCCL errors and sets an appropriate exception_ptr.
+    void checkAndSetException();
+
+    // Checks for NCCL errors and throws an appropriate exception.
+    void checkAndThrowException();
+
+    // Just checks whether GPU execution has completed, without modifying
+    // exception_ptr.
+    bool finishedGPUExecutionInternal() const;
+
     friend class ProcessGroupNCCL;
   };
@@ -115,7 +138,9 @@ class ProcessGroupNCCL : public ProcessGroup {
       const std::shared_ptr<Store>& store,
       int rank,
       int size,
-      const std::string& groupName = "");
+      const std::string& groupName = "",
+      const std::chrono::milliseconds& opTimeout =
+          std::chrono::milliseconds(kProcessGroupNCCLOpTimeoutMillis));

   virtual ~ProcessGroupNCCL();
@@ -169,6 +194,8 @@ class ProcessGroupNCCL : public ProcessGroup {
       std::vector<at::Tensor>& tensors,
       int tag) override;

+  static const int64_t kProcessGroupNCCLOpTimeoutMillis;
+
  protected:
   // Helper that broadcasts nccl unique ID to all ranks through the store
   void broadcastUniqueNCCLID(ncclUniqueId* ncclID);
@@ -179,6 +206,13 @@ class ProcessGroupNCCL : public ProcessGroup {
       const std::string& devicesKey,
       const std::vector<at::Device>& devices);

+  // Wrapper method which can be overridden for tests.
+  virtual std::exception_ptr checkForNCCLErrors(
+      const std::vector<std::shared_ptr<NCCLComm>>& ncclComms);
+
+  virtual std::shared_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
+      std::vector<at::Device> devices);
+
  private:
   // Helper that encapsulates work shared across all collective communication
   // primitives. The callbacks have the following signatures:
@@ -199,7 +233,25 @@ class ProcessGroupNCCL : public ProcessGroup {
       PreProcess pre,
       PostProcess post);

+  // Checks for NCCL errors on each of the communicators and returns an
+  // appropriate exception_ptr (nullptr if no errors).
+  static std::exception_ptr checkForNCCLErrorsInternal(
+      const std::vector<std::shared_ptr<NCCLComm>>& ncclComms);
+
+  // Function that runs as part of a separate thread and checks for errors on
+  // NCCL communicators. We need a separate thread to check for NCCL errors
+  // since we can't rely on the user calling certain methods like wait(),
+  // isCompleted() etc. to detect and remediate errors. In addition to this, we
+  // need a mechanism to safely abort and remove NCCL communicators from our
+  // cache. This can be done cleanly by having a thread for the ProcessGroupNCCL
+  // class. Attempting to modify the communicator cache from the WorkNCCL class
+  // might run into issues with object lifetime since the ProcessGroupNCCL
+  // object might get destroyed before the WorkNCCL object.
+  void ncclCommWatchdog();
+
  protected:
+  static const int64_t kWatchdogThreadSleepMillis;
+
   // Store that is used to exchange each Ranks's NCCL unique ID
   std::shared_ptr<Store> store_;
@@ -228,6 +280,21 @@ class ProcessGroupNCCL : public ProcessGroup {
   std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
       devNCCLCommMap_;

+  // Mutex to guard devNCCLCommMap_.
+  std::mutex devNCCLCommMapLock_;
+
+  // Watchdog thread which looks for errors on the cached NCCL communicators.
+  std::thread ncclCommWatchdogThread_;
+
+  // Whether or not we should terminate the watchdog thread.
+  std::atomic<bool> terminateWatchdog_;
+
+  // Condition variable to control how long the watchdog thread waits.
+  std::condition_variable watchdogCV_;
+
+  // Mutex for watchdog.
+  std::mutex watchdogCVMutex_;
+
   // The CUDA steams used by NCCL kernels
   std::unordered_map<std::string, std::vector<at::cuda::CUDAStream>>
       ncclStreams_;
@@ -266,6 +333,13 @@ class ProcessGroupNCCL : public ProcessGroup {
   // is that different group can have different ranks and we need ensure that
   // each group has its own uniform process group ID for all its ranks.
   static std::unordered_map<std::string, ssize_t> processGroupCounterMap_;

+  // Whether or not wait() and synchronize() are blocking operations that wait
+  // for the operation to complete.
+  bool blockingWait_ = false;
+
+  // Timeout for operations. This is only used when blockingWait_ is enabled.
+  std::chrono::milliseconds opTimeout_;
 };

 } // namespace c10d
torch/lib/c10d/test/CMakeLists.txt
@@ -23,6 +23,7 @@ if(USE_CUDA)
   endif()
   if(USE_C10D_NCCL)
     c10d_add_test(ProcessGroupNCCLTest.cpp c10d c10d_cuda_test)
+    c10d_add_test(ProcessGroupNCCLErrorsTest.cpp c10d c10d_cuda_test gtest_main)
   endif()
 else()
   if(USE_C10D_GLOO)
torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp (new file, 168 lines)
@@ -0,0 +1,168 @@
#include <c10d/FileStore.hpp>
#include <c10d/ProcessGroupNCCL.hpp>
#include <c10d/test/CUDATest.hpp>
#include <c10d/test/TestUtils.hpp>
#include <gtest/gtest.h>

using namespace c10d::test;

class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
 public:
  WorkNCCLSimulateErrors(
      const std::vector<at::Device>& devices,
      bool simulate_error)
      : WorkNCCL(devices), simulate_error_(simulate_error) {}

  std::exception_ptr checkForNCCLErrors(
      const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms)
      const override {
    if (simulate_error_) {
      return std::make_exception_ptr(std::runtime_error("Error"));
    }
    return c10d::ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors(ncclComms);
  }

 private:
  bool simulate_error_;
};

class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
 public:
  ProcessGroupNCCLSimulateErrors(
      const std::shared_ptr<c10d::Store>& store,
      int rank,
      int size)
      : ProcessGroupNCCL(store, rank, size), simulate_error_(false) {}

  std::exception_ptr checkForNCCLErrors(
      const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms) override {
    if (simulate_error_) {
      return std::make_exception_ptr(std::runtime_error("Error"));
    }
    return c10d::ProcessGroupNCCL::checkForNCCLErrors(ncclComms);
  }

  std::chrono::duration<int64_t, std::milli> getWatchdogSleepInterval() {
    return std::chrono::milliseconds(
        ProcessGroupNCCLSimulateErrors::kWatchdogThreadSleepMillis);
  }

  std::shared_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
      std::vector<at::Device> devices) override {
    return std::make_shared<WorkNCCLSimulateErrors>(devices, simulate_error_);
  }

  size_t getNCCLCommCacheSize() {
    return devNCCLCommMap_.size();
  }

  void simulate_error() {
    simulate_error_ = true;
  }

  void reset_error() {
    simulate_error_ = false;
  }

 private:
  bool simulate_error_;
};

class ProcessGroupNCCLErrorsTest : public ::testing::Test {
 protected:
  bool skipTest() {
    // Skip test if no cuda devices found.
    return cudaNumDevices() == 0;
  }

  void SetUp() override {
    size_t numDevices = cudaNumDevices();
    TemporaryFile file;
    store_ = std::make_shared<::c10d::FileStore>(file.path, 1);

    at::cuda::OptionalCUDAGuard deviceGuard;
    tensors_.resize(numDevices);
    for (auto i = 0; i < numDevices; ++i) {
      deviceGuard.set_index(i);
      tensors_[i] = at::ones({3, 3}, at::kCUDA);
    }
  }

  void TearDown() override {
    ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "0", 1) == 0);
  }

  std::vector<at::Tensor> tensors_;
  std::shared_ptr<::c10d::FileStore> store_;
};

TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) {
  if (skipTest()) {
    return;
  }

  ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0);
  ProcessGroupNCCLSimulateErrors pg(store_, 0, 1);

  auto work = pg.allreduce(tensors_);
  work->wait();
  EXPECT_TRUE(work->isSuccess());
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run all reduce with errors.
  pg.simulate_error();
  work = pg.allreduce(tensors_);
  EXPECT_THROW(work->wait(), std::runtime_error);

  // Verify the work item failed.
  EXPECT_TRUE(work->isCompleted());
  EXPECT_FALSE(work->isSuccess());
  EXPECT_THROW(work->wait(), std::runtime_error);

  // Should remove the nccl communicators which hit errors from the cache.
  std::this_thread::sleep_for(2 * pg.getWatchdogSleepInterval());
  EXPECT_EQ(0, pg.getNCCLCommCacheSize());

  // Verify we can recover from errors.
  pg.reset_error();
  work = pg.allreduce(tensors_);
  work->wait();
  EXPECT_TRUE(work->isSuccess());
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());
}

TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) {
  if (skipTest()) {
    return;
  }

  ProcessGroupNCCLSimulateErrors pg(store_, 0, 1);

  auto work = pg.allreduce(tensors_);
  pg.barrier()->wait();
  EXPECT_TRUE(work->isSuccess());
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run all reduce with errors.
  pg.simulate_error();
  work = pg.allreduce(tensors_);

  // Should not throw exceptions.
  work->wait();
  pg.barrier()->wait();

  // Verify the work item failed.
  EXPECT_TRUE(work->isCompleted());
  EXPECT_FALSE(work->isSuccess());

  // Should remove the nccl communicators which hit errors from the cache.
  std::this_thread::sleep_for(2 * pg.getWatchdogSleepInterval());
  EXPECT_EQ(0, pg.getNCCLCommCacheSize());

  // Verify we can recover from errors.
  pg.reset_error();
  work = pg.allreduce(tensors_);
  pg.barrier()->wait();
  EXPECT_TRUE(work->isSuccess());
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());
}