Update PjRtStreamExecutorClient main execute path to use CommonPjRtBuffer::ScopedHold.

Crucially, this now passes reference_held=true unconditionally. That is safe
because the only case in which false was previously passed was when the usage
stream was already the compute stream, and the flag is effectively ignored for
the compute stream (see MaybeWaitForEventOnStream).
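
A minimal sketch of that reasoning, assuming (per the description above) that
the release path only enqueues an extra stream wait for usage events that
neither held a reference nor were recorded on the compute stream. The helper
name and boolean parameters below are illustrative, not the actual
MaybeWaitForEventOnStream signature:

    // Illustrative predicate only: models when a usage event would force an
    // extra wait before the buffer's memory can be freed.
    bool NeedsWaitBeforeFree(bool reference_held, bool on_compute_stream) {
      if (on_compute_stream) {
        // Work on the compute stream is already ordered before deallocation,
        // so reference_held makes no difference here.
        return false;
      }
      // Off the compute stream, only usages that did not retain a reference
      // force a wait. Call sites previously passed reference_held=false only
      // when the usage stream was the compute stream, so switching them to
      // true cannot change which events are waited on.
      return !reference_held;
    }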

PiperOrigin-RevId: 822758577
Parker Schuh 2025-10-22 15:16:49 -07:00 committed by TensorFlower Gardener
parent 880f245b56
commit 13ea97f3a9
4 changed files with 81 additions and 63 deletions


@@ -79,6 +79,7 @@ cc_library(
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
"//xla/hlo/builder:xla_computation",
"//xla/pjrt:abstract_tracked_device_buffer",
"//xla/pjrt:common_pjrt_client",
"//xla/pjrt:device_event",
"//xla/pjrt:event_pool",


@@ -60,6 +60,7 @@ limitations under the License.
#include "xla/hlo/builder/xla_computation.h"
#include "xla/layout.h"
#include "xla/literal.h"
#include "xla/pjrt/abstract_tracked_device_buffer.h"
#include "xla/pjrt/device_event.h"
#include "xla/pjrt/distributed/client.h"
#include "xla/pjrt/distributed/in_memory_key_value_store.h"
@@ -1043,11 +1044,14 @@ StreamExecutorGpuClient::MakeCrossHostReceiveBuffers(
definition_event));
// Acquire a hold on the buffer to access the underlying memory.
PjRtStreamExecutorBuffer::ScopedHold hold = buffer->GetBufferWithUsageHold();
CommonPjRtBuffer::ScopedHold hold =
buffer->GetBufferWithHold(CommonPjRtBuffer::ScopedHold::kUsage);
auto recv = [this, gpu_collectives, notifier, local_device, definition_event,
stream, mem = hold->device_memory(), shape = shapes[0],
dtype = buffer->element_type()]() mutable {
stream,
mem = tensorflow::down_cast<TrackedDeviceBuffer*>(hold.buffer())
->device_memory(),
shape = shapes[0], dtype = buffer->element_type()]() mutable {
auto f = [&]() -> absl::Status {
// Create a CliqueId.
TF_ASSIGN_OR_RETURN(CliqueId clique_id,


@@ -727,15 +727,16 @@ PjRtStreamExecutorBuffer::DonateWithControlDependency(Future<> dependency) {
VLOG(1) << "PjRtStreamExecutorBuffer::DonateWithControlDependency";
std::unique_ptr<PjRtBuffer> new_buffer;
auto tracked_buffer =
GetBufferWithHold(PjRtStreamExecutorBuffer::ScopedHold::kDonation);
auto hold = GetBufferWithHold(CommonPjRtBuffer::ScopedHold::kDonation);
if (!tracked_buffer.ok()) {
if (!hold.ok()) {
return InvalidArgument(
"Invalid buffer passed to DonateWithControlDependency: %s",
tracked_buffer.status().ToString());
hold.status().ToString());
}
auto* tracked_buffer =
tensorflow::down_cast<TrackedDeviceBuffer*>(hold.buffer());
// Copy all the data in the existing tracked_buffer.
const auto& original_definition_events = tracked_buffer->definition_events();
absl::InlinedVector<BufferSequencingEventRef, 4> definition_events;
@@ -774,7 +775,7 @@ PjRtStreamExecutorBuffer::DonateWithControlDependency(Future<> dependency) {
local_device->ReturnStreamToPool(std::move(stream));
});
tracked_buffer.ConfirmDonation();
hold.ConfirmDonation();
return new_buffer;
}
@@ -1420,7 +1421,9 @@ Future<> PjRtStreamExecutorBuffer::GetReadyFuture() {
"GetReadyFuture() called on deleted or donated buffer"));
}
if (!definition_future_) {
definition_events = device_buffer()->definition_events();
definition_events =
tensorflow::down_cast<TrackedDeviceBuffer*>(device_buffer())
->definition_events();
std::tie(definition_promise, definition_future_) =
Future<>::MakePromise();
}
@@ -1563,12 +1566,12 @@ absl::Status CheckCompatibleShapes(bool strict_shape_checking,
// Makes a tuple from the arguments to an execution.
static absl::StatusOr<std::pair<ShapeTree<PjRtStreamExecutorExecutionInput>,
BufferSequencingEventRef>>
MakeTupleHelper(
PjRtStreamExecutorClient* client, LocalDeviceState* local_device,
bool strict_shape_checking, const Shape& tupled_parameter_shape,
absl::Span<PjRtBuffer* const> py_buffers,
absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
int device_ordinal) {
MakeTupleHelper(PjRtStreamExecutorClient* client,
LocalDeviceState* local_device, bool strict_shape_checking,
const Shape& tupled_parameter_shape,
absl::Span<PjRtBuffer* const> py_buffers,
absl::Span<const CommonPjRtBuffer::ScopedHold> device_buffers,
int device_ordinal) {
se::DeviceMemoryAllocator* allocator = client->allocator();
TransferManager* transfer_manager =
client->client()->backend().transfer_manager();
@@ -1610,11 +1613,11 @@ MakeTupleHelper(
local_device, allocator)};
++input_iterator;
// Then set each sub-tuple in turn from the parameters.
for (const PjRtStreamExecutorBuffer::ScopedHold& device_buffer :
device_buffers) {
for (const CommonPjRtBuffer::ScopedHold& device_buffer : device_buffers) {
input_iterator->second = {
device_buffer.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation,
device_buffer->device_memory()};
device_buffer.type() == CommonPjRtBuffer::ScopedHold::kDonation,
tensorflow::down_cast<TrackedDeviceBuffer*>(device_buffer.buffer())
->device_memory()};
++input_iterator;
}
CHECK(input_iterator == iterator_end);
@@ -1643,8 +1646,7 @@ MakeTupleHelper(
absl::StatusOr<std::unique_ptr<PjRtBuffer>> OutputBufferHelper(
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
BufferSequencingEventRef definition_event, PjRtClient* client,
PjRtDevice* device, LocalDeviceState* local_device,
std::vector<tsl::RCReference<RawSEDeviceMemory>>& buffers_to_release) {
PjRtDevice* device, LocalDeviceState* local_device) {
if (result_buffer.shape().IsTuple()) {
return absl::InternalError("OutputBufferHelper called on tuple.");
}
@@ -1680,9 +1682,6 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> OutputBufferHelper(
auto pjrt_buffer = std::make_unique<PjRtStreamExecutorBuffer>(
result_buffer.shape(), std::move(out_buffer), client, device,
memory_space);
RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device, local_device,
definition_event, local_device->compute_stream(),
&buffers_to_release);
return std::unique_ptr<PjRtBuffer>(std::move(pjrt_buffer));
}
@@ -1798,7 +1797,7 @@ PjRtStreamExecutorLoadedExecutable::MakeExecutionInputsAndWaitForEvents(
int device_ordinal, const ExecuteOptions& options,
absl::Span<const Shape> executable_parameter_shapes,
absl::Span<PjRtBuffer* const> argument_handles,
absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
absl::Span<const CommonPjRtBuffer::ScopedHold> device_buffers,
absl::flat_hash_set<BufferSequencingEvent*>& events) const {
std::vector<ShapeTree<PjRtStreamExecutorExecutionInput>> execution_inputs;
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
@@ -1834,11 +1833,12 @@ PjRtStreamExecutorLoadedExecutable::MakeExecutionInputsAndWaitForEvents(
execution_inputs.back();
auto input_iterator = execution_input.begin();
auto iterator_end = execution_input.end();
const auto& buf = device_buffers[i]->device_memory();
const auto& buf = tensorflow::down_cast<TrackedDeviceBuffer*>(
device_buffers[i].buffer())
->device_memory();
CHECK(input_iterator != iterator_end);
input_iterator->second = {
device_buffers[i].type() ==
PjRtStreamExecutorBuffer::ScopedHold::kDonation,
device_buffers[i].type() == CommonPjRtBuffer::ScopedHold::kDonation,
buf};
++input_iterator;
CHECK(input_iterator == iterator_end);
@@ -2154,7 +2154,7 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
int executable_idx, const RunId& run_id, const ExecuteOptions& options,
PjRtDevice* device,
std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
std::vector<CommonPjRtBuffer::ScopedHold>* device_buffers,
std::shared_ptr<DeviceAssignment> device_assignment,
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const {
int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
@@ -2194,16 +2194,17 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
TF_RETURN_IF_ERROR(TestBufferDonationClashes(
handle, donation_clashes, must_donate, i, replica, partition));
device_buffers->emplace_back(handle->GetBufferWithHold(
must_donate ? PjRtStreamExecutorBuffer::ScopedHold::kDonation
: PjRtStreamExecutorBuffer::ScopedHold::kUsage));
PjRtStreamExecutorBuffer::ScopedHold& device_buffer =
device_buffers->back();
if (!device_buffer.ok()) {
must_donate ? CommonPjRtBuffer::ScopedHold::kDonation
: CommonPjRtBuffer::ScopedHold::kUsage));
CommonPjRtBuffer::ScopedHold& hold = device_buffers->back();
if (!hold.ok()) {
return InvalidArgument(
"Invalid buffer passed to Execute() as argument %d to replica %d: "
"%s",
i, replica, device_buffer.status().ToString());
i, replica, hold.status().ToString());
}
auto* device_buffer =
tensorflow::down_cast<TrackedDeviceBuffer*>(hold.buffer());
// If we are trying to donate the buffer wait on the usage events as well
// as the definition events to ensure that all reads have been completed
// before the buffer is mutated. Usage holds are excluded during a donation
@@ -2408,9 +2409,7 @@ PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
int device_ordinal, const ExecuteOptions& options,
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
BufferSequencingEventRef definition_event, PjRtDevice* device,
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks,
std::vector<tsl::RCReference<RawSEDeviceMemory>>& buffers_to_release)
const {
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const {
tsl::profiler::TraceMe traceme("MakeOutputBuffers");
std::vector<std::unique_ptr<PjRtBuffer>> outputs;
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
@@ -2426,7 +2425,7 @@ PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtBuffer> buffer,
OutputBufferHelper(std::move(tuple_buffer), definition_event, client_,
device, device_state, buffers_to_release));
device, device_state));
outputs.push_back(std::move(buffer));
}
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
@@ -2438,7 +2437,7 @@ PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtBuffer> buffer,
OutputBufferHelper(std::move(result_buffer), definition_event, client_,
device, device_state, buffers_to_release));
device, device_state));
outputs.push_back(std::move(buffer));
}
return outputs;
@@ -2447,13 +2446,15 @@ PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
static absl::Status GetFirstInputError(
absl::Span<PjRtBuffer* const> argument_handles) {
for (auto* handle : argument_handles) {
auto* buffer = tensorflow::down_cast<PjRtStreamExecutorBuffer*>(handle);
PjRtStreamExecutorBuffer::ScopedHold hold =
buffer->GetBufferWithUsageHold();
auto* buffer = tensorflow::down_cast<CommonPjRtBuffer*>(handle);
CommonPjRtBuffer::ScopedHold hold =
buffer->GetBufferWithHold(CommonPjRtBuffer::ScopedHold::kUsage);
if (!hold.ok()) {
return hold.status();
}
for (const auto& event : hold->definition_events()) {
auto* tracked_buffer =
tensorflow::down_cast<TrackedDeviceBuffer*>(hold.buffer());
for (const auto& event : tracked_buffer->definition_events()) {
if (event->IsPredeterminedError()) {
return event->GetDefinedStatus();
}
@@ -2524,7 +2525,7 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
int executable_idx = executables_.size() > 1 ? partition : 0;
std::vector<absl::AnyInvocable<void() &&>> compute_callbacks;
std::vector<PjRtStreamExecutorBuffer::ScopedHold> device_buffers;
std::vector<CommonPjRtBuffer::ScopedHold> device_buffers;
device_buffers.reserve(argument_handles.size());
absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>>
result_buffer_or_status =
@@ -2543,33 +2544,46 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
se::Stream* stream = device_state->compute_stream();
auto definition_event = device_state->GetEventForComputeStreamSyncPoint(
auto definition_event_or = device_state->GetEventForComputeStreamSyncPoint(
device_state->GetNextComputeStreamSyncPoint(), client_->thread_pool());
if (!definition_event.ok()) {
if (!definition_event_or.ok()) {
StallStreamOnError(device_state, stream);
for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation) {
for (CommonPjRtBuffer::ScopedHold& b : device_buffers) {
if (b.type() == CommonPjRtBuffer::ScopedHold::kDonation) {
// Even though there was an error we need to call ConfirmDonation, which
// renders b invalid, since the computation has been enqueued and b has
// been donated.
b.ConfirmDonation();
}
}
return definition_event.status();
return definition_event_or.status();
}
std::vector<tsl::RCReference<RawSEDeviceMemory>> leaves_to_release;
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
leaves_to_release.reserve(result_buffer.leaf_count());
for (auto& node : result_buffer.leaves()) {
leaves_to_release.push_back(node.second);
}
}
std::vector<tsl::RCReference<RawSEDeviceMemory>> buffers_to_release;
auto definition_event = tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
*definition_event_or, "PjRtStreamExecutorLoadedExecutable", "Execute");
TF_ASSIGN_OR_RETURN(
std::vector<std::unique_ptr<PjRtBuffer>> outputs,
MakeOutputBuffers(device_ordinal, options, std::move(result_buffer),
*definition_event, device, compute_callbacks,
buffers_to_release));
*definition_event_or, device, compute_callbacks));
for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kUsage) {
RecordUsage(std::move(b), device_state, device_state, *definition_event,
stream, &buffers_to_release);
for (CommonPjRtBuffer::ScopedHold& b : device_buffers) {
if (b.type() == CommonPjRtBuffer::ScopedHold::kUsage) {
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
buffers_to_release.push_back(
tensorflow::down_cast<PjRtStreamExecutorRawBuffer*>(
b.buffer()->GetRawBuffer(b.parent()->memory_space()).get())
->device_buffer());
}
b.ConvertUsageHold(definition_event);
} else {
CHECK(b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation);
CHECK(b.type() == CommonPjRtBuffer::ScopedHold::kDonation);
b.ConfirmDonation();
}
}
@@ -2583,7 +2597,8 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
}
definition_event->AndThen(
[callbacks{std::move(compute_callbacks)},
buffers_to_release{std::move(buffers_to_release)}]() mutable {
buffers_to_release{std::move(buffers_to_release)},
leaves_to_release = std::move(leaves_to_release)]() mutable {
for (auto& fn : callbacks) {
std::move(fn)();
}
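
Taken together, the hunks above replace the stream-executor-specific hold with
the generic CommonPjRtBuffer::ScopedHold and down-cast only where concrete
device memory is needed. A condensed sketch of that recurring shape, with an
illustrative variable name (buffer is assumed to be a
PjRtStreamExecutorBuffer*) and error handling abbreviated:

    // Acquire a generic usage hold; it fails if the buffer was deleted or
    // already donated.
    CommonPjRtBuffer::ScopedHold hold =
        buffer->GetBufferWithHold(CommonPjRtBuffer::ScopedHold::kUsage);
    if (!hold.ok()) {
      return hold.status();
    }
    // The generic hold exposes only the abstract tracked buffer; the
    // stream-executor path down-casts to reach device_memory() and
    // definition_events().
    auto* tracked = tensorflow::down_cast<TrackedDeviceBuffer*>(hold.buffer());
    const auto& mem = tracked->device_memory();

Usage holds are then released with ConvertUsageHold(definition_event) and
donation holds with ConfirmDonation(), as in ExecuteHelper above.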


@@ -824,7 +824,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
int device_ordinal, const ExecuteOptions& options,
absl::Span<const Shape> executable_parameter_shapes,
absl::Span<PjRtBuffer* const> argument_handles,
absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
absl::Span<const CommonPjRtBuffer::ScopedHold> device_buffers,
absl::flat_hash_set<BufferSequencingEvent*>& events) const;
absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>>
@@ -832,7 +832,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
absl::Span<PjRtBuffer* const> argument_handles, int replica,
int partition, int executable_idx, const RunId& run_id,
const ExecuteOptions& options, PjRtDevice* device,
std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
std::vector<CommonPjRtBuffer::ScopedHold>* device_buffers,
std::shared_ptr<DeviceAssignment> device_assignment,
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const;
@@ -841,9 +841,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
int device_ordinal, const ExecuteOptions& options,
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
BufferSequencingEventRef definition_event, PjRtDevice* device,
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks,
std::vector<tsl::RCReference<RawSEDeviceMemory>>& buffers_to_release)
const;
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const;
absl::StatusOr<Result> ExecuteHelper(
absl::Span<PjRtBuffer* const> argument_handles, int replica,