Implement TrackedDeviceBuffer::GetReadyFuture, TrackedDeviceBuffer::Delete, and TrackedDeviceBuffer::CloneWithControlDependency.

PiperOrigin-RevId: 822804278
This commit is contained in:
Parker Schuh 2025-10-22 17:42:11 -07:00 committed by TensorFlower Gardener
parent 03be5156fe
commit c3e202374c
5 changed files with 89 additions and 181 deletions

View File

@ -674,63 +674,6 @@ bool PjRtStreamExecutorClient::IsOnCpu(PjRtMemorySpace* memory_space) {
return memory_space->kind() == PinnedHostMemorySpace::kKind;
}
// Donates this buffer into a new PjRtBuffer that refers to the same device
// memory but carries one additional definition event tied to `dependency`.
//
// Returns InvalidArgument if a donation hold cannot be taken on this buffer.
// On success the donation is confirmed, so this buffer no longer owns the
// device memory.
//
// When `dependency` becomes ready, its status decides how the extra
// definition event is resolved: an error status is forwarded to the event,
// an OK status causes the event to be recorded on a stream borrowed from the
// device's stream pool.
absl::StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorBuffer::DonateWithControlDependency(Future<> dependency) {
  VLOG(1) << "PjRtStreamExecutorBuffer::DonateWithControlDependency";

  auto hold = GetBufferWithHold(CommonPjRtBuffer::ScopedHold::kDonation);
  if (!hold.ok()) {
    return InvalidArgument(
        "Invalid buffer passed to DonateWithControlDependency: %s",
        hold.status().ToString());
  }

  auto* tracked_buffer =
      tensorflow::down_cast<TrackedDeviceBuffer*>(hold.buffer());
  auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client());

  // Copy all the definition events of the existing tracked_buffer, preceded by
  // a fresh event that represents `dependency`.
  const auto& original_definition_events = tracked_buffer->definition_events();
  absl::InlinedVector<BufferSequencingEventRef, 4> definition_events;
  auto definition_event_for_status =
      BufferSequencingEvent::Create(se_client->thread_pool());
  // definition_event_for_status must be the first one so that it blocks other
  // actions like D2H transfer from execution before the buffer is ready.
  definition_events.push_back(definition_event_for_status);
  definition_events.insert(definition_events.end(),
                           original_definition_events.begin(),
                           original_definition_events.end());

  auto new_device_buffer = std::make_unique<TrackedDeviceBuffer>(
      device(), tracked_buffer->device_memory(), std::move(definition_events));

  // Make the new buffer which is identical to the old, except for the new
  // definition event.
  std::unique_ptr<PjRtBuffer> new_buffer =
      std::make_unique<PjRtStreamExecutorBuffer>(
          on_device_shape(), std::move(new_device_buffer), se_client, device(),
          device()->default_memory_space().value_or(nullptr));

  auto* se_device =
      tensorflow::down_cast<PjRtStreamExecutorDevice*>(this->device());
  LocalDeviceState* local_device = se_device->local_device_state();

  dependency.OnReady(
      [definition_event_for_status = std::move(definition_event_for_status),
       local_device, client = se_client](absl::Status status) mutable {
        // Forward the absl::Status from the supplied dependency to the
        // definition event. A failed dependency must poison the event rather
        // than be recorded as a successful definition.
        if (!status.ok()) {
          client->SetEventAsError(definition_event_for_status, status);
          return;
        }
        auto stream = local_device->BorrowStreamFromPool();
        TF_CHECK_OK(client->AllocateAndRecordEvent(definition_event_for_status,
                                                   local_device, stream.get()));
        local_device->ReturnStreamToPool(std::move(stream));
      });

  hold.ConfirmDonation();
  return new_buffer;
}
absl::StatusOr<tsl::RCReference<PjRtDeviceEvent>>
PjRtStreamExecutorClient::LinearizeHostBufferInto(
const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
@ -1330,111 +1273,6 @@ PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) {
return device_memory;
}
// Deletes the buffer by calling
// Release(/*wait_for_operations_to_complete=*/false) and CHECK-failing if the
// returned status is not OK. Only the status of the Release result is
// inspected; the released device memory handle itself is discarded.
void PjRtStreamExecutorBuffer::Delete() {
VLOG(3) << "PjRtStreamExecutorBuffer::Delete";
// When wait_for_operations_to_complete is false, Release should never fail.
//
// The only usage events that
// Release(/*wait_for_operations_to_complete=*/false) doesn't wait for are
// events defined on the compute stream. All streams other than the compute
// stream are expected to WaitFor compute stream before any write operations.
TF_CHECK_OK(Release(/*wait_for_operations_to_complete=*/false).status());
}
// Returns a future that is fulfilled with the defined status of the buffer's
// first definition event once the definition events have been waited for.
//
// If the buffer has no device buffer (deleted or donated), an already-failed
// future carrying InvalidArgument is returned instead.
//
// The future is cached in `definition_future_`: only the call that creates it
// copies the definition events and schedules the wait; later calls just return
// the cached future wrapped with profiling hooks.
Future<> PjRtStreamExecutorBuffer::GetReadyFuture() {
absl::InlinedVector<BufferSequencingEventRef, 2> definition_events;
Promise<> definition_promise;
Future<> definition_future;
{
// Everything that touches device_buffer() and definition_future_ happens
// under mu_; the actual waiting is set up after the lock is dropped.
absl::MutexLock lock(&mu_);
if (device_buffer() == nullptr) {
return Future<>(InvalidArgument(
"GetReadyFuture() called on deleted or donated buffer"));
}
if (!definition_future_) {
// First call: snapshot the definition events and create the promise/future
// pair. `definition_events` stays empty on every later call.
definition_events =
tensorflow::down_cast<TrackedDeviceBuffer*>(device_buffer())
->definition_events();
std::tie(definition_promise, definition_future_) =
Future<>::MakePromise();
}
definition_future = definition_future_;
}
// NOTE(review): if the first call observes an empty definition-event list,
// the promise created above is never fulfilled by this function — presumably
// every buffer has at least one definition event; confirm.
if (!definition_events.empty()) {
auto* se_device =
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device());
LocalDeviceState* local_device_state = se_device->local_device_state();
// Keep a separate reference to the first event: `definition_events` is moved
// into the lambda below, but the first event is still needed to schedule it.
auto first_definition_event = definition_events[0];
// The promise is held through a shared_ptr so it can be captured both by this
// lambda and by the nested host callback.
auto async_wait_for_events =
[definition_events = std::move(definition_events),
local_device_state = std::move(local_device_state),
definition_promise = std::make_shared<Promise<>>(
std::move(definition_promise))]() mutable {
std::unique_ptr<se::Stream> stream;
// An error on the first definition event short-circuits the wait.
absl::Status defined_status =
definition_events[0]->GetDefinedStatus();
if (!defined_status.ok()) {
definition_promise->Set(defined_status);
return;
}
// Borrow a stream lazily, only if at least one event is still incomplete,
// and make that stream wait on every incomplete event.
for (auto& event : definition_events) {
if (!event->IsComplete()) {
if (stream == nullptr) {
stream = local_device_state->BorrowStreamFromPool();
}
event->WaitForEventOnStream(stream.get());
}
}
if (stream != nullptr) {
// Ownership of the stream is handed to the host callback, which returns
// it to the pool.
auto* stream_ptr = stream.release();
// We already borrowed a stream from the pool so we can safely do
// the callback directly on that stream instead of bouncing through
// local_device_state->ThenExecuteCallback. The direct callback
// saves significant time.
auto status = stream_ptr->DoHostCallback(
[definition_promise, stream_ptr, local_device_state,
event_with_status = definition_events[0]]() mutable {
local_device_state->ReturnStreamToPool(
std::unique_ptr<se::Stream>(stream_ptr));
definition_promise->Set(
event_with_status->GetDefinedStatus());
});
if (!status.ok()) {
// NOTE(review): on this path the callback never runs, so `stream_ptr` is
// neither returned to the pool nor destroyed here — looks like a leaked
// stream; confirm whether DoHostCallback can fail in practice.
definition_promise->Set(status);
return;
}
} else {
// All events are already complete; set the `definition_promise`
// with the status of the buffer's first definition event which may
// have error status to propagate.
definition_promise->Set(definition_events[0]->GetDefinedStatus());
}
};
// The task label embeds the address of the local lambda object; the lambda
// itself is moved into the event's task list.
first_definition_event->ExecuteOrAddToFutureTasks(
absl::StrFormat("async_wait_for_events_%p", &async_wait_for_events),
std::move(async_wait_for_events));
}
// Wrap the cached future so that blocking on it emits connected TraceMe
// producer/consumer events named "PjRtStreamExecutorBuffer::Await".
return FutureHelpers::WithProfiling(
std::move(definition_future),
/*on_block_start=*/
[] {
tsl::profiler::TraceMeProducer traceme(
"PjRtStreamExecutorBuffer::Await");
VLOG(3) << "PjRtStreamExecutorBuffer::Await";
return FutureHelpers::ProfilingKeys(
{/*traceme_context_id=*/traceme.GetContextId()});
},
/*on_block_end=*/
[](FutureHelpers::ProfilingKeys keys) {
tsl::profiler::TraceMeConsumer traceme(
"PjRtStreamExecutorBuffer::Await", keys.traceme_context_id);
});
}
namespace {
// Helper struct for the tuple that is transiently constructed to hold the

View File

@ -539,18 +539,6 @@ class PjRtStreamExecutorBuffer : public CommonPjRtBufferImpl {
PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) = delete;
PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete;
// Drops the buffer's reference to its associated device memory, leaving the
// buffer in an invalid state. The memory will be freed lazily when all async
// operations using the buffer have completed, according to the allocation
// semantics of the underlying platform. Delete may briefly block if another
// thread is in the process of enqueuing an operation on this buffer, but it
// will never block for a stream operation to complete. If an external
// framework holds a reference to the TrackedDeviceBuffer via
// GetBufferWithExternalReference, the memory will not be freed until the
// external framework drops the reference.
void Delete() override;
// Returns a future that is fulfilled once the buffer's definition events have
// been waited for, carrying the defined status of the first definition event.
// Calling this on a deleted or donated buffer yields a future holding an
// InvalidArgument error.
Future<> GetReadyFuture() override;
// Similar to Delete, drops the buffer's reference to its associated device
// memory, leaving the buffer in an invalid state, but returns the
@ -569,9 +557,6 @@ class PjRtStreamExecutorBuffer : public CommonPjRtBufferImpl {
// accesses via the buffer returned from Release.
absl::StatusOr<tsl::RCReference<RawSEDeviceMemory>> Release(
bool wait_for_operations_to_complete);
// Donates this buffer into a newly created buffer that refers to the same
// device memory and has one extra definition event associated with
// `dependency`. Returns InvalidArgument if a donation hold cannot be acquired
// on this buffer.
absl::StatusOr<std::unique_ptr<PjRtBuffer>> DonateWithControlDependency(
Future<> dependency) override;
};
// Allocates the device buffers for a buffer that will be used as the

View File

@ -107,8 +107,9 @@ void PjRtStreamExecutorDeviceEventPromise::SetFromSEEvent(
event.AndThen([event = event_, original_event = event]() {
if (auto* error = original_event.GetErrorIfPresent()) {
event.SetError(*error);
} else {
event.SetStateConcrete();
}
event.SetStateConcrete();
});
}

View File

@ -25,21 +25,26 @@ limitations under the License.
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/functional/any_invocable.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "xla/future.h"
#include "xla/pjrt/abstract_tracked_device_buffer.h"
#include "xla/pjrt/buffer_sequencing_event.h"
#include "xla/pjrt/device_event.h"
#include "xla/pjrt/event_pool.h"
#include "xla/pjrt/local_device_state.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_common.h"
#include "xla/pjrt/pjrt_stream_executor_client.h"
#include "xla/pjrt/se_raw_buffer.h"
#include "xla/service/shaped_buffer.h"
#include "xla/shape.h"
@ -47,10 +52,13 @@ limitations under the License.
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/event.h"
#include "xla/tsl/concurrency/async_value.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/status.h"
#include "xla/tsl/platform/threadpool.h"
#include "tsl/platform/casts.h"
#include "tsl/profiler/lib/connected_traceme.h"
#include "tsl/profiler/lib/context_types.h"
@ -207,6 +215,78 @@ void TrackedDeviceBuffer::AddUsageEvent(BufferSequencingEventRef event,
usage_events_.push_back({event, reference_held});
}
// Builds a new TrackedDeviceBuffer that refers to the same device memory as
// this one and whose definition events are: a freshly created gating event
// followed by all of this buffer's existing definition events.
//
// The gating event is resolved when `dependency` becomes ready. An error
// status is forwarded to it through the client; an OK status causes it to be
// recorded on a stream borrowed from (and then returned to) the stream pool of
// the first device of `memory_space`.
absl::StatusOr<std::unique_ptr<AbstractTrackedDeviceBuffer>>
TrackedDeviceBuffer::CloneWithControlDependency(PjRtMemorySpace* memory_space,
                                                Future<> dependency) {
  auto* se_client =
      tensorflow::down_cast<PjRtStreamExecutorClient*>(memory_space->client());

  auto gate_event = BufferSequencingEvent::Create(se_client->thread_pool());

  // The gating event must come first so that it blocks other actions, like a
  // D2H transfer issued from execution, until the buffer is ready.
  absl::InlinedVector<BufferSequencingEventRef, 4> cloned_events;
  cloned_events.push_back(gate_event);
  const auto& inherited_events = definition_events();
  cloned_events.insert(cloned_events.end(), inherited_events.begin(),
                       inherited_events.end());

  auto clone = std::make_unique<TrackedDeviceBuffer>(device_, device_memory(),
                                                     std::move(cloned_events));

  auto* se_device = tensorflow::down_cast<PjRtStreamExecutorDevice*>(
      memory_space->devices()[0]);
  LocalDeviceState* local_device = se_device->local_device_state();

  dependency.OnReady(
      [definition_event_for_status = std::move(gate_event), local_device,
       client = se_client](absl::Status status) mutable {
        // Propagate the outcome of the dependency onto the gating event.
        if (status.ok()) {
          auto stream = local_device->BorrowStreamFromPool();
          TF_CHECK_OK(client->AllocateAndRecordEvent(
              definition_event_for_status, local_device, stream.get()));
          local_device->ReturnStreamToPool(std::move(stream));
        } else {
          client->SetEventAsError(definition_event_for_status, status);
        }
      });

  return clone;
}
// Returns a future fulfilled by a callback registered through
// tsl::RunWhenReady over this buffer's definition events. The callback sets
// the promise to the first error found among the events (in definition order),
// or to OK when none of them carries an error.
//
// `memory_space` is not used by this implementation.
Future<> TrackedDeviceBuffer::GetReadyFuture(PjRtMemorySpace* memory_space) {
  auto [promise, future] = Future<>::MakePromise();

  // Take a reference on every definition event so the callback can inspect
  // them independently of this buffer.
  std::vector<tsl::RCReference<tsl::AsyncValue>> pending;
  pending.reserve(definition_events_.size());
  for (const auto& definition_event : definition_events_) {
    pending.push_back(definition_event.CopyRCRef());
  }

  // The span is created before `pending` is moved into the callback; it views
  // the vector's element storage, which travels with the vector on move.
  absl::Span<tsl::RCReference<tsl::AsyncValue> const> pending_view = pending;
  tsl::RunWhenReady(
      pending_view,
      [promise = std::move(promise), pending = std::move(pending)]() mutable {
        const absl::Status* first_error = nullptr;
        for (auto& value : pending) {
          first_error = value->GetErrorIfPresent();
          if (first_error != nullptr) {
            break;
          }
        }
        if (first_error != nullptr) {
          promise.Set(*first_error);
        } else {
          promise.Set();
        }
      });
  return future;
}
// Destroys this tracked buffer. `memory_space` is not used.
void TrackedDeviceBuffer::Delete(PjRtMemorySpace* memory_space) {
  // Nothing else needs to be done here: all events already hold onto refs to
  // the buffer to ensure liveness. Adopting `this` into a unique_ptr makes the
  // object get destroyed when the function returns.
  const std::unique_ptr<TrackedDeviceBuffer> self_owner(this);
}
TrackedDeviceBuffer::StreamAndEventContainer
TrackedDeviceBuffer::LockUseAndTransferUsageEvents() {
CHECK(in_use_);

View File

@ -180,9 +180,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {
void AddUsageEvent(tsl::RCReference<PjRtDeviceEvent> event) override;
void Delete(PjRtMemorySpace* memory_space) override {
LOG(FATAL) << "Implement";
}
void Delete(PjRtMemorySpace* memory_space) override;
absl::Status WaitUntilBufferReadyOnStream(std::intptr_t stream) override {
for (const BufferSequencingEventRef& event : definition_events()) {
@ -191,6 +189,12 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {
return absl::OkStatus();
}
// Returns a new tracked buffer over the same device memory whose definition
// events are a new event tied to `dependency` followed by this buffer's
// existing definition events.
absl::StatusOr<std::unique_ptr<AbstractTrackedDeviceBuffer>>
CloneWithControlDependency(PjRtMemorySpace* memory_space,
Future<> dependency) override;
// Returns a future that resolves to the first error carried by any of this
// buffer's definition events, or to OK if none of them has an error.
Future<> GetReadyFuture(PjRtMemorySpace* memory_space) override;
private:
PjRtDevice* device_;