mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Safer bookkeeping of NCCL communicators (#150681)
This consists mainly of two changes: (1) ensure we can reliably obtain the device from a `NCCLComm` object (there was one constructor which didn't set the device), and (2) use an RAII pattern for acquiring the lock to the global dictionary of `NCCLComm`s (which ensures the lock is released in case of exceptions). Pull Request resolved: https://github.com/pytorch/pytorch/pull/150681 Approved by: https://github.com/kwen2501
This commit is contained in:
parent
3da14d38bd
commit
3649e2e7bd
|
|
@ -92,7 +92,9 @@ std::shared_ptr<NCCLComm> NCCLComm::create_scalable(
|
|||
int numRanks,
|
||||
int rank,
|
||||
std::vector<ncclUniqueId>& commIds,
|
||||
at::DeviceIndex deviceIndex,
|
||||
ncclConfig_t& config) {
|
||||
at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex);
|
||||
auto comm = std::make_shared<NCCLComm>();
|
||||
comm->nonBlocking_ = config.blocking == 0;
|
||||
LOG(INFO) << "Rank " << rank << ": creating NCCL communicator with mode: "
|
||||
|
|
@ -112,6 +114,7 @@ std::shared_ptr<NCCLComm> NCCLComm::create_scalable(
|
|||
// in the log file and in the replay tool.
|
||||
comm->ncclId_ = commIds[0];
|
||||
comm->rank_ = rank;
|
||||
comm->deviceIndex_ = deviceIndex;
|
||||
comm->initialized_ = !comm->nonBlocking_;
|
||||
return comm;
|
||||
}
|
||||
|
|
@ -150,6 +153,10 @@ ncclComm_t NCCLComm::getNcclComm() {
|
|||
return ncclComm_;
|
||||
}
|
||||
|
||||
at::DeviceIndex NCCLComm::getDeviceIndex() {
|
||||
return deviceIndex_;
|
||||
}
|
||||
|
||||
// Wait for the communicator to be ready. This is a blocking function.
|
||||
// Arguments:
|
||||
// longInterval: if true, wait with sleep of an interval; otherwise, wait
|
||||
|
|
|
|||
|
|
@ -221,6 +221,7 @@ class NCCLComm {
|
|||
int numRanks,
|
||||
int rank,
|
||||
std::vector<ncclUniqueId>& commIds,
|
||||
at::DeviceIndex deviceIndex,
|
||||
ncclConfig_t& config);
|
||||
#endif // NCCL_HAS_INIT_RANK_SCALABLE
|
||||
#endif // NCCL_HAS_CONFIG
|
||||
|
|
@ -239,6 +240,7 @@ class NCCLComm {
|
|||
#endif
|
||||
|
||||
ncclUniqueId getNcclId();
|
||||
at::DeviceIndex getDeviceIndex();
|
||||
|
||||
// Must not be copyable
|
||||
NCCLComm(const NCCLComm&) = delete;
|
||||
|
|
|
|||
|
|
@ -302,10 +302,8 @@ static void cacheAllocatorRegisterHook(
|
|||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
|
||||
for (auto& it : ncclCommDevIdxMap) {
|
||||
auto& ncclComm = it.first;
|
||||
auto& devIdx = it.second;
|
||||
if (te.device_ == devIdx) {
|
||||
for (auto& [ncclComm, _] : ncclCommDevIdxMap) {
|
||||
if (te.device_ == ncclComm->getDeviceIndex()) {
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
ncclComm->registerSegment(reinterpret_cast<void*>(te.addr_), te.size_);
|
||||
}
|
||||
|
|
@ -321,10 +319,8 @@ static void cacheAllocatorDeregisterHook(
|
|||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
|
||||
for (auto& it : ncclCommDevIdxMap) {
|
||||
auto& ncclComm = it.first;
|
||||
auto& devIdx = it.second;
|
||||
if (te.device_ == devIdx) {
|
||||
for (auto& [ncclComm, _] : ncclCommDevIdxMap) {
|
||||
if (te.device_ == ncclComm->getDeviceIndex()) {
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
ncclComm->deregisterSegment(reinterpret_cast<void*>(te.addr_));
|
||||
}
|
||||
|
|
@ -345,11 +341,12 @@ static std::
|
|||
std::vector<std::shared_ptr<NCCLComm>> allNCCLComms;
|
||||
// within the critical section, we don't want to dump while holding the lock
|
||||
// as dump might hang
|
||||
ncclCommDevIdxMapMutex.lock();
|
||||
for (auto& [ncclComm, _] : ncclCommDevIdxMap) {
|
||||
allNCCLComms.push_back(ncclComm);
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
|
||||
for (auto& [ncclComm, _] : ncclCommDevIdxMap) {
|
||||
allNCCLComms.push_back(ncclComm);
|
||||
}
|
||||
}
|
||||
ncclCommDevIdxMapMutex.unlock();
|
||||
for (auto& ncclComm : allNCCLComms) {
|
||||
std::string ncclUniqueIDStr = buildNcclUniqueIdStr(ncclComm->getNcclId());
|
||||
ncclDumpMap[ncclUniqueIDStr] = ncclComm->ncclCommDump();
|
||||
|
|
@ -824,9 +821,10 @@ void ProcessGroupNCCL::WorkNCCL::abort() {
|
|||
// Abort all communicators of this work
|
||||
ncclComm_->abort();
|
||||
|
||||
ncclCommDevIdxMapMutex.lock();
|
||||
ncclCommDevIdxMap.erase(ncclComm_);
|
||||
ncclCommDevIdxMapMutex.unlock();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
|
||||
ncclCommDevIdxMap.erase(ncclComm_);
|
||||
}
|
||||
}
|
||||
|
||||
ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() = default;
|
||||
|
|
@ -1390,12 +1388,12 @@ bool ProcessGroupNCCL::abortComms(
|
|||
// communicators. Note that ncclCommDevIdxMap is a global container which may
|
||||
// contain other PG's communicators, thus we need to only erase communicators
|
||||
// for the current PG.
|
||||
ncclCommDevIdxMapMutex.lock();
|
||||
for (auto& it : devNCCLCommMap_) {
|
||||
auto& ncclComm = it.second;
|
||||
ncclCommDevIdxMap.erase(ncclComm);
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
|
||||
for (auto& [_, ncclComm] : devNCCLCommMap_) {
|
||||
ncclCommDevIdxMap.erase(ncclComm);
|
||||
}
|
||||
}
|
||||
ncclCommDevIdxMapMutex.unlock();
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
abortCommsFromMap(devNCCLCommMap_, abortReason);
|
||||
|
|
@ -2705,9 +2703,10 @@ void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) {
|
|||
// Clear used device indices.
|
||||
usedDeviceIdxs_.clear();
|
||||
|
||||
ncclCommDevIdxMapMutex.lock();
|
||||
ncclCommDevIdxMap.erase(ncclComm);
|
||||
ncclCommDevIdxMapMutex.unlock();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
|
||||
ncclCommDevIdxMap.erase(ncclComm);
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<NCCLComm> ProcessGroupNCCL::initNCCLComm(
|
||||
|
|
@ -2874,8 +2873,8 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::initNCCLComm(
|
|||
<< "ProcessGroupNCCL all-gather unique IDs through store took "
|
||||
<< timerDeltaMs << " ms";
|
||||
#if defined(NCCL_HAS_INIT_RANK_SCALABLE) && defined(NCCL_HAS_CONFIG)
|
||||
ncclComm =
|
||||
NCCLComm::create_scalable(numRanks, rank, ncclIDs, options_->config);
|
||||
ncclComm = NCCLComm::create_scalable(
|
||||
numRanks, rank, ncclIDs, deviceIndex, options_->config);
|
||||
#else
|
||||
C10_THROW_ERROR(
|
||||
DistBackendError,
|
||||
|
|
@ -2985,9 +2984,10 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::initNCCLComm(
|
|||
// on the same device.
|
||||
// NOTE: we need remove the communicator from this map when it is
|
||||
// destroyed, otherwise may register onto an invalid communicator.
|
||||
ncclCommDevIdxMapMutex.lock();
|
||||
ncclCommDevIdxMap.emplace(ncclComm, device.index());
|
||||
ncclCommDevIdxMapMutex.unlock();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
|
||||
ncclCommDevIdxMap.emplace(ncclComm, device.index());
|
||||
}
|
||||
}
|
||||
|
||||
it = devNCCLCommMap_.find(deviceKey);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user