Back out "Revert D31005792: [NCCL] Init dummy NCCL comms in constructor" (#65883)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65883

Original commit changeset: d8e962b8aab6
ghstack-source-id: 139836954

Test Plan: ci

Reviewed By: zhaojuanmao

Differential Revision: D31299350

fbshipit-source-id: 9ad5c8fa17f7038ba579cb1eda6d9271ac07a130
Parent: c1343ff706
Commit: f1f3bd8c36
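Note on the change: the ProcessGroupNCCL constructor now calls runHealthCheck(), which initializes and immediately destroys a dummy NCCL communicator on a detached thread and waits on a condition variable bounded by the process group timeout, so a broken NCCL setup surfaces as an error at construction time instead of as a later hang. The sketch below is only an illustration of that detach-and-wait pattern using the standard library; it is not the PyTorch implementation, and initDummyComm, runHealthCheckSketch, and the State struct are made-up names.

#include <chrono>
#include <condition_variable>
#include <exception>
#include <iostream>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <thread>

// Stand-in for communicator bring-up; anything that can hang or throw.
void initDummyComm(bool fail) {
  if (fail) {
    throw std::runtime_error("comm init failed");
  }
}

// Run the init step on a detached thread, wait on a condition variable with a
// timeout, and rethrow any captured exception in the caller.
void runHealthCheckSketch(std::chrono::milliseconds timeout, bool fail) {
  struct State {
    std::mutex mtx;
    std::condition_variable cv;
    bool done = false;
    std::exception_ptr err;
  };
  // Shared ownership keeps the state alive even if this function returns on
  // timeout while the detached thread is still running.
  auto state = std::make_shared<State>();

  std::thread([state, fail]() {
    try {
      initDummyComm(fail);
      std::lock_guard<std::mutex> lk(state->mtx);
      state->done = true;
    } catch (...) {
      std::lock_guard<std::mutex> lk(state->mtx);
      state->err = std::current_exception();
    }
    state->cv.notify_one();
  }).detach();

  std::unique_lock<std::mutex> lk(state->mtx);
  // The predicate guards against spurious wakeups and wakes on error too.
  bool finished = state->cv.wait_for(
      lk, timeout, [&] { return state->done || state->err != nullptr; });
  if (state->err) {
    std::rethrow_exception(state->err);
  }
  if (!finished) {
    throw std::runtime_error("health check timed out");
  }
}

int main() {
  try {
    runHealthCheckSketch(std::chrono::milliseconds(100), /*fail=*/true);
  } catch (const std::exception& e) {
    std::cout << "caught: " << e.what() << "\n";
  }
  return 0;
}

The sketch keeps the shared state in a shared_ptr so it outlives the caller if the wait times out while the detached thread is still running, and the wait_for predicate handles spurious wakeups; both are choices made for this standalone example.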
@@ -1,9 +1,12 @@
+#include <chrono>
 #include <iostream>

 #include <c10d/FileStore.hpp>
 #include <c10d/ProcessGroupNCCL.hpp>
 #include "CUDATest.hpp"
 #include "TestUtils.hpp"
+#include "c10d/ProcessGroup.hpp"
+#include "c10d/Types.hpp"

 #include <c10/cuda/CUDAGuard.h>
 #include <c10/cuda/CUDAStream.h>
@@ -19,7 +22,7 @@ using c10d::ProcessGroup;

 class NCCLTestBase {
  public:
-  NCCLTestBase(const std::string& path) : path_(path) {}
+  NCCLTestBase(const std::string& path, const std::chrono::milliseconds pgTimeout = kProcessGroupDefaultTimeout) : path_(path), pgTimeout_(pgTimeout) {}

   NCCLTestBase(NCCLTestBase&& other) {
     path_ = std::move(other.path_);
@@ -33,19 +36,22 @@ class NCCLTestBase {
   void initialize(int rank, int size) {
     auto store = c10::make_intrusive<::c10d::FileStore>(path_, size);

+    c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts = c10::make_intrusive<c10d::ProcessGroupNCCL::Options>();
+    opts->timeout = pgTimeout_;
     pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>(
-        new ::c10d::ProcessGroupNCCL(store, rank, size));
+        new ::c10d::ProcessGroupNCCL(store, rank, size, std::move(opts)));
   }

  protected:
   std::string path_;
   std::unique_ptr<::c10d::ProcessGroupNCCL> pg_;
+  std::chrono::milliseconds pgTimeout_;
 };

 class NCCLTest : public NCCLTestBase {
  public:
-  NCCLTest(const std::string& path, int worldSize)
-      : NCCLTestBase(path),
+  NCCLTest(const std::string& path, int worldSize, std::chrono::milliseconds pgTimeout = kProcessGroupDefaultTimeout)
+      : NCCLTestBase(path, pgTimeout),
         numDevices_(cudaNumDevices()),
         state_(::at::globalContext().lazyInitCUDA()),
         worldSize_(worldSize) {
@@ -497,10 +503,50 @@ void testReduceScatter(const std::string& path, int rank, int size) {
   }
 }

+void testProcessGroupNCCLHealthCheckFailHelper(const std::string& path, bool timeout) {
+  // simulate world_size > 1 here via threads.
+  const int worldSize = 4;
+  std::mutex m;
+  std::unordered_set<uint64_t> nums;
+  auto runTest = [&](int i) {
+    NCCLTest test(path, worldSize, std::chrono::milliseconds(3000));
+    // Catch error relating to health check failure
+    bool error_caught = false;
+    try {
+      test.initialize(timeout ? 0 : -1, worldSize);
+    } catch (const std::exception &e) {
+      std::string errMsg = e.what();
+      const std::string kTimeoutErr = "Failed to initialize NCCL communicator on rank";
+      const std::string kInvalidRankErr = "Invalid rank";
+      std::string expectedSubstr = timeout ? kTimeoutErr : kInvalidRankErr;
+      bool cond = errMsg.find(expectedSubstr) != std::string::npos;
+      EXPECT_TRUE(cond);
+      error_caught = true;
+    }
+    EXPECT_TRUE(error_caught);
+  };
+  std::vector<std::thread> threads;
+  threads.reserve(worldSize);
+  for (const auto r : c10::irange(worldSize)) {
+    threads.emplace_back(std::thread([=]() { runTest(r); }));
+  }
+  for (auto& t : threads) {
+    t.join();
+  }
+}
+
+void testProcessGroupNCCLHealthCheckFailException(const std::string& path, int /* unused */, int /* unused */) {
+  testProcessGroupNCCLHealthCheckFailHelper(path, /* timeout */ false);
+}
+
+void testProcessGroupNCCLHealthCheckFailTimeout(const std::string& path, int /* unused */, int /* unused */) {
+  testProcessGroupNCCLHealthCheckFailHelper(path, /* timeout */ true);
+}
+
 void testSequenceNumInit(const std::string& path, int /* unused */, int /* unused */) {
   // Note: ProcessGroupNCCLTest doesn't support multiprocess testing. So we
   // simulate world_size > 1 here via threads.
-  const int worldSize = 4;
+  const int worldSize = 2;
   std::mutex m;
   std::unordered_set<uint64_t> nums;
   auto runTest = [&](int i) {
@@ -625,6 +671,26 @@ TEST_F(ProcessGroupNCCLTest, testSequenceNumInit) {
   }
 }

+TEST_F(ProcessGroupNCCLTest, testProcessGroupNCCLHealthCheckFailTimeout) {
+  if (skipTest()) {
+    return;
+  }
+  {
+    TemporaryFile file;
+    testProcessGroupNCCLHealthCheckFailTimeout(file.path, rank_, size_);
+  }
+}
+
+TEST_F(ProcessGroupNCCLTest, testProcessGroupNCCLHealthCheckFailException) {
+  if (skipTest()) {
+    return;
+  }
+  {
+    TemporaryFile file;
+    testProcessGroupNCCLHealthCheckFailException(file.path, rank_, size_);
+  }
+}
+
 TEST_F(ProcessGroupNCCLTest, testReduceScatterBase) {
   if (skipTest()) {
     return;
@@ -657,11 +657,13 @@ class DistributedDataParallelTest(
         # otherwise process will be taken down and we can't check for errors.
         os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
         os.environ["NCCL_BLOCKING_WAIT"] = "1"
-        timeout = timedelta(seconds=2)
+        # TODO: smaller timeout can fail since PG NCCl does health check in
+        # constructor. Look into reducing this test's runtime.
         store = c10d.FileStore(self.file_name, self.world_size)
-        pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, timeout=timeout)
+        # provide sufficient timeout to initialize NCCL comm.
+        pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, timeout=timedelta(seconds=15))
         pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
-        pg.barrier().wait()
+        pg.barrier().wait(timedelta(seconds=5))
         # Simulate stuckness in rank 0.
         if self.rank == 0:
             pg_gloo.barrier().wait()
@@ -670,7 +672,7 @@ class DistributedDataParallelTest(
         if self.rank != 0:
             # Time out due to rank 0 not calling into allreduce.
             with self.assertRaises(RuntimeError):
-                pg.allreduce([inp]).wait()
+                pg.allreduce([inp]).wait(timedelta(seconds=5))

             # Now when nonzero rank attempts to use communicator, original failure reason should be logged.j
             try:
@@ -2263,14 +2265,14 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
             store,
             self.rank,
             self.world_size,
-            timeout=timedelta(seconds=self.op_timeout_sec),
+            timeout=timedelta(seconds=10),
         )
         process_group.allreduce(torch.rand(10).cuda(self.rank))
         if self.rank == 0:
             work = process_group.allreduce(torch.rand(10).cuda(self.rank))
             with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg):
                 # Operation would time out in blocking mode.
-                work.wait()
+                work.wait(timeout=timedelta(seconds=self.op_timeout_sec))
             # Run some GPU operations to make sure cuda has not gotten stuck.
             # It was observed cuda could get stuck if NCCL communicators were
             # not properly aborted before throwing RuntimeError.
@@ -2339,13 +2341,13 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
             store,
             self.rank,
             self.world_size,
-            timeout=timedelta(seconds=self.op_timeout_sec),
+            timeout=timedelta(seconds=10),
         )
         process_group.barrier().wait()
         if self.rank == 0:
             with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg):
                 # This should timeout
-                process_group.barrier().wait()
+                process_group.barrier().wait(timeout=timedelta(seconds=self.op_timeout_sec))

     def _run_invalid_nccl_blocking_wait_env(self, val):
         os.environ["NCCL_BLOCKING_WAIT"] = val
@@ -2382,21 +2384,20 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
         store = c10d.FileStore(self.file_name, self.world_size)

         # Initialize process_group.
-        timeout = 1
         process_group = c10d.ProcessGroupNCCL(
-            store, self.rank, self.world_size, timeout=timedelta(seconds=timeout)
+            store, self.rank, self.world_size, timeout=timedelta(seconds=10)
         )
-        process_group.allreduce(torch.rand(10).cuda(self.rank)).wait()
+        process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=timedelta(seconds=1))

         if self.rank == 0:
             # This should timeout in about 1 second.
             start = time.time()
             # Watchdog may abort timed out work resulting in NCCL error instead of operation timed out.
             with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg):
-                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait()
+                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=timedelta(seconds=1))
         else:
             # Sleep to ensure timeout.
-            time.sleep(2 * timeout)
+            time.sleep(2)

         self._wait_for_comm_abort(process_group)

@@ -2546,14 +2547,14 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
         store = c10d.FileStore(self.file_name, self.world_size)
         if self.rank == 0:
             with self.assertRaisesRegex(
-                RuntimeError, "Timed out initializing process group"
+                RuntimeError, "Health check failure"
             ):
                 c10d.init_process_group(
                     backend="nccl",
                     rank=self.rank,
                     world_size=self.world_size,
                     store=store,
-                    timeout=timedelta(seconds=1),
+                    timeout=timedelta(seconds=10),
                 )

     @requires_nccl()
@@ -2565,12 +2566,12 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
             rank=self.rank,
             world_size=self.world_size,
             store=store,
-            timeout=timedelta(seconds=1),
+            timeout=timedelta(seconds=10),
         )

         if self.rank == 0:
             with self.assertRaisesRegex(
-                RuntimeError, "Timed out initializing process group"
+                RuntimeError, "Health check failure"
             ):
                 c10d.new_group([0, 1], timeout=timedelta(seconds=1))

@@ -2588,12 +2589,12 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
             rank=self.rank,
             world_size=self.world_size,
             store=store,
-            timeout=timedelta(seconds=1),
+            timeout=timedelta(seconds=10),
         )

         if self.rank == 1:
             with self.assertRaisesRegex(
-                RuntimeError, "Timed out initializing process group"
+                RuntimeError, "Health check failure"
             ):
                 c10d.new_group([0, 1], timeout=timedelta(seconds=1))
@@ -39,7 +39,7 @@ class ProcessGroupNCCLJitTest(JitTestCase):

     def _create_nccl_pg(self, name_prefix):
         tcp_store = create_tcp_store(jit_class=True)
-        opts = torch.classes.dist_c10d.ProcessGroupNCCLOptions(0, True)
+        opts = torch.classes.dist_c10d.ProcessGroupNCCLOptions(10000, True)

         name = unique_process_group_name(name_prefix)

@@ -49,7 +49,7 @@ class ProcessGroupNCCLJitTest(JitTestCase):
         tcp_store = create_tcp_store(jit_class=True)

         return torch.classes.dist_c10d.frontend().new_process_group_helper(
-            self.world_size, self.rank, [], "nccl", tcp_store, name, 0)
+            self.world_size, self.rank, [], "nccl", tcp_store, name, 10000)

     @requires_nccl()
     @sandcastle_skip_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
@@ -172,7 +172,7 @@ class C10dFrontendJitTest(JitTestCase):
         pg_name = unique_process_group_name("singleton_test_process_group")

         ProcessGroupNCCL1 = frontend1.new_process_group_helper(
-            self.world_size, self.rank, [], "nccl", tcp_store, pg_name, 0)
+            self.world_size, self.rank, [], "nccl", tcp_store, pg_name, 10000)

         ProcessGroupNCCL2 = frontend2.get_process_group_by_name(pg_name)
         self.assertEqual(frontend2.get_name_of_process_group(ProcessGroupNCCL2), pg_name)
@@ -189,7 +189,7 @@ class C10dProcessGroupSerialization(JitTestCase):

             name = unique_process_group_name("module_member_process_group")
             self.pg = torch.classes.dist_c10d.frontend().new_process_group_helper(
-                1, 0, [], "nccl", tcp_store, name, 0)
+                1, 0, [], "nccl", tcp_store, name, 10000)

         def forward(self, input: torch.Tensor):
             if self.pg is None:
@@ -1,4 +1,5 @@
 #include <c10d/ProcessGroupNCCL.hpp>
+#include <c10/util/Exception.h>
 #include <sstream>

 #ifdef USE_C10D_NCCL
@@ -13,6 +14,7 @@

 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGraphsC10Utils.h>
+#include <c10/core/DeviceType.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <c10/util/irange.h>
 #include <c10/util/Logging.h>
@@ -159,6 +161,14 @@ std::vector<at::Device> getDeviceList(const std::vector<at::Tensor>& tensors) {
   return res;
 }

+// Return CUDA device with ordinal given by input rank.
+at::Device getDeviceForRank(int rank) {
+  TORCH_CHECK(rank >= 0, "Invalid rank ", rank);
+  auto numGPUs = at::cuda::getNumGPUs();
+  int16_t deviceIdx = static_cast<int16_t>(rank % numGPUs);
+  return at::Device(at::DeviceType::CUDA, deviceIdx);
+}
+
 // [Sync Streams] Helper that lets the input ncclStreams to wait for the current
 // stream. NCCL communications run on ncclStreams, but input tensors are
 // allocated on different streams (i.e., current streams). Communications on
@@ -502,6 +512,13 @@ ProcessGroupNCCL::ProcessGroupNCCL(
     asyncErrorHandling_ = false;
   }

+  // Perform health check by initializing dummy communicators and destroying
+  // them. This will help indicate any NCCL-related issues prior to the first
+  // collective.
+  // Run it in a separate thread and wait on CV to handle timeouts, since
+  // majority of getNCCLComm failures are hangs.
+  runHealthCheck();
+
 #ifdef ENABLE_NCCL_ERROR_CHECKING
   ncclCommWatchdogThread_ =
       std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this);
@@ -527,6 +544,64 @@ ProcessGroupNCCL::ProcessGroupNCCL(
             << "\nNCCL_DEBUG: " << ncclDebugLevel;
 }

+void ProcessGroupNCCL::runHealthCheck() {
+  // Run health check in a separate thread and wait on CV to handle timeouts,
+  // since majority of getNCCLComm failures are hangs.
+
+  struct HealthCheckData {
+    std::mutex healthCheckMutex;
+    std::condition_variable healthCheckCv;
+    bool healthCheckSuccess = false;
+    std::exception_ptr healthCheckException;
+  };
+
+  HealthCheckData healthCheckData;
+  auto t = std::thread([&healthCheckData, this]() {
+    try {
+      std::vector<at::Device> rankDevice = {getDeviceForRank(rank_)};
+      const auto key = getKeyFromDevices(rankDevice);
+      // OpType does not matter, only need to set to not go through send/recv
+      // path.
+      getNCCLComm(key, rankDevice, OpType::ALLREDUCE);
+      // Now destroy the communicators and remove them from cache so we don't
+      // use destroyed communicators.
+      destroyNCCLComms(key);
+      // Notify main thread the health check is complete.
+      {
+        std::lock_guard<std::mutex> lk(healthCheckData.healthCheckMutex);
+        healthCheckData.healthCheckSuccess = true;
+      }
+      healthCheckData.healthCheckCv.notify_one();
+    } catch (const std::exception& e) {
+      // Populate exception ptr.
+      healthCheckData.healthCheckException = std::current_exception();
+      // Unblock waiting main thread which will report exception.
+      healthCheckData.healthCheckCv.notify_one();
+    } // Unknown exceptions will just cause the program to terminate.
+  });
+  // We don't need to join the thread, just need to verify health check via the
+  // CV. Hence we detach the thread here.
+  t.detach(); // NOLINT
+  LOG(INFO) << "[Rank " << rank_ << "]"
+            << " will wait up to " << options_->timeout.count()
+            << " msec for NCCL health check to complete.";
+  std::unique_lock<std::mutex> lock(healthCheckData.healthCheckMutex);
+  healthCheckData.healthCheckCv.wait_for(
+      lock, options_->timeout, [&healthCheckData]() {
+        return healthCheckData.healthCheckSuccess;
+      });
+
+  if (healthCheckData.healthCheckException) {
+    std::rethrow_exception(healthCheckData.healthCheckException);
+  }
+  // If there is no exception, the likely culprit is a timeout/hang which is how
+  // most communicator init issues manifest themselves.
+  TORCH_CHECK(
+      healthCheckData.healthCheckSuccess,
+      "ProcessGroupNCCL: Health check failure: Failed to initialize NCCL communicator on rank ",
+      rank_);
+}
+
 void ProcessGroupNCCL::setSequenceNumberForGroup() {
   if (rank_ == 0) {
     // Create and broadcast sequence number
@@ -874,6 +949,30 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID(
   }
 }

+
+void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) {
+  std::lock_guard<std::mutex> lock(mutex_);
+  if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) {
+    TORCH_INTERNAL_ASSERT(
+        false,
+        "Expected to find key ",
+        devNCCLCommMapKey,
+        " in NCCL communicator map.");
+  }
+  std::vector<std::shared_ptr<NCCLComm>>& ncclComms =
+      devNCCLCommMap_[devNCCLCommMapKey];
+  // Loop through communicators and call ncclCommAbort.
+  for (const auto& comm : ncclComms) {
+    // ncclCommDestroy(comm->getNcclComm()) results in segfault when PG is being
+    // destroyed, so using ncclCommAbort here.
+    comm->ncclCommAbort();
+  }
+  // Remove communicators from the cache.
+  devNCCLCommMap_.erase(devNCCLCommMapKey);
+  // Clear used device indices.
+  usedDeviceIdxs_.clear();
+}
+
 std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
     const std::string& devicesKey,
     const std::vector<at::Device>& devices,
@@ -1697,7 +1796,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::barrier(
         "This can potentially cause a hang if this rank to GPU mapping is incorrect.",
         "Specify device_ids in barrier() to force use of a particular device."
     );
-    devices.emplace_back(at::DeviceType::CUDA, deviceIdx);
+    devices.emplace_back(getDeviceForRank(rank_));
   } else {
     for (auto usedDeviceIdx : usedDeviceIdxs_) {
      devices.emplace_back(at::DeviceType::CUDA, usedDeviceIdx);
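For reference, the new getDeviceForRank helper above maps a rank to a CUDA device by taking the rank modulo the number of visible GPUs, and barrier() now falls back to that mapping when no device has been used yet. The snippet below is a standalone illustration of that mapping only, assuming a host with two visible GPUs; numGPUs is a placeholder rather than a value queried from CUDA.

#include <iostream>

int main() {
  const int numGPUs = 2; // placeholder; the real code queries the CUDA device count
  for (int rank = 0; rank < 8; ++rank) {
    // Mirrors the rank % numGPUs mapping: ranks 0,2,4,6 -> cuda:0 and 1,3,5,7 -> cuda:1.
    std::cout << "rank " << rank << " -> cuda:" << (rank % numGPUs) << "\n";
  }
  return 0;
}

Note that when a host runs more ranks than it has visible GPUs, several ranks map to the same device under this scheme, which is why the surrounding warning asks callers to pass device_ids to barrier() when the guessed mapping may be wrong.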
@@ -430,6 +430,18 @@ class TORCH_API ProcessGroupNCCL : public ProcessGroup {
   void abortTimedOutCollectives(
       std::unordered_set<std::string>& abortedCommIds);

+  // Performs a health check by initializing dummy NCCL communicators and then
+  // destroying them. This will help indicate and signal any NCCL-related issues
+  // prior to the first collective. The actual initialization and subsequent
+  // destruction is ran on a separate thread and the main thread is signalled
+  // about timeouts/errors to report to the application.
+  void runHealthCheck();
+
+  // Destroys initialized NCCL communicators in devNCCLComMap_ given by input
+  // key. Throws if there are no communicators to destroy. Also removes
+  // communicators from the cache and clears used device indices.
+  void destroyNCCLComms(const std::string& devNCCLCommMapKey);
+
   void workCleanupLoop();

  protected:
@@ -7343,7 +7343,9 @@ class DistributedTest:
             # tests expected behavior when nonzero rank hangs.
             nccl_pg = dist.new_group(
                 ranks=list(i for i in range(int(self.world_size))),
-                timeout=timedelta(seconds=2),
+                # provide sufficient timeout so communicators
+                # can be initialized in ctor.
+                timeout=timedelta(seconds=15),
                 backend=dist.Backend.NCCL,
             )
             gloo_pg = dist.new_group(
@@ -7354,7 +7356,7 @@ class DistributedTest:
             # Let all ranks call allreduce first to set up communicators etc.
             # Directly simulating error here will run into store issue described
             # in https://github.com/pytorch/pytorch/issues/54524.
-            nccl_pg.allreduce(tensors).wait()
+            nccl_pg.allreduce(tensors).wait(timedelta(seconds=5))
             # All ranks besides 0 call into allreduce. This is to simulate a
             # desync across the world, where some ranks call into
             # monitored_barrier() and others are stuck in collective comm. In
@@ -7388,6 +7390,8 @@ class DistributedTest:
                         monitored_barrier_timeout_seconds, wait_all_ranks=wait_all_ranks
                     )

+            self._barrier(timeout=30)
+
         @with_nccl_blocking_wait
         @require_backend({"gloo", "nccl"})
         @require_backends_available({"gloo", "nccl"})