Back out "Revert D31005792: [NCCL] Init dummy NCCL comms in constructor" (#65883)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65883
Original commit changeset: d8e962b8aab6
ghstack-source-id: 139836954
Test Plan: ci
Reviewed By: zhaojuanmao
Differential Revision: D31299350
fbshipit-source-id: 9ad5c8fa17f7038ba579cb1eda6d9271ac07a130
This commit is contained in:
parent c1343ff706
commit f1f3bd8c36
@@ -1,9 +1,12 @@
+#include <chrono>
 #include <iostream>
 
 #include <c10d/FileStore.hpp>
 #include <c10d/ProcessGroupNCCL.hpp>
 #include "CUDATest.hpp"
 #include "TestUtils.hpp"
+#include "c10d/ProcessGroup.hpp"
+#include "c10d/Types.hpp"
 
 #include <c10/cuda/CUDAGuard.h>
 #include <c10/cuda/CUDAStream.h>

@@ -19,7 +22,7 @@ using c10d::ProcessGroup;
 
 class NCCLTestBase {
  public:
-  NCCLTestBase(const std::string& path) : path_(path) {}
+  NCCLTestBase(const std::string& path, const std::chrono::milliseconds pgTimeout = kProcessGroupDefaultTimeout) : path_(path), pgTimeout_(pgTimeout) {}
 
   NCCLTestBase(NCCLTestBase&& other) {
     path_ = std::move(other.path_);

@@ -33,19 +36,22 @@ class NCCLTestBase {
   void initialize(int rank, int size) {
     auto store = c10::make_intrusive<::c10d::FileStore>(path_, size);
 
+    c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts = c10::make_intrusive<c10d::ProcessGroupNCCL::Options>();
+    opts->timeout = pgTimeout_;
     pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>(
-        new ::c10d::ProcessGroupNCCL(store, rank, size));
+        new ::c10d::ProcessGroupNCCL(store, rank, size, std::move(opts)));
   }
 
  protected:
   std::string path_;
   std::unique_ptr<::c10d::ProcessGroupNCCL> pg_;
+  std::chrono::milliseconds pgTimeout_;
 };
 
 class NCCLTest : public NCCLTestBase {
  public:
-  NCCLTest(const std::string& path, int worldSize)
-      : NCCLTestBase(path),
+  NCCLTest(const std::string& path, int worldSize, std::chrono::milliseconds pgTimeout = kProcessGroupDefaultTimeout)
+      : NCCLTestBase(path, pgTimeout),
         numDevices_(cudaNumDevices()),
         state_(::at::globalContext().lazyInitCUDA()),
         worldSize_(worldSize) {
@@ -497,10 +503,50 @@ void testReduceScatter(const std::string& path, int rank, int size) {
   }
 }
 
+void testProcessGroupNCCLHealthCheckFailHelper(const std::string& path, bool timeout) {
+  // simulate world_size > 1 here via threads.
+  const int worldSize = 4;
+  std::mutex m;
+  std::unordered_set<uint64_t> nums;
+  auto runTest = [&](int i) {
+    NCCLTest test(path, worldSize, std::chrono::milliseconds(3000));
+    // Catch error relating to health check failure
+    bool error_caught = false;
+    try {
+      test.initialize(timeout ? 0 : -1, worldSize);
+    } catch (const std::exception &e) {
+      std::string errMsg = e.what();
+      const std::string kTimeoutErr = "Failed to initialize NCCL communicator on rank";
+      const std::string kInvalidRankErr = "Invalid rank";
+      std::string expectedSubstr = timeout ? kTimeoutErr : kInvalidRankErr;
+      bool cond = errMsg.find(expectedSubstr) != std::string::npos;
+      EXPECT_TRUE(cond);
+      error_caught = true;
+    }
+    EXPECT_TRUE(error_caught);
+  };
+  std::vector<std::thread> threads;
+  threads.reserve(worldSize);
+  for (const auto r : c10::irange(worldSize)) {
+    threads.emplace_back(std::thread([=]() { runTest(r); }));
+  }
+  for (auto& t : threads) {
+    t.join();
+  }
+}
+
+void testProcessGroupNCCLHealthCheckFailException(const std::string& path, int /* unused */, int /* unused */) {
+  testProcessGroupNCCLHealthCheckFailHelper(path, /* timeout */ false);
+}
+
+void testProcessGroupNCCLHealthCheckFailTimeout(const std::string& path, int /* unused */, int /* unused */) {
+  testProcessGroupNCCLHealthCheckFailHelper(path, /* timeout */ true);
+}
+
 void testSequenceNumInit(const std::string& path, int /* unused */, int /* unused */) {
   // Note: ProcessGroupNCCLTest doesn't support multiprocess testing. So we
   // simulate world_size > 1 here via threads.
-  const int worldSize = 4;
+  const int worldSize = 2;
   std::mutex m;
   std::unordered_set<uint64_t> nums;
   auto runTest = [&](int i) {

@@ -625,6 +671,26 @@ TEST_F(ProcessGroupNCCLTest, testSequenceNumInit) {
   }
 }
 
+TEST_F(ProcessGroupNCCLTest, testProcessGroupNCCLHealthCheckFailTimeout) {
+  if (skipTest()) {
+    return;
+  }
+  {
+    TemporaryFile file;
+    testProcessGroupNCCLHealthCheckFailTimeout(file.path, rank_, size_);
+  }
+}
+
+TEST_F(ProcessGroupNCCLTest, testProcessGroupNCCLHealthCheckFailException) {
+  if (skipTest()) {
+    return;
+  }
+  {
+    TemporaryFile file;
+    testProcessGroupNCCLHealthCheckFailException(file.path, rank_, size_);
+  }
+}
+
 TEST_F(ProcessGroupNCCLTest, testReduceScatterBase) {
   if (skipTest()) {
     return;
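Taken together, the test-harness changes above mean every NCCLTest now builds its process group through an explicit ProcessGroupNCCL::Options object whose timeout also bounds the health check that the constructor performs since this change. The standalone C++ sketch below consolidates those scattered "+" lines into one place; it is illustrative only and not part of the diff, and makeTestProcessGroup is a hypothetical helper name.

// Illustrative consolidation of the updated NCCLTestBase::initialize() path:
// an explicit Options timeout now also bounds the health check that the
// ProcessGroupNCCL constructor performs before returning.
#include <chrono>
#include <memory>
#include <string>

#include <c10d/FileStore.hpp>
#include <c10d/ProcessGroupNCCL.hpp>

std::unique_ptr<c10d::ProcessGroupNCCL> makeTestProcessGroup(
    const std::string& storePath,
    int rank,
    int size,
    std::chrono::milliseconds timeout) {
  auto store = c10::make_intrusive<c10d::FileStore>(storePath, size);
  auto opts = c10::make_intrusive<c10d::ProcessGroupNCCL::Options>();
  opts->timeout = timeout;
  // A rank whose dummy communicator cannot be initialized within `timeout`
  // now throws here, inside the constructor, instead of at the first
  // collective.
  return std::make_unique<c10d::ProcessGroupNCCL>(
      store, rank, size, std::move(opts));
}

Failing inside this constructor call is exactly what the testProcessGroupNCCLHealthCheckFail* tests above assert.
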
@@ -657,11 +657,13 @@ class DistributedDataParallelTest(
         # otherwise process will be taken down and we can't check for errors.
         os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
         os.environ["NCCL_BLOCKING_WAIT"] = "1"
-        timeout = timedelta(seconds=2)
+        # TODO: smaller timeout can fail since PG NCCl does health check in
+        # constructor. Look into reducing this test's runtime.
         store = c10d.FileStore(self.file_name, self.world_size)
-        pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, timeout=timeout)
+        # provide sufficient timeout to initialize NCCL comm.
+        pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, timeout=timedelta(seconds=15))
         pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
-        pg.barrier().wait()
+        pg.barrier().wait(timedelta(seconds=5))
         # Simulate stuckness in rank 0.
         if self.rank == 0:
             pg_gloo.barrier().wait()

@@ -670,7 +672,7 @@ class DistributedDataParallelTest(
         if self.rank != 0:
             # Time out due to rank 0 not calling into allreduce.
             with self.assertRaises(RuntimeError):
-                pg.allreduce([inp]).wait()
+                pg.allreduce([inp]).wait(timedelta(seconds=5))
 
             # Now when nonzero rank attempts to use communicator, original failure reason should be logged.j
             try:
@@ -2263,14 +2265,14 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
             store,
             self.rank,
             self.world_size,
-            timeout=timedelta(seconds=self.op_timeout_sec),
+            timeout=timedelta(seconds=10),
         )
         process_group.allreduce(torch.rand(10).cuda(self.rank))
         if self.rank == 0:
             work = process_group.allreduce(torch.rand(10).cuda(self.rank))
             with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg):
                 # Operation would time out in blocking mode.
-                work.wait()
+                work.wait(timeout=timedelta(seconds=self.op_timeout_sec))
             # Run some GPU operations to make sure cuda has not gotten stuck.
             # It was observed cuda could get stuck if NCCL communicators were
             # not properly aborted before throwing RuntimeError.

@@ -2339,13 +2341,13 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
             store,
             self.rank,
             self.world_size,
-            timeout=timedelta(seconds=self.op_timeout_sec),
+            timeout=timedelta(seconds=10),
         )
         process_group.barrier().wait()
         if self.rank == 0:
             with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg):
                 # This should timeout
-                process_group.barrier().wait()
+                process_group.barrier().wait(timeout=timedelta(seconds=self.op_timeout_sec))
 
     def _run_invalid_nccl_blocking_wait_env(self, val):
         os.environ["NCCL_BLOCKING_WAIT"] = val
@@ -2382,21 +2384,20 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
         store = c10d.FileStore(self.file_name, self.world_size)
 
         # Initialize process_group.
-        timeout = 1
         process_group = c10d.ProcessGroupNCCL(
-            store, self.rank, self.world_size, timeout=timedelta(seconds=timeout)
+            store, self.rank, self.world_size, timeout=timedelta(seconds=10)
         )
-        process_group.allreduce(torch.rand(10).cuda(self.rank)).wait()
+        process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=timedelta(seconds=1))
 
         if self.rank == 0:
             # This should timeout in about 1 second.
             start = time.time()
             # Watchdog may abort timed out work resulting in NCCL error instead of operation timed out.
             with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg):
-                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait()
+                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=timedelta(seconds=1))
         else:
             # Sleep to ensure timeout.
-            time.sleep(2 * timeout)
+            time.sleep(2)
 
         self._wait_for_comm_abort(process_group)
 
@@ -2546,14 +2547,14 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
         store = c10d.FileStore(self.file_name, self.world_size)
         if self.rank == 0:
             with self.assertRaisesRegex(
-                RuntimeError, "Timed out initializing process group"
+                RuntimeError, "Health check failure"
             ):
                 c10d.init_process_group(
                     backend="nccl",
                     rank=self.rank,
                     world_size=self.world_size,
                     store=store,
-                    timeout=timedelta(seconds=1),
+                    timeout=timedelta(seconds=10),
                 )
 
     @requires_nccl()

@@ -2565,12 +2566,12 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
             rank=self.rank,
             world_size=self.world_size,
             store=store,
-            timeout=timedelta(seconds=1),
+            timeout=timedelta(seconds=10),
         )
 
         if self.rank == 0:
             with self.assertRaisesRegex(
-                RuntimeError, "Timed out initializing process group"
+                RuntimeError, "Health check failure"
             ):
                 c10d.new_group([0, 1], timeout=timedelta(seconds=1))
 

@@ -2588,12 +2589,12 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
             rank=self.rank,
             world_size=self.world_size,
             store=store,
-            timeout=timedelta(seconds=1),
+            timeout=timedelta(seconds=10),
         )
 
         if self.rank == 1:
             with self.assertRaisesRegex(
-                RuntimeError, "Timed out initializing process group"
+                RuntimeError, "Health check failure"
             ):
                 c10d.new_group([0, 1], timeout=timedelta(seconds=1))
 
@@ -39,7 +39,7 @@ class ProcessGroupNCCLJitTest(JitTestCase):
 
     def _create_nccl_pg(self, name_prefix):
         tcp_store = create_tcp_store(jit_class=True)
-        opts = torch.classes.dist_c10d.ProcessGroupNCCLOptions(0, True)
+        opts = torch.classes.dist_c10d.ProcessGroupNCCLOptions(10000, True)
 
         name = unique_process_group_name(name_prefix)
 

@@ -49,7 +49,7 @@ class ProcessGroupNCCLJitTest(JitTestCase):
         tcp_store = create_tcp_store(jit_class=True)
 
         return torch.classes.dist_c10d.frontend().new_process_group_helper(
-            self.world_size, self.rank, [], "nccl", tcp_store, name, 0)
+            self.world_size, self.rank, [], "nccl", tcp_store, name, 10000)
 
     @requires_nccl()
     @sandcastle_skip_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")

@@ -172,7 +172,7 @@ class C10dFrontendJitTest(JitTestCase):
         pg_name = unique_process_group_name("singleton_test_process_group")
 
         ProcessGroupNCCL1 = frontend1.new_process_group_helper(
-            self.world_size, self.rank, [], "nccl", tcp_store, pg_name, 0)
+            self.world_size, self.rank, [], "nccl", tcp_store, pg_name, 10000)
 
         ProcessGroupNCCL2 = frontend2.get_process_group_by_name(pg_name)
         self.assertEqual(frontend2.get_name_of_process_group(ProcessGroupNCCL2), pg_name)

@@ -189,7 +189,7 @@ class C10dProcessGroupSerialization(JitTestCase):
 
                 name = unique_process_group_name("module_member_process_group")
                 self.pg = torch.classes.dist_c10d.frontend().new_process_group_helper(
-                    1, 0, [], "nccl", tcp_store, name, 0)
+                    1, 0, [], "nccl", tcp_store, name, 10000)
 
             def forward(self, input: torch.Tensor):
                 if self.pg is None:
@@ -1,4 +1,5 @@
 #include <c10d/ProcessGroupNCCL.hpp>
+#include <c10/util/Exception.h>
 #include <sstream>
 
 #ifdef USE_C10D_NCCL
@@ -13,6 +14,7 @@
 
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGraphsC10Utils.h>
+#include <c10/core/DeviceType.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <c10/util/irange.h>
 #include <c10/util/Logging.h>
@@ -159,6 +161,14 @@ std::vector<at::Device> getDeviceList(const std::vector<at::Tensor>& tensors) {
   return res;
 }
 
+// Return CUDA device with ordinal given by input rank.
+at::Device getDeviceForRank(int rank) {
+  TORCH_CHECK(rank >= 0, "Invalid rank ", rank);
+  auto numGPUs = at::cuda::getNumGPUs();
+  int16_t deviceIdx = static_cast<int16_t>(rank % numGPUs);
+  return at::Device(at::DeviceType::CUDA, deviceIdx);
+}
+
 // [Sync Streams] Helper that lets the input ncclStreams to wait for the current
 // stream. NCCL communications run on ncclStreams, but input tensors are
 // allocated on different streams (i.e., current streams). Communications on
@@ -502,6 +512,13 @@ ProcessGroupNCCL::ProcessGroupNCCL(
     asyncErrorHandling_ = false;
   }
 
+  // Perform health check by initializing dummy communicators and destroying
+  // them. This will help indicate any NCCL-related issues prior to the first
+  // collective.
+  // Run it in a separate thread and wait on CV to handle timeouts, since
+  // majority of getNCCLComm failures are hangs.
+  runHealthCheck();
+
 #ifdef ENABLE_NCCL_ERROR_CHECKING
   ncclCommWatchdogThread_ =
       std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this);
@@ -527,6 +544,64 @@ ProcessGroupNCCL::ProcessGroupNCCL(
             << "\nNCCL_DEBUG: " << ncclDebugLevel;
 }
 
+void ProcessGroupNCCL::runHealthCheck() {
+  // Run health check in a separate thread and wait on CV to handle timeouts,
+  // since majority of getNCCLComm failures are hangs.
+
+  struct HealthCheckData {
+    std::mutex healthCheckMutex;
+    std::condition_variable healthCheckCv;
+    bool healthCheckSuccess = false;
+    std::exception_ptr healthCheckException;
+  };
+
+  HealthCheckData healthCheckData;
+  auto t = std::thread([&healthCheckData, this]() {
+    try {
+      std::vector<at::Device> rankDevice = {getDeviceForRank(rank_)};
+      const auto key = getKeyFromDevices(rankDevice);
+      // OpType does not matter, only need to set to not go through send/recv
+      // path.
+      getNCCLComm(key, rankDevice, OpType::ALLREDUCE);
+      // Now destroy the communicators and remove them from cache so we don't
+      // use destroyed communicators.
+      destroyNCCLComms(key);
+      // Notify main thread the health check is complete.
+      {
+        std::lock_guard<std::mutex> lk(healthCheckData.healthCheckMutex);
+        healthCheckData.healthCheckSuccess = true;
+      }
+      healthCheckData.healthCheckCv.notify_one();
+    } catch (const std::exception& e) {
+      // Populate exception ptr.
+      healthCheckData.healthCheckException = std::current_exception();
+      // Unblock waiting main thread which will report exception.
+      healthCheckData.healthCheckCv.notify_one();
+    } // Unknown exceptions will just cause the program to terminate.
+  });
+  // We don't need to join the thread, just need to verify health check via the
+  // CV. Hence we detach the thread here.
+  t.detach(); // NOLINT
+  LOG(INFO) << "[Rank " << rank_ << "]"
+            << " will wait up to " << options_->timeout.count()
+            << " msec for NCCL health check to complete.";
+  std::unique_lock<std::mutex> lock(healthCheckData.healthCheckMutex);
+  healthCheckData.healthCheckCv.wait_for(
+      lock, options_->timeout, [&healthCheckData]() {
+        return healthCheckData.healthCheckSuccess;
+      });
+
+  if (healthCheckData.healthCheckException) {
+    std::rethrow_exception(healthCheckData.healthCheckException);
+  }
+  // If there is no exception, the likely culprit is a timeout/hang which is how
+  // most communicator init issues manifest themselves.
+  TORCH_CHECK(
+      healthCheckData.healthCheckSuccess,
+      "ProcessGroupNCCL: Health check failure: Failed to initialize NCCL communicator on rank ",
+      rank_);
+}
+
 void ProcessGroupNCCL::setSequenceNumberForGroup() {
   if (rank_ == 0) {
     // Create and broadcast sequence number
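The subtle part of runHealthCheck() is that a wedged NCCL initialization cannot be joined or cancelled, so the work runs on a detached thread while the constructor waits on a condition variable bounded by options_->timeout. The standalone sketch below is illustrative only and not PyTorch code; runWithTimeout and possiblyHangingInit are hypothetical names. Unlike the stack-allocated HealthCheckData above, it keeps the shared state in a std::shared_ptr so a worker that never finishes can safely outlive the waiter.

// Minimal, self-contained sketch of the timeout pattern used above: run work
// that may hang on a detached thread and let the caller wait on a condition
// variable with a deadline.
#include <chrono>
#include <condition_variable>
#include <exception>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <thread>

void possiblyHangingInit() {
  // Placeholder for work that may hang or throw (e.g. communicator init).
}

void runWithTimeout(std::chrono::milliseconds timeout) {
  // Shared state lives in a shared_ptr so the detached thread can safely
  // outlive this function if the work never finishes.
  struct State {
    std::mutex m;
    std::condition_variable cv;
    bool success = false;
    std::exception_ptr error;
  };
  auto state = std::make_shared<State>();

  std::thread([state] {
    try {
      possiblyHangingInit();
      {
        std::lock_guard<std::mutex> lk(state->m);
        state->success = true;
      }
    } catch (...) {
      std::lock_guard<std::mutex> lk(state->m);
      state->error = std::current_exception();
    }
    state->cv.notify_one();
  }).detach();

  std::unique_lock<std::mutex> lk(state->m);
  // Wake early on success or failure; otherwise give up after `timeout`.
  state->cv.wait_for(lk, timeout, [&] {
    return state->success || state->error != nullptr;
  });

  if (state->error) {
    std::rethrow_exception(state->error);
  }
  if (!state->success) {
    throw std::runtime_error("init did not complete within the timeout");
  }
}

int main() {
  runWithTimeout(std::chrono::milliseconds(3000));
  return 0;
}
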
@@ -874,6 +949,30 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID(
   }
 }
 
+
+void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) {
+  std::lock_guard<std::mutex> lock(mutex_);
+  if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) {
+    TORCH_INTERNAL_ASSERT(
+        false,
+        "Expected to find key ",
+        devNCCLCommMapKey,
+        " in NCCL communicator map.");
+  }
+  std::vector<std::shared_ptr<NCCLComm>>& ncclComms =
+      devNCCLCommMap_[devNCCLCommMapKey];
+  // Loop through communicators and call ncclCommAbort.
+  for (const auto& comm : ncclComms) {
+    // ncclCommDestroy(comm->getNcclComm()) results in segfault when PG is being
+    // destroyed, so using ncclCommAbort here.
+    comm->ncclCommAbort();
+  }
+  // Remove communicators from the cache.
+  devNCCLCommMap_.erase(devNCCLCommMapKey);
+  // Clear used device indices.
+  usedDeviceIdxs_.clear();
+}
+
 std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
     const std::string& devicesKey,
     const std::vector<at::Device>& devices,
@@ -1697,7 +1796,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::barrier(
         "This can potentially cause a hang if this rank to GPU mapping is incorrect.",
         "Specify device_ids in barrier() to force use of a particular device."
     );
-    devices.emplace_back(at::DeviceType::CUDA, deviceIdx);
+    devices.emplace_back(getDeviceForRank(rank_));
   } else {
     for (auto usedDeviceIdx : usedDeviceIdxs_) {
       devices.emplace_back(at::DeviceType::CUDA, usedDeviceIdx);
@@ -430,6 +430,18 @@ class TORCH_API ProcessGroupNCCL : public ProcessGroup {
   void abortTimedOutCollectives(
       std::unordered_set<std::string>& abortedCommIds);
 
+  // Performs a health check by initializing dummy NCCL communicators and then
+  // destroying them. This will help indicate and signal any NCCL-related issues
+  // prior to the first collective. The actual initialization and subsequent
+  // destruction is ran on a separate thread and the main thread is signalled
+  // about timeouts/errors to report to the application.
+  void runHealthCheck();
+
+  // Destroys initialized NCCL communicators in devNCCLComMap_ given by input
+  // key. Throws if there are no communicators to destroy. Also removes
+  // communicators from the cache and clears used device indices.
+  void destroyNCCLComms(const std::string& devNCCLCommMapKey);
+
   void workCleanupLoop();
 
  protected:
@@ -7343,7 +7343,9 @@ class DistributedTest:
             # tests expected behavior when nonzero rank hangs.
             nccl_pg = dist.new_group(
                 ranks=list(i for i in range(int(self.world_size))),
-                timeout=timedelta(seconds=2),
+                # provide sufficient timeout so communicators
+                # can be initialized in ctor.
+                timeout=timedelta(seconds=15),
                 backend=dist.Backend.NCCL,
             )
             gloo_pg = dist.new_group(

@@ -7354,7 +7356,7 @@ class DistributedTest:
             # Let all ranks call allreduce first to set up communicators etc.
             # Directly simulating error here will run into store issue described
             # in https://github.com/pytorch/pytorch/issues/54524.
-            nccl_pg.allreduce(tensors).wait()
+            nccl_pg.allreduce(tensors).wait(timedelta(seconds=5))
             # All ranks besides 0 call into allreduce. This is to simulate a
             # desync across the world, where some ranks call into
             # monitored_barrier() and others are stuck in collective comm. In

@@ -7388,6 +7390,8 @@ class DistributedTest:
                         monitored_barrier_timeout_seconds, wait_all_ranks=wait_all_ranks
                     )
 
+            self._barrier(timeout=30)
+
         @with_nccl_blocking_wait
         @require_backend({"gloo", "nccl"})
         @require_backends_available({"gloo", "nccl"})