make only current thread allocate to pool in NcclPG (#153990)

Follow-up to #153356 that fixes NCCL allocation to the pool.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153990
Approved by: https://github.com/kwen2501
This commit is contained in:
Natalia Gimelshein 2025-05-21 21:57:32 +00:00 committed by PyTorch MergeBot
parent 28bcd9eb30
commit 401fa87ace
3 changed files with 20 additions and 21 deletions

View File

@@ -34,6 +34,7 @@
#include <regex>
#include <set>
#include <stack>
#include <thread>
#include <utility>
#include <vector>
@@ -1286,7 +1287,10 @@ class DeviceCachingAllocator {
auto active_pool = MemPoolContext::getActiveMemPool();
if (!active_pool) {
for (MempoolId_t mempool_id : use_on_oom_pools) {
auto filter = [](cudaStream_t) { return true; };
auto tid = std::this_thread::get_id();
auto filter = [tid](cudaStream_t) {
return std::this_thread::get_id() == tid;
};
beginAllocateToPool(mempool_id, filter);
auto& mempool = get_pool(size, stream);
AllocParams mempool_params(
@@ -2133,12 +2137,10 @@ class DeviceCachingAllocator {
ensure_exists_and_incref_pool(mempool_id);
}
void setUseOnOOM(bool use_on_oom, MempoolId_t mempool_id) {
void setUseOnOOM(MempoolId_t mempool_id) {
// Choose if this pool should be used as a last resort before ooming
std::lock_guard<std::recursive_mutex> lock(mutex);
if (use_on_oom) {
use_on_oom_pools.insert(mempool_id);
}
use_on_oom_pools.insert(mempool_id);
}
// See Note [Interaction with CUDA graph capture]
@@ -3790,12 +3792,9 @@ class NativeCachingAllocator : public CUDAAllocator {
device_allocator[device]->ensureExistsAndIncrefPool(std::move(mempool_id));
}
void setUseOnOOM(
c10::DeviceIndex device,
bool use_on_oom,
MempoolId_t mempool_id) override {
void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) override {
assertValidDevice(device);
device_allocator[device]->setUseOnOOM(use_on_oom, std::move(mempool_id));
device_allocator[device]->setUseOnOOM(std::move(mempool_id));
}
// CUDAGraph interactions
@@ -4136,7 +4135,9 @@ MemPool::MemPool(
}
device_ = c10::cuda::current_device();
CUDACachingAllocator::ensureExistsAndIncrefPool(device_, id_);
CUDACachingAllocator::setUseOnOOM(device_, use_on_oom, id_);
if (use_on_oom) {
CUDACachingAllocator::setUseOnOOM(device_, id_);
}
}
MemPool::~MemPool() {

View File

@@ -248,10 +248,7 @@ class CUDAAllocator : public Allocator {
" does not yet support ensureExistsAndIncrefPool. "
"If you need it, please file an issue describing your use case.");
}
virtual void setUseOnOOM(
c10::DeviceIndex device,
bool use_on_oom,
MempoolId_t mempool_id) {
virtual void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) {
TORCH_CHECK(
false,
name(),
@@ -483,11 +480,8 @@ inline void ensureExistsAndIncrefPool(
MempoolId_t mempool_id) {
get()->ensureExistsAndIncrefPool(device, mempool_id);
}
inline void setUseOnOOM(
c10::DeviceIndex device,
bool use_on_oom,
MempoolId_t mempool_id) {
get()->setUseOnOOM(device, use_on_oom, mempool_id);
inline void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) {
get()->setUseOnOOM(device, mempool_id);
}
inline int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) {

View File

@@ -5570,8 +5570,12 @@ at::Tensor ProcessGroupNCCL::allocateTensor(
// Allocate tensor under this MemPool's context
auto ctx = c10::cuda::MemPoolContext(memPool_.get());
auto tid = std::this_thread::get_id();
c10::cuda::CUDACachingAllocator::beginAllocateToPool(
memPool_->device(), memPool_->id(), [](cudaStream_t) { return true; });
memPool_->device(), memPool_->id(), [=](cudaStream_t) {
auto current_tid = std::this_thread::get_id();
return current_tid == tid;
});
at::Tensor tensor = at::empty({size}, options);
c10::cuda::CUDACachingAllocator::endAllocateToPool(
memPool_->device(), memPool_->id());