expandable_segments <-> other allocator options (#134338)

Previously setting garbage_collection_threshold or max_split_size_mb along with expandable_segments:True could cause the allocator to hit assert failures when running nearly out of memory. This PR ensures garbage_collection and max_split freeing do not accidentally try to release expandable segments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134338
Approved by: https://github.com/ezyang
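For context, a minimal sketch of the option mix this change makes safe. The thresholds, block sizes, and the 256 MB budget below are arbitrary illustration values, and torch.cuda.memory._set_allocator_settings is the same private helper the new tests use (in normal use these options are set via PYTORCH_CUDA_ALLOC_CONF before the first CUDA allocation):

import torch

# Combine expandable segments with a GC threshold and max_split_size_mb,
# the combination that previously could trip the allocator's asserts.
# The option values here are arbitrary examples.
torch.cuda.memory._set_allocator_settings(
    "expandable_segments:True,garbage_collection_threshold:0.6,max_split_size_mb:64"
)

# Cap this process to a small budget so the allocator runs near its limit,
# the regime where garbage collection / oversize-block freeing is triggered.
mb = 1024 * 1024
_, total = torch.cuda.memory.mem_get_info()
torch.cuda.memory.set_per_process_memory_fraction(256 * mb / total)

blocks = [torch.ones(48 * mb, dtype=torch.int8, device="cuda") for _ in range(4)]
del blocks
# A larger request now has to free cached blocks, some of which may be
# expandable segments; before this fix that path could hit the asserts.
big = torch.ones(200 * mb, dtype=torch.int8, device="cuda")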
This commit is contained in:

parent 3fc6e47d42
commit d91b49dbaa
@@ -2611,7 +2611,7 @@ class DeviceCachingAllocator {
       while (it != large_blocks.blocks.end()) {
         Block* block = *it;
         ++it;
-        if (!block->is_split() &&
+        if (!block->is_split() && !block->expandable_segment_ &&
             static_cast<double>(block->gc_count()) >= age_threshold) {
           block_freed = true;
           gc_reclaimed += block->size;
@@ -2754,7 +2754,8 @@ class DeviceCachingAllocator {
         ? CUDAAllocatorConfig::max_split_size()
         : key.size;
     auto it = pool.blocks.lower_bound(&key);
-    if (it == pool.blocks.end() || (*it)->stream != p.stream()) {
+    if (it == pool.blocks.end() || (*it)->stream != p.stream() ||
+        (*it)->expandable_segment_) {
       // No single block is large enough; free multiple oversize blocks,
       // starting with the largest
       if (it == pool.blocks.begin())
@@ -2766,12 +2767,15 @@ class DeviceCachingAllocator {
              ((*it)->size >= CUDAAllocatorConfig::max_split_size()) &&
              ((*it)->stream == p.stream())) {
         auto cur = it;
-        totalReleased += (*it)->size;
-        if (it != pool.blocks.begin()) {
+        bool is_first = cur == pool.blocks.begin();
+        if (!is_first) {
           --it;
-          release_block(*cur, context);
-        } else {
-          release_block(*cur, context);
+        }
+        if (!(*cur)->expandable_segment_) {
+          release_block(*cur, context);
+          totalReleased += (*cur)->size;
+        }
+        if (is_first) {
           break;
         }
       }
@@ -4098,6 +4098,62 @@ class TestCudaMallocAsync(TestCase):
             finally:
                 torch.cuda.memory._record_memory_history(None)
 
+    def test_max_split_expandable(self):
+        torch.cuda.memory.empty_cache()
+        mb = 1024 * 1024
+        _, all_memory = torch.cuda.memory.mem_get_info()
+        total_allowed = 120 * mb
+        fraction_allowed = total_allowed / all_memory
+        assert int(fraction_allowed * all_memory) == total_allowed
+        torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed)
+
+        def alloc(n):
+            return torch.ones(n * mb, dtype=torch.int8, device="cuda")
+
+        torch.cuda.memory._set_allocator_settings(
+            "expandable_segments:False,max_split_size_mb:40"
+        )
+        a = alloc(40)
+        torch.cuda.memory._set_allocator_settings(
+            "expandable_segments:True,max_split_size_mb:40"
+        )
+        b = alloc(40)
+        torch.cuda.memory._set_allocator_settings(
+            "expandable_segments:False,max_split_size_mb:40"
+        )
+        c = alloc(40)
+        with self.assertRaises(torch.OutOfMemoryError):
+            alloc(40)
+        del a, b, c
+        # force release_cached_blocks to run with some expandable segments in the free list
+        alloc(120)
+
+    def test_garbage_collect_expandable(self):
+        torch.cuda.memory.empty_cache()
+        mb = 1024 * 1024
+        _, all_memory = torch.cuda.memory.mem_get_info()
+        total_allowed = 120 * mb
+        fraction_allowed = total_allowed / all_memory
+        assert int(fraction_allowed * all_memory) == total_allowed
+        torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed)
+
+        def alloc(n):
+            return torch.ones(n * mb, dtype=torch.int8, device="cuda")
+
+        torch.cuda.memory._set_allocator_settings(
+            "expandable_segments:False,garbage_collection_threshold:0.5"
+        )
+        a = alloc(40)
+        torch.cuda.memory._set_allocator_settings(
+            "expandable_segments:True,garbage_collection_threshold:0.5"
+        )
+        b = alloc(40)
+        del a, b
+        # causes GC to run. The expandable segment block will be split
+        # so GC would not attempt to free it anyway, but this at least makes sure
+        # expandable_segment blocks can be in the free list when this is called.
+        alloc(80)
+
     def test_allocator_settings(self):
         def power2_div(size, div_factor):
             pow2 = 1