Refactors empty_cache to return only MemPool memory to the system (#133602)
Canonically, the empty_cache API releases all cached blocks of the CUDACachingAllocator; there is no API that releases only the cached blocks of a given pool. This PR extends the empty_cache API so that, when it is called under a MemPoolContext, only the cached blocks that correspond to the pool id of the active pool are released.

Part of https://github.com/pytorch/pytorch/issues/124807.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133602
Approved by: https://github.com/ezyang
This commit is contained in:
parent bd369bb182
commit 341a28f0ce
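For illustration, here is a minimal usage sketch of the behavior described above. It is not part of the diff; it assumes the Python bindings exercised by the test changes further down (torch.cuda.MemPool, torch.cuda.use_mem_pool, torch.cuda.MemPoolContext) and that torch.cuda.empty_cache() reaches the emptyCache() path shown in the C++ hunks:

    import torch

    # A pool routed through the caching allocator (the test below passes a
    # custom pluggable allocator instead; the no-argument form is assumed here).
    pool = torch.cuda.MemPool()

    # Allocate into the pool, then drop the tensor so its block is cached.
    with torch.cuda.use_mem_pool(pool):
        x = torch.randn(1024, device="cuda")
    del x

    # With the pool active, empty_cache() should release only this pool's
    # cached blocks; other cached memory stays untouched.
    ctx = torch.cuda.MemPoolContext(pool)
    torch.cuda.empty_cache()
    del ctx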
@@ -2831,16 +2831,37 @@ class DeviceCachingAllocator {
   }
 
   bool release_cached_blocks(const std::shared_ptr<GatheredContext>& context) {
-    // First ensure that all blocks that can't currently be allocated due to
-    // outstanding events are returned to the pool.
-    synchronize_and_free_events(context);
+    MempoolId_t mempool_id = {0, 0};
+    auto active_mempool = MemPoolContext::getActiveMemPool();
+    if (active_mempool) {
+      mempool_id = active_mempool->id();
+    }
 
-    // Free all non-split cached blocks to system allocator
-    release_blocks(large_blocks, context);
-    release_blocks(small_blocks, context);
+    if (mempool_id.first == 0 && mempool_id.second == 0) {
+      // If there is no active mempool, we work on releasing *all* blocks.
+
+      // First ensure that all blocks that can't currently be allocated due to
+      // outstanding events are returned to the pool.
+      synchronize_and_free_events(context);
+
+      // Free all non-split cached blocks to system allocator
+      release_blocks(large_blocks, context);
+      release_blocks(small_blocks, context);
+    }
 
     for (auto it = graph_pools_freeable.begin();
          it != graph_pools_freeable.end();) {
+      if (mempool_id.first != 0 || mempool_id.second != 0) {
+        if (it->first == mempool_id) {
+          // If there is an active mempool, we sync only the events
+          // associated with the pool
+          synchronize_and_free_events(context, it->second);
+        } else {
+          // otherwise we move on
+          ++it;
+          continue;
+        }
+      }
       // See notifyCaptureDestroy for the strategy here.
       TORCH_INTERNAL_ASSERT(it->second->use_count == 0);
       release_blocks(it->second->small_blocks, context);
@@ -2886,10 +2907,21 @@ class DeviceCachingAllocator {
         block->device,
         context ? context : block->context_when_segment_allocated);
 
-    C10_CUDA_CHECK(cudaFree((void*)block->ptr));
+    auto* pool = block->pool;
+    auto active_pool = MemPoolContext::getActiveMemPool();
+    if (active_pool && active_pool->allocator() && pool->owner_PrivatePool) {
+      // Ensure that active_pool and pool are the same
+      auto pp = get_private_pool(active_pool->id());
+      TORCH_INTERNAL_ASSERT(pp == pool->owner_PrivatePool);
+
+      // If there is an active mempool with a given allocator,
+      // we use the given allocator's delete function.
+      active_pool->allocator()->raw_delete((void*)block->ptr);
+    } else {
+      C10_CUDA_CHECK(cudaFree((void*)block->ptr));
+    }
     total_allocated_memory -= block->size;
 
-    auto* pool = block->pool;
     if (pool->owner_PrivatePool) {
       // The cudaFreed block belonged to a CUDA graph's PrivatePool.
       TORCH_INTERNAL_ASSERT(pool->owner_PrivatePool->cudaMalloc_count > 0);
@@ -3017,7 +3049,8 @@ class DeviceCachingAllocator {
   }
 
   void synchronize_and_free_events(
-      const std::shared_ptr<GatheredContext>& context) {
+      const std::shared_ptr<GatheredContext>& context,
+      PrivatePool* pool = nullptr) {
     // Synchronize on outstanding events and then free associated blocks.
     stats.num_sync_all_streams++;
 
@@ -3026,10 +3059,18 @@ class DeviceCachingAllocator {
     TORCH_INTERNAL_ASSERT(captures_underway.empty());
     insert_events_deferred_until_no_capture(context);
 
-    for (auto& st : cuda_events) {
-      for (auto& e : st.second) {
-        EventPool::Event event = std::move(e.first);
-        Block* block = e.second;
+    for (auto it = cuda_events.begin(); it != cuda_events.end();) {
+      for (auto e = it->second.begin(); e != it->second.end();) {
+        Block* block = e->second;
+
+        // If a pool was passed, only synchronize the events
+        // that are associated with the pool, otherwise move on
+        if (pool && block->pool->owner_PrivatePool != pool) {
+          ++e;
+          continue;
+        }
+
+        EventPool::Event event = std::move(e->first);
 
         C10_CUDA_CHECK(cudaEventSynchronize(*event));
 
@@ -3037,10 +3078,18 @@ class DeviceCachingAllocator {
         if (block->event_count == 0) {
           free_block(block, context);
         }
+        // We are done with the event, so erase it from the deque
+        e = it->second.erase(e);
       }
+
+      // If the events deque is empty, only then erase the
+      // cuda event from the events map
+      if (it->second.empty()) {
+        it = cuda_events.erase(it);
+      } else {
+        it++;
+      }
     }
-
-    cuda_events.clear();
   }
 
   void remove_cudagraph_stream_uses(Block* block) {
@@ -3922,6 +3971,8 @@ MemPool::MemPool(
 MemPool::~MemPool() {
   TORCH_INTERNAL_ASSERT(use_count() == 1);
   CUDACachingAllocator::releasePool(device_, id_);
+  auto ctx = MemPoolContext(this);
+  c10::cuda::CUDACachingAllocator::emptyCache();
 }
 
 MempoolId_t MemPool::id() {
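The ~MemPool() change just above means that dropping the last reference to a pool now makes the pool active and calls emptyCache(), so only that pool's cached blocks are returned, through the pool's own allocator per the release_block hunk. A rough sketch of the user-visible effect, mirroring the test change below (the allocator object stands in for whatever pluggable allocator the pool was created with; the names are illustrative, not new API):

    pool = torch.cuda.MemPool(allocator)  # e.g. a CUDAPluggableAllocator
    with torch.cuda.use_mem_pool(pool):
        out = torch.randn(1024, device="cuda")

    del out   # the block goes back into the pool's cache
    del pool  # ~MemPool() activates the pool and calls emptyCache(),
              # which frees the cached block via the allocator's raw_delete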
@@ -4544,10 +4544,14 @@ class TestMemPool(TestCase):
         # holds a reference
         self.assertEqual(pool.use_count(), 1)
 
-        # no allocations happened yet, so called_dummy_alloc should be 0
+        # no allocations happened yet, so called_dummy_alloc and
+        # called_dummy_free should be 0
         alloc_lib = ctypes.CDLL(dummy_allocator)
         called_dummy_alloc = ctypes.c_int.in_dll(alloc_lib, "called_dummy_alloc")
+        called_dummy_free = ctypes.c_int.in_dll(alloc_lib, "called_dummy_free")
         self.assertEqual(called_dummy_alloc.value, 0)
+        self.assertEqual(called_dummy_free.value, 0)
+
         nelem_1mb = 1024 * 1024 // 4
 
         with torch.cuda.use_mem_pool(pool):
@@ -4582,6 +4586,15 @@ class TestMemPool(TestCase):
             # to make a new 2 MB buffer to accomodate out_2
             self.assertEqual(len(pool.snapshot()), 2)
 
+        del out_0, out_1, out_2
+
+        # pool's destructor calls emptyCache()
+        del pool
+
+        # called_dummy_free should be 321 if dummy_free was used to deallocate
+        # out tensor
+        self.assertEqual(called_dummy_free.value, 321)
+
     def test_mempool_context(self):
         active_pool = torch.cuda.MemPoolContext.active_pool()