Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[Cudagraph] better support for streams (#126809)
This PR fixes Issue #124391. There are two root causes.

### Root Cause 1 [better support for streams during cudagraph capture]

When recording a new function, CUDA graph trees record the memory block states (e.g., address, size, allocated, etc.) via `getCheckpointState`; call this record `block_state`. Later, CUDA graph trees recover exactly the same memory block states via `apply_checkpoint_execution_state_in_allocator`, which a) frees all memory blocks; b) allocates all recorded block states (regardless of `block_state->allocated`); c) frees blocks with `block_state->allocated == False`; and d) checks that each block_state matches the remaining blocks (e.g., `block_state->ptr == block->ptr`).

An error may occur when multiple streams exist during recording. [Note](https://github.com/pytorch/pytorch/blob/main/c10/cuda/CUDACachingAllocator.cpp#L2149-L2152) that a block will not be merged with other blocks if it is used by some streams, even if `block->allocated == False`. This can lead to a mismatch between `block_state->ptr` and `block->ptr` in `apply_checkpoint_execution_state_in_allocator`.

This PR solves the issue by not inserting events for streams that were used during cudagraph capture. The reason is that all events and streams used during cudagraph capture must have completed before the capture finishes.

### Root Cause 2 [fix a bug in checkpoint state]

When `getCheckpointState` creates a block state, it does not record `block->device`, so `block_state->device == 0` regardless of the real value of `block->device`. See [how](https://github.com/pytorch/pytorch/blob/main/c10/cuda/CUDACachingAllocator.cpp#L744-L750) BlockState is created from a block. When the block state is later used in `setSegmentStateToCheckpoint`, the incorrect [block_state.device (=0)](https://github.com/pytorch/pytorch/blob/main/c10/cuda/CUDACachingAllocator.cpp#L1526) is used, which leads to errors. We fix this by recording `block->device` into `block_state` in `getCheckpointState`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126809
Approved by: https://github.com/eellison
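For illustration only, here is a minimal Python sketch of the bookkeeping behind the Root Cause 1 fix. It is not the real C++ allocator: `Block`, `record_stream_use`, and the stream names are made-up stand-ins for `CUDACachingAllocator`'s internal state, and it only models how stream uses recorded during capture are stripped out again before events would be inserted.

```python
# Toy model of the Root Cause 1 bookkeeping (not the real allocator).
from dataclasses import dataclass, field


@dataclass
class Block:
    ptr: int
    stream_uses: set = field(default_factory=set)


# streams recorded on a block while a cudagraph capture was underway
block_to_cudagraph_stream_uses: dict[int, set] = {}


def record_stream_use(block: Block, stream: str, capture_underway: bool) -> None:
    block.stream_uses.add(stream)
    if capture_underway:
        # remember separately that this use happened during capture
        block_to_cudagraph_stream_uses.setdefault(block.ptr, set()).add(stream)


def remove_cudagraph_stream_uses(block: Block) -> None:
    # drop stream uses added during capture: work on those streams is known to
    # have completed by the time the capture finished, so no cudaEvent would be
    # needed for them when the block is freed
    captured = block_to_cudagraph_stream_uses.pop(block.ptr, set())
    block.stream_uses -= captured


block = Block(ptr=0x1000)
record_stream_use(block, "side_stream", capture_underway=False)
record_stream_use(block, "nccl_stream", capture_underway=True)
remove_cudagraph_stream_uses(block)
assert block.stream_uses == {"side_stream"}  # only pre-capture uses remain
```

The actual change keeps the same idea in C++: `block_to_cudagraph_stream_uses` is populated in the record-stream path while `captures_underway` is non-empty, and `remove_cudagraph_stream_uses` filters `block->stream_uses` before `insert_events` runs, as shown in the diff below.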
This commit is contained in:
parent a41f828da7
commit 2b72e2a596
@@ -742,7 +742,8 @@ struct PrivatePool {
 };
 
 BlockState::BlockState(Block* block)
-    : stream(block->stream),
+    : device(block->device),
+      stream(block->stream),
       stream_uses(block->stream_uses),
       size(block->size),
       ptr(block->ptr),
@@ -925,6 +926,10 @@ class DeviceCachingAllocator {
 
   std::vector<AllocatorTraceTracker> trace_trackers_;
 
+  // mapping from block to a stream_set, containing streams on which the block
+  // was used while cudagraph capturing
+  std::unordered_map<Block*, stream_set> block_to_cudagraph_stream_uses;
+
  public:
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
   DeviceCachingAllocator()
@@ -1362,6 +1367,9 @@ class DeviceCachingAllocator {
       return;
     }
     block->stream_uses.insert(stream);
+    if (C10_UNLIKELY(!captures_underway.empty())) {
+      block_to_cudagraph_stream_uses[block].insert(stream);
+    }
   }
 
   /** set memory fraction to limit maximum allocated memory **/
@@ -1454,7 +1462,9 @@ class DeviceCachingAllocator {
   /* Checkpoint the state of a private pool necessary to return it to its
    * current state */
   std::unique_ptr<PrivatePoolState> getCheckpointState(MempoolId_t id) {
+    auto context = maybeGatherContext(RecordContext::ALL);
     std::lock_guard<std::recursive_mutex> lock(mutex);
+    insert_events_deferred_until_no_capture(context);
 
     auto pool = graph_pools.find(id);
     if (pool != graph_pools.end()) {
@@ -2704,7 +2714,7 @@ class DeviceCachingAllocator {
     // This function syncs, so capture should not be underway. Might as well
     // make sure capture-deferred end of life events get processed too.
     TORCH_INTERNAL_ASSERT(captures_underway.empty());
-    insert_events_deferred_until_no_capture();
+    insert_events_deferred_until_no_capture(context);
 
     for (auto& st : cuda_events) {
       for (auto& e : st.second) {
@@ -2723,6 +2733,24 @@ class DeviceCachingAllocator {
     cuda_events.clear();
   }
 
+  void remove_cudagraph_stream_uses(Block* block) {
+    // remove stream uses added during cudagraph capture
+    // (i.e., block->stream_uses - block->cudagraph_stream_uses)
+    if (C10_UNLIKELY(
+            block_to_cudagraph_stream_uses.find(block) !=
+            block_to_cudagraph_stream_uses.end())) {
+      stream_set streams(std::move(block->stream_uses));
+      AT_ASSERT(block->stream_uses.empty());
+      for (auto& stream : streams) {
+        if (block_to_cudagraph_stream_uses[block].find(stream) ==
+            block_to_cudagraph_stream_uses[block].end()) {
+          block->stream_uses.insert(stream);
+        }
+      }
+      block_to_cudagraph_stream_uses.erase(block);
+    }
+  }
+
   void insert_events(Block* block) {
     c10::DeviceIndex prev_device = 0;
     C10_CUDA_CHECK(c10::cuda::GetDevice(&prev_device));
@@ -2742,18 +2770,27 @@ class DeviceCachingAllocator {
     C10_CUDA_CHECK(c10::cuda::MaybeSetDevice(prev_device));
   }
 
-  void insert_events_deferred_until_no_capture() {
+  void insert_events_deferred_until_no_capture(
+      const std::shared_ptr<GatheredContext>& context) {
     if (C10_UNLIKELY(!needs_events_deferred_until_no_capture.empty())) {
       for (auto* block : needs_events_deferred_until_no_capture) {
         TORCH_INTERNAL_ASSERT(!block->stream_uses.empty());
+        // only streams recorded before cudagraph will be used to insert events
+        // since we know all streams recorded during cudagraph must have
+        // completed (refer to Section 3.2.8.7.3.1 Cross-stream Dependencies and
+        // Events in CUDA Programming Guide).
+        remove_cudagraph_stream_uses(block);
         insert_events(block);
+        if (block->event_count == 0) {
+          free_block(block, context);
+        }
       }
       needs_events_deferred_until_no_capture.clear();
     }
   }
 
   void process_events(const std::shared_ptr<GatheredContext>& context) {
-    insert_events_deferred_until_no_capture();
+    insert_events_deferred_until_no_capture(context);
 
     // Process outstanding cudaEvents. Events that are completed are
     // removed from the queue, and the 'event_count' for the
@@ -127,6 +127,48 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
             inductor_out = compiled_matmul_cat_col(*inputs)
             self.assertTrue(same(eager_out, inductor_out, tol=0.001))
 
+    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
+    @skip_if_lt_x_gpu(2)
+    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
+    @patch.object(torch._inductor.config, "compile_threads", 1)
+    def test_allreduce_inductor_cudagraph_trees(self):
+        """
+        Tests whether cudagraph trees support all_reduce from nccl
+        """
+        import torch.distributed as dist
+
+        # dist.all_reduce is an inplace op in eager mode but a functionalized op in compiled mode,
+        # so we define eager_func and func separately with the same semantics.
+        def eager_func(x):
+            y = x * x
+            dist.all_reduce(y, op=dist.ReduceOp.SUM)
+            x = torch.nn.functional.silu(x)
+            return x * y
+
+        def func(x):
+            y = x * x
+            y = dist.all_reduce(y, op=dist.ReduceOp.SUM)
+            x = torch.nn.functional.silu(x)
+            return x * y
+
+        options = {
+            "triton.cudagraphs": True,
+            "triton.cudagraph_trees": True,
+        }
+
+        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+            compiled_func = torch.compile(
+                func, backend="inductor", fullgraph=True, options=options, dynamic=None
+            )
+
+            for nelem in [1024, 2048, 4096]:
+                x = torch.randn(nelem, device="cuda", dtype=torch.bfloat16)
+                golden_out = eager_func(x)
+
+                for _ in range(3):
+                    compiled_out = compiled_func(x)
+                    self.assertEqual(golden_out, compiled_out)
+
     def test_c10d_functional_tagged_pt2_compliant(self):
         op = torch.ops._c10d_functional.all_reduce.default
         self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)