From d173ba5a7566a867677442c788a72c9bddaf7c99 Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot <pytorchmergebot@users.noreply.github.com>
Date: Fri, 30 May 2025 06:53:35 +0000
Subject: [PATCH] Revert "Remove MemPoolContext (#154042)"

This reverts commit 3b38989b5f8f918cf1ad38bdade059608544af4b.

Reverted https://github.com/pytorch/pytorch/pull/154042 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/154042#issuecomment-2921401100))
---
 c10/cuda/CUDACachingAllocator.cpp          | 115 ++++++++++------
 c10/cuda/CUDACachingAllocator.h            |  46 +++++--
 c10/cuda/CUDAMallocAsyncAllocator.cpp      |   4 +-
 docs/source/conf.py                        |   1 +
 docs/source/cuda.rst                       |   1 +
 test/test_cuda.py                          | 126 +++++++++++-------
 torch/_C/__init__.pyi.in                   |   7 +-
 torch/csrc/cuda/CUDAPluggableAllocator.cpp |   7 +-
 torch/csrc/cuda/CUDAPluggableAllocator.h   |   5 +-
 torch/csrc/cuda/MemPool.cpp                |   4 +
 torch/csrc/cuda/Module.cpp                 |  22 +--
 .../distributed/c10d/ProcessGroupNCCL.cpp  |   7 +-
 torch/cuda/__init__.py                     |   1 +
 torch/cuda/memory.py                       |  36 ++++-
 14 files changed, 243 insertions(+), 139 deletions(-)

diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp
index b945efea785..140fc800b00 100644
--- a/c10/cuda/CUDACachingAllocator.cpp
+++ b/c10/cuda/CUDACachingAllocator.cpp
@@ -833,9 +833,8 @@ class EventPool {
 
 // CUDA graphs helper
 struct PrivatePool {
-  PrivatePool(MempoolId_t id, CUDAAllocator* allocator = nullptr)
+  PrivatePool(MempoolId_t id)
       : id(std::move(id)),
-        allocator_(allocator),
         large_blocks(/*small=*/false, this),
         small_blocks(/*small=*/true, this) {}
   PrivatePool(const PrivatePool&) = delete;
@@ -856,14 +855,8 @@ struct PrivatePool {
   // distinguish private blocks by adding a "pool id" check above the stream
   // check in BlockComparator. BlockComparator is performance- critical though,
   // I'd rather not add more logic to it.
-  CUDAAllocator* allocator_;
   BlockPool large_blocks;
   BlockPool small_blocks;
-
- public:
-  CUDAAllocator* allocator() {
-    return allocator_;
-  }
 };
 
 MempoolId_t BlockPool::owner_MempoolId() const {
@@ -912,8 +905,9 @@ struct MempoolIdHash {
 };
 
 cudaError_t allocPrimitive(void** ptr, size_t size, AllocParams& p) {
-  if (p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator()) {
-    *ptr = p.pool->owner_PrivatePool->allocator()->raw_alloc(size);
+  auto active_pool = MemPoolContext::getActiveMemPool();
+  if (active_pool && active_pool->allocator() && p.pool->owner_PrivatePool) {
+    *ptr = active_pool->allocator()->raw_alloc(size);
     return *ptr ? cudaSuccess : cudaErrorMemoryAllocation;
   } else {
     return C10_CUDA_ERROR_HANDLED(cudaMalloc(ptr, size));
@@ -1283,14 +1277,14 @@ class DeviceCachingAllocator {
             alloc_block(params, false, context, lock))
         // Free all non-split cached blocks and retry alloc.
         || (C10_LIKELY(captures_underway.empty()) &&
-            release_cached_blocks(context, {0, 0}) &&
+            release_cached_blocks(context) &&
             alloc_block(params, true, context, lock));
   }
 
   // we are about to oom, try to use existing mempools as a last resort
   if (!block_found && params.err == cudaErrorMemoryAllocation) {
     // if already trying to use a mempool, then just oom
-    bool active_pool = params.pool->owner_PrivatePool;
+    auto active_pool = MemPoolContext::getActiveMemPool();
     if (!active_pool) {
       for (MempoolId_t mempool_id : use_on_oom_pools) {
         auto tid = std::this_thread::get_id();
@@ -1677,10 +1671,10 @@ class DeviceCachingAllocator {
   }
 
   /** returns cached blocks to the system allocator **/
-  void emptyCache(MempoolId_t mempool_id) {
+  void emptyCache() {
     auto context = maybeGatherContext(RecordContext::ALL);
     std::lock_guard<std::recursive_mutex> lock(mutex);
-    release_cached_blocks(context, mempool_id);
+    release_cached_blocks(context);
   }
 
   /** Retrieves size of largest unused block held by the memory cache **/
@@ -1998,10 +1992,16 @@ class DeviceCachingAllocator {
 
   /** Dump a complete snapshot of the memory held by the allocator. Potentially
    * VERY expensive. **/
-  std::vector<SegmentInfo> snapshot(MempoolId_t mempool_id) {
+  std::vector<SegmentInfo> snapshot() {
     std::lock_guard<std::recursive_mutex> lock(mutex);
     std::vector<const Block*> all_blocks;
+    MempoolId_t mempool_id = {0, 0};
+
+    auto active_mempool = MemPoolContext::getActiveMemPool();
+    if (active_mempool) {
+      mempool_id = active_mempool->id();
+    }
 
     if (mempool_id.first != 0 || mempool_id.second != 0) {
       // If there is an active mempool, we find the corresponding PrivatePool
@@ -2011,7 +2011,7 @@ class DeviceCachingAllocator {
         all_blocks = get_private_pool_head_blocks(pool->second.get());
       }
     } else {
-      // When snapshot is called with non-default mempool_id, we return
+      // When snapshot is called outside a MemPoolContext, we return
       // all the blocks in the CUDACachingAllocator (as returned by
       // get_all_blocks).
       all_blocks = get_all_blocks();
@@ -2130,11 +2130,11 @@ class DeviceCachingAllocator {
     }
   }
 
-  void createOrIncrefPool(MempoolId_t mempool_id, CUDAAllocator* allocator) {
+  void ensureExistsAndIncrefPool(MempoolId_t mempool_id) {
     // Create a PrivatePool object if it does not exist yet
     // and increment its use_count
     std::lock_guard<std::recursive_mutex> lock(mutex);
-    create_or_incref_pool(mempool_id, allocator);
+    ensure_exists_and_incref_pool(mempool_id);
   }
 
   void setUseOnOOM(MempoolId_t mempool_id) {
@@ -2150,7 +2150,7 @@ class DeviceCachingAllocator {
       MempoolId_t mempool_id,
       std::function<bool(cudaStream_t)> filter) {
     std::lock_guard<std::recursive_mutex> lock(mutex);
-    create_or_incref_pool(mempool_id);
+    ensure_exists_and_incref_pool(mempool_id);
     for (auto it2 = captures_underway.begin(); it2 != captures_underway.end();
          ++it2) {
       TORCH_CHECK(
@@ -2272,24 +2272,21 @@ class DeviceCachingAllocator {
     return blocks;
   }
 
-  void create_or_incref_pool(
-      MempoolId_t mempool_id,
-      CUDAAllocator* allocator = nullptr) {
+  void ensure_exists_and_incref_pool(MempoolId_t mempool_id) {
    auto it = graph_pools.find(mempool_id);
    if (it == graph_pools.end()) {
      // mempool_id does not reference an existing pool.
      // Make a new pool for CUDAGraph capture or torch.cuda.use_mem_pool
      // usage. use_count is initially 1, which means the pool is
-      // being used since somebody called createOrIncrefPool.
+      // being used since somebody called ensureExistsAndIncrefPool.
      graph_pools.emplace(
-          mempool_id, std::make_unique<PrivatePool>(mempool_id, allocator));
+          mempool_id, std::make_unique<PrivatePool>(mempool_id));
    } else {
      // mempool_id references an existing pool, which the current CUDAGraph
      // capture or torch.cuda.use_mem_pool will
      // share. Check this pool is live (at least one other capture already
      // references it). Increment it to establish the usage.
      TORCH_INTERNAL_ASSERT(it->second->use_count > 0);
-      TORCH_INTERNAL_ASSERT(allocator == nullptr);
      it->second->use_count++;
    }
  }
@@ -2779,8 +2776,7 @@ class DeviceCachingAllocator {
    bool in_fbcode = false;
 #endif
 
-    bool active_pool =
-        p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator();
+    auto active_pool = MemPoolContext::getActiveMemPool();
    if (set_fraction &&
        total_allocated_memory + size > allowed_memory_maximum) {
      p.err = cudaErrorMemoryAllocation;
@@ -2805,6 +2801,12 @@ class DeviceCachingAllocator {
      }
      return bool(p.block);
    } else {
+      if (active_pool && active_pool->allocator() &&
+          p.pool->owner_PrivatePool) {
+        // Ensure that active_pool and p.pool are the same
+        auto pp = get_private_pool(active_pool->id());
+        TORCH_INTERNAL_ASSERT(pp == p.pool->owner_PrivatePool);
+      }
      if (CUDAAllocatorConfig::release_lock_on_cudamalloc()) {
        // At scope exit, acquire the lock again. This provides safety against
        // any potential exceptions in the cudaMallocMaybeCapturing function.
@@ -2924,9 +2926,13 @@ class DeviceCachingAllocator {
    return true;
  }
 
-  bool release_cached_blocks(
-      const std::shared_ptr<GatheredContext>& context,
-      MempoolId_t mempool_id) {
+  bool release_cached_blocks(const std::shared_ptr<GatheredContext>& context) {
+    MempoolId_t mempool_id = {0, 0};
+    auto active_mempool = MemPoolContext::getActiveMemPool();
+    if (active_mempool) {
+      mempool_id = active_mempool->id();
+    }
+
    if (mempool_id.first == 0 && mempool_id.second == 0) {
      // If there is no active mempool, we work on releasing *all* blocks.
@@ -2999,10 +3005,15 @@ class DeviceCachingAllocator {
        context ? context : block->context_when_segment_allocated);
 
    auto* pool = block->pool;
-    if (pool->owner_PrivatePool && pool->owner_PrivatePool->allocator()) {
+    auto active_pool = MemPoolContext::getActiveMemPool();
+    if (active_pool && active_pool->allocator() && pool->owner_PrivatePool) {
+      // Ensure that active_pool and pool are the same
+      auto pp = get_private_pool(active_pool->id());
+      TORCH_INTERNAL_ASSERT(pp == pool->owner_PrivatePool);
+
      // If there is an active mempool with a given allocator,
      // we use the given allocator's delete function.
-      pool->owner_PrivatePool->allocator()->raw_delete((void*)block->ptr);
+      active_pool->allocator()->raw_delete((void*)block->ptr);
    } else {
      C10_CUDA_CHECK(cudaFree((void*)block->ptr));
    }
@@ -3578,9 +3589,9 @@ class NativeCachingAllocator : public CUDAAllocator {
    }
  }
 
-  void emptyCache(MempoolId_t mempool_id) override {
+  void emptyCache() override {
    for (auto& da : device_allocator)
-      da->emptyCache(mempool_id);
+      da->emptyCache();
  }
 
  void enable(bool value) override {
@@ -3628,7 +3639,7 @@ class NativeCachingAllocator : public CUDAAllocator {
    device_allocator[block->device]->recordStream(block, stream);
  }
 
-  SnapshotInfo snapshot(MempoolId_t mempool_id) override {
+  SnapshotInfo snapshot() override {
    // Set-up converter to convert timestamps from tsc to microseconds.
    auto tsc_to_ns = clock_converter.makeConverter();
    auto tsc_to_us = [=](approx_time_t t_approx) {
@@ -3646,7 +3657,7 @@ class NativeCachingAllocator : public CUDAAllocator {
    // Get the device_traces' TraceEntry lists.
    for (auto& da : device_allocator) {
      result.device_traces.emplace_back(da->trace(tsc_to_us));
-      auto snap = da->snapshot(mempool_id);
+      auto snap = da->snapshot();
      result.segments.insert(result.segments.end(), snap.begin(), snap.end());
    }
 
@@ -3774,13 +3785,11 @@ class NativeCachingAllocator : public CUDAAllocator {
    device_allocator[device]->resetPeakStats();
  }
 
-  void createOrIncrefPool(
+  void ensureExistsAndIncrefPool(
      c10::DeviceIndex device,
-      MempoolId_t mempool_id,
-      CUDAAllocator* allocator) override {
+      MempoolId_t mempool_id) override {
    assertValidDevice(device);
-    device_allocator[device]->createOrIncrefPool(
-        std::move(mempool_id), allocator);
+    device_allocator[device]->ensureExistsAndIncrefPool(std::move(mempool_id));
  }
 
  void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) override {
@@ -4125,7 +4134,7 @@ MemPool::MemPool(
    id_ = {uuid_++, 0};
  }
  device_ = c10::cuda::current_device();
-  CUDACachingAllocator::createOrIncrefPool(device_, id_, allocator);
+  CUDACachingAllocator::ensureExistsAndIncrefPool(device_, id_);
  if (use_on_oom) {
    CUDACachingAllocator::setUseOnOOM(device_, id_);
  }
@@ -4134,7 +4143,8 @@ MemPool::MemPool(
 MemPool::~MemPool() {
  TORCH_INTERNAL_ASSERT(use_count() == 1);
  CUDACachingAllocator::releasePool(device_, id_);
-  c10::cuda::CUDACachingAllocator::emptyCache(id_);
+  auto ctx = MemPoolContext(this);
+  c10::cuda::CUDACachingAllocator::emptyCache();
 }
 
 MempoolId_t MemPool::id() {
@@ -4160,4 +4170,23 @@ MempoolId_t MemPool::graph_pool_handle(bool is_user_created) {
  return {uuid_++, 0};
 }
 
+// Note that active_mempool_ is a global variable here
+// and not inside MemPoolContext class, because on Windows we
+// can't use __declspec(dllexport) and __declspec(thread)
+// together: https://stackoverflow.com/a/50967977
+static thread_local MemPool* active_mempool_ = nullptr;
+
+MemPoolContext::MemPoolContext(MemPool* mempool)
+    : prev_mempool_(active_mempool_) {
+  active_mempool_ = mempool;
+}
+
+MemPoolContext::~MemPoolContext() {
+  active_mempool_ = prev_mempool_;
+}
+
+MemPool* MemPoolContext::getActiveMemPool() {
+  return active_mempool_;
+}
+
 } // namespace c10::cuda
diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h
index 6d86a2178d5..57d549d9260 100644
--- a/c10/cuda/CUDACachingAllocator.h
+++ b/c10/cuda/CUDACachingAllocator.h
@@ -211,7 +211,7 @@ class CUDAAllocator : public Allocator {
  virtual bool initialized() = 0;
  virtual double getMemoryFraction(c10::DeviceIndex device) = 0;
  virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0;
-  virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0;
+  virtual void emptyCache() = 0;
  virtual void enable(bool value) = 0;
  virtual bool isEnabled() const = 0;
  virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0;
@@ -221,7 +221,7 @@ class CUDAAllocator : public Allocator {
      c10::DeviceIndex device) = 0;
  virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0;
  virtual void resetPeakStats(c10::DeviceIndex device) = 0;
-  virtual SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) = 0;
+  virtual SnapshotInfo snapshot() = 0;
  virtual void beginAllocateToPool(
      c10::DeviceIndex device,
      MempoolId_t mempool_id,
@@ -239,14 +239,13 @@ class CUDAAllocator : public Allocator {
        " does not yet support getPoolUseCount. "
" "If you need it, please file an issue describing your use case."); } - virtual void createOrIncrefPool( + virtual void ensureExistsAndIncrefPool( c10::DeviceIndex /*device*/, - MempoolId_t /*mempool_id*/, - CUDAAllocator* allocator = nullptr) { + MempoolId_t /*mempool_id*/) { TORCH_CHECK( false, name(), - " does not yet support createOrIncrefPool. " + " does not yet support ensureExistsAndIncrefPool. " "If you need it, please file an issue describing your use case."); } virtual void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) { @@ -365,7 +364,7 @@ inline void setMemoryFraction(double fraction, c10::DeviceIndex device) { return get()->setMemoryFraction(fraction, device); } -inline void emptyCache(MempoolId_t mempool_id = {0, 0}) { +inline void emptyCache() { return get()->emptyCache(); } @@ -402,8 +401,8 @@ inline void resetPeakStats(c10::DeviceIndex device) { return get()->resetPeakStats(device); } -inline SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) { - return get()->snapshot(mempool_id); +inline SnapshotInfo snapshot() { + return get()->snapshot(); } inline std::shared_ptr getCheckpointState( @@ -476,11 +475,10 @@ inline void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) { inline void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) { return get()->releasePool(device, mempool_id); } -inline void createOrIncrefPool( +inline void ensureExistsAndIncrefPool( c10::DeviceIndex device, - MempoolId_t mempool_id, - CUDAAllocator* allocator_ptr = nullptr) { - get()->createOrIncrefPool(device, mempool_id, allocator_ptr); + MempoolId_t mempool_id) { + get()->ensureExistsAndIncrefPool(device, mempool_id); } inline void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) { get()->setUseOnOOM(device, mempool_id); @@ -557,4 +555,26 @@ struct C10_CUDA_API MemPool { c10::DeviceIndex device_; }; +// MemPoolContext holds the currently active pool and stashes the previous +// pool. On deletion it makes the previous pool active. +struct C10_CUDA_API MemPoolContext { + MemPoolContext(MemPool* mempool); + + ~MemPoolContext(); + + // getActiveMemPool() can be used to get the currently active pool. + // For instance: in CUDACachingAllocator, we can route allocations + // to a user provided allocator, by doing: + // + // auto active_pool = MemPoolContext::getActiveMemPool(); + // if (active_pool && active_pool->allocator()) { + // ptr = active_pool->allocator()->raw_alloc(size); + // } + // + static MemPool* getActiveMemPool(); + + private: + MemPool* prev_mempool_; +}; + } // namespace c10::cuda diff --git a/c10/cuda/CUDAMallocAsyncAllocator.cpp b/c10/cuda/CUDAMallocAsyncAllocator.cpp index b5f313e419d..dc56545ec37 100644 --- a/c10/cuda/CUDAMallocAsyncAllocator.cpp +++ b/c10/cuda/CUDAMallocAsyncAllocator.cpp @@ -496,7 +496,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator { // introduces performance nondeterminism. } - void emptyCache(/*unused*/ MempoolId_t mempool_id) override { + void emptyCache() override { std::lock_guard lk(general_mutex); for (int dev = 0; dev < device_count; dev++) { @@ -778,7 +778,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator { cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrUsedMemHigh, &zero)); } - SnapshotInfo snapshot(MempoolId_t mempool_id) override { + SnapshotInfo snapshot() override { TORCH_CHECK( false, "Calling snapshot with backend:cudaMallocAsync is not meaningful. 
" diff --git a/docs/source/conf.py b/docs/source/conf.py index ed943929cfa..a326c8cbd20 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -2282,6 +2282,7 @@ coverage_ignore_classes = [ "UnsynchronizedAccessError", # torch.cuda.memory "MemPool", + "MemPoolContext", # torch.distributed.elastic.multiprocessing.errors "ChildFailedError", "ProcessFailure", diff --git a/docs/source/cuda.rst b/docs/source/cuda.rst index 431758fc2ef..38532c4d5f5 100644 --- a/docs/source/cuda.rst +++ b/docs/source/cuda.rst @@ -128,6 +128,7 @@ Memory management CUDAPluggableAllocator change_current_allocator MemPool + MemPoolContext .. currentmodule:: torch.cuda.memory diff --git a/test/test_cuda.py b/test/test_cuda.py index c77c724d3c9..35949061c41 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -5049,57 +5049,41 @@ class TestMemPool(TestCase): # increments the id self.assertTrue(abs(pool2[1] - pool1[1]) > 0) - def get_dummy_allocator(self, check_vars): - dummy_allocator_source_vars = """ + def test_mempool_with_allocator(self): + pool = torch.cuda.MemPool() + + # MemPool doesn't have an allocator by default + self.assertEqual(pool.allocator, None) + + from torch.utils.cpp_extension import load_inline + + dummy_allocator_source = """ #include #include #include extern "C" { - C10_EXPORT int called_dummy_alloc = 0; - C10_EXPORT int called_dummy_free = 0; + C10_EXPORT int called_dummy_alloc = 0; + C10_EXPORT int called_dummy_free = 0; - // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865 - C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) { - called_dummy_alloc = 123; - void* ptr; - C10_CUDA_CHECK(cudaMallocManaged(&ptr, size)); - return ptr; - } - - C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) { - called_dummy_free = 321; - C10_CUDA_CHECK(cudaFree(ptr)); - } - } - """ - dummy_allocator_source_no_vars = """ - #include - #include - #include - - extern "C" { // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865 C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) { + called_dummy_alloc = 123; void* ptr; C10_CUDA_CHECK(cudaMallocManaged(&ptr, size)); return ptr; } C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) { + called_dummy_free = 321; C10_CUDA_CHECK(cudaFree(ptr)); } } """ - - from torch.utils.cpp_extension import load_inline - dummy_allocator_libname = "dummy_allocator" dummy_allocator = load_inline( name=dummy_allocator_libname, - cpp_sources=dummy_allocator_source_vars - if check_vars - else dummy_allocator_source_no_vars, + cpp_sources=dummy_allocator_source, is_python_module=False, keep_intermediates=False, verbose=True, @@ -5110,15 +5094,6 @@ class TestMemPool(TestCase): "dummy_alloc", "dummy_free", ) - return allocator, dummy_allocator - - def test_mempool_with_allocator(self): - pool = torch.cuda.MemPool() - - # MemPool doesn't have an allocator by default - self.assertEqual(pool.allocator, None) - allocator, dummy_allocator = self.get_dummy_allocator(check_vars=True) - pool = torch.cuda.MemPool(allocator.allocator()) # pool should point to the same allocator as the one passed into it @@ -5153,8 +5128,6 @@ class TestMemPool(TestCase): # out tensor self.assertEqual(called_dummy_alloc.value, 123) - out_non_pool = torch.empty(nelem_1mb, device="cuda") - with torch.cuda.use_mem_pool(pool): # pool should have 1 segment since we made a small allocation (1 MB) # above and so the CUDACachingAllocator packed it into a 2 MB buffer @@ 
@@ -5172,9 +5145,6 @@ class TestMemPool(TestCase):
             # to make a new 2 MB buffer to accommodate out_2
             self.assertEqual(len(pool.snapshot()), 2)
 
-            all_segments = torch.cuda.memory._snapshot()["segments"]
-            self.assertEqual(len(all_segments), 3)
-
         del out_0, out_1, out_2
 
         # pool's destructor calls emptyCache()
@@ -5186,7 +5156,40 @@ class TestMemPool(TestCase):
 
     @serialTest()
     def test_mempool_limited_memory_with_allocator(self):
-        allocator, _ = self.get_dummy_allocator(check_vars=False)
+        from torch.utils.cpp_extension import load_inline
+
+        dummy_allocator_source = """
+        #include <torch/extension.h>
+        #include <ATen/cuda/Exceptions.h>
+        #include <cuda_runtime_api.h>
+
+        extern "C" {
+          // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
+          C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
+            void* ptr;
+            C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
+            return ptr;
+          }
+
+          C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
+            C10_CUDA_CHECK(cudaFree(ptr));
+          }
+        }
+        """
+        dummy_allocator_libname = "dummy_allocator"
+        dummy_allocator = load_inline(
+            name=dummy_allocator_libname,
+            cpp_sources=dummy_allocator_source,
+            is_python_module=False,
+            keep_intermediates=False,
+            verbose=True,
+            with_cuda=True,
+        )
+        allocator = torch.cuda.memory.CUDAPluggableAllocator(
+            dummy_allocator,
+            "dummy_alloc",
+            "dummy_free",
+        )
 
         pool_do_not_use = torch.cuda.MemPool(allocator.allocator())
         pool_use = torch.cuda.MemPool(allocator.allocator(), use_on_oom=True)
@@ -5255,13 +5258,38 @@ class TestMemPool(TestCase):
 
         self._teardown_mempool_limited_memory_test()
 
+    def test_mempool_context(self):
+        active_pool = torch.cuda.MemPoolContext.active_pool()
+
+        # there is no active pool if none was made active
+        self.assertEqual(active_pool, None)
+
+        pool = torch.cuda.MemPool()
+        ctx = torch.cuda.MemPoolContext(pool)
+        active_pool = torch.cuda.MemPoolContext.active_pool()
+
+        # pool was made active
+        self.assertEqual(active_pool, pool)
+
+        del ctx
+        active_pool = torch.cuda.MemPoolContext.active_pool()
+
+        # ctx was deleted, so active pool is the previous one
+        self.assertEqual(active_pool, None)
+
     def test_mempool_multithread(self):
         pool_ids = []
+        active_pool_ids = []
 
         def create_mempool_and_make_active():
             pool = torch.cuda.MemPool()
             pool_ids.extend([pool.id])
 
+            ctx = torch.cuda.MemPoolContext(pool)
+            active_pool = torch.cuda.MemPoolContext.active_pool()
+            active_pool_ids.extend([active_pool.id])
+            del ctx
+
         num_threads = 4
         threads = [
             threading.Thread(target=create_mempool_and_make_active)
@@ -5276,12 +5304,14 @@ class TestMemPool(TestCase):
         # mempool id creation is atomic
         self.assertEqual(len(set(pool_ids)), 4)
 
+        # each thread should have different active mempool, since
+        # the pointer to the mempool is thread local
+        self.assertEqual(len(set(active_pool_ids)), 4)
+
     @skipIfRocm(msg="expandable_segments mode is not supported on ROCm")
-    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Load_inline doesn't work in fbcode")
     def test_mempool_expandable(self):
         torch.cuda.memory._set_allocator_settings("expandable_segments:True")
-        allocator, _ = self.get_dummy_allocator(check_vars=False)
-        pool = torch.cuda.MemPool(allocator.allocator())
+        pool = torch.cuda.MemPool()
 
         # torch.cuda.MemPool doesn't work with expandable segments
         with self.assertRaises(RuntimeError):
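Condensed from the restored tests above, the Python surface this revert brings back in one short usage sketch (assumes a CUDA build carrying these bindings):

import torch

pool = torch.cuda.MemPool()

# No pool is active until a MemPoolContext makes one active.
assert torch.cuda.MemPoolContext.active_pool() is None

ctx = torch.cuda.MemPoolContext(pool)  # stashes the previously active pool
assert torch.cuda.MemPoolContext.active_pool() == pool

del ctx  # restores the previous pool (here: none)
assert torch.cuda.MemPoolContext.active_pool() is None

# use_mem_pool() layers the context over begin/endAllocateToPool, so
# allocations made inside the block are routed into `pool`.
with torch.cuda.use_mem_pool(pool):
    out = torch.randn(1024, device="cuda")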
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 35cbbeda392..821b977f60f 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -2027,7 +2027,7 @@ def _cuda_resetPeakMemoryStats(device: _int) -> None: ...
 def _cuda_hostMemoryStats() -> dict[str, Any]: ...
 def _cuda_resetAccumulatedHostMemoryStats() -> None: ...
 def _cuda_resetPeakHostMemoryStats() -> None: ...
-def _cuda_memorySnapshot(mempool_id: tuple[_int, _int] | None) -> dict[str, Any]: ...
+def _cuda_memorySnapshot() -> dict[str, Any]: ...
 def _cuda_record_memory_history_legacy(
     enabled: _bool,
     record_context: _bool,
@@ -2304,6 +2304,11 @@ class _MemPool:
     def allocator(self) -> _cuda_CUDAAllocator | None: ...
     def use_count(self) -> _int: ...
 
+class _MemPoolContext:
+    def __init__(self, pool: _MemPool) -> None: ...
+    @staticmethod
+    def active_pool() -> _MemPool | None: ...
+
 def _cuda_isCurrentStreamCapturing() -> _bool: ...
 def _graph_pool_handle() -> tuple[_int, _int]: ...
diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.cpp b/torch/csrc/cuda/CUDAPluggableAllocator.cpp
index 43606807c6e..ba4c669b06d 100644
--- a/torch/csrc/cuda/CUDAPluggableAllocator.cpp
+++ b/torch/csrc/cuda/CUDAPluggableAllocator.cpp
@@ -184,8 +184,7 @@ void CUDAPluggableAllocator::setMemoryFraction(
   }
 }
 
-void CUDAPluggableAllocator::emptyCache(
-    /*unused*/ c10::cuda::MempoolId_t mempool_id) {
+void CUDAPluggableAllocator::emptyCache() {
   if (reset_fn_) {
     return reset_fn_();
   }
@@ -238,8 +237,8 @@ void CUDAPluggableAllocator::resetPeakStats(c10::DeviceIndex device) {
       "If you need it, please file an issue describing your use case.");
 }
 
-c10::cuda::CUDACachingAllocator::SnapshotInfo CUDAPluggableAllocator::snapshot(
-    c10::cuda::MempoolId_t mempool_id) {
+c10::cuda::CUDACachingAllocator::SnapshotInfo CUDAPluggableAllocator::
+    snapshot() {
   TORCH_CHECK(
       false,
       "CUDAPluggableAllocator does not yet support snapshot. "
diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.h b/torch/csrc/cuda/CUDAPluggableAllocator.h
index 54a478707b7..ade983e708c 100644
--- a/torch/csrc/cuda/CUDAPluggableAllocator.h
+++ b/torch/csrc/cuda/CUDAPluggableAllocator.h
@@ -114,7 +114,7 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator
   bool initialized() override;
   double getMemoryFraction(c10::DeviceIndex device) override;
   void setMemoryFraction(double fraction, c10::DeviceIndex device) override;
-  void emptyCache(c10::cuda::MempoolId_t mempool_id = {0, 0}) override;
+  void emptyCache() override;
   void enable(bool) override {}
   bool isEnabled() const override {
     return true;
@@ -128,8 +128,7 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator
       c10::DeviceIndex device) override;
   void resetAccumulatedStats(c10::DeviceIndex device) override;
   void resetPeakStats(c10::DeviceIndex device) override;
-  c10::cuda::CUDACachingAllocator::SnapshotInfo snapshot(
-      c10::cuda::MempoolId_t mempool) override;
+  c10::cuda::CUDACachingAllocator::SnapshotInfo snapshot() override;
   void beginAllocateToPool(
       c10::DeviceIndex device,
       c10::cuda::MempoolId_t mempool_id,
diff --git a/torch/csrc/cuda/MemPool.cpp b/torch/csrc/cuda/MemPool.cpp
index b651a4b5e68..8e53a37591e 100644
--- a/torch/csrc/cuda/MemPool.cpp
+++ b/torch/csrc/cuda/MemPool.cpp
@@ -24,4 +24,8 @@ void THCPMemPool_init(PyObject* module) {
       .def_property_readonly("id", &::c10::cuda::MemPool::id)
       .def_property_readonly("allocator", &::c10::cuda::MemPool::allocator)
       .def("use_count", &::c10::cuda::MemPool::use_count);
+  shared_ptr_class_<::c10::cuda::MemPoolContext>(torch_C_m, "_MemPoolContext")
+      .def(py::init<c10::cuda::MemPool*>())
+      .def_static(
+          "active_pool", &::c10::cuda::MemPoolContext::getActiveMemPool);
 }
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index cbb8910fff5..0391067d886 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -721,24 +721,8 @@ CapturedTraceback* getFromContext(
       "attempting to gather stack context from the wrong StackContext type.");
 }
 
-PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) {
+PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
   HANDLE_TH_ERRORS
-  c10::cuda::MempoolId_t mempool_id = {0, 0};
-  if (arg && arg != Py_None) {
-    TORCH_CHECK(PyTuple_Check(arg), "mempool_id must be a tuple");
-    Py_ssize_t size = PyTuple_Size(arg);
-    TORCH_CHECK(size == 2, "mempool_id must be a tuple of 2 integers");
-
-    auto id1 = THPObjectPtr(PyTuple_GetItem(arg, 0));
-    auto id2 = THPObjectPtr(PyTuple_GetItem(arg, 1));
-    TORCH_CHECK(
-        THPUtils_checkLong(id1) && THPUtils_checkLong(id2),
-        "mempool_id elements must be integers");
-
-    mempool_id = c10::cuda::MempoolId_t(
-        static_cast<c10::cuda::CaptureId_t>(THPUtils_unpackLong(id1)),
-        static_cast<c10::cuda::CaptureId_t>(THPUtils_unpackLong(id2)));
-  }
   using c10::cuda::CUDACachingAllocator::BlockInfo;
   using c10::cuda::CUDACachingAllocator::SegmentInfo;
 
@@ -818,7 +802,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) {
     return segmentDict;
   };
 
-  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(mempool_id);
+  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
 
   py::list segments;
 
@@ -2027,7 +2011,7 @@ static struct PyMethodDef _THCPModule_methods[] = {
      THCPModule_resetPeakMemoryStats,
      METH_O,
      nullptr},
-    {"_cuda_memorySnapshot", THCPModule_memorySnapshot, METH_O, nullptr},
+    {"_cuda_memorySnapshot", THCPModule_memorySnapshot, METH_NOARGS, nullptr},
    {"_cuda_attach_out_of_memory_observer",
     THCPModule_attachOutOfMemoryObserver,
     METH_O,
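With _cuda_memorySnapshot switched to METH_NOARGS above, a per-pool snapshot can no longer be requested by argument; the caller activates the pool first. A hedged sketch of that calling convention (snapshot_for is a hypothetical helper; MemPool.snapshot() in torch/cuda/memory.py below does effectively this):

import torch

def snapshot_for(pool):
    ctx = torch.cuda.MemPoolContext(pool)  # make `pool` active on this thread
    try:
        # _snapshot() consults the active pool and reports only its segments
        return torch.cuda.memory._snapshot()
    finally:
        del ctx  # restore whatever pool was active before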
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index ef21660e7d7..b67d98a6d76 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -1186,7 +1186,8 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) {
   // We must ensure we're listening for allocator trace events in order to
   // register future segments allocated in this pool (this call is idempotent).
   attachAllocatorHooks();
-  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(pool->id());
+  auto ctx = c10::cuda::MemPoolContext(pool);
+  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
   for (const auto& segmentInfo : snapshot.segments) {
     TORCH_INTERNAL_ASSERT(
         segmentInfo.device == pool->device(),
@@ -1220,7 +1221,8 @@ void ProcessGroupNCCL::deregisterMemPool(c10::cuda::MemPool* pool) {
     auto iter = ncclCommMemPoolMap.find(ncclComm);
     iter->second.erase(pool->id());
   }
-  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(pool->id());
+  auto ctx = c10::cuda::MemPoolContext(pool);
+  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
   for (const auto& segmentInfo : snapshot.segments) {
     TORCH_INTERNAL_ASSERT(
         segmentInfo.device == pool->device(),
@@ -5570,6 +5572,7 @@ at::Tensor ProcessGroupNCCL::allocateTensor(
   }
 
   // Allocate tensor under this MemPool's context
+  auto ctx = c10::cuda::MemPoolContext(memPool_.get());
   auto tid = std::this_thread::get_id();
   c10::cuda::CUDACachingAllocator::beginAllocateToPool(
       memPool_->device(), memPool_->id(), [=](cudaStream_t) {
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index 4f38de8a2c5..ca311b0f691 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -1872,6 +1872,7 @@ __all__ = [
     "memory_summary",
     "memory_usage",
     "MemPool",
+    "MemPoolContext",
     "use_mem_pool",
     "temperature",
     "power_draw",
diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py
index df02fb28d44..470dda35394 100644
--- a/torch/cuda/memory.py
+++ b/torch/cuda/memory.py
@@ -60,6 +60,7 @@ __all__ = [
     "CUDAPluggableAllocator",
     "change_current_allocator",
     "MemPool",
+    "MemPoolContext",
     "use_mem_pool",
 ]
 
@@ -72,6 +73,7 @@ if not hasattr(torch._C, "_cuda_CUDAAllocator"):
 if not hasattr(torch._C, "_MemPool"):
     # Define dummy base classes
     torch._C.__dict__["_MemPool"] = _dummy_type("_MemPool")
+    torch._C.__dict__["_MemPoolContext"] = _dummy_type("_MemPoolContext")
     torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type(
         "_cuda_beginAllocateToPool"
     )
@@ -90,6 +92,7 @@ from torch._C import (  # noqa: F401
     _cuda_endAllocateToPool,
     _cuda_releasePool,
     _MemPool,
+    _MemPoolContext,
 )
 
 
@@ -614,7 +617,7 @@ def max_memory_cached(device: "Device" = None) -> int:
     return max_memory_reserved(device=device)
 
 
-def memory_snapshot(mempool_id=None):
+def memory_snapshot():
     r"""Return a snapshot of the CUDA memory allocator state across all devices.
 
     Interpreting the output of this function requires familiarity with the
@@ -624,7 +627,7 @@ def memory_snapshot(mempool_id=None):
     See :ref:`cuda-memory-management` for
     more details about GPU memory management.
     """
-    return torch._C._cuda_memorySnapshot(mempool_id)["segments"]
+    return torch._C._cuda_memorySnapshot()["segments"]
 
 
 def memory_summary(device: "Device" = None, abbreviated: bool = False) -> str:
@@ -995,7 +998,7 @@ def _snapshot(device: "Device" = None):
     Returns:
         The Snapshot dictionary object
     """
-    return _C._cuda_memorySnapshot(None)
+    return _C._cuda_memorySnapshot()
 
 
 def _dump_snapshot(filename="dump_snapshot.pickle"):
@@ -1107,6 +1110,25 @@ def _get_current_allocator() -> _CUDAAllocator:
     return _CUDAAllocator(torch._C._cuda_getAllocator())
 
 
+class MemPoolContext(_MemPoolContext):
+    r"""MemPoolContext holds the currently active pool and stashes the previous
+    pool. On deletion it makes the previous pool active.
+
+    Args:
+        pool(torch.cuda.MemPool): a MemPool object to be made active so that
+            allocations route to this pool.
+ + """ + + def __init__(self, pool: _MemPool): + super().__init__(pool) + + @staticmethod + def active_pool() -> Optional[_MemPool]: + r"""Returns the active MemPool""" + return _MemPoolContext.active_pool() + + class MemPool(_MemPool): r"""MemPool represents a pool of memory in a caching allocator. Currently, it's just the ID of the pool object maintained in the CUDACachingAllocator. @@ -1155,7 +1177,11 @@ class MemPool(_MemPool): See :ref:`cuda-memory-management` for more details about GPU memory management. """ - snapshot = torch.cuda.memory_snapshot(self.id) + try: + ctx = MemPoolContext(self) + snapshot = torch.cuda.memory_snapshot() + finally: + del ctx return snapshot @@ -1176,6 +1202,7 @@ def use_mem_pool(pool: MemPool, device: "Device" = None): (e.g. by calling backward) the allocations in that thread will not route to the given pool. """ + ctx = MemPoolContext(pool) device_index = ( torch.cuda.current_device() if device is None else _get_device_index(device) ) @@ -1185,3 +1212,4 @@ def use_mem_pool(pool: MemPool, device: "Device" = None): finally: _cuda_endAllocateToPool(device_index, pool.id) _cuda_releasePool(device_index, pool.id) + del ctx