Fix multicast (multimem) memory being retained in symmetric memory (#160343)

Fixes the out-of-memory (OOM) error reported in #160289.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160343
Approved by: https://github.com/eqy
This commit is contained in:
Natalia Gimelshein 2025-08-12 02:03:15 +00:00 committed by PyTorch MergeBot
parent 95210cc409
commit be53f609aa
3 changed files with 17 additions and 4 deletions

View File

@ -53,7 +53,8 @@
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \ #define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \
_(cuMulticastAddDevice, 12030) \ _(cuMulticastAddDevice, 12030) \
_(cuMulticastBindMem, 12030) \ _(cuMulticastBindMem, 12030) \
_(cuMulticastCreate, 12030) _(cuMulticastCreate, 12030) \
_(cuMulticastUnbind, 12030)
#else #else
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) #define C10_LIBCUDA_DRIVER_API_OPTIONAL(_)
#endif #endif

View File

@ -46,11 +46,13 @@ AllocationRef::AllocationRef(
void* ptr, void* ptr,
HandleType handle, HandleType handle,
size_t block_size, size_t block_size,
int device_idx) int device_idx,
bool is_multicast)
: ptr(ptr), : ptr(ptr),
handle(handle), handle(handle),
block_size(block_size), block_size(block_size),
device_idx(device_idx) {} device_idx(device_idx),
is_multicast(is_multicast) {}
AllocationRef::~AllocationRef() { AllocationRef::~AllocationRef() {
if (is_finalizing()) { if (is_finalizing()) {
@ -63,6 +65,10 @@ AllocationRef::~AllocationRef() {
auto driver_api = c10::cuda::DriverAPI::get(); auto driver_api = c10::cuda::DriverAPI::get();
C10_CUDA_DRIVER_CHECK( C10_CUDA_DRIVER_CHECK(
driver_api->cuMemUnmap_(reinterpret_cast<CUdeviceptr>(ptr), block_size)); driver_api->cuMemUnmap_(reinterpret_cast<CUdeviceptr>(ptr), block_size));
if (is_multicast) {
C10_CUDA_DRIVER_CHECK(
driver_api->cuMulticastUnbind_(handle, device_idx, 0, block_size));
}
C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handle)); C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handle));
#elif defined(USE_ROCM) #elif defined(USE_ROCM)
C10_HIP_CHECK(hipMemUnmap(reinterpret_cast<hipDeviceptr_t>(ptr), block_size)); C10_HIP_CHECK(hipMemUnmap(reinterpret_cast<hipDeviceptr_t>(ptr), block_size));
@ -797,6 +803,10 @@ c10::intrusive_ptr<CUDASymmetricMemory> make_symm_mem(
for (int r = 0; r < world_size; ++r) { for (int r = 0; r < world_size; ++r) {
if (r == rank) { if (r == rank) {
alloc_refs.emplace_back(block->alloc_ref); alloc_refs.emplace_back(block->alloc_ref);
if (mc_addr != nullptr) {
alloc_refs.push_back(c10::make_intrusive<AllocationRef>(
mc_addr, mc_handle, block->block_size, block->device_idx, true));
}
continue; continue;
} }
alloc_refs.push_back(c10::make_intrusive<AllocationRef>( alloc_refs.push_back(c10::make_intrusive<AllocationRef>(

View File

@ -15,12 +15,14 @@ struct AllocationRef : public c10::intrusive_ptr_target {
HandleType handle; HandleType handle;
size_t block_size; size_t block_size;
int device_idx; int device_idx;
bool is_multicast;
AllocationRef( AllocationRef(
void* ptr, void* ptr,
HandleType handle, HandleType handle,
size_t block_size, size_t block_size,
int device_idx); int device_idx,
bool is_multicast = false);
~AllocationRef(); ~AllocationRef();
}; };