mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[CUDA graphs] Allows DeviceCachingAllocator to capture cross-stream memory use (#55860)
Summary: Safely deallocating and repurposing memory used across streams relies on recording end-of-life events in all an allocation's usage streams beyond its original allocation stream. The events are later queried to see if all GPU work in those extra streams that could have used the allocation is done (from the CPU's perspective) before repurposing the allocation for use in its original stream. The trouble is, calling EventQuery on an ordinary event recorded in a capturing stream is illegal. Calling EventQuery while capture is underway is also illegal. So when we call `tensor.record_stream` (or `c10::cuda::cudaCachingAllocator::recordStream`) on any tensor that's used or deleted in or around a capture, we often end up with a confusing error thrown from the cudaEventQuery in DeviceCachingAllocator::process_events(). This PR enables hopefully-safe deletion of tensors used across streams in or around capture with a conservative but simple approach: don't record or process end of life events for such tensors until the allocator's sure no captures are underway. You could whiteboard cases where this causes cross-stream-used allocations to be unavailable for reuse longer than absolutely necessary, but cross-stream-used allocations are uncommon, so for practical purposes this approach's impact on the memory footprint of captured sequences should be small. Pull Request resolved: https://github.com/pytorch/pytorch/pull/55860 Reviewed By: ejguan Differential Revision: D27822557 Pulled By: ezyang fbshipit-source-id: b2e18a19d83ed05bad67a8157a14a606ed14d04e
This commit is contained in:
parent
3e42da09df
commit
ffdecc1ac4
|
|
@ -287,6 +287,8 @@ class DeviceCachingAllocator {
|
|||
// Most of the time it's zero, in which case malloc can avoid calling
|
||||
// cudaStreamGetCaptureInfo in the hot path.
|
||||
int captures_underway = 0;
|
||||
// See free() for this thing's purpose
|
||||
std::vector<Block*> needs_events_deferred_until_no_capture;
|
||||
// outstanding cuda events
|
||||
std::deque<std::pair<cudaEvent_t, Block*>> cuda_events;
|
||||
|
||||
|
|
@ -323,8 +325,17 @@ class DeviceCachingAllocator {
|
|||
{
|
||||
std::unique_lock<std::recursive_mutex> lock(mutex);
|
||||
|
||||
// process outstanding cudaEvents
|
||||
process_events();
|
||||
if (C10_LIKELY(captures_underway == 0)) {
|
||||
// Processes end-of-life events for outstanding allocations used on multiple streams
|
||||
// (checks if their GPU-side uses are complete and recycles their memory if so)
|
||||
//
|
||||
// Q. Why skip process_events if a capture might be underway?
|
||||
// A. process_events involves cudaEventQueries, illegal during CUDA graph capture.
|
||||
// Dumb simple solution: defer reclaiming these allocations until after capture.
|
||||
// Cross-stream memory use is uncommon, so the deferral's effect on memory use
|
||||
// during capture should be small.
|
||||
process_events();
|
||||
}
|
||||
|
||||
size = round_size(size);
|
||||
auto& pool = get_pool(size, stream);
|
||||
|
|
@ -458,7 +469,14 @@ class DeviceCachingAllocator {
|
|||
update_stat_array(stats.allocated_bytes, -block->size, {stat_types});
|
||||
|
||||
if (!block->stream_uses.empty()) {
|
||||
insert_events(block);
|
||||
if (C10_UNLIKELY(captures_underway)) {
|
||||
// It's forbidden to cudaEventQuery an event recorded during CUDA graph capture.
|
||||
// We conservatively defer recording end-of-life events until the next call to
|
||||
// process_events() (which won't happen until no captures are underway)
|
||||
needs_events_deferred_until_no_capture.push_back(block);
|
||||
} else {
|
||||
insert_events(block);
|
||||
}
|
||||
} else {
|
||||
free_block(block);
|
||||
}
|
||||
|
|
@ -588,7 +606,7 @@ class DeviceCachingAllocator {
|
|||
|
||||
block_info.size = block->size;
|
||||
block_info.allocated = block->allocated;
|
||||
block_info.active = block->allocated || (block->event_count > 0);
|
||||
block_info.active = block->allocated || (block->event_count > 0) || !block->stream_uses.empty();
|
||||
|
||||
segment_info.total_size += block_info.size;
|
||||
if (block_info.allocated) {
|
||||
|
|
@ -689,7 +707,7 @@ class DeviceCachingAllocator {
|
|||
/** moves a block into a pool of cached free blocks */
|
||||
void free_block(Block* block)
|
||||
{
|
||||
TORCH_INTERNAL_ASSERT(!block->allocated && block->event_count == 0);
|
||||
TORCH_INTERNAL_ASSERT(!block->allocated && block->event_count == 0 && block->stream_uses.empty());
|
||||
|
||||
size_t original_block_size = block->size;
|
||||
|
||||
|
|
@ -728,7 +746,7 @@ class DeviceCachingAllocator {
|
|||
/** combine previously split blocks. returns the size of the subsumed block, or 0 on failure. */
|
||||
size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool)
|
||||
{
|
||||
if (!src || src->allocated || src->event_count > 0) {
|
||||
if (!src || src->allocated || src->event_count > 0 || !src->stream_uses.empty()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
@ -749,7 +767,8 @@ class DeviceCachingAllocator {
|
|||
|
||||
const size_t subsumed_size = src->size;
|
||||
dst->size += subsumed_size;
|
||||
pool.blocks.erase(src);
|
||||
auto erased = pool.blocks.erase(src);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(erased == 1);
|
||||
delete src;
|
||||
|
||||
return subsumed_size;
|
||||
|
|
@ -946,6 +965,11 @@ class DeviceCachingAllocator {
|
|||
void synchronize_and_free_events() {
|
||||
// Synchronize on outstanding events and then free associated blocks.
|
||||
|
||||
// This function syncs, so capture should not be underway. Might as well
|
||||
// make sure capture-deferred end of life events get processed too.
|
||||
TORCH_INTERNAL_ASSERT(captures_underway == 0);
|
||||
insert_events_deferred_until_no_capture();
|
||||
|
||||
for (auto& e : cuda_events) {
|
||||
cudaEvent_t event = e.first;
|
||||
Block* block = e.second;
|
||||
|
|
@ -982,8 +1006,20 @@ class DeviceCachingAllocator {
|
|||
C10_CUDA_CHECK(cudaSetDevice(prev_device));
|
||||
}
|
||||
|
||||
void insert_events_deferred_until_no_capture() {
|
||||
if (C10_UNLIKELY(needs_events_deferred_until_no_capture.size() > 0)) {
|
||||
for (auto* block : needs_events_deferred_until_no_capture) {
|
||||
TORCH_INTERNAL_ASSERT(!block->stream_uses.empty());
|
||||
insert_events(block);
|
||||
}
|
||||
needs_events_deferred_until_no_capture.clear();
|
||||
}
|
||||
}
|
||||
|
||||
void process_events()
|
||||
{
|
||||
insert_events_deferred_until_no_capture();
|
||||
|
||||
// Process outstanding cudaEvents. Events that are completed are removed
|
||||
// from the queue, and the 'event_count' for the corresponding allocation
|
||||
// is decremented. Stops at the first event which has not been completed.
|
||||
|
|
|
|||
|
|
@ -3417,6 +3417,46 @@ torch.cuda.synchronize()
|
|||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@unittest.skipIf((not TEST_CUDA) or
|
||||
TEST_WITH_ROCM or
|
||||
int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
|
||||
def test_graph_record_stream(self):
|
||||
# Makes sure graph capture defers attempting to reclaim allocations used across streams. See
|
||||
# "Q. Why skip process_events if a capture might be underway?" in c10/cuda/CUDACachingAllocator.cpp
|
||||
potential_problem = torch.zeros((3,), device="cuda")
|
||||
a = torch.zeros((3,), device="cuda")
|
||||
s0 = torch.cuda.Stream()
|
||||
s1 = torch.cuda.Stream()
|
||||
s2 = torch.cuda.Stream()
|
||||
g = torch.cuda._Graph()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
with torch.cuda.stream(s0):
|
||||
potential_problem.record_stream(s0)
|
||||
torch.cuda._sleep(TestCuda.FIFTY_MIL_CYCLES)
|
||||
potential_problem.fill_(1.)
|
||||
del potential_problem
|
||||
|
||||
with torch.cuda.stream(s1):
|
||||
g.capture_begin()
|
||||
# potential_problem's allocation should still be outstanding. if DeviceCachingAllocator::malloc
|
||||
# mistakenly calls process_events, it will trigger cudaEventQueries on potential_problem's end-of-life
|
||||
# event, which will cause the capture to error.
|
||||
b = a.clone()
|
||||
|
||||
# Let's also see what happens if we record_stream on a tensor during capture.
|
||||
s2.wait_stream(s1)
|
||||
with torch.cuda.stream(s2):
|
||||
b.fill_(1.)
|
||||
b.record_stream(s2) # dummy record_stream
|
||||
del b
|
||||
s1.wait_stream(s2)
|
||||
g.capture_end()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# dummy allocation triggers process_events, Hopefully successfully processes b's end-of-life event.
|
||||
c = torch.zeros((3,), device="cuda")
|
||||
|
||||
def test_batch_norm_gather_stats(self):
|
||||
input = torch.randn(1, 3, 3, 3, device='cuda')
|
||||
mean, invstd = torch.batch_norm_gather_stats(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user