Add option to use mempool on OOM (#151487)
MemPool is a separate pool of memory handled by the caching allocator. This PR adds an option to let the caching allocator try to use this pool as a last resort instead of OOMing, by associating a ``use_on_oom`` bool with each MemPool.
Usage:
Users can optionally specify a ``use_on_oom`` bool (which is False by default) during MemPool creation. If true, then the CUDACachingAllocator will be able to use memory in this pool as a last resort instead of OOMing.
```
pool = torch.cuda.MemPool(allocator, use_on_oom=True)
with torch.cuda.use_mem_pool(pool):
    a = torch.randn(40 * 1024 * 1024, dtype=torch.uint8, device="cuda")
del a

# at the memory limit, this will succeed by using pool's memory in order to avoid the oom
b = torch.randn(40 * 1024 * 1024, dtype=torch.uint8, device="cuda")
```
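For contrast, here is a minimal sketch of the default behavior (same setup as above, assuming the process is already at its memory limit): with ``use_on_oom`` left as False, the allocator does not fall back to the pool's cached memory and the allocation raises ``torch.OutOfMemoryError``, mirroring the `pool_do_not_use` case in the test below.
```
pool = torch.cuda.MemPool(allocator)  # use_on_oom defaults to False
with torch.cuda.use_mem_pool(pool):
    a = torch.randn(40 * 1024 * 1024, dtype=torch.uint8, device="cuda")
del a

try:
    # at the memory limit, the allocator cannot reuse this pool's cached memory
    b = torch.randn(40 * 1024 * 1024, dtype=torch.uint8, device="cuda")
except torch.OutOfMemoryError:
    print("OOM: the pool was not registered as a last-resort fallback")
```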
Testing:
```
python test/test_cuda.py -k test_mempool_limited_memory_with_allocator
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151487
Approved by: https://github.com/eqy, https://github.com/syed-ahmed, https://github.com/ngimel
This commit is contained in:
parent
65b845f82b
commit
d22c4cc353
```
@@ -1070,6 +1070,9 @@ class DeviceCachingAllocator {
   std::vector<std::pair<MempoolId_t, std::function<bool(cudaStream_t)>>>
       captures_underway;

+  // tracks which pools we can use as a last resort before ooming
+  ska::flat_hash_set<MempoolId_t, MempoolIdHash> use_on_oom_pools;
+
   // See free() for this thing's purpose
   std::vector<Block*> needs_events_deferred_until_no_capture;
   // outstanding cuda events
```
```
@@ -1257,6 +1260,30 @@ class DeviceCachingAllocator {
                 alloc_block(params, true, context, lock));
    }

+    // we are about to oom, try to use existing mempools as a last resort
+    if (!block_found && params.err == cudaErrorMemoryAllocation) {
+      // if already trying to use a mempool, then just oom
+      auto active_pool = MemPoolContext::getActiveMemPool();
+      if (!active_pool) {
+        for (MempoolId_t mempool_id : use_on_oom_pools) {
+          auto filter = [](cudaStream_t) { return true; };
+          beginAllocateToPool(mempool_id, filter);
+          auto& pool = get_pool(size, stream);
+          const size_t alloc_size = get_allocation_size(size);
+          AllocParams mempool_params(
+              device, size, stream, &pool, alloc_size, stats);
+          mempool_params.stat_types = get_stat_types_for_pool(pool);
+          block_found = get_free_block(mempool_params);
+          endAllocateToPool(mempool_id);
+          releasePool(mempool_id);
+          if (block_found) {
+            params = mempool_params;
+            break;
+          }
+        }
+      }
+    }
+
    if (!block_found) {
      // For any error code other than cudaErrorMemoryAllocation,
      // alloc_block should have thrown an exception already.
```
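To summarize the control flow added in this hunk, here is a hedged, Python-level model (illustrative only; `try_use_on_oom_pools` and `cached_free_blocks` are hypothetical names, not part of PyTorch): on a ``cudaErrorMemoryAllocation`` failure with no mempool currently active, the allocator walks the registered ``use_on_oom`` pools and reuses the first cached block that fits; it requests no new device memory, and otherwise proceeds to OOM.
```
# Hedged, Python-level model of the fallback above. Illustrative only.
def try_use_on_oom_pools(request_size, use_on_oom_pools, cached_free_blocks,
                         active_pool=None):
    """Return (pool_id, block_size) reusing a cached block, or None to OOM."""
    # If an allocation is already being routed to a mempool, just OOM.
    if active_pool is not None:
        return None
    for pool_id in use_on_oom_pools:
        # Mirrors get_free_block(): only blocks already cached in the pool
        # are considered; no new device memory is requested.
        blocks = cached_free_blocks.get(pool_id, [])
        for i, block_size in enumerate(blocks):
            if block_size >= request_size:
                return pool_id, blocks.pop(i)
    return None


# Example: a 40 MiB block cached in a fallback pool satisfies a 30 MiB request.
free_blocks = {("user", 1): [40 * 1024 * 1024]}
print(try_use_on_oom_pools(30 * 1024 * 1024, [("user", 1)], free_blocks))
```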
```
@@ -2087,6 +2114,14 @@ class DeviceCachingAllocator {
    ensure_exists_and_incref_pool(mempool_id);
  }

+  void setUseOnOOM(bool use_on_oom, MempoolId_t mempool_id) {
+    // Choose if this pool should be used as a last resort before ooming
+    std::lock_guard<std::recursive_mutex> lock(mutex);
+    if (use_on_oom) {
+      use_on_oom_pools.insert(mempool_id);
+    }
+  }
+
  // See Note [Interaction with CUDA graph capture]

  // Called by CUDAGraph::capture_begin
```
```
@@ -3714,6 +3749,14 @@ class NativeCachingAllocator : public CUDAAllocator {
    device_allocator[device]->ensureExistsAndIncrefPool(std::move(mempool_id));
  }

+  void setUseOnOOM(
+      c10::DeviceIndex device,
+      bool use_on_oom,
+      MempoolId_t mempool_id) override {
+    assertValidDevice(device);
+    device_allocator[device]->setUseOnOOM(use_on_oom, std::move(mempool_id));
+  }
+
  // CUDAGraph interactions
  void beginAllocateToPool(
      c10::DeviceIndex device,
```
```
@@ -4042,7 +4085,8 @@ std::atomic<CaptureId_t> MemPool::uuid_{1};

 MemPool::MemPool(
     CUDACachingAllocator::CUDAAllocator* allocator,
-    bool is_user_created)
+    bool is_user_created,
+    bool use_on_oom)
     : allocator_(allocator), is_user_created_(is_user_created) {
   if (is_user_created_) {
     id_ = {0, uid_++};
```
```
@@ -4051,6 +4095,7 @@ MemPool::MemPool(
   }
   device_ = c10::cuda::current_device();
   CUDACachingAllocator::ensureExistsAndIncrefPool(device_, id_);
+  CUDACachingAllocator::setUseOnOOM(device_, use_on_oom, id_);
 }

 MemPool::~MemPool() {
```
```
@@ -243,6 +243,17 @@ class CUDAAllocator : public Allocator {
        " does not yet support ensureExistsAndIncrefPool. "
        "If you need it, please file an issue describing your use case.");
  }
+  virtual void setUseOnOOM(
+      c10::DeviceIndex device,
+      bool use_on_oom,
+      MempoolId_t mempool_id) {
+    TORCH_CHECK(
+        false,
+        name(),
+        " does not yet support setUseOnOOM. "
+        "If you need it, please file an issue describing your use case.");
+  }
+
  // returns true if the allocated blocks are equal to expected live allocations
  virtual bool checkPoolLiveAllocations(
      c10::DeviceIndex device,
```
```
@@ -457,6 +468,12 @@ inline void ensureExistsAndIncrefPool(
    MempoolId_t mempool_id) {
  get()->ensureExistsAndIncrefPool(device, mempool_id);
 }
+inline void setUseOnOOM(
+    c10::DeviceIndex device,
+    bool use_on_oom,
+    MempoolId_t mempool_id) {
+  get()->setUseOnOOM(device, use_on_oom, mempool_id);
+}

 inline int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) {
   return get()->getPoolUseCount(device, mempool_id);
```
```
@@ -506,7 +523,8 @@ namespace c10::cuda {
 struct C10_CUDA_API MemPool {
   MemPool(
       CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
-      bool is_user_created = true);
+      bool is_user_created = true,
+      bool use_on_oom = false);
   MemPool(const MemPool&) = delete;
   MemPool(MemPool&&) = default;
   MemPool& operator=(const MemPool&) = delete;
```
```
@@ -770,6 +770,21 @@ be called internally on deletion of the pool, hence returning all the memory to
   del tensor, del pool

+Users can optionally specify a ``use_on_oom`` bool (which is False by default) during MemPool
+creation. If true, then the CUDACachingAllocator will be able to use memory in this pool as
+a last resort instead of OOMing.
+
+.. code:: python
+
+    pool = torch.cuda.MemPool(allocator, use_on_oom=True)
+    with torch.cuda.use_mem_pool(pool):
+        a = torch.randn(40 * 1024 * 1024, dtype=torch.uint8, device="cuda")
+    del a
+
+    # at the memory limit, this will succeed by using pool's memory in order to avoid the oom
+    b = torch.randn(40 * 1024 * 1024, dtype=torch.uint8, device="cuda")
+
+
 The following :meth:`torch.cuda.MemPool.use_count` and :meth:`torch.cuda.MemPool.snapshot`
 APIs can be used for debugging purposes:
```
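The docs hunk above also points at :meth:`torch.cuda.MemPool.use_count` and :meth:`torch.cuda.MemPool.snapshot`. A minimal sketch of how they might be used for debugging (assuming a CUDA device is available; the exact snapshot format follows the caching allocator's snapshot structure and may vary across versions):
```
pool = torch.cuda.MemPool()
with torch.cuda.use_mem_pool(pool):
    x = torch.randn(1024, device="cuda")
print(pool.use_count())  # reference count currently held on the pool
print(pool.snapshot())   # allocator state attributed to this pool's segments
del x
```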
```
@@ -4920,6 +4920,25 @@ class TestBlockStateAbsorption(TestCase):

 @unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
 class TestMemPool(TestCase):
+    def _setup_mempool_limited_memory_test(self, additional_allowed_memory_in_mb):
+        device = torch.device("cuda:0")
+
+        self.init_fraction = torch.cuda.get_per_process_memory_fraction()
+        torch.cuda.memory.empty_cache()
+        mb = 1024 * 1024
+        _, all_memory = torch.cuda.memory.mem_get_info(device)
+        pre_reserved = torch.cuda.memory_reserved(device)
+        total_allowed = additional_allowed_memory_in_mb * mb + pre_reserved
+        fraction_allowed = total_allowed / all_memory
+        torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed, device)
+
+        dtype = torch.int8
+        return device, dtype
+
+    def _teardown_mempool_limited_memory_test(self):
+        torch.cuda.memory.empty_cache()
+        torch.cuda.memory.set_per_process_memory_fraction(self.init_fraction)
+
     def test_mempool_id(self):
         pool1 = torch.cuda.graph_pool_handle()
         pool2 = torch.cuda.MemPool().id
```
```
@@ -5036,6 +5055,110 @@ class TestMemPool(TestCase):
         # out tensor
         self.assertEqual(called_dummy_free.value, 321)

+    @serialTest()
+    def test_mempool_limited_memory_with_allocator(self):
+        from torch.utils.cpp_extension import load_inline
+
+        dummy_allocator_source = """
+        #include <torch/extension.h>
+        #include <ATen/cuda/Exceptions.h>
+        #include <cuda_runtime_api.h>
+
+        extern "C" {
+          // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
+          C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
+            void* ptr;
+            C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
+            return ptr;
+          }
+
+          C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
+            C10_CUDA_CHECK(cudaFree(ptr));
+          }
+        }
+        """
+        dummy_allocator_libname = "dummy_allocator"
+        dummy_allocator = load_inline(
+            name=dummy_allocator_libname,
+            cpp_sources=dummy_allocator_source,
+            is_python_module=False,
+            keep_intermediates=False,
+            verbose=True,
+            with_cuda=True,
+        )
+        allocator = torch.cuda.memory.CUDAPluggableAllocator(
+            dummy_allocator,
+            "dummy_alloc",
+            "dummy_free",
+        )
+        pool_do_not_use = torch.cuda.MemPool(allocator.allocator())
+        pool_use = torch.cuda.MemPool(allocator.allocator(), use_on_oom=True)
+
+        nelem_1mb = 1024 * 1024 // 4
+
+        self._setup_mempool_limited_memory_test(80)
+        # remaining free mem: 80 mb
+        # mempool_use [] 0 mb
+        # mempool_do_not_use [] 0 mb
+        # default pool [] 0 mb
+        with torch.cuda.use_mem_pool(pool_do_not_use):
+            a = torch.randn(40 * nelem_1mb, device="cuda")
+        with torch.cuda.use_mem_pool(pool_use):
+            b = torch.randn(40 * nelem_1mb, device="cuda")
+        a_dataptr = a.data_ptr()
+        b_dataptr = b.data_ptr()
+        # remaining free mem: 0 mb
+        # mempool_do_not_use [aaaa] 40 mb
+        # mempool_use [bbbb] 40 mb
+        # default pool [] 0 mb
+        with self.assertRaises(torch.OutOfMemoryError):
+            # out of memory
+            c = torch.randn(40 * nelem_1mb, device="cuda")
+
+        del a, b
+        # remaining free mem: 0 mb
+        # mempool_do_not_use [____] 40 mb
+        # mempool_use [____] 40 mb
+        # default pool [] 0 mb
+
+        # c should not oom and instead can use mempool_use as fallback
+        c = torch.randn(30 * nelem_1mb, device="cuda")
+        c_dataptr = c.data_ptr()
+        # remaining free mem: 0 mb
+        # mempool_do_not_use [____] 40 mb
+        # mempool_use [ccc_] 40 mb
+        # default pool [] 0 mb
+        with self.assertRaises(torch.OutOfMemoryError):
+            # out of memory since can't use mempool_do_not_use
+            d = torch.randn(30 * nelem_1mb, device="cuda")
+
+        del c
+        # remaining free mem: 0 mb
+        # mempool_do_not_use [____] 40 mb
+        # mempool_use [____] 40 mb
+        # default pool [] 0 mb
+
+        # expect that we used same memory address for both b and c
+        self.assertEqual(b_dataptr, c_dataptr)
+
+        # make sure we can still use mempool_use as intended after c is deleted
+        with torch.cuda.use_mem_pool(pool_use):
+            e = torch.randn(20 * nelem_1mb, device="cuda")
+        # remaining free mem: 0 mb
+        # mempool_do_not_use [____] 40 mb
+        # mempool_use [ee__] 40 mb
+        # default pool [] 0 mb
+
+        e_dataptr = e.data_ptr()
+        del e
+
+        self.assertEqual(e_dataptr, c_dataptr)
+
+        # pool's destructor calls emptyCache()
+        del pool_use, pool_do_not_use
+
+        self._teardown_mempool_limited_memory_test()
+
     def test_mempool_context(self):
         active_pool = torch.cuda.MemPoolContext.active_pool()
```
```
@@ -2182,7 +2182,7 @@ class _CUDAGraph:

 # Defined in torch/csrc/cuda/MemPool.cpp
 class _MemPool:
-    def __init__(self, allocator: Optional[_cuda_CUDAAllocator] = None, is_user_created: _bool = True) -> None: ...
+    def __init__(self, allocator: Optional[_cuda_CUDAAllocator] = None, is_user_created: _bool = True, use_on_oom: _bool = False) -> None: ...
     @property
     def id(self) -> Tuple[_int, _int]: ...
     @property
```
```
@@ -9,15 +9,17 @@
 template <typename T>
 using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;

+// NOLINTNEXTLINE(misc-use-internal-linkage)
 void THCPMemPool_init(PyObject* module) {
   auto torch_C_m = py::handle(module).cast<py::module>();
   shared_ptr_class_<::c10::cuda::MemPool>(torch_C_m, "_MemPool")
       .def(
           py::init([](c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator,
-                      bool is_user_created) {
+                      bool is_user_created,
+                      bool use_on_oom) {
             torch::utils::device_lazy_init(at::kCUDA);
             return std::make_shared<::c10::cuda::MemPool>(
-                allocator, is_user_created);
+                allocator, is_user_created, use_on_oom);
           }))
       .def_property_readonly("id", &::c10::cuda::MemPool::id)
       .def_property_readonly("allocator", &::c10::cuda::MemPool::allocator)
```
```
@@ -1129,11 +1129,18 @@ class MemPool(_MemPool):
            define how memory gets allocated in the pool. If :attr:`allocator`
            is ``None`` (default), memory allocation follows the default/
            current configuration of the CUDACachingAllocator.
+        use_on_oom(bool): a bool that indicates if this pool can be used
+            as a last resort if a memory allocation outside of the pool fails due
+            to Out Of Memory. This is False by default.

    """

-    def __init__(self, allocator: Optional[_cuda_CUDAAllocator] = None):
-        super().__init__(allocator, True)
+    def __init__(
+        self,
+        allocator: Optional[_cuda_CUDAAllocator] = None,
+        use_on_oom: bool = False,
+    ):
+        super().__init__(allocator, True, use_on_oom)

    @property
    def id(self) -> tuple[int, int]:
```