Rewrite autograd producer consumer stream sync logic (#151079)

Also see previous work: https://github.com/pytorch/pytorch/pull/142097

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151079
Approved by: https://github.com/albanD

This commit is contained in:
parent f136046919
commit f78e4529a9

@@ -13131,14 +13131,12 @@ class TestAutogradStreamSynchronization(TestCase):
             for _ in range(2):
                 test()
 
-    # This fails because we currently sync to the default stream
     # AttributeError: module 'torch.mps' has no attribute 'default_stream'
     @skipIfMPS
     @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
     @unittest.skipIf(
         torch.accelerator.device_count() < 2, "accelerator count is less than 2"
     )
-    @unittest.expectedFailure
    def test_consumer_to_single_producer_case_3_correctness_non_default_ambient_stream(
         self,
     ):

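For orientation, a hypothetical sketch of the multi-device shape this re-enabled test targets (illustrative only, not the test body; the device pair, the op, and the final check are assumptions):

import torch

# Case 3 shape (illustrative): forward crosses devices while the ambient
# stream on the consumer side is not the default stream.
def case_3_sketch():
    side = torch.cuda.Stream()  # non-default ambient stream on cuda:0
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        x = torch.ones(4, device="cuda:0", requires_grad=True)
        y = (x.to("cuda:1") * 3).sum()  # producer work lands on cuda:1
        y.backward()                    # x.grad is produced back on cuda:0
    torch.cuda.synchronize()
    assert torch.equal(x.grad, torch.full_like(x, 3.0))

if torch.accelerator.is_available() and torch.accelerator.device_count() >= 2:
    case_3_sketch()
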
@@ -13312,7 +13310,6 @@ class TestAutogradStreamSynchronization(TestCase):
     # This test may spuriously fail on non-cuda accelerators (since we won't
     # be calling sleep)
     @unittest.skipIf(not TEST_CUDA, "requires CUDA")
-    @unittest.expectedFailure
     def test_side_stream_backward_overlap(self):
         # In case 2/3, we would designate the consumer as the accumulation
         # stream and naively, one might have the consumer wait for the producer

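A minimal sketch of the overlap pattern this test exercises (assumes CUDA; this is not the test body): forward work runs on a side stream, backward is launched from the current stream, and the engine is responsible for ordering the gradients across the two.

import torch

def side_stream_backward_sketch():
    x = torch.ones(1 << 20, device="cuda", requires_grad=True)
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())   # order after x's init
    with torch.cuda.stream(side):
        y = (x * 2).sum()                           # producer on side stream
    torch.cuda.current_stream().wait_stream(side)   # user-side contract for y
    y.backward()                                    # engine orders grad streams
    torch.cuda.synchronize()
    assert torch.equal(x.grad, torch.full_like(x, 2.0))
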
@@ -8379,16 +8379,6 @@ BACKWARD_SKIPS_AND_XFAILS = [
         sample_match_fn=lambda device, sample: ("noncontig_holes" in sample.name),
         name="broken_unflatten_backward",
     ),
-    # -> CPU device conversion backwards is broken
-    XFailRule(
-        error_type=RuntimeError,
-        error_msg="Unknown layout in record_stream_any_impl",
-        op_match_fn=lambda device, op: (op.full_name == "to"),
-        sample_match_fn=lambda device, sample: (
-            sample.kwargs.get("device", None) == "cpu"
-        ),
-        name="broken_to_backward",
-    ),
     # sum() backward is not implemented for non-full reductions
     XFailRule(
         error_type=NotImplementedError,

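The deleted broken_to_backward rule covered backward through a CUDA-to-CPU conversion on nested tensors, which used to raise "Unknown layout in record_stream_any_impl". A hedged sketch of the now-passing shape (the jagged layout and the full-sum reduction are assumptions, not the suite's exact sample):

import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 3), torch.randn(4, 3)],
    layout=torch.jagged,
    device="cuda",
    requires_grad=True,
)
out = nt.to("cpu")    # DtoH transfer in forward => HtoD copy in backward
out.sum().backward()  # formerly: "Unknown layout in record_stream_any_impl"
torch.cuda.synchronize()
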
@@ -1065,13 +1065,32 @@ void Engine::evaluate_function(
     Node* func,
     InputBuffer& inputs,
     const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
-  // The InputBuffer::adds that supplied incoming grads took pains to
-  // ensure they're safe to consume in the context of the present
-  // func's stream (if applicable). So we guard onto that stream
-  // before working with the grads in any capacity.
+  // Locally set the current stream to func's associated stream
   auto opt_parent_stream = (*func).stream();
   c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};
 
+  // Ensure that the incoming gradients are ready
+  for (size_t pos = 0; pos < inputs.ready_events.size(); ++pos) {
+    if (!inputs.buffer[pos].defined()) {
+      continue;
+    }
+    const auto device = inputs.buffer[pos].device();
+    // TODO: Use at::accelerator::isAccelerator(device->type()) instead
+    bool is_accelerator =
+        device.is_cuda() || device.is_mtia() || device.is_privateuseone();
+    if (!is_accelerator) {
+      continue;
+    }
+    TORCH_INTERNAL_ASSERT(inputs.ready_events[pos].has_value());
+    TORCH_INTERNAL_ASSERT(inputs.ready_streams[pos].has_value());
+    TORCH_INTERNAL_ASSERT(opt_parent_stream.has_value());
+    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
+    if (opt_parent_stream.value() != inputs.ready_streams[pos].value()) {
+      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
+      opt_parent_stream->wait(inputs.ready_events[pos].value());
+    }
+  }
+
   // If exec_info_ is not empty, we have to instrument the execution
   auto& exec_info_ = graph_task->exec_info_;
   if (!exec_info_.empty()) {

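Per gradient slot, the added loop is the standard record-then-wait event pattern. A Python analogue of the same ordering primitive (a sketch, not the engine's code):

import torch

producer = torch.cuda.Stream()
consumer = torch.cuda.Stream()
with torch.cuda.stream(producer):
    grad = torch.randn(1024, device="cuda")  # gradient produced here
ready = producer.record_event()              # ~ inputs.ready_events[pos]
if consumer != producer:                     # ~ parent stream != ready stream
    consumer.wait_event(ready)               # ~ opt_parent_stream->wait(event)
with torch.cuda.stream(consumer):
    grad.mul_(2)                             # safe: ordered after the producer
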
@@ -26,8 +26,13 @@ namespace {
 // See https://github.com/pytorch/pytorch/issues/60306
 // TODO: clean this up when https://github.com/pytorch/pytorch/issues/60306 is
 // improved
-void record_stream_any_impl(Variable& var, c10::Stream& stream) {
+void record_stream_any_impl(Variable& var, const c10::Stream& stream) {
   // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
+  if (stream.device_index() != var.device().index()) {
+    return;
+  }
+
   const auto guard = c10::impl::VirtualGuardImpl(device_of(var).value().type());
 
   if (C10_UNLIKELY(at::isBatchedTensor(var))) {

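The helper now also returns early when the stream's device index differs from var's. Its job corresponds to Tensor.record_stream at the Python level (a sketch; the C++ helper additionally walks nested and batched tensor components):

import torch

s = torch.cuda.Stream()
t = torch.empty(1024, device="cuda")  # allocated on the current stream
with torch.cuda.stream(s):
    t.add_(1)                         # consumed on a different stream
t.record_stream(s)  # tell the caching allocator not to reuse t's memory
                    # until the work queued on `s` has completed
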
@@ -126,99 +131,151 @@ static void accumulate(
   }
 }
 
+// Note: [Stream sync contract when dealing with multi-deviced-ness]
+//
+// An operator can deal with multiple devices, e.g. if it does a device
+// transfer, etc. However, for the purpose of stream synchronization, the engine
+// is only aware of single canonical device/stream for each autograd Node.
+//
+// For the proper synchronization, the Node author should make sure of the
+// following:
+//
+// 1) A node consuming a gradient should wait on the canonical stream before
+//    using it.
+// 2) A node producing a gradient should have it ready on the canonical
+//    stream during node execution.
+//
+
+// Note: [Autograd Producer-Consumer Stream Syncs]
+//
+// The producer-consumer stream syncs are partially handled in this method
+// and partially handled in the engine prior to the consumer's execution.
+// The logic here is mainly responsible for handling the synchronization needed
+// for accumulation and recording the event that the consumer should wait on
+// later. The corresponding wait and record_stream happens in the engine.
+//
+// First producer
+// ==============
+// There are several things we need to do upon seeing the first producer:
+// 1) Determine the accumulation stream (which may or may not be used):
+//    case A) var's device matches consumer node's canonical device
+//            (The producer node's canonical device may or may not match)
+//            -> accumulator stream = consumer stream
+//    case B) var's device matches producer node's canonical device
+//            and does not match consumer node's canonical device
+//            -> accumulator stream = producer stream
+//    case C) var device matches neither
+//            -> accumulator stream = var device's current stream
+//    See Note [Stream sync contract when dealing with multi-deviced-ness]
+// 2) Because we are the first producer, there's no accumulation necessary.
+//    Just move var into the buffer.
+// 3) Update the ready_events and streams for the current position.
+//    ready_events are events you need to wait for to ensure the corresponding
+//    buffers are ready. The events are updated as we accumulate into the
+//    buffer.
+//
+// Nth producer
+// ============
+// 1) Synchronize for accumulation. Accumulation operates on both the new
+//    incoming gradient and the existing gradient in the buffer.
+//    (i) wait stream and (ii) record stream to make sure both are ready to be
+//    used on the accumulation stream.
+// 2) Accumulate on the accumulation stream
+// 3) Update the ready event and stream for the current position.
+//
 void InputBuffer::add(
     size_t pos,
     Variable&& var,
-    const std::optional<c10::Stream>& opt_producer_stream,
+    const std::optional<c10::Stream>& opt_producer_stream_,
     const std::optional<c10::Stream>& opt_consumer_stream) {
   TORCH_INTERNAL_ASSERT(pos < buffer.size());
 
   if (!var.defined()) {
     return;
   }
-  // Switches to accumulate device
-  // The device (and stream) chosen for accumulation is:
-  // (1) var is not a CUDA/privateuse1 variable. Accumulation happens on var's
-  // device. (2) var is a CUDA/privateuse1 variable and it, the consumer, and
-  // the producer share the same device:
-  //     (2a) Uses the consumer's stream as the accumulation stream
-  //     (2b) Syncs the accumulation stream with the producer's stream (if
-  //     different) (2c) Accumulates.
-  // (3) var is a CUDA/MTIA/privateuse1 variable and it shares a device with
-  // the consumer but not the producer:
-  //     (3a) Uses the consumer's stream as the accumulation stream
-  //     (3b) Syncs the accumulation stream with the consumer device's default
-  //     stream (3c) Accumulates.
-  // (4) var is a CUDA/MTIA/privateuse1 variable and it shares a device with
-  // the producer but not the consumer:
-  //     (4a) Uses the producer device's default stream as the accumulation
-  //     stream (4b) Syncs the accumulation stream with the producer's
-  //     stream (4c) Accumulates.
-  // (5) var is a CUDA/MTIA/privateuse1 variable and it does not share a device
-  // with the consumer or producer.
-  //     Accumulation happens on the var device's default stream.
-
-  auto const device = device_of(var);
-  TORCH_INTERNAL_ASSERT(device.has_value());
-  std::optional<c10::Stream> opt_accumulate_stream = std::nullopt;
-  const auto device_type = device->type();
-  if (device->is_cuda() || device->is_mtia() || device->is_privateuseone()) {
-    const auto on_producer =
-        opt_producer_stream && device == opt_producer_stream->device();
-    const auto on_consumer =
-        opt_consumer_stream && device == opt_consumer_stream->device();
-
-    if (on_producer && on_consumer) {
-      // (2a)
-      opt_accumulate_stream = opt_consumer_stream;
-      if (opt_accumulate_stream != opt_producer_stream) {
-        // (2b)
-        auto event = c10::Event{device_type};
-        event.record(*opt_producer_stream);
-        opt_accumulate_stream->wait(event);
-        record_stream_any_impl(var, *opt_accumulate_stream);
-      }
-    } else {
-      std::optional<c10::Stream> opt_sync_stream = std::nullopt;
-      const auto guard = c10::impl::VirtualGuardImpl{device_type};
-      if (on_consumer && !on_producer) {
-        // (3a)
-        opt_accumulate_stream = opt_consumer_stream;
-        opt_sync_stream = guard.getDefaultStream(opt_consumer_stream->device());
-      } else if (on_producer && !on_consumer) {
-        // (4a)
-        opt_accumulate_stream =
-            guard.getDefaultStream(opt_producer_stream->device());
-        opt_sync_stream = opt_producer_stream;
-      } else {
-        // (5)
-        opt_accumulate_stream = guard.getDefaultStream(*device);
-      }
-      if (opt_sync_stream && (opt_accumulate_stream != opt_sync_stream)) {
-        // (3b), (4b)
-        c10::OptionalDeviceGuard device_guard{opt_sync_stream->device()};
-        auto event = c10::Event{device_type};
-        event.record(*opt_sync_stream);
-        opt_accumulate_stream->wait(event);
-        const auto guard = c10::impl::VirtualGuardImpl(device_type);
-        record_stream_any_impl(var, *opt_accumulate_stream);
-      }
-    }
-  }
-
-  auto& old_var = buffer[pos];
-  if (!old_var.defined()) {
-    buffer[pos] = std::move(var);
-  } else {
-    if (opt_accumulate_stream) {
-      c10::OptionalStreamGuard stream_guard{opt_accumulate_stream};
-      accumulate(buffer, pos, std::move(var));
-    } else {
-      // (1) non-CUDA/privateuse1 variable
-      // Accumulation happens on variable's device
-      c10::OptionalDeviceGuard device_guard{device};
-      accumulate(buffer, pos, std::move(var));
-    }
-  }
-}
+  const auto device = var.device();
+  const auto device_type = device.type();
+  // TODO: Use at::accelerator::isAccelerator(device->type()) instead
+  bool is_accelerator =
+      device.is_cuda() || device.is_mtia() || device.is_privateuseone();
+  //
+  // Non-accelerator case
+  //
+  if (!is_accelerator) {
+    if (!buffer[pos].defined()) {
+      buffer[pos] = std::move(var);
+    } else {
+      c10::OptionalDeviceGuard device_guard{device};
+      accumulate(buffer, pos, std::move(var));
+    }
+    return;
+  }
+  // Handle the case where var is on an accelerator but producer node has no
+  // canonical stream, e.g. this can happen if forward is DtoH
+  const std::optional<c10::Stream>& opt_producer_stream =
+      (opt_producer_stream_.has_value()
+           ? opt_producer_stream_
+           : std::optional<c10::Stream>(
+                 at::accelerator::getCurrentStream(device.index())));
+
+  TORCH_INTERNAL_ASSERT(opt_consumer_stream && opt_producer_stream);
+
+  // See Note: [Autograd Producer-Consumer Stream Syncs]
+  if (!opt_accum_streams[pos].has_value()) {
+    // [ First producer ]
+    TORCH_INTERNAL_ASSERT(!buffer[pos].defined());
+    // 1)
+    if (opt_consumer_stream->device() == device) {
+      // Case A
+      opt_accum_streams[pos] = opt_consumer_stream;
+      if (*opt_consumer_stream != *opt_producer_stream) {
+        // We will end up doing record_stream on the accumulation stream
+        // (which is the consumer stream) later, but we also need to do
+        // it here in case we don't end up accumulating.
+        record_stream_any_impl(var, *opt_consumer_stream);
+      }
+    } else if (opt_producer_stream->device() == device) {
+      // Case B
+      opt_accum_streams[pos] = opt_producer_stream;
+    } else {
+      // Case C
+      opt_accum_streams[pos] =
+          at::accelerator::getCurrentStream(device.index());
+    }
+    // 2)
+    buffer[pos] = std::move(var);
+    // 3)
+    auto event = c10::Event{device_type};
+    event.record(*opt_producer_stream);
+    ready_events[pos] = std::move(event);
+    ready_streams[pos] = opt_producer_stream;
+  } else {
+    // [ Nth producer ]
+    auto accum_stream = opt_accum_streams[pos];
+    auto& ready_event = ready_events[pos];
+    auto& ready_stream = ready_streams[pos];
+    TORCH_INTERNAL_ASSERT(accum_stream && ready_event && ready_stream);
+    // 1)
+    if (*accum_stream != *opt_producer_stream) {
+      auto event = c10::Event{device_type};
+      event.record(*opt_producer_stream);
+      accum_stream->wait(event);
+      record_stream_any_impl(var, *accum_stream);
+    }
+    if (*accum_stream != *ready_stream) {
+      accum_stream->wait(*ready_event);
+      // This is redundant for case A, but needed for case C
+      record_stream_any_impl(buffer[pos], *accum_stream);
+    }
+    // 2)
+    c10::OptionalStreamGuard stream_guard{accum_stream};
+    accumulate(buffer, pos, std::move(var));
+    // 3)
+    auto event = c10::Event{device_type};
+    event.record(*accum_stream);
+    ready_events[pos] = std::move(event);
+    ready_streams[pos] = accum_stream;
+  }
+}

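Condensing the First-producer step 1 above, a hypothetical Python mirror of the accumulation-stream choice (illustrative names only; the authoritative logic is the C++ just shown):

import torch

def pick_accum_stream(var_device, consumer_stream, producer_stream):
    if consumer_stream.device == var_device:      # Case A
        return consumer_stream
    if producer_stream.device == var_device:      # Case B
        return producer_stream
    return torch.cuda.current_stream(var_device)  # Case C

The Nth-producer path then needs at most two waits: one ordering the accumulation stream after the incoming producer, and one ordering it after the event last recorded for the buffer slot.
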
@@ -15,7 +15,11 @@
 namespace torch::autograd {
 
 struct InputBuffer {
-  explicit InputBuffer(size_t size) : buffer(size) {}
+  explicit InputBuffer(size_t size)
+      : buffer(size),
+        opt_accum_streams(size),
+        ready_events(size),
+        ready_streams(size) {}
   InputBuffer(const InputBuffer& other) = delete;
   InputBuffer(InputBuffer&& other) = default;
   explicit InputBuffer(variable_list&& inputs) : buffer(std::move(inputs)) {}

@@ -38,6 +42,14 @@ struct InputBuffer {
   static std::vector<Variable> variables(InputBuffer&& g);
 
   std::vector<Variable> buffer;
+  // The stream used for accumulation when a variable is used multiple times.
+  std::vector<std::optional<c10::Stream>> opt_accum_streams;
+  // The events you need to wait for to ensure the corresponding buffers
+  // are ready. The events are updated as we accumulate into the buffer.
+  std::vector<std::optional<c10::Event>> ready_events;
+  // The streams corresponding to the events above. This is only used to
+  // check if more synchronization is needed or not.
+  std::vector<std::optional<c10::Stream>> ready_streams;
 };
 
 } // namespace torch::autograd

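The three new vectors are sized together with buffer, so a single pos indexes a gradient slot and all of its sync metadata. A hypothetical Python mirror of the layout (the real struct is C++):

from typing import Optional

import torch

class InputBufferSketch:
    def __init__(self, size: int) -> None:
        self.buffer: list[Optional[torch.Tensor]] = [None] * size
        # Stream used for accumulation when a slot sees multiple producers.
        self.opt_accum_streams: list[Optional[torch.cuda.Stream]] = [None] * size
        # Event to wait on before buffer[pos] may be consumed.
        self.ready_events: list[Optional[torch.cuda.Event]] = [None] * size
        # Stream each ready event was recorded on (waits are skipped if equal).
        self.ready_streams: list[Optional[torch.cuda.Stream]] = [None] * size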