[Cudagraph] better support for streams (#126809)

This PR fixes Issue #124391.

There are two root causes.

### Root Cause 1 [better support for streams during cudagraph capture]

When recording a new function, the CUDA graph tree records the state of every memory block (e.g., address, size, allocation status) via `getCheckpointState`. Let's call such a record `block_state`.

Later, the CUDA graph tree wants to restore exactly the same memory block states via `apply_checkpoint_execution_state_in_allocator`, which a) frees all memory blocks; b) allocates a block for every recorded block state (regardless of `block_state->allocated`); c) frees the blocks with `block_state->allocated == False`; and d) checks that each remaining block matches its recorded state (e.g., `block_state->ptr == block->ptr`).
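
For orientation, here is a schematic sketch of that restore sequence. It is not the actual allocator code: `restore_checkpoint_sketch` and its callback parameters (`free_all_blocks`, `allocate`, `free_block`) are hypothetical placeholders for the allocator internals.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Schematic sketch only: the real logic lives in
// apply_checkpoint_execution_state_in_allocator / setSegmentStateToCheckpoint.
struct BlockState {
  void* ptr = nullptr;
  size_t size = 0;
  bool allocated = false;
};

void restore_checkpoint_sketch(
    const std::vector<BlockState>& states,
    const std::function<void()>& free_all_blocks,    // placeholder for (a)
    const std::function<void*(size_t)>& allocate,    // placeholder for (b)
    const std::function<void(void*)>& free_block) {  // placeholder for (c)
  free_all_blocks();                     // a) empty the private pool
  std::vector<void*> ptrs;
  for (const auto& s : states) {
    ptrs.push_back(allocate(s.size));    // b) replay every recorded block state
  }
  for (size_t i = 0; i < states.size(); ++i) {
    if (!states[i].allocated) {
      free_block(ptrs[i]);               // c) free the states recorded as free
    }
  }
  for (size_t i = 0; i < states.size(); ++i) {
    // d) the replay must land each block at exactly the recorded address
    assert(states[i].ptr == ptrs[i]);
  }
}
```

Step d) is the consistency check that trips under Root Cause 1.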

An error may occur when multiple streams exist during recording. [Note](https://github.com/pytorch/pytorch/blob/main/c10/cuda/CUDACachingAllocator.cpp#L2149-L2152) that a block will not be merged with its neighbors if it is still marked as used by some streams, even if `block->allocated == False`. This can lead to a mismatch between `block_state->ptr` and `block->ptr` in `apply_checkpoint_execution_state_in_allocator`.
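
A toy model of that merge rule shows how a leftover stream use keeps a segment carved into two blocks, which is enough to change the addresses that the replay in step b) produces. `ToyBlock`, `try_merge_with_next`, and the integer stream ids are made up for illustration; this is not the allocator's `try_merge_blocks`.

```cpp
#include <cassert>
#include <cstddef>
#include <set>

// Toy stand-in for a memory block inside one segment.
struct ToyBlock {
  char* ptr = nullptr;
  size_t size = 0;
  bool allocated = false;
  std::set<int> stream_uses;  // streams the block was used on
  ToyBlock* next = nullptr;   // right-hand neighbor in the segment
};

// Mirrors the spirit of the linked check: a neighbor that is allocated or that
// still has outstanding stream uses cannot be merged away.
bool try_merge_with_next(ToyBlock& block) {
  ToyBlock* next = block.next;
  if (next == nullptr || next->allocated || !next->stream_uses.empty()) {
    return false;
  }
  block.size += next->size;
  block.next = next->next;
  return true;
}

int main() {
  char segment[256];

  ToyBlock low, high;
  low.ptr = segment;
  low.size = 128;
  high.ptr = segment + 128;
  high.size = 128;
  high.stream_uses.insert(7);  // the block was also used on a side stream
  low.next = &high;

  // `high` still has a recorded stream use, so the two free halves cannot be
  // merged: the segment stays carved as [0,128)+[128,256) instead of [0,256),
  // and a replayed larger allocation would not land at the recorded address.
  assert(!try_merge_with_next(low));

  // Once the capture-time stream uses are dropped (what the fix below
  // arranges), merging succeeds and the original carving can be reproduced.
  high.stream_uses.clear();
  assert(try_merge_with_next(low));
  assert(low.size == 256);
  return 0;
}
```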

This PR solves the issue by not inserting events for stream uses that were recorded during cudagraph capture. This is safe because all events and streams used during cudagraph capture must have completed before the capture finishes, so no cross-stream synchronization is needed for those uses afterwards.
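
A condensed sketch of that mechanism follows; the actual implementation is in the `CUDACachingAllocator.cpp` diff below (`block_to_cudagraph_stream_uses` / `remove_cudagraph_stream_uses`). Here `Stream`, `record_stream`, and `remove_capture_time_uses` are simplified stand-ins, not the real allocator API.

```cpp
#include <unordered_map>
#include <unordered_set>

// Condensed sketch, not the real allocator code: capture-time stream uses are
// remembered separately so they can be subtracted again before events are
// inserted for the block after capture ends.
using Stream = int;  // simplified stand-in for cuda::CUDAStream
using StreamSet = std::unordered_set<Stream>;

struct Block {
  StreamSet stream_uses;
};

std::unordered_map<Block*, StreamSet> block_to_cudagraph_stream_uses;

void record_stream(Block* block, Stream stream, bool capture_underway) {
  block->stream_uses.insert(stream);
  if (capture_underway) {
    // remember that this use happened while a cudagraph capture was underway
    block_to_cudagraph_stream_uses[block].insert(stream);
  }
}

void remove_capture_time_uses(Block* block) {
  // these streams are guaranteed to have finished with the block before the
  // capture ended, so no cross-stream event needs to be recorded for them
  auto it = block_to_cudagraph_stream_uses.find(block);
  if (it == block_to_cudagraph_stream_uses.end()) {
    return;
  }
  for (Stream s : it->second) {
    block->stream_uses.erase(s);
  }
  block_to_cudagraph_stream_uses.erase(it);
}
```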

### Root Cause 2 [fix a bug in checkpoint state]
When `getCheckpointState` creates a block state, it does not record `block->device`, so `block_state->device == 0` regardless of the real value of `block->device`. See [how](https://github.com/pytorch/pytorch/blob/main/c10/cuda/CUDACachingAllocator.cpp#L744-L750) `BlockState` is created from a block.

When the block state is used during `setSegmentStateToCheckpoint`, we read [`block_state.device` (= 0)](https://github.com/pytorch/pytorch/blob/main/c10/cuda/CUDACachingAllocator.cpp#L1526), which leads to errors.

This PR fixes the issue by recording `block->device` into `block_state` in `getCheckpointState`.
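
For illustration, here is a hypothetical, heavily simplified analog of the bug shape; `Block` and `BlockState` below are stand-ins, not the real c10 types. A checkpoint record whose constructor never copies the device index reports device 0 to every consumer, even when the block lives on another GPU:

```cpp
#include <cassert>

// Simplified stand-in types, for illustration only.
struct Block {
  int device = 1;  // the block actually lives on GPU 1
  void* ptr = nullptr;
};

struct BlockState {
  int device = 0;  // stays 0 because the constructor never copies it
  void* ptr = nullptr;
  explicit BlockState(const Block& b)
      : /* device(b.device),  <-- the missing copy this PR adds */ ptr(b.ptr) {}
};

int main() {
  Block block;
  BlockState state(block);
  // A setSegmentStateToCheckpoint-style consumer would now act on device 0
  // even though the block is on device 1.
  assert(state.device == 0 && block.device == 1);
  return 0;
}
```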

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126809
Approved by: https://github.com/eellison
Author: Boyuan Feng, 2024-05-29 04:52:35 +00:00 (committed by PyTorch MergeBot)
Parent: a41f828da7
Commit: 2b72e2a596
2 changed files with 83 additions and 4 deletions

Changes to `c10/cuda/CUDACachingAllocator.cpp`:

```diff
@@ -742,7 +742,8 @@ struct PrivatePool {
 };
 
 BlockState::BlockState(Block* block)
-    : stream(block->stream),
+    : device(block->device),
+      stream(block->stream),
       stream_uses(block->stream_uses),
       size(block->size),
       ptr(block->ptr),
@@ -925,6 +926,10 @@ class DeviceCachingAllocator {
   std::vector<AllocatorTraceTracker> trace_trackers_;
 
+  // mapping from block to a stream_set, containing streams on which the block
+  // was used while cudagraph capturing
+  std::unordered_map<Block*, stream_set> block_to_cudagraph_stream_uses;
+
  public:
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
   DeviceCachingAllocator()
@@ -1362,6 +1367,9 @@ class DeviceCachingAllocator {
       return;
     }
     block->stream_uses.insert(stream);
+    if (C10_UNLIKELY(!captures_underway.empty())) {
+      block_to_cudagraph_stream_uses[block].insert(stream);
+    }
   }
 
   /** set memory fraction to limit maximum allocated memory **/
@@ -1454,7 +1462,9 @@ class DeviceCachingAllocator {
   /* Checkpoint the state of a private pool necessary to return it to its
    * current state */
   std::unique_ptr<PrivatePoolState> getCheckpointState(MempoolId_t id) {
+    auto context = maybeGatherContext(RecordContext::ALL);
     std::lock_guard<std::recursive_mutex> lock(mutex);
+    insert_events_deferred_until_no_capture(context);
 
     auto pool = graph_pools.find(id);
     if (pool != graph_pools.end()) {
@@ -2704,7 +2714,7 @@ class DeviceCachingAllocator {
     // This function syncs, so capture should not be underway. Might as well
     // make sure capture-deferred end of life events get processed too.
     TORCH_INTERNAL_ASSERT(captures_underway.empty());
-    insert_events_deferred_until_no_capture();
+    insert_events_deferred_until_no_capture(context);
 
     for (auto& st : cuda_events) {
       for (auto& e : st.second) {
@@ -2723,6 +2733,24 @@ class DeviceCachingAllocator {
     cuda_events.clear();
   }
 
+  void remove_cudagraph_stream_uses(Block* block) {
+    // remove stream uses added during cudagraph capture
+    // (i.e., block->stream_uses - block->cudagraph_stream_uses)
+    if (C10_UNLIKELY(
+            block_to_cudagraph_stream_uses.find(block) !=
+            block_to_cudagraph_stream_uses.end())) {
+      stream_set streams(std::move(block->stream_uses));
+      AT_ASSERT(block->stream_uses.empty());
+      for (auto& stream : streams) {
+        if (block_to_cudagraph_stream_uses[block].find(stream) ==
+            block_to_cudagraph_stream_uses[block].end()) {
+          block->stream_uses.insert(stream);
+        }
+      }
+      block_to_cudagraph_stream_uses.erase(block);
+    }
+  }
+
   void insert_events(Block* block) {
     c10::DeviceIndex prev_device = 0;
     C10_CUDA_CHECK(c10::cuda::GetDevice(&prev_device));
@@ -2742,18 +2770,27 @@ class DeviceCachingAllocator {
     C10_CUDA_CHECK(c10::cuda::MaybeSetDevice(prev_device));
   }
 
-  void insert_events_deferred_until_no_capture() {
+  void insert_events_deferred_until_no_capture(
+      const std::shared_ptr<GatheredContext>& context) {
     if (C10_UNLIKELY(!needs_events_deferred_until_no_capture.empty())) {
       for (auto* block : needs_events_deferred_until_no_capture) {
         TORCH_INTERNAL_ASSERT(!block->stream_uses.empty());
+        // only streams recorded before cudagraph will be used to insert events
+        // since we know all streams recorded during cudagraph must have
+        // completed (refer to Section 3.2.8.7.3.1 Cross-stream Dependencies and
+        // Events in CUDA Programming Guide).
+        remove_cudagraph_stream_uses(block);
         insert_events(block);
+        if (block->event_count == 0) {
+          free_block(block, context);
+        }
       }
       needs_events_deferred_until_no_capture.clear();
     }
   }
 
   void process_events(const std::shared_ptr<GatheredContext>& context) {
-    insert_events_deferred_until_no_capture();
+    insert_events_deferred_until_no_capture(context);
 
     // Process outstanding cudaEvents. Events that are completed are
     // removed from the queue, and the 'event_count' for the
```

Changes to the distributed collectives test (`TestCollectivesMultiProc`):

```diff
@@ -127,6 +127,48 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
             inductor_out = compiled_matmul_cat_col(*inputs)
             self.assertTrue(same(eager_out, inductor_out, tol=0.001))
 
+    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
+    @skip_if_lt_x_gpu(2)
+    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
+    @patch.object(torch._inductor.config, "compile_threads", 1)
+    def test_allreduce_inductor_cudagraph_trees(self):
+        """
+        Tests whether cudagraph trees support all_reduce from nccl
+        """
+        import torch.distributed as dist
+
+        # dist.all_reduce is an inplace op in eager mode but a functionalized op in compiled mode.
+        # so we define eager_func and func separately for the same semantics.
+        def eager_func(x):
+            y = x * x
+            dist.all_reduce(y, op=dist.ReduceOp.SUM)
+            x = torch.nn.functional.silu(x)
+            return x * y
+
+        def func(x):
+            y = x * x
+            y = dist.all_reduce(y, op=dist.ReduceOp.SUM)
+            x = torch.nn.functional.silu(x)
+            return x * y
+
+        options = {
+            "triton.cudagraphs": True,
+            "triton.cudagraph_trees": True,
+        }
+
+        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+            compiled_func = torch.compile(
+                func, backend="inductor", fullgraph=True, options=options, dynamic=None
+            )
+
+            for nelem in [1024, 2048, 4096]:
+                x = torch.randn(nelem, device="cuda", dtype=torch.bfloat16)
+                golden_out = eager_func(x)
+
+                for _ in range(3):
+                    compiled_out = compiled_func(x)
+                    self.assertEqual(golden_out, compiled_out)
+
     def test_c10d_functional_tagged_pt2_compliant(self):
         op = torch.ops._c10d_functional.all_reduce.default
         self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)
```