Add option to use mempool on OOM (#151487)

MemPool is a separate pool of memory handled by the caching allocator. This PR adds an option to let the caching allocator try to use such a pool as a last resort instead of OOMing, by associating a ``use_on_oom`` bool with each MemPool.

Usage:
Users can optionally specify a ``use_on_oom`` bool (which is False by default) during MemPool creation. If true, then the CUDACachingAllocator will be able to use memory in this pool as a last resort instead of OOMing.

```
pool = torch.cuda.MemPool(allocator, use_on_oom=True)
with torch.cuda.use_mem_pool(pool):
    a = torch.empty(40 * 1024 * 1024, dtype=torch.uint8, device="cuda")
del a
# at the memory limit, this will succeed by using the pool's memory as a last resort instead of raising an OOM
b = torch.empty(40 * 1024 * 1024, dtype=torch.uint8, device="cuda")
```
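
For reference, here is a more self-contained variant of the snippet above. This is an illustrative sketch rather than part of the PR: it assumes a single CUDA device, uses the default allocator for the pool, and borrows the per-process memory-fraction trick from the new test to simulate being at the memory limit.

```
import torch

device = torch.device("cuda:0")
torch.cuda.empty_cache()

# Cap this process at roughly what is already reserved plus ~80 MiB so that
# allocations outside the pool run into the limit quickly.
mb = 1024 * 1024
_, total = torch.cuda.mem_get_info(device)
reserved = torch.cuda.memory_reserved(device)
torch.cuda.set_per_process_memory_fraction((reserved + 81 * mb) / total, device)

pool = torch.cuda.MemPool(use_on_oom=True)  # default allocator, opted in to the OOM fallback
with torch.cuda.use_mem_pool(pool):
    a = torch.empty(40 * mb, dtype=torch.uint8, device="cuda")
a_ptr = a.data_ptr()

filler = torch.empty(40 * mb, dtype=torch.uint8, device="cuda")  # uses up the remaining headroom
del a  # a's 40 MiB block stays cached inside the pool

# The default pool is out of headroom, so instead of raising torch.OutOfMemoryError
# the allocator falls back to the opted-in pool and reuses a's freed block.
b = torch.empty(40 * mb, dtype=torch.uint8, device="cuda")
print(b.data_ptr() == a_ptr)  # True in practice

torch.cuda.set_per_process_memory_fraction(1.0, device)  # restore the default cap
```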

Testing:
```
python test/test_cuda.py -k test_mempool_limited_memory_with_allocator
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151487
Approved by: https://github.com/eqy, https://github.com/syed-ahmed, https://github.com/ngimel
Dan Johnson 2025-04-25 17:52:27 -07:00 committed by PyTorch MergeBot
parent 65b845f82b
commit d22c4cc353
7 changed files with 217 additions and 7 deletions

View File

@@ -1070,6 +1070,9 @@ class DeviceCachingAllocator {
std::vector<std::pair<MempoolId_t, std::function<bool(cudaStream_t)>>>
captures_underway;
// tracks which pools we can use as a last resort before ooming
ska::flat_hash_set<MempoolId_t, MempoolIdHash> use_on_oom_pools;
// See free() for this thing's purpose
std::vector<Block*> needs_events_deferred_until_no_capture;
// outstanding cuda events
@@ -1257,6 +1260,30 @@ class DeviceCachingAllocator {
alloc_block(params, true, context, lock));
}
// we are about to oom, try to use existing mempools as a last resort
if (!block_found && params.err == cudaErrorMemoryAllocation) {
// if already trying to use a mempool, then just oom
auto active_pool = MemPoolContext::getActiveMemPool();
if (!active_pool) {
for (MempoolId_t mempool_id : use_on_oom_pools) {
auto filter = [](cudaStream_t) { return true; };
beginAllocateToPool(mempool_id, filter);
auto& pool = get_pool(size, stream);
const size_t alloc_size = get_allocation_size(size);
AllocParams mempool_params(
device, size, stream, &pool, alloc_size, stats);
mempool_params.stat_types = get_stat_types_for_pool(pool);
block_found = get_free_block(mempool_params);
endAllocateToPool(mempool_id);
releasePool(mempool_id);
if (block_found) {
params = mempool_params;
break;
}
}
}
}
if (!block_found) {
// For any error code other than cudaErrorMemoryAllocation,
// alloc_block should have thrown an exception already.
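
In Python terms, the new last-resort path above behaves roughly like the toy model below. This is plain Python with no CUDA, written only for illustration; the names are made up and loosely mirror the C++.

```
from typing import Optional

# Toy model of the last-resort lookup: when a normal allocation fails with an
# OOM, walk the opted-in pools and try to satisfy the request from a cached
# free block, skipping the whole fallback if a pool is already active.
class ToyCachingAllocator:
    def __init__(self) -> None:
        self.free_blocks: dict[tuple, list[int]] = {}  # mempool_id -> sizes of cached free blocks
        self.use_on_oom_pools: set[tuple] = set()      # pools registered via setUseOnOOM(true, id)
        self.active_pool: Optional[tuple] = None       # set while inside use_mem_pool(...)

    def _get_free_block(self, mempool_id: tuple, size: int) -> Optional[int]:
        blocks = self.free_blocks.get(mempool_id, [])
        for i, block_size in enumerate(blocks):
            if block_size >= size:
                return blocks.pop(i)  # hand the cached block back to the caller
        return None

    def malloc_last_resort(self, size: int) -> Optional[int]:
        if self.active_pool is not None:
            return None  # already allocating to a pool: just OOM
        for mempool_id in self.use_on_oom_pools:
            block = self._get_free_block(mempool_id, size)
            if block is not None:
                return block
        return None

alloc = ToyCachingAllocator()
alloc.use_on_oom_pools.add((0, 1))
alloc.free_blocks[(0, 1)] = [40 << 20]     # a 40 MiB block freed earlier inside the pool
print(alloc.malloc_last_resort(30 << 20))  # 41943040: served from the opted-in pool
print(alloc.malloc_last_resort(30 << 20))  # None: the pool's block is gone, so this would OOM
```
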
@@ -2087,6 +2114,14 @@ class DeviceCachingAllocator {
ensure_exists_and_incref_pool(mempool_id);
}
void setUseOnOOM(bool use_on_oom, MempoolId_t mempool_id) {
// Choose if this pool should be used as a last resort before ooming
std::lock_guard<std::recursive_mutex> lock(mutex);
if (use_on_oom) {
use_on_oom_pools.insert(mempool_id);
}
}
// See Note [Interaction with CUDA graph capture]
// Called by CUDAGraph::capture_begin
@@ -3714,6 +3749,14 @@ class NativeCachingAllocator : public CUDAAllocator {
device_allocator[device]->ensureExistsAndIncrefPool(std::move(mempool_id));
}
void setUseOnOOM(
c10::DeviceIndex device,
bool use_on_oom,
MempoolId_t mempool_id) override {
assertValidDevice(device);
device_allocator[device]->setUseOnOOM(use_on_oom, std::move(mempool_id));
}
// CUDAGraph interactions
void beginAllocateToPool(
c10::DeviceIndex device,
@@ -4042,7 +4085,8 @@ std::atomic<CaptureId_t> MemPool::uuid_{1};
MemPool::MemPool(
CUDACachingAllocator::CUDAAllocator* allocator,
bool is_user_created)
bool is_user_created,
bool use_on_oom)
: allocator_(allocator), is_user_created_(is_user_created) {
if (is_user_created_) {
id_ = {0, uid_++};
@@ -4051,6 +4095,7 @@ MemPool::MemPool(
}
device_ = c10::cuda::current_device();
CUDACachingAllocator::ensureExistsAndIncrefPool(device_, id_);
CUDACachingAllocator::setUseOnOOM(device_, use_on_oom, id_);
}
MemPool::~MemPool() {

View File

@@ -243,6 +243,17 @@ class CUDAAllocator : public Allocator {
" does not yet support ensureExistsAndIncrefPool. "
"If you need it, please file an issue describing your use case.");
}
virtual void setUseOnOOM(
c10::DeviceIndex device,
bool use_on_oom,
MempoolId_t mempool_id) {
TORCH_CHECK(
false,
name(),
" does not yet support setUseOnOOM. "
"If you need it, please file an issue describing your use case.");
}
// returns true if the allocated blocks are equal to expected live allocations
virtual bool checkPoolLiveAllocations(
c10::DeviceIndex device,
@@ -457,6 +468,12 @@ inline void ensureExistsAndIncrefPool(
MempoolId_t mempool_id) {
get()->ensureExistsAndIncrefPool(device, mempool_id);
}
inline void setUseOnOOM(
c10::DeviceIndex device,
bool use_on_oom,
MempoolId_t mempool_id) {
get()->setUseOnOOM(device, use_on_oom, mempool_id);
}
inline int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) {
return get()->getPoolUseCount(device, mempool_id);
@@ -506,7 +523,8 @@ namespace c10::cuda {
struct C10_CUDA_API MemPool {
MemPool(
CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
bool is_user_created = true);
bool is_user_created = true,
bool use_on_oom = false);
MemPool(const MemPool&) = delete;
MemPool(MemPool&&) = default;
MemPool& operator=(const MemPool&) = delete;

View File

@@ -770,6 +770,21 @@ be called internally on deletion of the pool, hence returning all the memory to
del tensor, del pool
Users can optionally specify a ``use_on_oom`` bool (which is False by default) during MemPool
creation. If true, then the CUDACachingAllocator will be able to use memory in this pool as
a last resort instead of OOMing.
.. code:: python
pool = torch.cuda.MemPool(allocator, use_on_oom=True)
with torch.cuda.use_mem_pool(pool):
    a = torch.empty(40 * 1024 * 1024, dtype=torch.uint8, device="cuda")
del a
# at the memory limit, this will succeed by using the pool's memory as a last resort instead of raising an OOM
b = torch.empty(40 * 1024 * 1024, dtype=torch.uint8, device="cuda")
The following :meth:`torch.cuda.MemPool.use_count` and :meth:`torch.cuda.MemPool.snapshot`
APIs can be used for debugging purposes:
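
For illustration, a small sketch of those two calls (not taken from the diff; it assumes a CUDA device and uses the default allocator):

```
import torch

pool = torch.cuda.MemPool(use_on_oom=True)
print(pool.use_count())  # typically 1 right after creation: only this MemPool object holds a reference

with torch.cuda.use_mem_pool(pool):
    x = torch.empty(1024 * 1024, dtype=torch.uint8, device="cuda")

# snapshot() reports the segments held on behalf of this pool, in the same
# per-segment format as torch.cuda.memory_snapshot().
for segment in pool.snapshot():
    print(segment["total_size"], segment["allocated_size"])
```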

View File

@@ -4920,6 +4920,25 @@ class TestBlockStateAbsorption(TestCase):
@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
class TestMemPool(TestCase):
def _setup_mempool_limited_memory_test(self, additional_allowed_memory_in_mb):
device = torch.device("cuda:0")
self.init_fraction = torch.cuda.get_per_process_memory_fraction()
torch.cuda.memory.empty_cache()
mb = 1024 * 1024
_, all_memory = torch.cuda.memory.mem_get_info(device)
pre_reserved = torch.cuda.memory_reserved(device)
total_allowed = additional_allowed_memory_in_mb * mb + pre_reserved
fraction_allowed = total_allowed / all_memory
torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed, device)
dtype = torch.int8
return device, dtype
def _teardown_mempool_limited_memory_test(self):
torch.cuda.memory.empty_cache()
torch.cuda.memory.set_per_process_memory_fraction(self.init_fraction)
def test_mempool_id(self):
pool1 = torch.cuda.graph_pool_handle()
pool2 = torch.cuda.MemPool().id
@@ -5036,6 +5055,110 @@ class TestMemPool(TestCase):
# out tensor
self.assertEqual(called_dummy_free.value, 321)
@serialTest()
def test_mempool_limited_memory_with_allocator(self):
from torch.utils.cpp_extension import load_inline
dummy_allocator_source = """
#include <torch/extension.h>
#include <ATen/cuda/Exceptions.h>
#include <cuda_runtime_api.h>
extern "C" {
// Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
void* ptr;
C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
return ptr;
}
C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
C10_CUDA_CHECK(cudaFree(ptr));
}
}
"""
dummy_allocator_libname = "dummy_allocator"
dummy_allocator = load_inline(
name=dummy_allocator_libname,
cpp_sources=dummy_allocator_source,
is_python_module=False,
keep_intermediates=False,
verbose=True,
with_cuda=True,
)
allocator = torch.cuda.memory.CUDAPluggableAllocator(
dummy_allocator,
"dummy_alloc",
"dummy_free",
)
pool_do_not_use = torch.cuda.MemPool(allocator.allocator())
pool_use = torch.cuda.MemPool(allocator.allocator(), use_on_oom=True)
nelem_1mb = 1024 * 1024 // 4
self._setup_mempool_limited_memory_test(80)
# remaining free mem: 80 mb
# mempool_use [] 0 mb
# mempool_do_not_use [] 0 mb
# default pool [] 0 mb
with torch.cuda.use_mem_pool(pool_do_not_use):
a = torch.randn(40 * nelem_1mb, device="cuda")
with torch.cuda.use_mem_pool(pool_use):
b = torch.randn(40 * nelem_1mb, device="cuda")
a_dataptr = a.data_ptr()
b_dataptr = b.data_ptr()
# remaining free mem: 0 mb
# mempool_do_not_use [aaaa] 40 mb
# mempool_use [bbbb] 40 mb
# default pool [] 0 mb
with self.assertRaises(torch.OutOfMemoryError):
# out of memory
c = torch.randn(40 * nelem_1mb, device="cuda")
del a, b
# remaining free mem: 0 mb
# mempool_do_not_use [____] 40 mb
# mempool_use [____] 40 mb
# default pool [] 0 mb
# c should not oom and instead can use mempool_use as fallback
c = torch.randn(30 * nelem_1mb, device="cuda")
c_dataptr = c.data_ptr()
# remaining free mem: 0 mb
# mempool_do_not_use [____] 40 mb
# mempool_use [ccc_] 40 mb
# default pool [] 0 mb
with self.assertRaises(torch.OutOfMemoryError):
# out of memory since can't use mempool_do_not_use
d = torch.randn(30 * nelem_1mb, device="cuda")
del c
# remaining free mem: 0 mb
# mempool_do_not_use [____] 40 mb
# mempool_use [____] 40 mb
# default pool [] 0 mb
        # expect that we used the same memory address for both b and c
self.assertEqual(b_dataptr, c_dataptr)
# make sure we can still use mempool_use as intended after c is deleted
with torch.cuda.use_mem_pool(pool_use):
e = torch.randn(20 * nelem_1mb, device="cuda")
# remaining free mem: 0 mb
# mempool_do_not_use [____] 40 mb
# mempool_use [ee__] 40 mb
# default pool [] 0 mb
e_dataptr = e.data_ptr()
del e
self.assertEqual(e_dataptr, c_dataptr)
# pool's destructor calls emptyCache()
del pool_use, pool_do_not_use
self._teardown_mempool_limited_memory_test()
def test_mempool_context(self):
active_pool = torch.cuda.MemPoolContext.active_pool()

View File

@@ -2182,7 +2182,7 @@ class _CUDAGraph:
# Defined in torch/csrc/cuda/MemPool.cpp
class _MemPool:
def __init__(self, allocator: Optional[_cuda_CUDAAllocator] = None, is_user_created: _bool = True) -> None: ...
def __init__(self, allocator: Optional[_cuda_CUDAAllocator] = None, is_user_created: _bool = True, use_on_oom: _bool = False) -> None: ...
@property
def id(self) -> Tuple[_int, _int]: ...
@property

View File

@@ -9,15 +9,17 @@
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
// NOLINTNEXTLINE(misc-use-internal-linkage)
void THCPMemPool_init(PyObject* module) {
auto torch_C_m = py::handle(module).cast<py::module>();
shared_ptr_class_<::c10::cuda::MemPool>(torch_C_m, "_MemPool")
.def(
py::init([](c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator,
bool is_user_created) {
bool is_user_created,
bool use_on_oom) {
torch::utils::device_lazy_init(at::kCUDA);
return std::make_shared<::c10::cuda::MemPool>(
allocator, is_user_created);
allocator, is_user_created, use_on_oom);
}))
.def_property_readonly("id", &::c10::cuda::MemPool::id)
.def_property_readonly("allocator", &::c10::cuda::MemPool::allocator)

View File

@@ -1129,11 +1129,18 @@ class MemPool(_MemPool):
define how memory gets allocated in the pool. If :attr:`allocator`
is ``None`` (default), memory allocation follows the default/
current configuration of the CUDACachingAllocator.
use_on_oom(bool): a bool that indicates if this pool can be used
as a last resort if a memory allocation outside of the pool fails due
to Out Of Memory. This is False by default.
"""
def __init__(self, allocator: Optional[_cuda_CUDAAllocator] = None):
super().__init__(allocator, True)
def __init__(
self,
allocator: Optional[_cuda_CUDAAllocator] = None,
use_on_oom: bool = False,
):
super().__init__(allocator, True, use_on_oom)
@property
def id(self) -> tuple[int, int]: