Rewrite autograd producer consumer stream sync logic (#151079)

Also see previous work https://github.com/pytorch/pytorch/pull/142097

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151079
Approved by: https://github.com/albanD
soulitzer 2025-05-15 14:54:39 -07:00 committed by PyTorch MergeBot
parent 2ce0b66db8
commit a060f3d272
5 changed files with 185 additions and 101 deletions

test/test_autograd.py

@@ -13132,14 +13132,12 @@ class TestAutogradStreamSynchronization(TestCase):
for _ in range(2):
test()
# This fails because we currently sync to the default stream
# AttributeError: module 'torch.mps' has no attribute 'default_stream'
@skipIfMPS
@unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
@unittest.skipIf(
torch.accelerator.device_count() < 2, "accelerator count is less than 2"
)
@unittest.expectedFailure
def test_consumer_to_single_producer_case_3_correctness_non_default_ambient_stream(
self,
):
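For context, the pattern this test exercises looks roughly like the following sketch (a hypothetical repro, not the test body): backward is driven while a non-default stream is current, and the gradient crosses a device boundary, so neither the producer's nor the consumer's canonical stream can be assumed to be the default stream.

import torch

# Hypothetical sketch, assuming two CUDA devices; names and shapes are
# illustrative only.
s = torch.cuda.Stream()
with torch.cuda.stream(s):                   # non-default ambient stream
    a = torch.ones(8, device="cuda:0", requires_grad=True)
    b = a.to("cuda:1")                       # forward crosses devices
    b.sum().backward()                       # grad flows back to cuda:0
torch.cuda.current_stream().wait_stream(s)   # order later reads of a.grad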
@@ -13313,7 +13311,6 @@ class TestAutogradStreamSynchronization(TestCase):
# This test may spuriously fail on non-CUDA accelerators (since we won't
# be calling sleep)
@unittest.skipIf(not TEST_CUDA, "requires CUDA")
@unittest.expectedFailure
def test_side_stream_backward_overlap(self):
# In case 2/3, we would designate the consumer as the accumulation
# stream and naively, one might have the consumer wait for the producer
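The overlap in question can be pictured with a sketch like this (hypothetical, using only public torch.cuda stream APIs): two independent branches run forward on separate side streams, and their backward kernels should likewise be free to overlap rather than being serialized by an eager consumer-to-producer sync.

import torch

# Hypothetical sketch: branches on side streams whose backwards can overlap.
default = torch.cuda.current_stream()
s1, s2 = torch.cuda.Stream(), torch.cuda.Stream()
x = torch.randn(8, device="cuda", requires_grad=True)
s1.wait_stream(default)                # x is allocated on the default stream
s2.wait_stream(default)
with torch.cuda.stream(s1):
    a = x * 2                          # branch 1 forward on s1
with torch.cuda.stream(s2):
    b = x * 3                          # branch 2 forward on s2
default.wait_stream(s1)
default.wait_stream(s2)
(a.sum() + b.sum()).backward()         # the two branch backwards may overlap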

test/test_nestedtensor.py

@@ -8379,16 +8379,6 @@ BACKWARD_SKIPS_AND_XFAILS = [
sample_match_fn=lambda device, sample: ("noncontig_holes" in sample.name),
name="broken_unflatten_backward",
),
# -> CPU device conversion backwards is broken
XFailRule(
error_type=RuntimeError,
error_msg="Unknown layout in record_stream_any_impl",
op_match_fn=lambda device, op: (op.full_name == "to"),
sample_match_fn=lambda device, sample: (
sample.kwargs.get("device", None) == "cpu"
),
name="broken_to_backward",
),
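The rule removed above covered gradients flowing backward through a to("cpu") conversion, which used to hit the record_stream_any_impl layout error on these (jagged-layout) test samples. In plain dense-tensor terms, the now-working pattern is (hypothetical example, not from the test suite):

import torch

# Forward is a DtoH copy, so backward must copy the gradient back HtoD.
x = torch.randn(4, device="cuda", requires_grad=True)
y = x.to("cpu")          # forward: device-to-host
y.sum().backward()       # backward: gradient returns host-to-device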
# sum() backward is not implemented for non-full reductions
XFailRule(
error_type=NotImplementedError,

torch/csrc/autograd/engine.cpp

@@ -1065,13 +1065,32 @@ void Engine::evaluate_function(
Node* func,
InputBuffer& inputs,
const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
// The InputBuffer::adds that supplied incoming grads took pains to
// ensure they're safe to consume in the context of the present
// func's stream (if applicable). So we guard onto that stream
// before working with the grads in any capacity.
// Locally set the current stream to func's associated stream
auto opt_parent_stream = (*func).stream();
c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};
// Ensure that the incoming gradients are ready
for (size_t pos = 0; pos < inputs.ready_events.size(); ++pos) {
if (!inputs.buffer[pos].defined()) {
continue;
}
const auto device = inputs.buffer[pos].device();
// TODO: Use at::accelerator::isAccelerator(device.type()) instead
bool is_accelerator =
device.is_cuda() || device.is_mtia() || device.is_privateuseone();
if (!is_accelerator) {
continue;
}
TORCH_INTERNAL_ASSERT(inputs.ready_events[pos].has_value());
TORCH_INTERNAL_ASSERT(inputs.ready_streams[pos].has_value());
TORCH_INTERNAL_ASSERT(opt_parent_stream.has_value());
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
if (opt_parent_stream.value() != inputs.ready_streams[pos].value()) {
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
opt_parent_stream->wait(inputs.ready_events[pos].value());
}
}
// If exec_info_ is not empty, we have to instrument the execution
auto& exec_info_ = graph_task->exec_info_;
if (!exec_info_.empty()) {
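In user-level terms, the wait inserted here is the standard event-ordering primitive. A rough Python analogue using public torch.cuda APIs (illustrative only):

import torch

# The consumer stream waits on the event recorded when the incoming
# gradient became ready, then runs the node's work.
producer, consumer = torch.cuda.Stream(), torch.cuda.Stream()
with torch.cuda.stream(producer):
    grad = torch.ones(4, device="cuda")
    ready = producer.record_event()   # ~ ready_events[pos]
consumer.wait_event(ready)            # ~ opt_parent_stream->wait(event)
with torch.cuda.stream(consumer):
    out = grad * 2                    # ordered after the producer's work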

torch/csrc/autograd/input_buffer.cpp

@@ -26,8 +26,13 @@ namespace {
// See https://github.com/pytorch/pytorch/issues/60306
// TODO: clean this up when https://github.com/pytorch/pytorch/issues/60306 is
// improved
void record_stream_any_impl(Variable& var, c10::Stream& stream) {
void record_stream_any_impl(Variable& var, const c10::Stream& stream) {
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
if (stream.device_index() != var.device().index()) {
return;
}
const auto guard = c10::impl::VirtualGuardImpl(device_of(var).value().type());
if (C10_UNLIKELY(at::isBatchedTensor(var))) {
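For reference, the contract that record_stream_any_impl implements per layout is the allocator's record_stream. In Python terms (illustrative sketch):

import torch

# Mark t as in use on side stream s so the caching allocator does not
# hand its memory to another allocation until work queued on s completes.
t = torch.empty(4, device="cuda")
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    t.mul_(2)
t.record_stream(s)    # safe even if t is freed on another stream later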
@@ -126,99 +131,160 @@ static void accumulate(
}
}
// Note: [Stream sync contract when dealing with multi-deviced-ness]
//
// An operator can deal with multiple devices, e.g. if it does a device
// transfer, etc. However, for the purpose of stream synchronization, the
// engine is only aware of a single canonical device/stream for each autograd
// Node.
//
// For proper synchronization, the Node author should ensure the following:
//
// 1) A node consuming a gradient should wait on the canonical stream before
// using it.
// 2) A node producing a gradient should have it ready on the canonical
// stream during node execution.
//
// Note: [Autograd Producer-Consumer Stream Syncs]
//
// The producer-consumer stream syncs are partially handled in this method
// and partially handled in the engine prior to the consumer's execution.
// The logic here is mainly responsible for handling the synchronization needed
// for accumulation and recording the event that the consumer should wait on
// later. The corresponding wait and record_stream happen in the engine.
//
// First producer
// ==============
// There are several things we need to do upon seeing the first producer:
// 1) Determine the accumulation stream (which may or may not be used):
// case A) var's device matches consumer node's canonical device
// (The producer node's canonical device may or may not match)
// -> accumulator stream = consumer stream
// case B) var's device matches producer node's canonical device
// and does not match consumer node's canonical device
// -> accumulator stream = producer stream
// case C) var's device matches neither
// -> accumulator stream = var device's current stream
// See Note [Stream sync contract when dealing with
// multi-deviced-ness]
// 2) Because we are the first producer, there's no accumulation necessary.
// Just move var into the buffer.
// 3) Update the ready_events and streams for the current position.
// ready_events are events you need to wait for to ensure the corresponding
// buffers are ready. The events are updated as we accumulate into the
// buffer.
//
// Nth producer
// ============
// 1) Synchronize for accumulation. Accumulation operates on both the new
// incoming gradient and the existing gradient in the buffer.
// (i) wait stream and (ii) record stream to make sure both are ready to be
// used on the accumulation stream.
// 2) Accumulate on the accumulation stream
// 3) Update the ready event and stream for the current position.
//
void InputBuffer::add(
size_t pos,
Variable&& var,
const std::optional<c10::Stream>& opt_producer_stream,
const std::optional<c10::Stream>& opt_consumer_stream) {
const std::optional<c10::Stream>& opt_producer_stream_,
const std::optional<c10::Stream>& opt_consumer_stream_) {
TORCH_INTERNAL_ASSERT(pos < buffer.size());
if (!var.defined()) {
return;
}
// Switches to accumulate device
// The device (and stream) chosen for accumulation is:
// (1) var is not a CUDA/privateuse1 variable. Accumulation happens on var's
// device. (2) var is a CUDA/privateuse1 variable and it, the consumer, and
// the producer share the same device:
// (2a) Uses the consumer's stream as the accumulation stream
// (2b) Syncs the accumulation stream with the producer's stream (if
// different) (2c) Accumulates.
// (3) var is a CUDA/MTIA/privateuse1 variable and it shares a device with
// the consumer but not the producer:
// (3a) Uses the consumer's stream as the accumulation stream
// (3b) Syncs the accumulation stream with the consumer device's default
// stream (3c) Accumulates.
// (4) var is a CUDA/MTIA/privateuse1 variable and it shares a device with
// the producer but not the consumer:
// (4a) Uses the producer device's default stream as the accumulation
// stream (4b) Syncs the accumulation stream with the producer's
// stream (4c) Accumulates.
// (5) var is a CUDA/MTIA/privateuse1 variable and it does not share a device
// with the consumer or producer.
// Accumulation happens on the var device's default stream.
auto const device = device_of(var);
TORCH_INTERNAL_ASSERT(device.has_value());
std::optional<c10::Stream> opt_accumulate_stream = std::nullopt;
const auto device_type = device->type();
if (device->is_cuda() || device->is_mtia() || device->is_privateuseone()) {
const auto on_producer =
opt_producer_stream && device == opt_producer_stream->device();
const auto on_consumer =
opt_consumer_stream && device == opt_consumer_stream->device();
if (on_producer && on_consumer) {
// (2a)
opt_accumulate_stream = opt_consumer_stream;
if (opt_accumulate_stream != opt_producer_stream) {
// (2b)
auto event = c10::Event{device_type};
event.record(*opt_producer_stream);
opt_accumulate_stream->wait(event);
record_stream_any_impl(var, *opt_accumulate_stream);
}
const auto device = var.device();
const auto device_type = device.type();
// TODO: Use at::accelerator::isAccelerator(device.type()) instead
bool is_accelerator =
device.is_cuda() || device.is_mtia() || device.is_privateuseone();
//
// Non-accelerator case
//
if (!is_accelerator) {
if (!buffer[pos].defined()) {
buffer[pos] = std::move(var);
} else {
std::optional<c10::Stream> opt_sync_stream = std::nullopt;
const auto guard = c10::impl::VirtualGuardImpl{device_type};
if (on_consumer && !on_producer) {
// (3a)
opt_accumulate_stream = opt_consumer_stream;
opt_sync_stream = guard.getDefaultStream(opt_consumer_stream->device());
} else if (on_producer && !on_consumer) {
// (4a)
opt_accumulate_stream =
guard.getDefaultStream(opt_producer_stream->device());
opt_sync_stream = opt_producer_stream;
} else {
// (5)
opt_accumulate_stream = guard.getDefaultStream(*device);
}
if (opt_sync_stream && (opt_accumulate_stream != opt_sync_stream)) {
// (3b), (4b)
c10::OptionalDeviceGuard device_guard{opt_sync_stream->device()};
auto event = c10::Event{device_type};
event.record(*opt_sync_stream);
opt_accumulate_stream->wait(event);
const auto guard = c10::impl::VirtualGuardImpl(device_type);
record_stream_any_impl(var, *opt_accumulate_stream);
}
}
}
auto& old_var = buffer[pos];
if (!old_var.defined()) {
buffer[pos] = std::move(var);
} else {
if (opt_accumulate_stream) {
c10::OptionalStreamGuard stream_guard{opt_accumulate_stream};
accumulate(buffer, pos, std::move(var));
} else {
// (1) non-CUDA/privateuse1 variable
// Accumulation happens on variable's device
c10::OptionalDeviceGuard device_guard{device};
accumulate(buffer, pos, std::move(var));
}
return;
}
// Handle the case where var is on an accelerator but the producer node has
// no canonical stream, e.g. this can happen when the forward op was a DtoH
// copy
const std::optional<c10::Stream>& opt_producer_stream =
(opt_producer_stream_.has_value()
? opt_producer_stream_
: std::optional<c10::Stream>(
at::accelerator::getCurrentStream(device.index())));
// opt_consumer_stream is always non-null when is_accelerator is true and
// InputBuffer is used by the engine. InputBuffer is also used elsewhere,
// however (e.g. by other engine implementations), so fall back to the
// device's current stream when it is absent.
const std::optional<c10::Stream>& opt_consumer_stream =
(opt_consumer_stream_.has_value()
? opt_consumer_stream_
: std::optional<c10::Stream>(
at::accelerator::getCurrentStream(device.index())));
TORCH_INTERNAL_ASSERT(opt_consumer_stream && opt_producer_stream);
// See Note: [Autograd Producer-Consumer Stream Syncs]
if (!opt_accum_streams[pos].has_value()) {
// [ First producer ]
TORCH_INTERNAL_ASSERT(!buffer[pos].defined());
// 1)
if (opt_consumer_stream->device() == device) {
// Case A
opt_accum_streams[pos] = opt_consumer_stream;
if (*opt_consumer_stream != *opt_producer_stream) {
// We will end up doing record_stream on the accumulation stream
// (which is the consumer stream) later, but we also need to do
// it here in case we don't end up accumulating.
record_stream_any_impl(var, *opt_consumer_stream);
}
} else if (opt_producer_stream->device() == device) {
// Case B
opt_accum_streams[pos] = opt_producer_stream;
} else {
// Case C
opt_accum_streams[pos] =
at::accelerator::getCurrentStream(device.index());
}
// 2)
buffer[pos] = std::move(var);
// 3)
auto event = c10::Event{device_type};
event.record(*opt_producer_stream);
ready_events[pos] = std::move(event);
ready_streams[pos] = opt_producer_stream;
} else {
// [ Nth producer ]
auto accum_stream = opt_accum_streams[pos];
auto& ready_event = ready_events[pos];
auto& ready_stream = ready_streams[pos];
TORCH_INTERNAL_ASSERT(accum_stream && ready_event && ready_stream);
// 1)
if (*accum_stream != *opt_producer_stream) {
auto event = c10::Event{device_type};
event.record(*opt_producer_stream);
accum_stream->wait(event);
record_stream_any_impl(var, *accum_stream);
}
if (*accum_stream != *ready_stream) {
accum_stream->wait(*ready_event);
// This is redundant for case A, but needed for case C
record_stream_any_impl(buffer[pos], *accum_stream);
}
// 2)
c10::OptionalStreamGuard stream_guard{accum_stream};
accumulate(buffer, pos, std::move(var));
// 3)
auto event = c10::Event{device_type};
event.record(*accum_stream);
ready_events[pos] = std::move(event);
ready_streams[pos] = accum_stream;
}
}
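To make the first/Nth-producer case analysis concrete, here is a hypothetical user-level scenario that lands in case A (the event recording and waits described above all happen inside the engine):

import torch

# x receives two gradients produced on different streams; the engine
# designates the consumer stream as the accumulation stream (case A) and
# orders both producers against it with events before accumulating.
x = torch.randn(8, device="cuda", requires_grad=True)
default = torch.cuda.current_stream()
s1, s2 = torch.cuda.Stream(), torch.cuda.Stream()
s1.wait_stream(default)
s2.wait_stream(default)
with torch.cuda.stream(s1):
    a = x * 2                  # producer 1 of x's gradient
with torch.cuda.stream(s2):
    b = x * 3                  # producer 2 of x's gradient
default.wait_stream(s1)
default.wait_stream(s2)
(a + b).sum().backward()       # both grads accumulate in one InputBuffer slot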

torch/csrc/autograd/input_buffer.h

@@ -15,7 +15,11 @@
namespace torch::autograd {
struct InputBuffer {
explicit InputBuffer(size_t size) : buffer(size) {}
explicit InputBuffer(size_t size)
: buffer(size),
opt_accum_streams(size),
ready_events(size),
ready_streams(size) {}
InputBuffer(const InputBuffer& other) = delete;
InputBuffer(InputBuffer&& other) = default;
explicit InputBuffer(variable_list&& inputs) : buffer(std::move(inputs)) {}
@@ -38,6 +42,14 @@ struct InputBuffer {
static std::vector<Variable> variables(InputBuffer&& g);
std::vector<Variable> buffer;
// The stream used for accumulation when a variable is used multiple times.
std::vector<std::optional<c10::Stream>> opt_accum_streams;
// The events you need to wait for to ensure the corresponding buffers
// are ready. The events are updated as we accumulate into the buffer.
std::vector<std::optional<c10::Event>> ready_events;
// The streams corresponding to the events above. This is only used to
// check if more synchronization is needed or not.
std::vector<std::optional<c10::Stream>> ready_streams;
};
} // namespace torch::autograd
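A note on the shape of this state: the three new vectors are indexed in lockstep with buffer, so each input position carries its own accumulation stream, readiness event, and recording stream. This is the invariant that Engine::evaluate_function asserts before waiting: whenever buffer[pos] holds a defined accelerator gradient, InputBuffer::add has populated ready_events[pos] and ready_streams[pos].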