diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp
index 140fc800b00..b945efea785 100644
--- a/c10/cuda/CUDACachingAllocator.cpp
+++ b/c10/cuda/CUDACachingAllocator.cpp
@@ -833,8 +833,9 @@ class EventPool {
 
 // CUDA graphs helper
 struct PrivatePool {
-  PrivatePool(MempoolId_t id)
+  PrivatePool(MempoolId_t id, CUDAAllocator* allocator = nullptr)
       : id(std::move(id)),
+        allocator_(allocator),
         large_blocks(/*small=*/false, this),
         small_blocks(/*small=*/true, this) {}
   PrivatePool(const PrivatePool&) = delete;
@@ -855,8 +856,14 @@ struct PrivatePool {
   // distinguish private blocks by adding a "pool id" check above the stream
   // check in BlockComparator. BlockComparator is performance- critical though,
   // I'd rather not add more logic to it.
+  CUDAAllocator* allocator_;
   BlockPool large_blocks;
   BlockPool small_blocks;
+
+ public:
+  CUDAAllocator* allocator() {
+    return allocator_;
+  }
 };
 
 MempoolId_t BlockPool::owner_MempoolId() const {
@@ -905,9 +912,8 @@ struct MempoolIdHash {
 };
 
 cudaError_t allocPrimitive(void** ptr, size_t size, AllocParams& p) {
-  auto active_pool = MemPoolContext::getActiveMemPool();
-  if (active_pool && active_pool->allocator() && p.pool->owner_PrivatePool) {
-    *ptr = active_pool->allocator()->raw_alloc(size);
+  if (p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator()) {
+    *ptr = p.pool->owner_PrivatePool->allocator()->raw_alloc(size);
     return *ptr ? cudaSuccess : cudaErrorMemoryAllocation;
   } else {
     return C10_CUDA_ERROR_HANDLED(cudaMalloc(ptr, size));
@@ -1277,14 +1283,14 @@ class DeviceCachingAllocator {
               alloc_block(params, false, context, lock))
           // Free all non-split cached blocks and retry alloc.
           || (C10_LIKELY(captures_underway.empty()) &&
-              release_cached_blocks(context) &&
+              release_cached_blocks(context, {0, 0}) &&
               alloc_block(params, true, context, lock));
     }
 
     // we are about to oom, try to use existing mempools as a last resort
     if (!block_found && params.err == cudaErrorMemoryAllocation) {
       // if already trying to use a mempool, then just oom
-      auto active_pool = MemPoolContext::getActiveMemPool();
+      bool active_pool = params.pool->owner_PrivatePool;
       if (!active_pool) {
         for (MempoolId_t mempool_id : use_on_oom_pools) {
           auto tid = std::this_thread::get_id();
@@ -1671,10 +1677,10 @@ class DeviceCachingAllocator {
   }
 
   /** returns cached blocks to the system allocator **/
-  void emptyCache() {
+  void emptyCache(MempoolId_t mempool_id) {
     auto context = maybeGatherContext(RecordContext::ALL);
     std::lock_guard<std::recursive_mutex> lock(mutex);
-    release_cached_blocks(context);
+    release_cached_blocks(context, mempool_id);
   }
 
   /** Retrieves size of largest unused block held by the memory cache **/
@@ -1992,16 +1998,10 @@ class DeviceCachingAllocator {
 
   /** Dump a complete snapshot of the memory held by the allocator. Potentially
    * VERY expensive. **/
-  std::vector<SegmentInfo> snapshot() {
+  std::vector<SegmentInfo> snapshot(MempoolId_t mempool_id) {
     std::lock_guard<std::recursive_mutex> lock(mutex);
 
     std::vector<const Block*> all_blocks;
-    MempoolId_t mempool_id = {0, 0};
-
-    auto active_mempool = MemPoolContext::getActiveMemPool();
-    if (active_mempool) {
-      mempool_id = active_mempool->id();
-    }
 
     if (mempool_id.first != 0 || mempool_id.second != 0) {
       // If there is an active mempool, we find the corresponding PrivatePool
@@ -2011,7 +2011,7 @@ class DeviceCachingAllocator {
         all_blocks = get_private_pool_head_blocks(pool->second.get());
       }
     } else {
-      // When snapshot is called outside a MemPoolContext, we return
+      // When snapshot is called with the default mempool_id, we return
       // all the blocks in the CUDACachingAllocator (as returned by
       // get_all_blocks).
       all_blocks = get_all_blocks();
@@ -2130,11 +2130,11 @@ class DeviceCachingAllocator {
     }
   }
 
-  void ensureExistsAndIncrefPool(MempoolId_t mempool_id) {
+  void createOrIncrefPool(MempoolId_t mempool_id, CUDAAllocator* allocator) {
     // Create a PrivatePool object if it does not exist yet
     // and increment its use_count
     std::lock_guard<std::recursive_mutex> lock(mutex);
-    ensure_exists_and_incref_pool(mempool_id);
+    create_or_incref_pool(mempool_id, allocator);
   }
 
   void setUseOnOOM(MempoolId_t mempool_id) {
@@ -2150,7 +2150,7 @@ class DeviceCachingAllocator {
       MempoolId_t mempool_id,
       std::function<bool(cudaStream_t)> filter) {
     std::lock_guard<std::recursive_mutex> lock(mutex);
-    ensure_exists_and_incref_pool(mempool_id);
+    create_or_incref_pool(mempool_id);
     for (auto it2 = captures_underway.begin(); it2 != captures_underway.end();
          ++it2) {
       TORCH_CHECK(
@@ -2272,21 +2272,24 @@ class DeviceCachingAllocator {
     return blocks;
   }
 
-  void ensure_exists_and_incref_pool(MempoolId_t mempool_id) {
+  void create_or_incref_pool(
+      MempoolId_t mempool_id,
+      CUDAAllocator* allocator = nullptr) {
     auto it = graph_pools.find(mempool_id);
     if (it == graph_pools.end()) {
       // mempool_id does not reference an existing pool.
       // Make a new pool for CUDAGraph capture or torch.cuda.use_mem_pool
       // usage. use_count is initially 1, which means the pool is
-      // being used since somebody called ensureExistsAndIncrefPool.
+      // being used since somebody called createOrIncrefPool.
       graph_pools.emplace(
-          mempool_id, std::make_unique<PrivatePool>(mempool_id));
+          mempool_id, std::make_unique<PrivatePool>(mempool_id, allocator));
     } else {
       // mempool_id references an existing pool, which the current CUDAGraph
       // capture or torch.cuda.use_mem_pool will
       // share. Check this pool is live (at least one other capture already
       // references it). Increment it to establish the usage.
       TORCH_INTERNAL_ASSERT(it->second->use_count > 0);
+      TORCH_INTERNAL_ASSERT(allocator == nullptr);
       it->second->use_count++;
     }
   }
@@ -2776,7 +2779,8 @@ class DeviceCachingAllocator {
     bool in_fbcode = false;
 #endif
 
-    auto active_pool = MemPoolContext::getActiveMemPool();
+    bool active_pool =
+        p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator();
     if (set_fraction &&
         total_allocated_memory + size > allowed_memory_maximum) {
       p.err = cudaErrorMemoryAllocation;
@@ -2801,12 +2805,6 @@ class DeviceCachingAllocator {
       }
       return bool(p.block);
     } else {
-      if (active_pool && active_pool->allocator() &&
-          p.pool->owner_PrivatePool) {
-        // Ensure that active_pool and p.pool are the same
-        auto pp = get_private_pool(active_pool->id());
-        TORCH_INTERNAL_ASSERT(pp == p.pool->owner_PrivatePool);
-      }
       if (CUDAAllocatorConfig::release_lock_on_cudamalloc()) {
         // At scope exit, acquire the lock again. This provides safety against
         // any potential exceptions in the cudaMallocMaybeCapturing function.
@@ -2926,13 +2924,9 @@ class DeviceCachingAllocator {
     return true;
   }
 
-  bool release_cached_blocks(const std::shared_ptr<GatheredContext>& context) {
-    MempoolId_t mempool_id = {0, 0};
-    auto active_mempool = MemPoolContext::getActiveMemPool();
-    if (active_mempool) {
-      mempool_id = active_mempool->id();
-    }
-
+  bool release_cached_blocks(
+      const std::shared_ptr<GatheredContext>& context,
+      MempoolId_t mempool_id) {
 
     if (mempool_id.first == 0 && mempool_id.second == 0) {
       // If there is no active mempool, we work on releasing *all* blocks.
@@ -3005,15 +2999,10 @@ class DeviceCachingAllocator {
         context ? context : block->context_when_segment_allocated);
 
     auto* pool = block->pool;
-    auto active_pool = MemPoolContext::getActiveMemPool();
-    if (active_pool && active_pool->allocator() && pool->owner_PrivatePool) {
-      // Ensure that active_pool and pool are the same
-      auto pp = get_private_pool(active_pool->id());
-      TORCH_INTERNAL_ASSERT(pp == pool->owner_PrivatePool);
-
+    if (pool->owner_PrivatePool && pool->owner_PrivatePool->allocator()) {
       // If there is an active mempool with a given allocator,
       // we use the given allocator's delete function.
-      active_pool->allocator()->raw_delete((void*)block->ptr);
+      pool->owner_PrivatePool->allocator()->raw_delete((void*)block->ptr);
     } else {
       C10_CUDA_CHECK(cudaFree((void*)block->ptr));
     }
@@ -3589,9 +3578,9 @@ class NativeCachingAllocator : public CUDAAllocator {
     }
   }
 
-  void emptyCache() override {
+  void emptyCache(MempoolId_t mempool_id) override {
     for (auto& da : device_allocator)
-      da->emptyCache();
+      da->emptyCache(mempool_id);
   }
 
   void enable(bool value) override {
@@ -3639,7 +3628,7 @@ class NativeCachingAllocator : public CUDAAllocator {
     device_allocator[block->device]->recordStream(block, stream);
   }
 
-  SnapshotInfo snapshot() override {
+  SnapshotInfo snapshot(MempoolId_t mempool_id) override {
     // Set-up converter to convert timestamps from tsc to microseconds.
     auto tsc_to_ns = clock_converter.makeConverter();
     auto tsc_to_us = [=](approx_time_t t_approx) {
@@ -3657,7 +3646,7 @@ class NativeCachingAllocator : public CUDAAllocator {
     // Get the device_traces' TraceEntry lists.
     for (auto& da : device_allocator) {
       result.device_traces.emplace_back(da->trace(tsc_to_us));
-      auto snap = da->snapshot();
+      auto snap = da->snapshot(mempool_id);
       result.segments.insert(result.segments.end(), snap.begin(), snap.end());
     }
 
@@ -3785,11 +3774,13 @@ class NativeCachingAllocator : public CUDAAllocator {
     device_allocator[device]->resetPeakStats();
   }
 
-  void ensureExistsAndIncrefPool(
+  void createOrIncrefPool(
       c10::DeviceIndex device,
-      MempoolId_t mempool_id) override {
+      MempoolId_t mempool_id,
+      CUDAAllocator* allocator) override {
     assertValidDevice(device);
-    device_allocator[device]->ensureExistsAndIncrefPool(std::move(mempool_id));
+    device_allocator[device]->createOrIncrefPool(
+        std::move(mempool_id), allocator);
   }
 
   void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) override {
@@ -4134,7 +4125,7 @@ MemPool::MemPool(
     id_ = {uuid_++, 0};
   }
   device_ = c10::cuda::current_device();
-  CUDACachingAllocator::ensureExistsAndIncrefPool(device_, id_);
+  CUDACachingAllocator::createOrIncrefPool(device_, id_, allocator);
   if (use_on_oom) {
     CUDACachingAllocator::setUseOnOOM(device_, id_);
   }
@@ -4143,8 +4134,7 @@ MemPool::MemPool(
 MemPool::~MemPool() {
   TORCH_INTERNAL_ASSERT(use_count() == 1);
   CUDACachingAllocator::releasePool(device_, id_);
-  auto ctx = MemPoolContext(this);
-  c10::cuda::CUDACachingAllocator::emptyCache();
+  c10::cuda::CUDACachingAllocator::emptyCache(id_);
 }
 
 MempoolId_t MemPool::id() {
@@ -4170,23 +4160,4 @@ MempoolId_t MemPool::graph_pool_handle(bool is_user_created) {
   return {uuid_++, 0};
 }
 
-// Note that active_mempool_ is a global variable here
-// and not inside MemPoolContext class, because in windows we
-// can't use __declspec(dllexport) and __declspec(thread)
-// together: https://stackoverflow.com/a/50967977
-static thread_local MemPool* active_mempool_ = nullptr;
-
-MemPoolContext::MemPoolContext(MemPool* mempool)
-    : prev_mempool_(active_mempool_) {
-  active_mempool_ = mempool;
-}
-
-MemPoolContext::~MemPoolContext() {
-  active_mempool_ = prev_mempool_;
-}
-
-MemPool* MemPoolContext::getActiveMemPool() {
-  return active_mempool_;
-}
-
 } // namespace c10::cuda
diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h
index 57d549d9260..6d86a2178d5 100644
--- a/c10/cuda/CUDACachingAllocator.h
+++ b/c10/cuda/CUDACachingAllocator.h
@@ -211,7 +211,7 @@ class CUDAAllocator : public Allocator {
   virtual bool initialized() = 0;
   virtual double getMemoryFraction(c10::DeviceIndex device) = 0;
   virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0;
-  virtual void emptyCache() = 0;
+  virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0;
   virtual void enable(bool value) = 0;
   virtual bool isEnabled() const = 0;
   virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0;
@@ -221,7 +221,7 @@ class CUDAAllocator : public Allocator {
       c10::DeviceIndex device) = 0;
   virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0;
   virtual void resetPeakStats(c10::DeviceIndex device) = 0;
-  virtual SnapshotInfo snapshot() = 0;
+  virtual SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) = 0;
   virtual void beginAllocateToPool(
       c10::DeviceIndex device,
       MempoolId_t mempool_id,
@@ -239,13 +239,14 @@
         " does not yet support getPoolUseCount. "
" "If you need it, please file an issue describing your use case."); } - virtual void ensureExistsAndIncrefPool( + virtual void createOrIncrefPool( c10::DeviceIndex /*device*/, - MempoolId_t /*mempool_id*/) { + MempoolId_t /*mempool_id*/, + CUDAAllocator* allocator = nullptr) { TORCH_CHECK( false, name(), - " does not yet support ensureExistsAndIncrefPool. " + " does not yet support createOrIncrefPool. " "If you need it, please file an issue describing your use case."); } virtual void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) { @@ -364,7 +365,7 @@ inline void setMemoryFraction(double fraction, c10::DeviceIndex device) { return get()->setMemoryFraction(fraction, device); } -inline void emptyCache() { +inline void emptyCache(MempoolId_t mempool_id = {0, 0}) { return get()->emptyCache(); } @@ -401,8 +402,8 @@ inline void resetPeakStats(c10::DeviceIndex device) { return get()->resetPeakStats(device); } -inline SnapshotInfo snapshot() { - return get()->snapshot(); +inline SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) { + return get()->snapshot(mempool_id); } inline std::shared_ptr getCheckpointState( @@ -475,10 +476,11 @@ inline void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) { inline void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) { return get()->releasePool(device, mempool_id); } -inline void ensureExistsAndIncrefPool( +inline void createOrIncrefPool( c10::DeviceIndex device, - MempoolId_t mempool_id) { - get()->ensureExistsAndIncrefPool(device, mempool_id); + MempoolId_t mempool_id, + CUDAAllocator* allocator_ptr = nullptr) { + get()->createOrIncrefPool(device, mempool_id, allocator_ptr); } inline void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) { get()->setUseOnOOM(device, mempool_id); @@ -555,26 +557,4 @@ struct C10_CUDA_API MemPool { c10::DeviceIndex device_; }; -// MemPoolContext holds the currently active pool and stashes the previous -// pool. On deletion it makes the previous pool active. -struct C10_CUDA_API MemPoolContext { - MemPoolContext(MemPool* mempool); - - ~MemPoolContext(); - - // getActiveMemPool() can be used to get the currently active pool. - // For instance: in CUDACachingAllocator, we can route allocations - // to a user provided allocator, by doing: - // - // auto active_pool = MemPoolContext::getActiveMemPool(); - // if (active_pool && active_pool->allocator()) { - // ptr = active_pool->allocator()->raw_alloc(size); - // } - // - static MemPool* getActiveMemPool(); - - private: - MemPool* prev_mempool_; -}; - } // namespace c10::cuda diff --git a/c10/cuda/CUDAMallocAsyncAllocator.cpp b/c10/cuda/CUDAMallocAsyncAllocator.cpp index dc56545ec37..b5f313e419d 100644 --- a/c10/cuda/CUDAMallocAsyncAllocator.cpp +++ b/c10/cuda/CUDAMallocAsyncAllocator.cpp @@ -496,7 +496,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator { // introduces performance nondeterminism. } - void emptyCache() override { + void emptyCache(/*unused*/ MempoolId_t mempool_id) override { std::lock_guard lk(general_mutex); for (int dev = 0; dev < device_count; dev++) { @@ -778,7 +778,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator { cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrUsedMemHigh, &zero)); } - SnapshotInfo snapshot() override { + SnapshotInfo snapshot(MempoolId_t mempool_id) override { TORCH_CHECK( false, "Calling snapshot with backend:cudaMallocAsync is not meaningful. 
" diff --git a/docs/source/conf.py b/docs/source/conf.py index a326c8cbd20..ed943929cfa 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -2282,7 +2282,6 @@ coverage_ignore_classes = [ "UnsynchronizedAccessError", # torch.cuda.memory "MemPool", - "MemPoolContext", # torch.distributed.elastic.multiprocessing.errors "ChildFailedError", "ProcessFailure", diff --git a/docs/source/cuda.rst b/docs/source/cuda.rst index 38532c4d5f5..431758fc2ef 100644 --- a/docs/source/cuda.rst +++ b/docs/source/cuda.rst @@ -128,7 +128,6 @@ Memory management CUDAPluggableAllocator change_current_allocator MemPool - MemPoolContext .. currentmodule:: torch.cuda.memory diff --git a/test/test_cuda.py b/test/test_cuda.py index f82aa15a9af..f37b7c82008 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -5049,41 +5049,57 @@ class TestMemPool(TestCase): # increments the id self.assertTrue(abs(pool2[1] - pool1[1]) > 0) - def test_mempool_with_allocator(self): - pool = torch.cuda.MemPool() - - # MemPool doesn't have an allocator by default - self.assertEqual(pool.allocator, None) - - from torch.utils.cpp_extension import load_inline - - dummy_allocator_source = """ + def get_dummy_allocator(self, check_vars): + dummy_allocator_source_vars = """ #include #include #include extern "C" { - C10_EXPORT int called_dummy_alloc = 0; - C10_EXPORT int called_dummy_free = 0; + C10_EXPORT int called_dummy_alloc = 0; + C10_EXPORT int called_dummy_free = 0; + // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865 + C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) { + called_dummy_alloc = 123; + void* ptr; + C10_CUDA_CHECK(cudaMallocManaged(&ptr, size)); + return ptr; + } + + C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) { + called_dummy_free = 321; + C10_CUDA_CHECK(cudaFree(ptr)); + } + } + """ + dummy_allocator_source_no_vars = """ + #include + #include + #include + + extern "C" { // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865 C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) { - called_dummy_alloc = 123; void* ptr; C10_CUDA_CHECK(cudaMallocManaged(&ptr, size)); return ptr; } C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) { - called_dummy_free = 321; C10_CUDA_CHECK(cudaFree(ptr)); } } """ + + from torch.utils.cpp_extension import load_inline + dummy_allocator_libname = "dummy_allocator" dummy_allocator = load_inline( name=dummy_allocator_libname, - cpp_sources=dummy_allocator_source, + cpp_sources=dummy_allocator_source_vars + if check_vars + else dummy_allocator_source_no_vars, is_python_module=False, keep_intermediates=False, verbose=True, @@ -5094,6 +5110,15 @@ class TestMemPool(TestCase): "dummy_alloc", "dummy_free", ) + return allocator, dummy_allocator + + def test_mempool_with_allocator(self): + pool = torch.cuda.MemPool() + + # MemPool doesn't have an allocator by default + self.assertEqual(pool.allocator, None) + allocator, dummy_allocator = self.get_dummy_allocator(check_vars=True) + pool = torch.cuda.MemPool(allocator.allocator()) # pool should point to the same allocator as the one passed into it @@ -5128,6 +5153,8 @@ class TestMemPool(TestCase): # out tensor self.assertEqual(called_dummy_alloc.value, 123) + out_non_pool = torch.empty(nelem_1mb, device="cuda") + with torch.cuda.use_mem_pool(pool): # pool should have 1 segment since we made a small allocation (1 MB) # above and so the CUDACachingAllocator packed it into a 2 MB buffer @@ 
             # to make a new 2 MB buffer to accomodate out_2
             self.assertEqual(len(pool.snapshot()), 2)
 
+            all_segments = torch.cuda.memory._snapshot()["segments"]
+            self.assertEqual(len(all_segments), 3)
+
         del out_0, out_1, out_2
 
         # pool's destructor calls emptyCache()
@@ -5156,40 +5186,7 @@ class TestMemPool(TestCase):
 
     @serialTest()
     def test_mempool_limited_memory_with_allocator(self):
-        from torch.utils.cpp_extension import load_inline
-
-        dummy_allocator_source = """
-        #include <torch/extension.h>
-        #include <ATen/cuda/Exceptions.h>
-        #include <cuda_runtime_api.h>
-
-        extern "C" {
-          // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
-          C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
-            void* ptr;
-            C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
-            return ptr;
-          }
-
-          C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
-            C10_CUDA_CHECK(cudaFree(ptr));
-          }
-        }
-        """
-        dummy_allocator_libname = "dummy_allocator"
-        dummy_allocator = load_inline(
-            name=dummy_allocator_libname,
-            cpp_sources=dummy_allocator_source,
-            is_python_module=False,
-            keep_intermediates=False,
-            verbose=True,
-            with_cuda=True,
-        )
-        allocator = torch.cuda.memory.CUDAPluggableAllocator(
-            dummy_allocator,
-            "dummy_alloc",
-            "dummy_free",
-        )
+        allocator, _ = self.get_dummy_allocator(check_vars=False)
+
         pool_do_not_use = torch.cuda.MemPool(allocator.allocator())
         pool_use = torch.cuda.MemPool(allocator.allocator(), use_on_oom=True)
@@ -5258,38 +5255,13 @@ class TestMemPool(TestCase):
 
         self._teardown_mempool_limited_memory_test()
 
-    def test_mempool_context(self):
-        active_pool = torch.cuda.MemPoolContext.active_pool()
-
-        # there is no active pool if none was made active
-        self.assertEqual(active_pool, None)
-
-        pool = torch.cuda.MemPool()
-        ctx = torch.cuda.MemPoolContext(pool)
-        active_pool = torch.cuda.MemPoolContext.active_pool()
-
-        # pool was made active
-        self.assertEqual(active_pool, pool)
-
-        del ctx
-        active_pool = torch.cuda.MemPoolContext.active_pool()
-
-        # ctx was deleted, so active pool is the previous one
-        self.assertEqual(active_pool, None)
-
     def test_mempool_multithread(self):
         pool_ids = []
-        active_pool_ids = []
 
         def create_mempool_and_make_active():
             pool = torch.cuda.MemPool()
             pool_ids.extend([pool.id])
-            ctx = torch.cuda.MemPoolContext(pool)
-            active_pool = torch.cuda.MemPoolContext.active_pool()
-            active_pool_ids.extend([active_pool.id])
-            del ctx
-
         num_threads = 4
         threads = [
             threading.Thread(target=create_mempool_and_make_active)
@@ -5304,14 +5276,12 @@ class TestMemPool(TestCase):
 
         # mempool id creation is atomic
         self.assertEqual(len(set(pool_ids)), 4)
 
-        # each thread should have different active mempool, since
-        # the pointer to the mempool is thread local
-        self.assertEqual(len(set(active_pool_ids)), 4)
-
     @skipIfRocm(msg="expandable_segments mode is not supported on ROCm")
+    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Load_inline doesn't work in fbcode")
     def test_mempool_expandable(self):
         torch.cuda.memory._set_allocator_settings("expandable_segments:True")
-        pool = torch.cuda.MemPool()
+        allocator, _ = self.get_dummy_allocator(check_vars=False)
+        pool = torch.cuda.MemPool(allocator.allocator())
 
         # torch.cuda.MemPool doesn't work with expandable segments
         with self.assertRaises(RuntimeError):
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 821b977f60f..35cbbeda392 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -2027,7 +2027,7 @@ def _cuda_resetPeakMemoryStats(device: _int) -> None: ...
 def _cuda_hostMemoryStats() -> dict[str, Any]: ...
 def _cuda_resetAccumulatedHostMemoryStats() -> None: ...
 def _cuda_resetPeakHostMemoryStats() -> None: ...
-def _cuda_memorySnapshot() -> dict[str, Any]: ...
+def _cuda_memorySnapshot(mempool_id: tuple[_int, _int] | None) -> dict[str, Any]: ...
 def _cuda_record_memory_history_legacy(
     enabled: _bool,
     record_context: _bool,
@@ -2304,11 +2304,6 @@ class _MemPool:
     def allocator(self) -> _cuda_CUDAAllocator | None: ...
     def use_count(self) -> _int: ...
 
-class _MemPoolContext:
-    def __init__(self, pool: _MemPool) -> None: ...
-    @staticmethod
-    def active_pool() -> _MemPool | None: ...
-
 def _cuda_isCurrentStreamCapturing() -> _bool: ...
 
 def _graph_pool_handle() -> tuple[_int, _int]: ...
diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.cpp b/torch/csrc/cuda/CUDAPluggableAllocator.cpp
index ba4c669b06d..43606807c6e 100644
--- a/torch/csrc/cuda/CUDAPluggableAllocator.cpp
+++ b/torch/csrc/cuda/CUDAPluggableAllocator.cpp
@@ -184,7 +184,8 @@ void CUDAPluggableAllocator::setMemoryFraction(
   }
 }
 
-void CUDAPluggableAllocator::emptyCache() {
+void CUDAPluggableAllocator::emptyCache(
+    /*unused*/ c10::cuda::MempoolId_t mempool_id) {
   if (reset_fn_) {
     return reset_fn_();
   }
@@ -237,8 +238,8 @@ void CUDAPluggableAllocator::resetPeakStats(c10::DeviceIndex device) {
       "If you need it, please file an issue describing your use case.");
 }
 
-c10::cuda::CUDACachingAllocator::SnapshotInfo CUDAPluggableAllocator::
-    snapshot() {
+c10::cuda::CUDACachingAllocator::SnapshotInfo CUDAPluggableAllocator::snapshot(
+    c10::cuda::MempoolId_t mempool_id) {
   TORCH_CHECK(
       false,
       "CUDAPluggableAllocator does not yet support snapshot. "
diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.h b/torch/csrc/cuda/CUDAPluggableAllocator.h
index ade983e708c..54a478707b7 100644
--- a/torch/csrc/cuda/CUDAPluggableAllocator.h
+++ b/torch/csrc/cuda/CUDAPluggableAllocator.h
@@ -114,7 +114,7 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator
   bool initialized() override;
   double getMemoryFraction(c10::DeviceIndex device) override;
   void setMemoryFraction(double fraction, c10::DeviceIndex device) override;
-  void emptyCache() override;
+  void emptyCache(c10::cuda::MempoolId_t mempool_id = {0, 0}) override;
   void enable(bool) override {}
   bool isEnabled() const override {
     return true;
@@ -128,7 +128,8 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator
       c10::DeviceIndex device) override;
   void resetAccumulatedStats(c10::DeviceIndex device) override;
   void resetPeakStats(c10::DeviceIndex device) override;
-  c10::cuda::CUDACachingAllocator::SnapshotInfo snapshot() override;
+  c10::cuda::CUDACachingAllocator::SnapshotInfo snapshot(
+      c10::cuda::MempoolId_t mempool) override;
   void beginAllocateToPool(
       c10::DeviceIndex device,
       c10::cuda::MempoolId_t mempool_id,
diff --git a/torch/csrc/cuda/MemPool.cpp b/torch/csrc/cuda/MemPool.cpp
index 8e53a37591e..b651a4b5e68 100644
--- a/torch/csrc/cuda/MemPool.cpp
+++ b/torch/csrc/cuda/MemPool.cpp
@@ -24,8 +24,4 @@ void THCPMemPool_init(PyObject* module) {
       .def_property_readonly("id", &::c10::cuda::MemPool::id)
       .def_property_readonly("allocator", &::c10::cuda::MemPool::allocator)
       .def("use_count", &::c10::cuda::MemPool::use_count);
-  shared_ptr_class_<::c10::cuda::MemPoolContext>(torch_C_m, "_MemPoolContext")
-      .def(py::init<c10::cuda::MemPool*>())
-      .def_static(
-          "active_pool", &::c10::cuda::MemPoolContext::getActiveMemPool);
 }
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index 0391067d886..cbb8910fff5 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -721,8 +721,24 @@ CapturedTraceback* getFromContext(
       "attempting to gather stack context from the wrong StackContext type.");
 }
 
-PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
+PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) {
   HANDLE_TH_ERRORS
+  c10::cuda::MempoolId_t mempool_id = {0, 0};
+  if (arg && arg != Py_None) {
+    TORCH_CHECK(PyTuple_Check(arg), "mempool_id must be a tuple");
+    Py_ssize_t size = PyTuple_Size(arg);
+    TORCH_CHECK(size == 2, "mempool_id must be a tuple of 2 integers");
+
+    auto id1 = THPObjectPtr(PyTuple_GetItem(arg, 0));
+    auto id2 = THPObjectPtr(PyTuple_GetItem(arg, 1));
+    TORCH_CHECK(
+        THPUtils_checkLong(id1) && THPUtils_checkLong(id2),
+        "mempool_id elements must be integers");
+
+    mempool_id = c10::cuda::MempoolId_t(
+        static_cast<c10::cuda::CaptureId_t>(THPUtils_unpackLong(id1)),
+        static_cast<c10::cuda::CaptureId_t>(THPUtils_unpackLong(id2)));
+  }
 
   using c10::cuda::CUDACachingAllocator::BlockInfo;
   using c10::cuda::CUDACachingAllocator::SegmentInfo;
@@ -802,7 +818,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
     return segmentDict;
   };
 
-  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
+  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(mempool_id);
 
   py::list segments;
 
@@ -2011,7 +2027,7 @@ static struct PyMethodDef _THCPModule_methods[] = {
      THCPModule_resetPeakMemoryStats,
      METH_O,
      nullptr},
-    {"_cuda_memorySnapshot", THCPModule_memorySnapshot, METH_NOARGS, nullptr},
+    {"_cuda_memorySnapshot", THCPModule_memorySnapshot, METH_O, nullptr},
     {"_cuda_attach_out_of_memory_observer",
      THCPModule_attachOutOfMemoryObserver,
      METH_O,
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index b67d98a6d76..ef21660e7d7 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -1186,8 +1186,7 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) {
   // We must ensure we're listening for allocator trace events in order to
   // register future segments allocated in this pool (this call is idempotent).
   attachAllocatorHooks();
-  auto ctx = c10::cuda::MemPoolContext(pool);
-  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
+  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(pool->id());
   for (const auto& segmentInfo : snapshot.segments) {
     TORCH_INTERNAL_ASSERT(
         segmentInfo.device == pool->device(),
@@ -1221,8 +1220,7 @@ void ProcessGroupNCCL::deregisterMemPool(c10::cuda::MemPool* pool) {
     auto iter = ncclCommMemPoolMap.find(ncclComm);
     iter->second.erase(pool->id());
   }
-  auto ctx = c10::cuda::MemPoolContext(pool);
-  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
+  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(pool->id());
   for (const auto& segmentInfo : snapshot.segments) {
     TORCH_INTERNAL_ASSERT(
         segmentInfo.device == pool->device(),
@@ -5572,7 +5570,6 @@ at::Tensor ProcessGroupNCCL::allocateTensor(
   }
 
   // Allocate tensor under this MemPool's context
-  auto ctx = c10::cuda::MemPoolContext(memPool_.get());
   auto tid = std::this_thread::get_id();
   c10::cuda::CUDACachingAllocator::beginAllocateToPool(
       memPool_->device(), memPool_->id(), [=](cudaStream_t) {
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index ca311b0f691..4f38de8a2c5 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -1872,7 +1872,6 @@ __all__ = [
     "memory_summary",
     "memory_usage",
     "MemPool",
-    "MemPoolContext",
     "use_mem_pool",
     "temperature",
     "power_draw",
diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py
index 470dda35394..df02fb28d44 100644
--- a/torch/cuda/memory.py
+++ b/torch/cuda/memory.py
@@ -60,7 +60,6 @@ __all__ = [
     "CUDAPluggableAllocator",
     "change_current_allocator",
     "MemPool",
-    "MemPoolContext",
     "use_mem_pool",
 ]
 
@@ -73,7 +72,6 @@ if not hasattr(torch._C, "_cuda_CUDAAllocator"):
 if not hasattr(torch._C, "_MemPool"):
     # Define dummy base classes
     torch._C.__dict__["_MemPool"] = _dummy_type("_MemPool")
-    torch._C.__dict__["_MemPoolContext"] = _dummy_type("_MemPoolContext")
     torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type(
         "_cuda_beginAllocateToPool"
     )
@@ -92,7 +90,6 @@ from torch._C import (  # noqa: F401
     _cuda_endAllocateToPool,
     _cuda_releasePool,
     _MemPool,
-    _MemPoolContext,
 )
 
 
@@ -617,7 +614,7 @@ def max_memory_cached(device: "Device" = None) -> int:
     return max_memory_reserved(device=device)
 
 
-def memory_snapshot():
+def memory_snapshot(mempool_id=None):
     r"""Return a snapshot of the CUDA memory allocator state across all devices.
 
     Interpreting the output of this function requires familiarity with the
@@ -627,7 +624,7 @@ def memory_snapshot():
     See :ref:`cuda-memory-management` for more details about GPU memory
     management.
     """
-    return torch._C._cuda_memorySnapshot()["segments"]
+    return torch._C._cuda_memorySnapshot(mempool_id)["segments"]
 
 
 def memory_summary(device: "Device" = None, abbreviated: bool = False) -> str:
@@ -998,7 +995,7 @@ def _snapshot(device: "Device" = None):
     Returns:
         The Snapshot dictionary object
     """
-    return _C._cuda_memorySnapshot()
+    return _C._cuda_memorySnapshot(None)
 
 
 def _dump_snapshot(filename="dump_snapshot.pickle"):
@@ -1110,25 +1107,6 @@ def _get_current_allocator() -> _CUDAAllocator:
     return _CUDAAllocator(torch._C._cuda_getAllocator())
 
 
-class MemPoolContext(_MemPoolContext):
-    r"""MemPoolContext holds the currently active pool and stashes the previous
-    pool. On deletion it makes the previous pool active.
-
-    Args:
-        pool(torch.cuda.MemPool): a MemPool object to be made active so that
-            allocations route to this pool.
- - """ - - def __init__(self, pool: _MemPool): - super().__init__(pool) - - @staticmethod - def active_pool() -> Optional[_MemPool]: - r"""Returns the active MemPool""" - return _MemPoolContext.active_pool() - - class MemPool(_MemPool): r"""MemPool represents a pool of memory in a caching allocator. Currently, it's just the ID of the pool object maintained in the CUDACachingAllocator. @@ -1177,11 +1155,7 @@ class MemPool(_MemPool): See :ref:`cuda-memory-management` for more details about GPU memory management. """ - try: - ctx = MemPoolContext(self) - snapshot = torch.cuda.memory_snapshot() - finally: - del ctx + snapshot = torch.cuda.memory_snapshot(self.id) return snapshot @@ -1202,7 +1176,6 @@ def use_mem_pool(pool: MemPool, device: "Device" = None): (e.g. by calling backward) the allocations in that thread will not route to the given pool. """ - ctx = MemPoolContext(pool) device_index = ( torch.cuda.current_device() if device is None else _get_device_index(device) ) @@ -1212,4 +1185,3 @@ def use_mem_pool(pool: MemPool, device: "Device" = None): finally: _cuda_endAllocateToPool(device_index, pool.id) _cuda_releasePool(device_index, pool.id) - del ctx