Safer bookkeeping of NCCL communicators (#150681)

This consists mainly of two changes:
- ensure we can reliably obtain the device from a `NCCLComm` object (there was one constructor which didn't set the device)
- use a RAII pattern for acquiring the lock to the global dictionary of `NCCLComms` (which ensures the lock is released in case of exceptions)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150681
Approved by: https://github.com/kwen2501
This commit is contained in:
Luca Wehrstedt 2025-04-07 12:34:01 +00:00 committed by PyTorch MergeBot
parent 3da14d38bd
commit 3649e2e7bd
3 changed files with 37 additions and 28 deletions

View File

@ -92,7 +92,9 @@ std::shared_ptr<NCCLComm> NCCLComm::create_scalable(
int numRanks,
int rank,
std::vector<ncclUniqueId>& commIds,
at::DeviceIndex deviceIndex,
ncclConfig_t& config) {
at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex);
auto comm = std::make_shared<NCCLComm>();
comm->nonBlocking_ = config.blocking == 0;
LOG(INFO) << "Rank " << rank << ": creating NCCL communicator with mode: "
@ -112,6 +114,7 @@ std::shared_ptr<NCCLComm> NCCLComm::create_scalable(
// in the log file and in the replay tool.
comm->ncclId_ = commIds[0];
comm->rank_ = rank;
comm->deviceIndex_ = deviceIndex;
comm->initialized_ = !comm->nonBlocking_;
return comm;
}
@ -150,6 +153,10 @@ ncclComm_t NCCLComm::getNcclComm() {
return ncclComm_;
}
at::DeviceIndex NCCLComm::getDeviceIndex() {
return deviceIndex_;
}
// Wait for the communicator to be ready. This is a blocking function.
// Arguments:
// longInterval: if true, wait with sleep of an interval; otherwise, wait

View File

@ -221,6 +221,7 @@ class NCCLComm {
int numRanks,
int rank,
std::vector<ncclUniqueId>& commIds,
at::DeviceIndex deviceIndex,
ncclConfig_t& config);
#endif // NCCL_HAS_INIT_RANK_SCALABLE
#endif // NCCL_HAS_CONFIG
@ -239,6 +240,7 @@ class NCCLComm {
#endif
ncclUniqueId getNcclId();
at::DeviceIndex getDeviceIndex();
// Must not be copyable
NCCLComm(const NCCLComm&) = delete;

View File

@ -302,10 +302,8 @@ static void cacheAllocatorRegisterHook(
}
std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
for (auto& it : ncclCommDevIdxMap) {
auto& ncclComm = it.first;
auto& devIdx = it.second;
if (te.device_ == devIdx) {
for (auto& [ncclComm, _] : ncclCommDevIdxMap) {
if (te.device_ == ncclComm->getDeviceIndex()) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
ncclComm->registerSegment(reinterpret_cast<void*>(te.addr_), te.size_);
}
@ -321,10 +319,8 @@ static void cacheAllocatorDeregisterHook(
}
std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
for (auto& it : ncclCommDevIdxMap) {
auto& ncclComm = it.first;
auto& devIdx = it.second;
if (te.device_ == devIdx) {
for (auto& [ncclComm, _] : ncclCommDevIdxMap) {
if (te.device_ == ncclComm->getDeviceIndex()) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
ncclComm->deregisterSegment(reinterpret_cast<void*>(te.addr_));
}
@ -345,11 +341,12 @@ static std::
std::vector<std::shared_ptr<NCCLComm>> allNCCLComms;
// within the critical section, we don't want to dump while holding the lock
// as dump might hang
ncclCommDevIdxMapMutex.lock();
for (auto& [ncclComm, _] : ncclCommDevIdxMap) {
allNCCLComms.push_back(ncclComm);
{
std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
for (auto& [ncclComm, _] : ncclCommDevIdxMap) {
allNCCLComms.push_back(ncclComm);
}
}
ncclCommDevIdxMapMutex.unlock();
for (auto& ncclComm : allNCCLComms) {
std::string ncclUniqueIDStr = buildNcclUniqueIdStr(ncclComm->getNcclId());
ncclDumpMap[ncclUniqueIDStr] = ncclComm->ncclCommDump();
@ -824,9 +821,10 @@ void ProcessGroupNCCL::WorkNCCL::abort() {
// Abort all communicators of this work
ncclComm_->abort();
ncclCommDevIdxMapMutex.lock();
ncclCommDevIdxMap.erase(ncclComm_);
ncclCommDevIdxMapMutex.unlock();
{
std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
ncclCommDevIdxMap.erase(ncclComm_);
}
}
ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() = default;
@ -1390,12 +1388,12 @@ bool ProcessGroupNCCL::abortComms(
// communicators. Note that ncclCommDevIdxMap is a global container which may
// contain other PG's communicators, thus we need to only erase communicators
// for the current PG.
ncclCommDevIdxMapMutex.lock();
for (auto& it : devNCCLCommMap_) {
auto& ncclComm = it.second;
ncclCommDevIdxMap.erase(ncclComm);
{
std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
for (auto& [_, ncclComm] : devNCCLCommMap_) {
ncclCommDevIdxMap.erase(ncclComm);
}
}
ncclCommDevIdxMapMutex.unlock();
std::lock_guard<std::mutex> lock(mutex_);
abortCommsFromMap(devNCCLCommMap_, abortReason);
@ -2705,9 +2703,10 @@ void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) {
// Clear used device indices.
usedDeviceIdxs_.clear();
ncclCommDevIdxMapMutex.lock();
ncclCommDevIdxMap.erase(ncclComm);
ncclCommDevIdxMapMutex.unlock();
{
std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
ncclCommDevIdxMap.erase(ncclComm);
}
}
std::shared_ptr<NCCLComm> ProcessGroupNCCL::initNCCLComm(
@ -2874,8 +2873,8 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::initNCCLComm(
<< "ProcessGroupNCCL all-gather unique IDs through store took "
<< timerDeltaMs << " ms";
#if defined(NCCL_HAS_INIT_RANK_SCALABLE) && defined(NCCL_HAS_CONFIG)
ncclComm =
NCCLComm::create_scalable(numRanks, rank, ncclIDs, options_->config);
ncclComm = NCCLComm::create_scalable(
numRanks, rank, ncclIDs, deviceIndex, options_->config);
#else
C10_THROW_ERROR(
DistBackendError,
@ -2985,9 +2984,10 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::initNCCLComm(
// on the same device.
// NOTE: we need remove the communicator from this map when it is
// destroyed, otherwise may register onto an invalid communicator.
ncclCommDevIdxMapMutex.lock();
ncclCommDevIdxMap.emplace(ncclComm, device.index());
ncclCommDevIdxMapMutex.unlock();
{
std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
ncclCommDevIdxMap.emplace(ncclComm, device.index());
}
}
it = devNCCLCommMap_.find(deviceKey);