mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Update PjRtStreamExecutorClient main execute path to use CommonPjRtBuffer::ScopedHold. Crucially
this now passes reference_held=true always. This is fine because the only time this was ever passed as false was if this was already on the compute stream and this bool is basically ignored if the stream is the compute stream (see MaybeWaitForEventOnStream). PiperOrigin-RevId: 822758577
This commit is contained in:
parent
880f245b56
commit
13ea97f3a9
1
third_party/xla/xla/pjrt/gpu/BUILD
vendored
1
third_party/xla/xla/pjrt/gpu/BUILD
vendored
|
|
@ -79,6 +79,7 @@ cc_library(
|
|||
"//xla/core/collectives:communicator",
|
||||
"//xla/core/collectives:rank_id",
|
||||
"//xla/hlo/builder:xla_computation",
|
||||
"//xla/pjrt:abstract_tracked_device_buffer",
|
||||
"//xla/pjrt:common_pjrt_client",
|
||||
"//xla/pjrt:device_event",
|
||||
"//xla/pjrt:event_pool",
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ limitations under the License.
|
|||
#include "xla/hlo/builder/xla_computation.h"
|
||||
#include "xla/layout.h"
|
||||
#include "xla/literal.h"
|
||||
#include "xla/pjrt/abstract_tracked_device_buffer.h"
|
||||
#include "xla/pjrt/device_event.h"
|
||||
#include "xla/pjrt/distributed/client.h"
|
||||
#include "xla/pjrt/distributed/in_memory_key_value_store.h"
|
||||
|
|
@ -1043,11 +1044,14 @@ StreamExecutorGpuClient::MakeCrossHostReceiveBuffers(
|
|||
definition_event));
|
||||
|
||||
// Acquire a hold on the buffer to access the underlying memory.
|
||||
PjRtStreamExecutorBuffer::ScopedHold hold = buffer->GetBufferWithUsageHold();
|
||||
CommonPjRtBuffer::ScopedHold hold =
|
||||
buffer->GetBufferWithHold(CommonPjRtBuffer::ScopedHold::kUsage);
|
||||
|
||||
auto recv = [this, gpu_collectives, notifier, local_device, definition_event,
|
||||
stream, mem = hold->device_memory(), shape = shapes[0],
|
||||
dtype = buffer->element_type()]() mutable {
|
||||
stream,
|
||||
mem = tensorflow::down_cast<TrackedDeviceBuffer*>(hold.buffer())
|
||||
->device_memory(),
|
||||
shape = shapes[0], dtype = buffer->element_type()]() mutable {
|
||||
auto f = [&]() -> absl::Status {
|
||||
// Create a CliqueId.
|
||||
TF_ASSIGN_OR_RETURN(CliqueId clique_id,
|
||||
|
|
|
|||
|
|
@ -727,15 +727,16 @@ PjRtStreamExecutorBuffer::DonateWithControlDependency(Future<> dependency) {
|
|||
VLOG(1) << "PjRtStreamExecutorBuffer::DonateWithControlDependency";
|
||||
std::unique_ptr<PjRtBuffer> new_buffer;
|
||||
|
||||
auto tracked_buffer =
|
||||
GetBufferWithHold(PjRtStreamExecutorBuffer::ScopedHold::kDonation);
|
||||
auto hold = GetBufferWithHold(CommonPjRtBuffer::ScopedHold::kDonation);
|
||||
|
||||
if (!tracked_buffer.ok()) {
|
||||
if (!hold.ok()) {
|
||||
return InvalidArgument(
|
||||
"Invalid buffer passed to DonateWithControlDependency: %s",
|
||||
tracked_buffer.status().ToString());
|
||||
hold.status().ToString());
|
||||
}
|
||||
|
||||
auto* tracked_buffer =
|
||||
tensorflow::down_cast<TrackedDeviceBuffer*>(hold.buffer());
|
||||
// Copy all the data in the existing tracked_buffer.
|
||||
const auto& original_definition_events = tracked_buffer->definition_events();
|
||||
absl::InlinedVector<BufferSequencingEventRef, 4> definition_events;
|
||||
|
|
@ -774,7 +775,7 @@ PjRtStreamExecutorBuffer::DonateWithControlDependency(Future<> dependency) {
|
|||
local_device->ReturnStreamToPool(std::move(stream));
|
||||
});
|
||||
|
||||
tracked_buffer.ConfirmDonation();
|
||||
hold.ConfirmDonation();
|
||||
return new_buffer;
|
||||
}
|
||||
|
||||
|
|
@ -1420,7 +1421,9 @@ Future<> PjRtStreamExecutorBuffer::GetReadyFuture() {
|
|||
"GetReadyFuture() called on deleted or donated buffer"));
|
||||
}
|
||||
if (!definition_future_) {
|
||||
definition_events = device_buffer()->definition_events();
|
||||
definition_events =
|
||||
tensorflow::down_cast<TrackedDeviceBuffer*>(device_buffer())
|
||||
->definition_events();
|
||||
std::tie(definition_promise, definition_future_) =
|
||||
Future<>::MakePromise();
|
||||
}
|
||||
|
|
@ -1563,11 +1566,11 @@ absl::Status CheckCompatibleShapes(bool strict_shape_checking,
|
|||
// Makes a tuple from the arguments to an execution.
|
||||
static absl::StatusOr<std::pair<ShapeTree<PjRtStreamExecutorExecutionInput>,
|
||||
BufferSequencingEventRef>>
|
||||
MakeTupleHelper(
|
||||
PjRtStreamExecutorClient* client, LocalDeviceState* local_device,
|
||||
bool strict_shape_checking, const Shape& tupled_parameter_shape,
|
||||
MakeTupleHelper(PjRtStreamExecutorClient* client,
|
||||
LocalDeviceState* local_device, bool strict_shape_checking,
|
||||
const Shape& tupled_parameter_shape,
|
||||
absl::Span<PjRtBuffer* const> py_buffers,
|
||||
absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
|
||||
absl::Span<const CommonPjRtBuffer::ScopedHold> device_buffers,
|
||||
int device_ordinal) {
|
||||
se::DeviceMemoryAllocator* allocator = client->allocator();
|
||||
TransferManager* transfer_manager =
|
||||
|
|
@ -1610,11 +1613,11 @@ MakeTupleHelper(
|
|||
local_device, allocator)};
|
||||
++input_iterator;
|
||||
// Then set each sub-tuple in turn from the parameters.
|
||||
for (const PjRtStreamExecutorBuffer::ScopedHold& device_buffer :
|
||||
device_buffers) {
|
||||
for (const CommonPjRtBuffer::ScopedHold& device_buffer : device_buffers) {
|
||||
input_iterator->second = {
|
||||
device_buffer.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation,
|
||||
device_buffer->device_memory()};
|
||||
device_buffer.type() == CommonPjRtBuffer::ScopedHold::kDonation,
|
||||
tensorflow::down_cast<TrackedDeviceBuffer*>(device_buffer.buffer())
|
||||
->device_memory()};
|
||||
++input_iterator;
|
||||
}
|
||||
CHECK(input_iterator == iterator_end);
|
||||
|
|
@ -1643,8 +1646,7 @@ MakeTupleHelper(
|
|||
absl::StatusOr<std::unique_ptr<PjRtBuffer>> OutputBufferHelper(
|
||||
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
|
||||
BufferSequencingEventRef definition_event, PjRtClient* client,
|
||||
PjRtDevice* device, LocalDeviceState* local_device,
|
||||
std::vector<tsl::RCReference<RawSEDeviceMemory>>& buffers_to_release) {
|
||||
PjRtDevice* device, LocalDeviceState* local_device) {
|
||||
if (result_buffer.shape().IsTuple()) {
|
||||
return absl::InternalError("OutputBufferHelper called on tuple.");
|
||||
}
|
||||
|
|
@ -1680,9 +1682,6 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> OutputBufferHelper(
|
|||
auto pjrt_buffer = std::make_unique<PjRtStreamExecutorBuffer>(
|
||||
result_buffer.shape(), std::move(out_buffer), client, device,
|
||||
memory_space);
|
||||
RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device, local_device,
|
||||
definition_event, local_device->compute_stream(),
|
||||
&buffers_to_release);
|
||||
return std::unique_ptr<PjRtBuffer>(std::move(pjrt_buffer));
|
||||
}
|
||||
|
||||
|
|
@ -1798,7 +1797,7 @@ PjRtStreamExecutorLoadedExecutable::MakeExecutionInputsAndWaitForEvents(
|
|||
int device_ordinal, const ExecuteOptions& options,
|
||||
absl::Span<const Shape> executable_parameter_shapes,
|
||||
absl::Span<PjRtBuffer* const> argument_handles,
|
||||
absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
|
||||
absl::Span<const CommonPjRtBuffer::ScopedHold> device_buffers,
|
||||
absl::flat_hash_set<BufferSequencingEvent*>& events) const {
|
||||
std::vector<ShapeTree<PjRtStreamExecutorExecutionInput>> execution_inputs;
|
||||
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
|
||||
|
|
@ -1834,11 +1833,12 @@ PjRtStreamExecutorLoadedExecutable::MakeExecutionInputsAndWaitForEvents(
|
|||
execution_inputs.back();
|
||||
auto input_iterator = execution_input.begin();
|
||||
auto iterator_end = execution_input.end();
|
||||
const auto& buf = device_buffers[i]->device_memory();
|
||||
const auto& buf = tensorflow::down_cast<TrackedDeviceBuffer*>(
|
||||
device_buffers[i].buffer())
|
||||
->device_memory();
|
||||
CHECK(input_iterator != iterator_end);
|
||||
input_iterator->second = {
|
||||
device_buffers[i].type() ==
|
||||
PjRtStreamExecutorBuffer::ScopedHold::kDonation,
|
||||
device_buffers[i].type() == CommonPjRtBuffer::ScopedHold::kDonation,
|
||||
buf};
|
||||
++input_iterator;
|
||||
CHECK(input_iterator == iterator_end);
|
||||
|
|
@ -2154,7 +2154,7 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
|
|||
absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
|
||||
int executable_idx, const RunId& run_id, const ExecuteOptions& options,
|
||||
PjRtDevice* device,
|
||||
std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
|
||||
std::vector<CommonPjRtBuffer::ScopedHold>* device_buffers,
|
||||
std::shared_ptr<DeviceAssignment> device_assignment,
|
||||
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const {
|
||||
int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
|
||||
|
|
@ -2194,16 +2194,17 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
|
|||
TF_RETURN_IF_ERROR(TestBufferDonationClashes(
|
||||
handle, donation_clashes, must_donate, i, replica, partition));
|
||||
device_buffers->emplace_back(handle->GetBufferWithHold(
|
||||
must_donate ? PjRtStreamExecutorBuffer::ScopedHold::kDonation
|
||||
: PjRtStreamExecutorBuffer::ScopedHold::kUsage));
|
||||
PjRtStreamExecutorBuffer::ScopedHold& device_buffer =
|
||||
device_buffers->back();
|
||||
if (!device_buffer.ok()) {
|
||||
must_donate ? CommonPjRtBuffer::ScopedHold::kDonation
|
||||
: CommonPjRtBuffer::ScopedHold::kUsage));
|
||||
CommonPjRtBuffer::ScopedHold& hold = device_buffers->back();
|
||||
if (!hold.ok()) {
|
||||
return InvalidArgument(
|
||||
"Invalid buffer passed to Execute() as argument %d to replica %d: "
|
||||
"%s",
|
||||
i, replica, device_buffer.status().ToString());
|
||||
i, replica, hold.status().ToString());
|
||||
}
|
||||
auto* device_buffer =
|
||||
tensorflow::down_cast<TrackedDeviceBuffer*>(hold.buffer());
|
||||
// If we are trying to donate the buffer wait on the usage events as well
|
||||
// as the definition events to ensure that all reads have been completed
|
||||
// before the buffer is mutated. Usage holds are excluded during a donation
|
||||
|
|
@ -2408,9 +2409,7 @@ PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
|
|||
int device_ordinal, const ExecuteOptions& options,
|
||||
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
|
||||
BufferSequencingEventRef definition_event, PjRtDevice* device,
|
||||
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks,
|
||||
std::vector<tsl::RCReference<RawSEDeviceMemory>>& buffers_to_release)
|
||||
const {
|
||||
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const {
|
||||
tsl::profiler::TraceMe traceme("MakeOutputBuffers");
|
||||
std::vector<std::unique_ptr<PjRtBuffer>> outputs;
|
||||
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
|
||||
|
|
@ -2426,7 +2425,7 @@ PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
|
|||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<PjRtBuffer> buffer,
|
||||
OutputBufferHelper(std::move(tuple_buffer), definition_event, client_,
|
||||
device, device_state, buffers_to_release));
|
||||
device, device_state));
|
||||
outputs.push_back(std::move(buffer));
|
||||
}
|
||||
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
|
||||
|
|
@ -2438,7 +2437,7 @@ PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
|
|||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<PjRtBuffer> buffer,
|
||||
OutputBufferHelper(std::move(result_buffer), definition_event, client_,
|
||||
device, device_state, buffers_to_release));
|
||||
device, device_state));
|
||||
outputs.push_back(std::move(buffer));
|
||||
}
|
||||
return outputs;
|
||||
|
|
@ -2447,13 +2446,15 @@ PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
|
|||
static absl::Status GetFirstInputError(
|
||||
absl::Span<PjRtBuffer* const> argument_handles) {
|
||||
for (auto* handle : argument_handles) {
|
||||
auto* buffer = tensorflow::down_cast<PjRtStreamExecutorBuffer*>(handle);
|
||||
PjRtStreamExecutorBuffer::ScopedHold hold =
|
||||
buffer->GetBufferWithUsageHold();
|
||||
auto* buffer = tensorflow::down_cast<CommonPjRtBuffer*>(handle);
|
||||
CommonPjRtBuffer::ScopedHold hold =
|
||||
buffer->GetBufferWithHold(CommonPjRtBuffer::ScopedHold::kUsage);
|
||||
if (!hold.ok()) {
|
||||
return hold.status();
|
||||
}
|
||||
for (const auto& event : hold->definition_events()) {
|
||||
auto* tracked_buffer =
|
||||
tensorflow::down_cast<TrackedDeviceBuffer*>(hold.buffer());
|
||||
for (const auto& event : tracked_buffer->definition_events()) {
|
||||
if (event->IsPredeterminedError()) {
|
||||
return event->GetDefinedStatus();
|
||||
}
|
||||
|
|
@ -2524,7 +2525,7 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
|
|||
int executable_idx = executables_.size() > 1 ? partition : 0;
|
||||
|
||||
std::vector<absl::AnyInvocable<void() &&>> compute_callbacks;
|
||||
std::vector<PjRtStreamExecutorBuffer::ScopedHold> device_buffers;
|
||||
std::vector<CommonPjRtBuffer::ScopedHold> device_buffers;
|
||||
device_buffers.reserve(argument_handles.size());
|
||||
absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>>
|
||||
result_buffer_or_status =
|
||||
|
|
@ -2543,33 +2544,46 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
|
|||
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
|
||||
se::Stream* stream = device_state->compute_stream();
|
||||
|
||||
auto definition_event = device_state->GetEventForComputeStreamSyncPoint(
|
||||
auto definition_event_or = device_state->GetEventForComputeStreamSyncPoint(
|
||||
device_state->GetNextComputeStreamSyncPoint(), client_->thread_pool());
|
||||
if (!definition_event.ok()) {
|
||||
if (!definition_event_or.ok()) {
|
||||
StallStreamOnError(device_state, stream);
|
||||
for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
|
||||
if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation) {
|
||||
for (CommonPjRtBuffer::ScopedHold& b : device_buffers) {
|
||||
if (b.type() == CommonPjRtBuffer::ScopedHold::kDonation) {
|
||||
// Even though there was an error we need to call ConfirmDonation, which
|
||||
// renders b invalid, since the computation has been enqueued and b has
|
||||
// been donated.
|
||||
b.ConfirmDonation();
|
||||
}
|
||||
}
|
||||
return definition_event.status();
|
||||
return definition_event_or.status();
|
||||
}
|
||||
std::vector<tsl::RCReference<RawSEDeviceMemory>> leaves_to_release;
|
||||
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
|
||||
leaves_to_release.reserve(result_buffer.leaf_count());
|
||||
for (auto& node : result_buffer.leaves()) {
|
||||
leaves_to_release.push_back(node.second);
|
||||
}
|
||||
}
|
||||
std::vector<tsl::RCReference<RawSEDeviceMemory>> buffers_to_release;
|
||||
auto definition_event = tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
|
||||
*definition_event_or, "PjRtStreamExecutorLoadedExecutable", "Execute");
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<std::unique_ptr<PjRtBuffer>> outputs,
|
||||
MakeOutputBuffers(device_ordinal, options, std::move(result_buffer),
|
||||
*definition_event, device, compute_callbacks,
|
||||
buffers_to_release));
|
||||
*definition_event_or, device, compute_callbacks));
|
||||
|
||||
for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
|
||||
if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kUsage) {
|
||||
RecordUsage(std::move(b), device_state, device_state, *definition_event,
|
||||
stream, &buffers_to_release);
|
||||
for (CommonPjRtBuffer::ScopedHold& b : device_buffers) {
|
||||
if (b.type() == CommonPjRtBuffer::ScopedHold::kUsage) {
|
||||
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
|
||||
buffers_to_release.push_back(
|
||||
tensorflow::down_cast<PjRtStreamExecutorRawBuffer*>(
|
||||
b.buffer()->GetRawBuffer(b.parent()->memory_space()).get())
|
||||
->device_buffer());
|
||||
}
|
||||
b.ConvertUsageHold(definition_event);
|
||||
} else {
|
||||
CHECK(b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation);
|
||||
CHECK(b.type() == CommonPjRtBuffer::ScopedHold::kDonation);
|
||||
b.ConfirmDonation();
|
||||
}
|
||||
}
|
||||
|
|
@ -2583,7 +2597,8 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
|
|||
}
|
||||
definition_event->AndThen(
|
||||
[callbacks{std::move(compute_callbacks)},
|
||||
buffers_to_release{std::move(buffers_to_release)}]() mutable {
|
||||
buffers_to_release{std::move(buffers_to_release)},
|
||||
leaves_to_release = std::move(leaves_to_release)]() mutable {
|
||||
for (auto& fn : callbacks) {
|
||||
std::move(fn)();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -824,7 +824,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
|
|||
int device_ordinal, const ExecuteOptions& options,
|
||||
absl::Span<const Shape> executable_parameter_shapes,
|
||||
absl::Span<PjRtBuffer* const> argument_handles,
|
||||
absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
|
||||
absl::Span<const CommonPjRtBuffer::ScopedHold> device_buffers,
|
||||
absl::flat_hash_set<BufferSequencingEvent*>& events) const;
|
||||
|
||||
absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>>
|
||||
|
|
@ -832,7 +832,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
|
|||
absl::Span<PjRtBuffer* const> argument_handles, int replica,
|
||||
int partition, int executable_idx, const RunId& run_id,
|
||||
const ExecuteOptions& options, PjRtDevice* device,
|
||||
std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
|
||||
std::vector<CommonPjRtBuffer::ScopedHold>* device_buffers,
|
||||
std::shared_ptr<DeviceAssignment> device_assignment,
|
||||
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const;
|
||||
|
||||
|
|
@ -841,9 +841,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
|
|||
int device_ordinal, const ExecuteOptions& options,
|
||||
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
|
||||
BufferSequencingEventRef definition_event, PjRtDevice* device,
|
||||
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks,
|
||||
std::vector<tsl::RCReference<RawSEDeviceMemory>>& buffers_to_release)
|
||||
const;
|
||||
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const;
|
||||
|
||||
absl::StatusOr<Result> ExecuteHelper(
|
||||
absl::Span<PjRtBuffer* const> argument_handles, int replica,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user