Rewrite autograd producer consumer stream sync logic (#151079)

Also see previous work: https://github.com/pytorch/pytorch/pull/142097

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151079
Approved by: https://github.com/albanD

commit f78e4529a9
parent f136046919
@@ -13131,14 +13131,12 @@ class TestAutogradStreamSynchronization(TestCase):
        for _ in range(2):
            test()

    # This fails because we currently sync to the default stream
    # AttributeError: module 'torch.mps' has no attribute 'default_stream'
    @skipIfMPS
    @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
    @unittest.skipIf(
        torch.accelerator.device_count() < 2, "accelerator count is less than 2"
    )
    @unittest.expectedFailure
    def test_consumer_to_single_producer_case_3_correctness_non_default_ambient_stream(
        self,
    ):
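For orientation, the scenario this test family exercises looks roughly like the sketch below. This is a hypothetical minimal repro, not the test body: the function name, shapes, and ops are invented, and it assumes two CUDA devices.

import torch

# Hedged sketch (invented names/shapes, not the test body): the gradient's
# producer and consumer live on different devices, and backward() is invoked
# while a non-default stream is current (a "non-default ambient stream").
def ambient_stream_backward_sketch():
    x = torch.ones(8, device="cuda:0", requires_grad=True)
    y = x.to("cuda:1").sum()                  # cross-device forward
    side = torch.cuda.Stream(device="cuda:0")
    with torch.cuda.stream(side):             # non-default current stream
        y.backward()
    torch.cuda.synchronize()
    assert torch.equal(x.grad, torch.ones_like(x))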
@@ -13312,7 +13310,6 @@ class TestAutogradStreamSynchronization(TestCase):
    # This test may spuriously fail on non-cuda accelerators (since we won't
    # be calling sleep)
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    @unittest.expectedFailure
    def test_side_stream_backward_overlap(self):
        # In case 2/3, we would designate the consumer as the accumulation
        # stream and naively, one might have the consumer wait for the producer
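Spelled out with raw stream primitives, the naive ordering that comment describes looks like the sketch below (user-level code, not the engine's implementation):

import torch

# Naive scheme sketch: the consumer stream blocks on an event recorded on the
# producer stream. Everything enqueued on the consumer afterwards is then
# serialized behind the producer, losing side-stream overlap during backward.
producer = torch.cuda.Stream()
consumer = torch.cuda.current_stream()
with torch.cuda.stream(producer):
    grad = torch.randn(1024, device="cuda")  # stand-in for a produced gradient
done = torch.cuda.Event()
done.record(producer)
consumer.wait_event(done)  # consumer waits even when it could keep running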
@@ -8379,16 +8379,6 @@ BACKWARD_SKIPS_AND_XFAILS = [
        sample_match_fn=lambda device, sample: ("noncontig_holes" in sample.name),
        name="broken_unflatten_backward",
    ),
    # -> CPU device conversion backwards is broken
    XFailRule(
        error_type=RuntimeError,
        error_msg="Unknown layout in record_stream_any_impl",
        op_match_fn=lambda device, op: (op.full_name == "to"),
        sample_match_fn=lambda device, sample: (
            sample.kwargs.get("device", None) == "cpu"
        ),
        name="broken_to_backward",
    ),
    # sum() backward is not implemented for non-full reductions
    XFailRule(
        error_type=NotImplementedError,
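(The hunk above is cut off mid-rule by the diff view.) XFailRule is a helper from the PyTorch test utilities; conceptually a rule fires when both of its predicates accept the failing case, roughly as in this sketch (illustrative only, the real helper may differ):

def rule_applies(rule, device, op, sample):
    # A rule matches when both the op predicate and the sample predicate agree.
    return rule.op_match_fn(device, op) and rule.sample_match_fn(device, sample)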
@@ -1065,13 +1065,32 @@ void Engine::evaluate_function(
    Node* func,
    InputBuffer& inputs,
    const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
  // The InputBuffer::adds that supplied incoming grads took pains to
  // ensure they're safe to consume in the context of the present
  // func's stream (if applicable). So we guard onto that stream
  // before working with the grads in any capacity.
  // Locally set the current stream to func's associated stream
  auto opt_parent_stream = (*func).stream();
  c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};

  // Ensure that the incoming gradients are ready
  for (size_t pos = 0; pos < inputs.ready_events.size(); ++pos) {
    if (!inputs.buffer[pos].defined()) {
      continue;
    }
    const auto device = inputs.buffer[pos].device();
    // TODO: Use at::accelerator::isAccelerator(device->type()) instead
    bool is_accelerator =
        device.is_cuda() || device.is_mtia() || device.is_privateuseone();
    if (!is_accelerator) {
      continue;
    }
    TORCH_INTERNAL_ASSERT(inputs.ready_events[pos].has_value());
    TORCH_INTERNAL_ASSERT(inputs.ready_streams[pos].has_value());
    TORCH_INTERNAL_ASSERT(opt_parent_stream.has_value());
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    if (opt_parent_stream.value() != inputs.ready_streams[pos].value()) {
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      opt_parent_stream->wait(inputs.ready_events[pos].value());
    }
  }

  // If exec_info_ is not empty, we have to instrument the execution
  auto& exec_info_ = graph_task->exec_info_;
  if (!exec_info_.empty()) {
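In user-level Python terms, the sync loop added above behaves roughly like this sketch (the names are invented; the real logic is the C++ above, operating on c10::Event and c10::Stream):

import torch

# Sketch of the engine-side wait (invented names): before a node consumes its
# input buffer, the node's stream waits on the event that marked each slot
# ready, unless the slot became ready on that very stream.
def sync_ready_slots(node_stream, buffer, ready_events, ready_streams):
    for grad, event, stream in zip(buffer, ready_events, ready_streams):
        if grad is None or not grad.is_cuda:
            continue  # mirrors the undefined-grad and non-accelerator skips
        if stream != node_stream:
            node_stream.wait_event(event)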
@@ -26,8 +26,13 @@ namespace {
// See https://github.com/pytorch/pytorch/issues/60306
// TODO: clean this up when https://github.com/pytorch/pytorch/issues/60306 is
// improved
void record_stream_any_impl(Variable& var, c10::Stream& stream) {
void record_stream_any_impl(Variable& var, const c10::Stream& stream) {
  // NOLINTNEXTLINE(bugprone-unchecked-optional-access)

  if (stream.device_index() != var.device().index()) {
    return;
  }

  const auto guard = c10::impl::VirtualGuardImpl(device_of(var).value().type());

  if (C10_UNLIKELY(at::isBatchedTensor(var))) {
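For background, record_stream addresses a caching-allocator hazard: memory allocated on one stream but consumed on another can be recycled too early. The user-level pattern, as a sketch:

import torch

# User-level record_stream pattern (background for the helper above): mark a
# cross-stream use so the caching allocator defers reuse of the storage until
# the consuming stream's pending work completes.
side = torch.cuda.Stream()
with torch.cuda.stream(side):
    y = torch.randn(1024, device="cuda")       # allocated on `side`
torch.cuda.current_stream().wait_stream(side)  # order consumer after producer
y.record_stream(torch.cuda.current_stream())   # mark use on the consumer
z = y * 2                                      # now safe to consume here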
@@ -126,99 +131,151 @@ static void accumulate(
  }
}

// Note: [Stream sync contract when dealing with multi-deviced-ness]
//
// An operator can deal with multiple devices, e.g. if it does a device
// transfer, etc. However, for the purpose of stream synchronization, the
// engine is only aware of a single canonical device/stream for each autograd
// Node.
//
// For proper synchronization, the Node author should make sure of the
// following:
//
// 1) A node consuming a gradient should wait on the canonical stream before
//    using it.
// 2) A node producing a gradient should have it ready on the canonical
//    stream during node execution.
//

// Note: [Autograd Producer-Consumer Stream Syncs]
//
// The producer-consumer stream syncs are partially handled in this method
// and partially handled in the engine prior to the consumer's execution.
// The logic here is mainly responsible for handling the synchronization needed
// for accumulation and recording the event that the consumer should wait on
// later. The corresponding wait and record_stream happen in the engine.
//
// First producer
// ==============
// There are several things we need to do upon seeing the first producer:
// 1) Determine the accumulation stream (which may or may not be used):
//    case A) var's device matches the consumer node's canonical device
//            (the producer node's canonical device may or may not match)
//            -> accumulation stream = consumer stream
//    case B) var's device matches the producer node's canonical device
//            and does not match the consumer node's canonical device
//            -> accumulation stream = producer stream
//    case C) var's device matches neither
//            -> accumulation stream = var device's current stream
//    See Note [Stream sync contract when dealing with multi-deviced-ness]
// 2) Because we are the first producer, no accumulation is necessary;
//    just move var into the buffer.
// 3) Update the ready_events and ready_streams for the current position.
//    ready_events are the events you need to wait for to ensure the
//    corresponding buffers are ready. The events are updated as we
//    accumulate into the buffer.
//
// Nth producer
// ============
// 1) Synchronize for accumulation. Accumulation operates on both the new
//    incoming gradient and the existing gradient in the buffer:
//    (i) wait stream and (ii) record stream to make sure both are ready to
//    be used on the accumulation stream.
// 2) Accumulate on the accumulation stream.
// 3) Update the ready event and stream for the current position.
//
void InputBuffer::add(
    size_t pos,
    Variable&& var,
    const std::optional<c10::Stream>& opt_producer_stream,
    const std::optional<c10::Stream>& opt_producer_stream_,
    const std::optional<c10::Stream>& opt_consumer_stream) {
  TORCH_INTERNAL_ASSERT(pos < buffer.size());

  if (!var.defined()) {
    return;
  }

  // Switches to accumulate device
  // The device (and stream) chosen for accumulation is:
  //  (1) var is not a CUDA/privateuse1 variable. Accumulation happens on
  //      var's device.
  //  (2) var is a CUDA/privateuse1 variable and it, the consumer, and the
  //      producer share the same device:
  //       (2a) Uses the consumer's stream as the accumulation stream
  //       (2b) Syncs the accumulation stream with the producer's stream
  //            (if different)
  //       (2c) Accumulates.
  //  (3) var is a CUDA/MTIA/privateuse1 variable and it shares a device with
  //      the consumer but not the producer:
  //       (3a) Uses the consumer's stream as the accumulation stream
  //       (3b) Syncs the accumulation stream with the consumer device's
  //            default stream
  //       (3c) Accumulates.
  //  (4) var is a CUDA/MTIA/privateuse1 variable and it shares a device with
  //      the producer but not the consumer:
  //       (4a) Uses the producer device's default stream as the accumulation
  //            stream
  //       (4b) Syncs the accumulation stream with the producer's stream
  //       (4c) Accumulates.
  //  (5) var is a CUDA/MTIA/privateuse1 variable and it does not share a
  //      device with the consumer or producer.
  //      Accumulation happens on the var device's default stream.

  auto const device = device_of(var);
  TORCH_INTERNAL_ASSERT(device.has_value());
  std::optional<c10::Stream> opt_accumulate_stream = std::nullopt;
  const auto device_type = device->type();
  if (device->is_cuda() || device->is_mtia() || device->is_privateuseone()) {
    const auto on_producer =
        opt_producer_stream && device == opt_producer_stream->device();
    const auto on_consumer =
        opt_consumer_stream && device == opt_consumer_stream->device();

    if (on_producer && on_consumer) {
      // (2a)
      opt_accumulate_stream = opt_consumer_stream;
      if (opt_accumulate_stream != opt_producer_stream) {
        // (2b)
        auto event = c10::Event{device_type};
        event.record(*opt_producer_stream);
        opt_accumulate_stream->wait(event);
        record_stream_any_impl(var, *opt_accumulate_stream);
      }
  const auto device = var.device();
  const auto device_type = device.type();
  // TODO: Use at::accelerator::isAccelerator(device->type()) instead
  bool is_accelerator =
      device.is_cuda() || device.is_mtia() || device.is_privateuseone();
  //
  // Non-accelerator case
  //
  if (!is_accelerator) {
    if (!buffer[pos].defined()) {
      buffer[pos] = std::move(var);
    } else {
      std::optional<c10::Stream> opt_sync_stream = std::nullopt;
      const auto guard = c10::impl::VirtualGuardImpl{device_type};
      if (on_consumer && !on_producer) {
        // (3a)
        opt_accumulate_stream = opt_consumer_stream;
        opt_sync_stream = guard.getDefaultStream(opt_consumer_stream->device());
      } else if (on_producer && !on_consumer) {
        // (4a)
        opt_accumulate_stream =
            guard.getDefaultStream(opt_producer_stream->device());
        opt_sync_stream = opt_producer_stream;
      } else {
        // (5)
        opt_accumulate_stream = guard.getDefaultStream(*device);
      }
      if (opt_sync_stream && (opt_accumulate_stream != opt_sync_stream)) {
        // (3b), (4b)
        c10::OptionalDeviceGuard device_guard{opt_sync_stream->device()};
        auto event = c10::Event{device_type};
        event.record(*opt_sync_stream);
        opt_accumulate_stream->wait(event);
        const auto guard = c10::impl::VirtualGuardImpl(device_type);
        record_stream_any_impl(var, *opt_accumulate_stream);
      }
    }
  }

  auto& old_var = buffer[pos];
  if (!old_var.defined()) {
    buffer[pos] = std::move(var);
  } else {
    if (opt_accumulate_stream) {
      c10::OptionalStreamGuard stream_guard{opt_accumulate_stream};
      accumulate(buffer, pos, std::move(var));
    } else {
      // (1) non-CUDA/privateuse1 variable
      //     Accumulation happens on the variable's device
      c10::OptionalDeviceGuard device_guard{device};
      accumulate(buffer, pos, std::move(var));
    }
    return;
  }

  // Handle the case where var is on an accelerator but the producer node has
  // no canonical stream, e.g. this can happen if forward is DtoH
  const std::optional<c10::Stream>& opt_producer_stream =
      (opt_producer_stream_.has_value()
           ? opt_producer_stream_
           : std::optional<c10::Stream>(
                 at::accelerator::getCurrentStream(device.index())));

  TORCH_INTERNAL_ASSERT(opt_consumer_stream && opt_producer_stream);

  // See Note: [Autograd Producer-Consumer Stream Syncs]
  if (!opt_accum_streams[pos].has_value()) {
    // [ First producer ]
    TORCH_INTERNAL_ASSERT(!buffer[pos].defined());
    // 1)
    if (opt_consumer_stream->device() == device) {
      // Case A
      opt_accum_streams[pos] = opt_consumer_stream;
      if (*opt_consumer_stream != *opt_producer_stream) {
        // We will end up doing record_stream on the accumulation stream
        // (which is the consumer stream) later, but we also need to do
        // it here in case we don't end up accumulating.
        record_stream_any_impl(var, *opt_consumer_stream);
      }
    } else if (opt_producer_stream->device() == device) {
      // Case B
      opt_accum_streams[pos] = opt_producer_stream;
    } else {
      // Case C
      opt_accum_streams[pos] =
          at::accelerator::getCurrentStream(device.index());
    }
    // 2)
    buffer[pos] = std::move(var);
    // 3)
    auto event = c10::Event{device_type};
    event.record(*opt_producer_stream);
    ready_events[pos] = std::move(event);
    ready_streams[pos] = opt_producer_stream;
  } else {
    // [ Nth producer ]
    auto accum_stream = opt_accum_streams[pos];
    auto& ready_event = ready_events[pos];
    auto& ready_stream = ready_streams[pos];
    TORCH_INTERNAL_ASSERT(accum_stream && ready_event && ready_stream);
    // 1)
    if (*accum_stream != *opt_producer_stream) {
      auto event = c10::Event{device_type};
      event.record(*opt_producer_stream);
      accum_stream->wait(event);
      record_stream_any_impl(var, *accum_stream);
    }
    if (*accum_stream != *ready_stream) {
      accum_stream->wait(*ready_event);
      // This is redundant for case A, but needed for case C
      record_stream_any_impl(buffer[pos], *accum_stream);
    }
    // 2)
    c10::OptionalStreamGuard stream_guard{accum_stream};
    accumulate(buffer, pos, std::move(var));
    // 3)
    auto event = c10::Event{device_type};
    event.record(*accum_stream);
    ready_events[pos] = std::move(event);
    ready_streams[pos] = accum_stream;
  }
}
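Putting the two notes together, the whole per-slot protocol can be mirrored in a short Python sketch. Everything here is illustrative: BufferState, add_grad, and the stream arguments are invented stand-ins for InputBuffer's C++ internals, and CUDA stands in for any accelerator backend.

import torch

class BufferState:
    """Invented mirror of InputBuffer's parallel arrays (sketch only)."""
    def __init__(self, size):
        self.buffer = [None] * size
        self.accum_streams = [None] * size   # ~ opt_accum_streams
        self.ready_events = [None] * size
        self.ready_streams = [None] * size

def add_grad(state, pos, var, producer_stream, consumer_stream):
    device = var.device
    if state.accum_streams[pos] is None:
        # [ First producer ] 1) pick the accumulation stream (cases A/B/C)
        if consumer_stream.device == device:
            accum = consumer_stream                    # case A
        elif producer_stream.device == device:
            accum = producer_stream                    # case B
        else:
            accum = torch.cuda.current_stream(device)  # case C
        state.accum_streams[pos] = accum
        # 2) first producer: no accumulation, just take the gradient
        state.buffer[pos] = var
        # 3) record readiness on the producer stream
        event = torch.cuda.Event()
        event.record(producer_stream)
        state.ready_events[pos] = event
        state.ready_streams[pos] = producer_stream
    else:
        # [ Nth producer ] 1) make both operands safe on the accum stream
        accum = state.accum_streams[pos]
        if accum != producer_stream:
            event = torch.cuda.Event()
            event.record(producer_stream)
            accum.wait_event(event)
            var.record_stream(accum)
        if accum != state.ready_streams[pos]:
            accum.wait_event(state.ready_events[pos])
            state.buffer[pos].record_stream(accum)
        # 2) accumulate on the accumulation stream
        with torch.cuda.stream(accum):
            state.buffer[pos] = state.buffer[pos] + var
        # 3) refresh the slot's ready event/stream
        event = torch.cuda.Event()
        event.record(accum)
        state.ready_events[pos] = event
        state.ready_streams[pos] = accum

One detail the sketch drops for brevity: in case A, the C++ additionally record_streams var on first sight when producer and consumer streams differ, covering the path where no accumulation ever happens.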
@@ -15,7 +15,11 @@
namespace torch::autograd {

struct InputBuffer {
  explicit InputBuffer(size_t size) : buffer(size) {}
  explicit InputBuffer(size_t size)
      : buffer(size),
        opt_accum_streams(size),
        ready_events(size),
        ready_streams(size) {}
  InputBuffer(const InputBuffer& other) = delete;
  InputBuffer(InputBuffer&& other) = default;
  explicit InputBuffer(variable_list&& inputs) : buffer(std::move(inputs)) {}

@@ -38,6 +42,14 @@ struct InputBuffer {
  static std::vector<Variable> variables(InputBuffer&& g);

  std::vector<Variable> buffer;
  // The stream used for accumulation when a variable is used multiple times.
  std::vector<std::optional<c10::Stream>> opt_accum_streams;
  // The events you need to wait for to ensure the corresponding buffers
  // are ready. The events are updated as we accumulate into the buffer.
  std::vector<std::optional<c10::Event>> ready_events;
  // The streams corresponding to the events above. This is only used to
  // check if more synchronization is needed or not.
  std::vector<std::optional<c10::Stream>> ready_streams;
};

} // namespace torch::autograd