Back out "Revert D31005792: [NCCL] Init dummy NCCL comms in constructor" (#65883)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65883

Original commit changeset: d8e962b8aab6
ghstack-source-id: 139836954

Test Plan: ci

Reviewed By: zhaojuanmao

Differential Revision: D31299350

fbshipit-source-id: 9ad5c8fa17f7038ba579cb1eda6d9271ac07a130
Rohan Varma 2021-10-08 15:58:27 -07:00 committed by Facebook GitHub Bot
parent c1343ff706
commit f1f3bd8c36
6 changed files with 213 additions and 31 deletions

View File

@@ -1,9 +1,12 @@
+#include <chrono>
 #include <iostream>
 #include <c10d/FileStore.hpp>
 #include <c10d/ProcessGroupNCCL.hpp>
 #include "CUDATest.hpp"
 #include "TestUtils.hpp"
+#include "c10d/ProcessGroup.hpp"
+#include "c10d/Types.hpp"
 #include <c10/cuda/CUDAGuard.h>
 #include <c10/cuda/CUDAStream.h>
@@ -19,7 +22,7 @@ using c10d::ProcessGroup;
 class NCCLTestBase {
  public:
-  NCCLTestBase(const std::string& path) : path_(path) {}
+  NCCLTestBase(const std::string& path, const std::chrono::milliseconds pgTimeout = kProcessGroupDefaultTimeout) : path_(path), pgTimeout_(pgTimeout) {}
   NCCLTestBase(NCCLTestBase&& other) {
     path_ = std::move(other.path_);
@@ -33,19 +36,22 @@ class NCCLTestBase {
   void initialize(int rank, int size) {
     auto store = c10::make_intrusive<::c10d::FileStore>(path_, size);
+    c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts = c10::make_intrusive<c10d::ProcessGroupNCCL::Options>();
+    opts->timeout = pgTimeout_;
     pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>(
-        new ::c10d::ProcessGroupNCCL(store, rank, size));
+        new ::c10d::ProcessGroupNCCL(store, rank, size, std::move(opts)));
   }
  protected:
   std::string path_;
   std::unique_ptr<::c10d::ProcessGroupNCCL> pg_;
+  std::chrono::milliseconds pgTimeout_;
 };
 class NCCLTest : public NCCLTestBase {
  public:
-  NCCLTest(const std::string& path, int worldSize)
-      : NCCLTestBase(path),
+  NCCLTest(const std::string& path, int worldSize, std::chrono::milliseconds pgTimeout = kProcessGroupDefaultTimeout)
+      : NCCLTestBase(path, pgTimeout),
         numDevices_(cudaNumDevices()),
         state_(::at::globalContext().lazyInitCUDA()),
         worldSize_(worldSize) {
@@ -497,10 +503,50 @@ void testReduceScatter(const std::string& path, int rank, int size) {
   }
 }
+void testProcessGroupNCCLHealthCheckFailHelper(const std::string& path, bool timeout) {
+  // simulate world_size > 1 here via threads.
+  const int worldSize = 4;
+  std::mutex m;
+  std::unordered_set<uint64_t> nums;
+  auto runTest = [&](int i) {
+    NCCLTest test(path, worldSize, std::chrono::milliseconds(3000));
+    // Catch error relating to health check failure
+    bool error_caught = false;
+    try {
+      test.initialize(timeout ? 0 : -1, worldSize);
+    } catch (const std::exception &e) {
+      std::string errMsg = e.what();
+      const std::string kTimeoutErr = "Failed to initialize NCCL communicator on rank";
+      const std::string kInvalidRankErr = "Invalid rank";
+      std::string expectedSubstr = timeout ? kTimeoutErr : kInvalidRankErr;
+      bool cond = errMsg.find(expectedSubstr) != std::string::npos;
+      EXPECT_TRUE(cond);
+      error_caught = true;
+    }
+    EXPECT_TRUE(error_caught);
+  };
+  std::vector<std::thread> threads;
+  threads.reserve(worldSize);
+  for (const auto r : c10::irange(worldSize)) {
+    threads.emplace_back(std::thread([=]() { runTest(r); }));
+  }
+  for (auto& t : threads) {
+    t.join();
+  }
+}
+void testProcessGroupNCCLHealthCheckFailException(const std::string& path, int /* unused */, int /* unused */) {
+  testProcessGroupNCCLHealthCheckFailHelper(path, /* timeout */ false);
+}
+void testProcessGroupNCCLHealthCheckFailTimeout(const std::string& path, int /* unused */, int /* unused */) {
+  testProcessGroupNCCLHealthCheckFailHelper(path, /* timeout */ true);
+}
 void testSequenceNumInit(const std::string& path, int /* unused */, int /* unused */) {
   // Note: ProcessGroupNCCLTest doesn't support multiprocess testing. So we
   // simulate world_size > 1 here via threads.
-  const int worldSize = 4;
+  const int worldSize = 2;
   std::mutex m;
   std::unordered_set<uint64_t> nums;
   auto runTest = [&](int i) {
@@ -625,6 +671,26 @@ TEST_F(ProcessGroupNCCLTest, testSequenceNumInit) {
   }
 }
+TEST_F(ProcessGroupNCCLTest, testProcessGroupNCCLHealthCheckFailTimeout) {
+  if (skipTest()) {
+    return;
+  }
+  {
+    TemporaryFile file;
+    testProcessGroupNCCLHealthCheckFailTimeout(file.path, rank_, size_);
+  }
+}
+TEST_F(ProcessGroupNCCLTest, testProcessGroupNCCLHealthCheckFailException) {
+  if (skipTest()) {
+    return;
+  }
+  {
+    TemporaryFile file;
+    testProcessGroupNCCLHealthCheckFailException(file.path, rank_, size_);
+  }
+}
 TEST_F(ProcessGroupNCCLTest, testReduceScatterBase) {
   if (skipTest()) {
     return;

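For reference, the timeout plumbing added to the test harness above reduces to a small usage pattern: build `ProcessGroupNCCL::Options`, set its `timeout`, and pass it to the `ProcessGroupNCCL` constructor, which now performs the health check within that bound. A minimal sketch follows (not part of this diff; the helper name `makePgWithTimeout` is made up for illustration, the calls mirror `NCCLTestBase::initialize` above):

```cpp
#include <chrono>
#include <memory>
#include <string>

#include <c10d/FileStore.hpp>
#include <c10d/ProcessGroupNCCL.hpp>

// Hypothetical helper: construct a ProcessGroupNCCL whose constructor-time
// health check is bounded by `timeout`, mirroring NCCLTestBase::initialize.
std::unique_ptr<::c10d::ProcessGroupNCCL> makePgWithTimeout(
    const std::string& path,
    int rank,
    int size,
    std::chrono::milliseconds timeout) {
  auto store = c10::make_intrusive<::c10d::FileStore>(path, size);
  auto opts = c10::make_intrusive<c10d::ProcessGroupNCCL::Options>();
  // The constructor initializes and destroys dummy communicators and throws
  // if that does not complete within this timeout.
  opts->timeout = timeout;
  return std::make_unique<::c10d::ProcessGroupNCCL>(
      store, rank, size, std::move(opts));
}
```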
View File

@@ -657,11 +657,13 @@ class DistributedDataParallelTest(
         # otherwise process will be taken down and we can't check for errors.
         os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
         os.environ["NCCL_BLOCKING_WAIT"] = "1"
-        timeout = timedelta(seconds=2)
+        # TODO: smaller timeout can fail since PG NCCl does health check in
+        # constructor. Look into reducing this test's runtime.
         store = c10d.FileStore(self.file_name, self.world_size)
-        pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, timeout=timeout)
+        # provide sufficient timeout to initialize NCCL comm.
+        pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, timeout=timedelta(seconds=15))
         pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
-        pg.barrier().wait()
+        pg.barrier().wait(timedelta(seconds=5))
         # Simulate stuckness in rank 0.
         if self.rank == 0:
             pg_gloo.barrier().wait()
@@ -670,7 +672,7 @@ class DistributedDataParallelTest(
         if self.rank != 0:
            # Time out due to rank 0 not calling into allreduce.
            with self.assertRaises(RuntimeError):
-                pg.allreduce([inp]).wait()
+                pg.allreduce([inp]).wait(timedelta(seconds=5))
            # Now when nonzero rank attempts to use communicator, original failure reason should be logged.j
            try:
@@ -2263,14 +2265,14 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
             store,
             self.rank,
             self.world_size,
-            timeout=timedelta(seconds=self.op_timeout_sec),
+            timeout=timedelta(seconds=10),
         )
         process_group.allreduce(torch.rand(10).cuda(self.rank))
         if self.rank == 0:
             work = process_group.allreduce(torch.rand(10).cuda(self.rank))
             with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg):
                 # Operation would time out in blocking mode.
-                work.wait()
+                work.wait(timeout=timedelta(seconds=self.op_timeout_sec))
             # Run some GPU operations to make sure cuda has not gotten stuck.
             # It was observed cuda could get stuck if NCCL communicators were
             # not properly aborted before throwing RuntimeError.
@@ -2339,13 +2341,13 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
             store,
             self.rank,
             self.world_size,
-            timeout=timedelta(seconds=self.op_timeout_sec),
+            timeout=timedelta(seconds=10),
         )
         process_group.barrier().wait()
         if self.rank == 0:
             with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg):
                 # This should timeout
-                process_group.barrier().wait()
+                process_group.barrier().wait(timeout=timedelta(seconds=self.op_timeout_sec))
     def _run_invalid_nccl_blocking_wait_env(self, val):
         os.environ["NCCL_BLOCKING_WAIT"] = val
@@ -2382,21 +2384,20 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
         store = c10d.FileStore(self.file_name, self.world_size)
         # Initialize process_group.
-        timeout = 1
         process_group = c10d.ProcessGroupNCCL(
-            store, self.rank, self.world_size, timeout=timedelta(seconds=timeout)
+            store, self.rank, self.world_size, timeout=timedelta(seconds=10)
         )
-        process_group.allreduce(torch.rand(10).cuda(self.rank)).wait()
+        process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=timedelta(seconds=1))
         if self.rank == 0:
            # This should timeout in about 1 second.
            start = time.time()
            # Watchdog may abort timed out work resulting in NCCL error instead of operation timed out.
            with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg):
-                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait()
+                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=timedelta(seconds=1))
         else:
             # Sleep to ensure timeout.
-            time.sleep(2 * timeout)
+            time.sleep(2)
         self._wait_for_comm_abort(process_group)
@@ -2546,14 +2547,14 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
         store = c10d.FileStore(self.file_name, self.world_size)
         if self.rank == 0:
             with self.assertRaisesRegex(
-                RuntimeError, "Timed out initializing process group"
+                RuntimeError, "Health check failure"
             ):
                 c10d.init_process_group(
                     backend="nccl",
                     rank=self.rank,
                     world_size=self.world_size,
                     store=store,
-                    timeout=timedelta(seconds=1),
+                    timeout=timedelta(seconds=10),
                 )
     @requires_nccl()
@@ -2565,12 +2566,12 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
             rank=self.rank,
             world_size=self.world_size,
             store=store,
-            timeout=timedelta(seconds=1),
+            timeout=timedelta(seconds=10),
         )
         if self.rank == 0:
             with self.assertRaisesRegex(
-                RuntimeError, "Timed out initializing process group"
+                RuntimeError, "Health check failure"
             ):
                 c10d.new_group([0, 1], timeout=timedelta(seconds=1))
@@ -2588,12 +2589,12 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
             rank=self.rank,
             world_size=self.world_size,
             store=store,
-            timeout=timedelta(seconds=1),
+            timeout=timedelta(seconds=10),
         )
         if self.rank == 1:
             with self.assertRaisesRegex(
-                RuntimeError, "Timed out initializing process group"
+                RuntimeError, "Health check failure"
             ):
                 c10d.new_group([0, 1], timeout=timedelta(seconds=1))

View File

@@ -39,7 +39,7 @@ class ProcessGroupNCCLJitTest(JitTestCase):
     def _create_nccl_pg(self, name_prefix):
         tcp_store = create_tcp_store(jit_class=True)
-        opts = torch.classes.dist_c10d.ProcessGroupNCCLOptions(0, True)
+        opts = torch.classes.dist_c10d.ProcessGroupNCCLOptions(10000, True)
         name = unique_process_group_name(name_prefix)
@@ -49,7 +49,7 @@ class ProcessGroupNCCLJitTest(JitTestCase):
         tcp_store = create_tcp_store(jit_class=True)
         return torch.classes.dist_c10d.frontend().new_process_group_helper(
-            self.world_size, self.rank, [], "nccl", tcp_store, name, 0)
+            self.world_size, self.rank, [], "nccl", tcp_store, name, 10000)
    @requires_nccl()
    @sandcastle_skip_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
@@ -172,7 +172,7 @@ class C10dFrontendJitTest(JitTestCase):
         pg_name = unique_process_group_name("singleton_test_process_group")
         ProcessGroupNCCL1 = frontend1.new_process_group_helper(
-            self.world_size, self.rank, [], "nccl", tcp_store, pg_name, 0)
+            self.world_size, self.rank, [], "nccl", tcp_store, pg_name, 10000)
         ProcessGroupNCCL2 = frontend2.get_process_group_by_name(pg_name)
         self.assertEqual(frontend2.get_name_of_process_group(ProcessGroupNCCL2), pg_name)
@@ -189,7 +189,7 @@ class C10dProcessGroupSerialization(JitTestCase):
         name = unique_process_group_name("module_member_process_group")
         self.pg = torch.classes.dist_c10d.frontend().new_process_group_helper(
-            1, 0, [], "nccl", tcp_store, name, 0)
+            1, 0, [], "nccl", tcp_store, name, 10000)
     def forward(self, input: torch.Tensor):
         if self.pg is None:

View File

@@ -1,4 +1,5 @@
 #include <c10d/ProcessGroupNCCL.hpp>
+#include <c10/util/Exception.h>
 #include <sstream>
 #ifdef USE_C10D_NCCL
@@ -13,6 +14,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGraphsC10Utils.h>
+#include <c10/core/DeviceType.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <c10/util/irange.h>
 #include <c10/util/Logging.h>
@@ -159,6 +161,14 @@ std::vector<at::Device> getDeviceList(const std::vector<at::Tensor>& tensors) {
   return res;
 }
+// Return CUDA device with ordinal given by input rank.
+at::Device getDeviceForRank(int rank) {
+  TORCH_CHECK(rank >= 0, "Invalid rank ", rank);
+  auto numGPUs = at::cuda::getNumGPUs();
+  int16_t deviceIdx = static_cast<int16_t>(rank % numGPUs);
+  return at::Device(at::DeviceType::CUDA, deviceIdx);
+}
 // [Sync Streams] Helper that lets the input ncclStreams to wait for the current
 // stream. NCCL communications run on ncclStreams, but input tensors are
 // allocated on different streams (i.e., current streams). Communications on
@@ -502,6 +512,13 @@ ProcessGroupNCCL::ProcessGroupNCCL(
     asyncErrorHandling_ = false;
   }
+  // Perform health check by initializing dummy communicators and destroying
+  // them. This will help indicate any NCCL-related issues prior to the first
+  // collective.
+  // Run it in a separate thread and wait on CV to handle timeouts, since
+  // majority of getNCCLComm failures are hangs.
+  runHealthCheck();
 #ifdef ENABLE_NCCL_ERROR_CHECKING
   ncclCommWatchdogThread_ =
       std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this);
@@ -527,6 +544,64 @@ ProcessGroupNCCL::ProcessGroupNCCL(
             << "\nNCCL_DEBUG: " << ncclDebugLevel;
 }
+void ProcessGroupNCCL::runHealthCheck() {
+  // Run health check in a separate thread and wait on CV to handle timeouts,
+  // since majority of getNCCLComm failures are hangs.
+  struct HealthCheckData {
+    std::mutex healthCheckMutex;
+    std::condition_variable healthCheckCv;
+    bool healthCheckSuccess = false;
+    std::exception_ptr healthCheckException;
+  };
+  HealthCheckData healthCheckData;
+  auto t = std::thread([&healthCheckData, this]() {
+    try {
+      std::vector<at::Device> rankDevice = {getDeviceForRank(rank_)};
+      const auto key = getKeyFromDevices(rankDevice);
+      // OpType does not matter, only need to set to not go through send/recv
+      // path.
+      getNCCLComm(key, rankDevice, OpType::ALLREDUCE);
+      // Now destroy the communicators and remove them from cache so we don't
+      // use destroyed communicators.
+      destroyNCCLComms(key);
+      // Notify main thread the health check is complete.
+      {
+        std::lock_guard<std::mutex> lk(healthCheckData.healthCheckMutex);
+        healthCheckData.healthCheckSuccess = true;
+      }
+      healthCheckData.healthCheckCv.notify_one();
+    } catch (const std::exception& e) {
+      // Populate exception ptr.
+      healthCheckData.healthCheckException = std::current_exception();
+      // Unblock waiting main thread which will report exception.
+      healthCheckData.healthCheckCv.notify_one();
+    } // Unknown exceptions will just cause the program to terminate.
+  });
+  // We don't need to join the thread, just need to verify health check via the
+  // CV. Hence we detach the thread here.
+  t.detach(); // NOLINT
+  LOG(INFO) << "[Rank " << rank_ << "]"
+            << " will wait up to " << options_->timeout.count()
+            << " msec for NCCL health check to complete.";
+  std::unique_lock<std::mutex> lock(healthCheckData.healthCheckMutex);
+  healthCheckData.healthCheckCv.wait_for(
+      lock, options_->timeout, [&healthCheckData]() {
+        return healthCheckData.healthCheckSuccess;
+      });
+  if (healthCheckData.healthCheckException) {
+    std::rethrow_exception(healthCheckData.healthCheckException);
+  }
+  // If there is no exception, the likely culprit is a timeout/hang which is how
+  // most communicator init issues manifest themselves.
+  TORCH_CHECK(
+      healthCheckData.healthCheckSuccess,
+      "ProcessGroupNCCL: Health check failure: Failed to initialize NCCL communicator on rank ",
+      rank_);
+}
 void ProcessGroupNCCL::setSequenceNumberForGroup() {
   if (rank_ == 0) {
     // Create and broadcast sequence number
@@ -874,6 +949,30 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID(
   }
 }
+void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) {
+  std::lock_guard<std::mutex> lock(mutex_);
+  if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) {
+    TORCH_INTERNAL_ASSERT(
+        false,
+        "Expected to find key ",
+        devNCCLCommMapKey,
+        " in NCCL communicator map.");
+  }
+  std::vector<std::shared_ptr<NCCLComm>>& ncclComms =
+      devNCCLCommMap_[devNCCLCommMapKey];
+  // Loop through communicators and call ncclCommAbort.
+  for (const auto& comm : ncclComms) {
+    // ncclCommDestroy(comm->getNcclComm()) results in segfault when PG is being
+    // destroyed, so using ncclCommAbort here.
+    comm->ncclCommAbort();
+  }
+  // Remove communicators from the cache.
+  devNCCLCommMap_.erase(devNCCLCommMapKey);
+  // Clear used device indices.
+  usedDeviceIdxs_.clear();
+}
 std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
     const std::string& devicesKey,
     const std::vector<at::Device>& devices,
@@ -1697,7 +1796,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::barrier(
         "This can potentially cause a hang if this rank to GPU mapping is incorrect.",
         "Specify device_ids in barrier() to force use of a particular device."
     );
-    devices.emplace_back(at::DeviceType::CUDA, deviceIdx);
+    devices.emplace_back(getDeviceForRank(rank_));
   } else {
     for (auto usedDeviceIdx : usedDeviceIdxs_) {
      devices.emplace_back(at::DeviceType::CUDA, usedDeviceIdx);

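The control flow of `runHealthCheck` above generalizes to a small pattern: do potentially hanging work on a detached thread and have the caller wait on a condition variable with a timeout instead of joining. Below is a standalone sketch of that pattern, not the PyTorch implementation: the names `runWithTimeout` and `doWork` are illustrative, and unlike the diff (which captures the wait state by reference) the state here is held in a `shared_ptr` so a worker that wakes up after the caller has timed out does not touch a destroyed object.

```cpp
#include <chrono>
#include <condition_variable>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <thread>

// Returns true if doWork finished in time, rethrows any exception doWork
// raised, and returns false on timeout (the likely signature of a hang).
bool runWithTimeout(
    const std::function<void()>& doWork,
    std::chrono::milliseconds timeout) {
  struct State {
    std::mutex mtx;
    std::condition_variable cv;
    bool success = false;
    std::exception_ptr error;
  };
  // Shared ownership lets the detached worker outlive a timed-out caller.
  auto state = std::make_shared<State>();

  // Run the work on a detached thread; the caller never joins it, so a hung
  // worker cannot block the caller past the timeout.
  std::thread([state, doWork]() {
    try {
      doWork();
      std::lock_guard<std::mutex> lk(state->mtx);
      state->success = true;
    } catch (...) {
      std::lock_guard<std::mutex> lk(state->mtx);
      state->error = std::current_exception();
    }
    state->cv.notify_one();
  }).detach();

  std::unique_lock<std::mutex> lk(state->mtx);
  state->cv.wait_for(lk, timeout, [&state]() {
    return state->success || state->error != nullptr;
  });
  if (state->error) {
    std::rethrow_exception(state->error);
  }
  return state->success;
}
```

Called roughly as `runWithTimeout([&] { /* init and destroy dummy comms */ }, options_->timeout)`, this is the shape of the constructor-time health check minus the NCCL specifics.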
View File

@@ -430,6 +430,18 @@ class TORCH_API ProcessGroupNCCL : public ProcessGroup {
   void abortTimedOutCollectives(
       std::unordered_set<std::string>& abortedCommIds);
+  // Performs a health check by initializing dummy NCCL communicators and then
+  // destroying them. This will help indicate and signal any NCCL-related issues
+  // prior to the first collective. The actual initialization and subsequent
+  // destruction is ran on a separate thread and the main thread is signalled
+  // about timeouts/errors to report to the application.
+  void runHealthCheck();
+  // Destroys initialized NCCL communicators in devNCCLComMap_ given by input
+  // key. Throws if there are no communicators to destroy. Also removes
+  // communicators from the cache and clears used device indices.
+  void destroyNCCLComms(const std::string& devNCCLCommMapKey);
   void workCleanupLoop();
  protected:

View File

@@ -7343,7 +7343,9 @@ class DistributedTest:
             # tests expected behavior when nonzero rank hangs.
             nccl_pg = dist.new_group(
                 ranks=list(i for i in range(int(self.world_size))),
-                timeout=timedelta(seconds=2),
+                # provide sufficient timeout so communicators
+                # can be initialized in ctor.
+                timeout=timedelta(seconds=15),
                 backend=dist.Backend.NCCL,
             )
             gloo_pg = dist.new_group(
@@ -7354,7 +7356,7 @@
             # Let all ranks call allreduce first to set up communicators etc.
             # Directly simulating error here will run into store issue described
             # in https://github.com/pytorch/pytorch/issues/54524.
-            nccl_pg.allreduce(tensors).wait()
+            nccl_pg.allreduce(tensors).wait(timedelta(seconds=5))
             # All ranks besides 0 call into allreduce. This is to simulate a
             # desync across the world, where some ranks call into
             # monitored_barrier() and others are stuck in collective comm. In
@@ -7388,6 +7390,8 @@
                     monitored_barrier_timeout_seconds, wait_all_ranks=wait_all_ranks
                 )
+            self._barrier(timeout=30)
         @with_nccl_blocking_wait
         @require_backend({"gloo", "nccl"})
         @require_backends_available({"gloo", "nccl"})