diff --git a/test/test_autograd.py b/test/test_autograd.py
index 4688c1685b6..ab4c5862db5 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -13131,14 +13131,12 @@ class TestAutogradStreamSynchronization(TestCase):
         for _ in range(2):
             test()
 
-    # This fails because we currently sync to the default stream
     # AttributeError: module 'torch.mps' has no attribute 'default_stream'
     @skipIfMPS
     @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
     @unittest.skipIf(
         torch.accelerator.device_count() < 2, "accelerator count is less than 2"
     )
-    @unittest.expectedFailure
     def test_consumer_to_single_producer_case_3_correctness_non_default_ambient_stream(
         self,
     ):
@@ -13312,7 +13310,6 @@ class TestAutogradStreamSynchronization(TestCase):
     # This test may spuriously fail on non-cuda accelerators (since we won't
     # be calling sleep)
     @unittest.skipIf(not TEST_CUDA, "requires CUDA")
-    @unittest.expectedFailure
     def test_side_stream_backward_overlap(self):
         # In case 2/3, we would designate the consumer as the accumulation
         # stream and naively, one might have the consumer wait for the producer
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index 1f4a29d95eb..f53268cb24d 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -8379,16 +8379,6 @@ BACKWARD_SKIPS_AND_XFAILS = [
         sample_match_fn=lambda device, sample: ("noncontig_holes" in sample.name),
         name="broken_unflatten_backward",
     ),
-    # -> CPU device conversion backwards is broken
-    XFailRule(
-        error_type=RuntimeError,
-        error_msg="Unknown layout in record_stream_any_impl",
-        op_match_fn=lambda device, op: (op.full_name == "to"),
-        sample_match_fn=lambda device, sample: (
-            sample.kwargs.get("device", None) == "cpu"
-        ),
-        name="broken_to_backward",
-    ),
     # sum() backward is not implemented for non-full reductions
     XFailRule(
         error_type=NotImplementedError,
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index f1bdd05d84c..b6ef6216ef4 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -1065,13 +1065,32 @@ void Engine::evaluate_function(
     Node* func,
     InputBuffer& inputs,
     const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
-  // The InputBuffer::adds that supplied incoming grads took pains to
-  // ensure they're safe to consume in the context of the present
-  // func's stream (if applicable). So we guard onto that stream
-  // before working with the grads in any capacity.
+  // Locally set the current stream to func's associated stream
   auto opt_parent_stream = (*func).stream();
   c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};
 
+  // Ensure that the incoming gradients are ready
+  for (size_t pos = 0; pos < inputs.ready_events.size(); ++pos) {
+    if (!inputs.buffer[pos].defined()) {
+      continue;
+    }
+    const auto device = inputs.buffer[pos].device();
+    // TODO: Use at::accelerator::isAccelerator(device.type()) instead
+    bool is_accelerator =
+        device.is_cuda() || device.is_mtia() || device.is_privateuseone();
+    if (!is_accelerator) {
+      continue;
+    }
+    TORCH_INTERNAL_ASSERT(inputs.ready_events[pos].has_value());
+    TORCH_INTERNAL_ASSERT(inputs.ready_streams[pos].has_value());
+    TORCH_INTERNAL_ASSERT(opt_parent_stream.has_value());
+    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
+    if (opt_parent_stream.value() != inputs.ready_streams[pos].value()) {
+      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
+      opt_parent_stream->wait(inputs.ready_events[pos].value());
+    }
+  }
+
   // If exec_info_ is not empty, we have to instrument the execution
   auto& exec_info_ = graph_task->exec_info_;
   if (!exec_info_.empty()) {
diff --git a/torch/csrc/autograd/input_buffer.cpp b/torch/csrc/autograd/input_buffer.cpp
index 4525a50aede..a52d20652d6 100644
--- a/torch/csrc/autograd/input_buffer.cpp
+++ b/torch/csrc/autograd/input_buffer.cpp
@@ -26,8 +26,13 @@ namespace {
 // See https://github.com/pytorch/pytorch/issues/60306
 // TODO: clean this up when https://github.com/pytorch/pytorch/issues/60306 is
 // improved
-void record_stream_any_impl(Variable& var, c10::Stream& stream) {
+void record_stream_any_impl(Variable& var, const c10::Stream& stream) {
   // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
+
+  if (stream.device_index() != var.device().index()) {
+    return;
+  }
+
   const auto guard = c10::impl::VirtualGuardImpl(device_of(var).value().type());
 
   if (C10_UNLIKELY(at::isBatchedTensor(var))) {
@@ -126,99 +131,151 @@ static void accumulate(
   }
 }
 
+// Note: [Stream sync contract when dealing with multi-deviced-ness]
+//
+// An operator can deal with multiple devices, e.g. if it does a device
+// transfer, etc. However, for the purpose of stream synchronization, the
+// engine is only aware of a single canonical device/stream for each autograd
+// Node.
+//
+// For proper synchronization, the Node author should make sure of the
+// following:
+//
+// 1) A node consuming a gradient should wait on the canonical stream before
+//    using it.
+// 2) A node producing a gradient should have it ready on the canonical
+//    stream during node execution.
+//
+
+// Note: [Autograd Producer-Consumer Stream Syncs]
+//
+// The producer-consumer stream syncs are partially handled in this method
+// and partially handled in the engine prior to the consumer's execution.
+// The logic here is mainly responsible for handling the synchronization needed
+// for accumulation and recording the event that the consumer should wait on
+// later. The corresponding wait and record_stream happen in the engine.
+//
+// First producer
+// ==============
+// There are several things we need to do upon seeing the first producer:
+// 1) Determine the accumulation stream (which may or may not be used):
+//    case A) var's device matches consumer node's canonical device
+//            (the producer node's canonical device may or may not match)
+//            -> accumulation stream = consumer stream
+//    case B) var's device matches producer node's canonical device
+//            and does not match consumer node's canonical device
+//            -> accumulation stream = producer stream
+//    case C) var's device matches neither
+//            -> accumulation stream = var device's current stream
+//    See Note [Stream sync contract when dealing with
+//    multi-deviced-ness]
+// 2) Because we are the first producer, there's no accumulation necessary.
+//    Just move var into the buffer.
+// 3) Update the ready_events and ready_streams for the current position.
+//    ready_events are events you need to wait for to ensure the corresponding
+//    buffers are ready. The events are updated as we accumulate into the
+//    buffer.
+//
+// Nth producer
+// ============
+// 1) Synchronize for accumulation. Accumulation operates on both the new
+//    incoming gradient and the existing gradient in the buffer, so we
+//    (i) wait stream and (ii) record stream to make sure both are ready to
+//    be used on the accumulation stream.
+// 2) Accumulate on the accumulation stream.
+// 3) Update the ready event and stream for the current position.
+//
 void InputBuffer::add(
     size_t pos,
     Variable&& var,
-    const std::optional<c10::Stream>& opt_producer_stream,
+    const std::optional<c10::Stream>& opt_producer_stream_,
     const std::optional<c10::Stream>& opt_consumer_stream) {
   TORCH_INTERNAL_ASSERT(pos < buffer.size());
+
   if (!var.defined()) {
     return;
   }
-
-  // Switches to accumulate device
-  // The device (and stream) chosen for accumulation is:
-  // (1) var is not a CUDA/privateuse1 variable. Accumulation happens on var's
-  //     device. (2) var is a CUDA/privateuse1 variable and it, the consumer,
-  //     and the producer share the same device:
-  //     (2a) Uses the consumer's stream as the accumulation stream
-  //     (2b) Syncs the accumulation stream with the producer's stream (if
-  //     different) (2c) Accumulates.
-  // (3) var is a CUDA/MTIA/privateuse1 variable and it shares a device with
-  //     the consumer but not the producer:
-  //     (3a) Uses the consumer's stream as the accumulation stream
-  //     (3b) Syncs the accumulation stream with the consumer device's default
-  //     stream (3c) Accumulates.
-  // (4) var is a CUDA/MTIA/privateuse1 variable and it shares a device with
-  //     the producer but not the consumer:
-  //     (4a) Uses the producer device's default stream as the accumulation
-  //     stream (4b) Syncs the accumulation stream with the producer's
-  //     stream (4c) Accumulates.
-  // (5) var is a CUDA/MTIA/privateuse1 variable and it does not share a device
-  //     with the consumer or producer.
-  //     Accumulation happens on the var device's default stream.
-
-  auto const device = device_of(var);
-  TORCH_INTERNAL_ASSERT(device.has_value());
-  std::optional<c10::Stream> opt_accumulate_stream = std::nullopt;
-  const auto device_type = device->type();
-  if (device->is_cuda() || device->is_mtia() || device->is_privateuseone()) {
-    const auto on_producer =
-        opt_producer_stream && device == opt_producer_stream->device();
-    const auto on_consumer =
-        opt_consumer_stream && device == opt_consumer_stream->device();
-
-    if (on_producer && on_consumer) {
-      // (2a)
-      opt_accumulate_stream = opt_consumer_stream;
-      if (opt_accumulate_stream != opt_producer_stream) {
-        // (2b)
-        auto event = c10::Event{device_type};
-        event.record(*opt_producer_stream);
-        opt_accumulate_stream->wait(event);
-        record_stream_any_impl(var, *opt_accumulate_stream);
-      }
+  const auto device = var.device();
+  const auto device_type = device.type();
+  // TODO: Use at::accelerator::isAccelerator(device.type()) instead
+  bool is_accelerator =
+      device.is_cuda() || device.is_mtia() || device.is_privateuseone();
+  //
+  // Non-accelerator case
+  //
+  if (!is_accelerator) {
+    if (!buffer[pos].defined()) {
+      buffer[pos] = std::move(var);
     } else {
-      std::optional<c10::Stream> opt_sync_stream = std::nullopt;
-      const auto guard = c10::impl::VirtualGuardImpl{device_type};
-      if (on_consumer && !on_producer) {
-        // (3a)
-        opt_accumulate_stream = opt_consumer_stream;
-        opt_sync_stream = guard.getDefaultStream(opt_consumer_stream->device());
-      } else if (on_producer && !on_consumer) {
-        // (4a)
-        opt_accumulate_stream =
-            guard.getDefaultStream(opt_producer_stream->device());
-        opt_sync_stream = opt_producer_stream;
-      } else {
-        // (5)
-        opt_accumulate_stream = guard.getDefaultStream(*device);
-      }
-      if (opt_sync_stream && (opt_accumulate_stream != opt_sync_stream)) {
-        // (3b), (4b)
-        c10::OptionalDeviceGuard device_guard{opt_sync_stream->device()};
-        auto event = c10::Event{device_type};
-        event.record(*opt_sync_stream);
-        opt_accumulate_stream->wait(event);
-        const auto guard = c10::impl::VirtualGuardImpl(device_type);
-        record_stream_any_impl(var, *opt_accumulate_stream);
-      }
-    }
-  }
-
-  auto& old_var = buffer[pos];
-  if (!old_var.defined()) {
-    buffer[pos] = std::move(var);
-  } else {
-    if (opt_accumulate_stream) {
-      c10::OptionalStreamGuard stream_guard{opt_accumulate_stream};
-      accumulate(buffer, pos, std::move(var));
-    } else {
-      // (1) non-CUDA/privateuse1 variable
-      // Accumulation happens on variable's device
       c10::OptionalDeviceGuard device_guard{device};
       accumulate(buffer, pos, std::move(var));
     }
+    return;
+  }
+  // Handle the case where var is on an accelerator but the producer node has
+  // no canonical stream, e.g. this can happen if the forward is DtoH
+  const std::optional<c10::Stream>& opt_producer_stream =
+      (opt_producer_stream_.has_value()
+           ? opt_producer_stream_
+           : std::optional<c10::Stream>(
+                 at::accelerator::getCurrentStream(device.index())));
+
+  TORCH_INTERNAL_ASSERT(opt_consumer_stream && opt_producer_stream);
+
+  // See Note: [Autograd Producer-Consumer Stream Syncs]
+  if (!opt_accum_streams[pos].has_value()) {
+    // [ First producer ]
+    TORCH_INTERNAL_ASSERT(!buffer[pos].defined());
+    // 1)
+    if (opt_consumer_stream->device() == device) {
+      // Case A
+      opt_accum_streams[pos] = opt_consumer_stream;
+      if (*opt_consumer_stream != *opt_producer_stream) {
+        // We will end up doing record_stream on the accumulation stream
+        // (which is the consumer stream) later, but we also need to do
+        // it here in case we don't end up accumulating.
+        record_stream_any_impl(var, *opt_consumer_stream);
+      }
+    } else if (opt_producer_stream->device() == device) {
+      // Case B
+      opt_accum_streams[pos] = opt_producer_stream;
+    } else {
+      // Case C
+      opt_accum_streams[pos] =
+          at::accelerator::getCurrentStream(device.index());
+    }
+    // 2)
+    buffer[pos] = std::move(var);
+    // 3)
+    auto event = c10::Event{device_type};
+    event.record(*opt_producer_stream);
+    ready_events[pos] = std::move(event);
+    ready_streams[pos] = opt_producer_stream;
+  } else {
+    // [ Nth producer ]
+    auto accum_stream = opt_accum_streams[pos];
+    auto& ready_event = ready_events[pos];
+    auto& ready_stream = ready_streams[pos];
+    TORCH_INTERNAL_ASSERT(accum_stream && ready_event && ready_stream);
+    // 1)
+    if (*accum_stream != *opt_producer_stream) {
+      auto event = c10::Event{device_type};
+      event.record(*opt_producer_stream);
+      accum_stream->wait(event);
+      record_stream_any_impl(var, *accum_stream);
+    }
+    if (*accum_stream != *ready_stream) {
+      accum_stream->wait(*ready_event);
+      // This is redundant for case A, but needed for case C
+      record_stream_any_impl(buffer[pos], *accum_stream);
+    }
+    // 2)
+    c10::OptionalStreamGuard stream_guard{accum_stream};
+    accumulate(buffer, pos, std::move(var));
+    // 3)
+    auto event = c10::Event{device_type};
+    event.record(*accum_stream);
+    ready_events[pos] = std::move(event);
+    ready_streams[pos] = accum_stream;
   }
 }
diff --git a/torch/csrc/autograd/input_buffer.h b/torch/csrc/autograd/input_buffer.h
index 5c3b46fbdaa..89abd91f491 100644
--- a/torch/csrc/autograd/input_buffer.h
+++ b/torch/csrc/autograd/input_buffer.h
@@ -15,7 +15,11 @@ namespace torch::autograd {
 
 struct InputBuffer {
-  explicit InputBuffer(size_t size) : buffer(size) {}
+  explicit InputBuffer(size_t size)
+      : buffer(size),
+        opt_accum_streams(size),
+        ready_events(size),
+        ready_streams(size) {}
   InputBuffer(const InputBuffer& other) = delete;
   InputBuffer(InputBuffer&& other) = default;
   explicit InputBuffer(variable_list&& inputs) : buffer(std::move(inputs)) {}
@@ -38,6 +42,14 @@ struct InputBuffer {
   static std::vector<Variable> variables(InputBuffer&& g);
 
   std::vector<Variable> buffer;
+  // The stream used for accumulation when a variable is used multiple times.
+  std::vector<std::optional<c10::Stream>> opt_accum_streams;
+  // The events you need to wait for to ensure the corresponding buffers
+  // are ready. The events are updated as we accumulate into the buffer.
+  std::vector<std::optional<c10::Event>> ready_events;
+  // The streams corresponding to the events above. This is only used to
+  // check if more synchronization is needed or not.
+  std::vector<std::optional<c10::Stream>> ready_streams;
 };
 
 } // namespace torch::autograd
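
The core contract of the patch (see the two notes added to input_buffer.cpp) is an event-based handshake: the producer of a gradient records an event on its stream, and the consumer waits on that event before touching the gradient. The sketch below is not part of the patch; it is a user-level Python rendering of that handshake using only public torch.cuda stream/event APIs, with illustrative variable names, assuming a CUDA device is available:

    import torch

    # Producer/consumer handshake, mirroring what InputBuffer::add and
    # Engine::evaluate_function now do together.
    producer = torch.cuda.Stream()
    consumer = torch.cuda.Stream()

    with torch.cuda.stream(producer):
        grad = torch.randn(1024, device="cuda")  # gradient produced on `producer`
        ready_event = torch.cuda.Event()
        ready_event.record(producer)  # analogous to ready_events[pos]

    with torch.cuda.stream(consumer):
        ready_event.wait(consumer)    # the engine-side wait before node execution
        grad.record_stream(consumer)  # keep the caching allocator from reusing grad early
        out = grad * 2                # safely ordered after the producer's work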
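The case A/B/C accumulation-stream choice described in Note [Autograd Producer-Consumer Stream Syncs] reduces to a small decision function. A hypothetical pure-Python mirror follows; the helper name and signature are invented for illustration and are not engine API:

    import torch

    def pick_accum_stream(var_device, producer_stream, consumer_stream):
        # Case A: var lives on the consumer's device -> accumulate there.
        if consumer_stream.device == var_device:
            return consumer_stream
        # Case B: var lives on the producer's device -> accumulate there.
        if producer_stream.device == var_device:
            return producer_stream
        # Case C: neither matches -> use var device's current stream.
        return torch.cuda.current_stream(var_device)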
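Likewise, the Nth-producer synchronization (make both the buffered gradient and the incoming gradient ready on the accumulation stream, accumulate there, then re-record readiness) can be sketched with the same public APIs. Again, the function below is illustrative, not the C++ implementation; the unconditional wait_event is harmless when the event was recorded on the accumulation stream itself:

    import torch

    def accumulate_nth(accum_stream, producer_stream, ready_event, buf, new_grad):
        # 1)(i) make the incoming gradient ready on the accumulation stream
        if accum_stream != producer_stream:
            accum_stream.wait_stream(producer_stream)
            new_grad.record_stream(accum_stream)
        # 1)(ii) make the buffered gradient ready on the accumulation stream
        accum_stream.wait_event(ready_event)
        # 2) accumulate on the accumulation stream
        with torch.cuda.stream(accum_stream):
            buf += new_grad
        # 3) record a fresh ready event for the engine to wait on later
        new_ready = torch.cuda.Event()
        new_ready.record(accum_stream)
        return new_ready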