diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index f51db0d8a61..22f905d9576 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -674,63 +674,6 @@ bool PjRtStreamExecutorClient::IsOnCpu(PjRtMemorySpace* memory_space) { return memory_space->kind() == PinnedHostMemorySpace::kKind; } -absl::StatusOr> -PjRtStreamExecutorBuffer::DonateWithControlDependency(Future<> dependency) { - VLOG(1) << "PjRtStreamExecutorBuffer::DonateWithControlDependency"; - std::unique_ptr new_buffer; - - auto hold = GetBufferWithHold(CommonPjRtBuffer::ScopedHold::kDonation); - - if (!hold.ok()) { - return InvalidArgument( - "Invalid buffer passed to DonateWithControlDependency: %s", - hold.status().ToString()); - } - - auto* tracked_buffer = - tensorflow::down_cast(hold.buffer()); - // Copy all the data in the existing tracked_buffer. - const auto& original_definition_events = tracked_buffer->definition_events(); - absl::InlinedVector definition_events; - auto* se_client = tensorflow::down_cast(client()); - - auto definition_event_for_status = - BufferSequencingEvent::Create(se_client->thread_pool()); - // definition_event_for_status must be the first one so that it blocks other - // actions like D2H transfer from execution before the buffer is ready. - definition_events.push_back(definition_event_for_status); - definition_events.insert(definition_events.end(), - original_definition_events.begin(), - original_definition_events.end()); - - auto new_device_buffer = std::make_unique( - device(), tracked_buffer->device_memory(), std::move(definition_events)); - - // Make the new buffer which is identical to the old, except for the new - // definition event. - new_buffer = - std::unique_ptr(std::make_unique( - on_device_shape(), std::move(new_device_buffer), se_client, device(), - device()->default_memory_space().value_or(nullptr))); - - auto* device = - tensorflow::down_cast(this->device()); - LocalDeviceState* local_device = device->local_device_state(); - dependency.OnReady( - [definition_event_for_status = std::move(definition_event_for_status), - local_device, client = se_client](absl::Status status) mutable { - // Forward the absl::Status from the supplied dependency to the - // definition event. - auto stream = local_device->BorrowStreamFromPool(); - TF_CHECK_OK(client->AllocateAndRecordEvent(definition_event_for_status, - local_device, stream.get())); - local_device->ReturnStreamToPool(std::move(stream)); - }); - - hold.ConfirmDonation(); - return new_buffer; -} - absl::StatusOr> PjRtStreamExecutorClient::LinearizeHostBufferInto( const void* data, PrimitiveType type, absl::Span dims, @@ -1330,111 +1273,6 @@ PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) { return device_memory; } -void PjRtStreamExecutorBuffer::Delete() { - VLOG(3) << "PjRtStreamExecutorBuffer::Delete"; - - // When wait_for_reads_to_complete is false, Release should never fail. - // - // The only usage events that - // Release(/*wait_for_operations_to_complete=*/false) doesn't wait for are - // events defined on the compute stream. All streams other than the compute - // stream are expected to WaitFor compute stream before any write operations. - TF_CHECK_OK(Release(/*wait_for_operations_to_complete=*/false).status()); -} - -Future<> PjRtStreamExecutorBuffer::GetReadyFuture() { - absl::InlinedVector definition_events; - Promise<> definition_promise; - Future<> definition_future; - { - absl::MutexLock lock(&mu_); - if (device_buffer() == nullptr) { - return Future<>(InvalidArgument( - "GetReadyFuture() called on deleted or donated buffer")); - } - if (!definition_future_) { - definition_events = - tensorflow::down_cast(device_buffer()) - ->definition_events(); - std::tie(definition_promise, definition_future_) = - Future<>::MakePromise(); - } - definition_future = definition_future_; - } - - if (!definition_events.empty()) { - auto* se_device = - tensorflow::down_cast(device()); - LocalDeviceState* local_device_state = se_device->local_device_state(); - auto first_definition_event = definition_events[0]; - auto async_wait_for_events = - [definition_events = std::move(definition_events), - local_device_state = std::move(local_device_state), - definition_promise = std::make_shared>( - std::move(definition_promise))]() mutable { - std::unique_ptr stream; - absl::Status defined_status = - definition_events[0]->GetDefinedStatus(); - if (!defined_status.ok()) { - definition_promise->Set(defined_status); - return; - } - for (auto& event : definition_events) { - if (!event->IsComplete()) { - if (stream == nullptr) { - stream = local_device_state->BorrowStreamFromPool(); - } - event->WaitForEventOnStream(stream.get()); - } - } - - if (stream != nullptr) { - auto* stream_ptr = stream.release(); - // We already borrowed a stream from the pool so we can safely do - // the callback directly on that stream instead of bouncing through - // local_device_state->ThenExecuteCallback. The direct callback - // saves significant time. - auto status = stream_ptr->DoHostCallback( - [definition_promise, stream_ptr, local_device_state, - event_with_status = definition_events[0]]() mutable { - local_device_state->ReturnStreamToPool( - std::unique_ptr(stream_ptr)); - definition_promise->Set( - event_with_status->GetDefinedStatus()); - }); - if (!status.ok()) { - definition_promise->Set(status); - return; - } - } else { - // All events are already complete; set the `definition_promise` - // with the status of the buffer's first definition event which may - // have error status to propagate. - definition_promise->Set(definition_events[0]->GetDefinedStatus()); - } - }; - first_definition_event->ExecuteOrAddToFutureTasks( - absl::StrFormat("async_wait_for_events_%p", &async_wait_for_events), - std::move(async_wait_for_events)); - } - - return FutureHelpers::WithProfiling( - std::move(definition_future), - /*on_block_start=*/ - [] { - tsl::profiler::TraceMeProducer traceme( - "PjRtStreamExecutorBuffer::Await"); - VLOG(3) << "PjRtStreamExecutorBuffer::Await"; - return FutureHelpers::ProfilingKeys( - {/*traceme_context_id=*/traceme.GetContextId()}); - }, - /*on_block_end=*/ - [](FutureHelpers::ProfilingKeys keys) { - tsl::profiler::TraceMeConsumer traceme( - "PjRtStreamExecutorBuffer::Await", keys.traceme_context_id); - }); -} - namespace { // Helper struct for the tuple that is transiently constructed to hold the diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index 4ea2d660a01..8b598737b35 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -539,18 +539,6 @@ class PjRtStreamExecutorBuffer : public CommonPjRtBufferImpl { PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) = delete; PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete; - // Drops the buffer's reference to its associated device memory, leaving the - // buffer in an invalid state. The memory will be freed lazily when all async - // operations using the buffer have completed, according to the allocation - // semantics of the underlying platform. Delete may briefly block if another - // thread is in the process of enqueuing an operation on this buffer, but it - // will never block for a stream operation to complete. If an external - // framework holds a reference to the TrackedDeviceBuffer via - // GetBufferWithExternalReference, the memory will not be freed until the - // external framework drops the reference. - void Delete() override; - - Future<> GetReadyFuture() override; // Similar to Delete, drops the buffer's reference to its associated device // memory, leaving the buffer in an invalid state, but returns the @@ -569,9 +557,6 @@ class PjRtStreamExecutorBuffer : public CommonPjRtBufferImpl { // accesses via the buffer returned from Release. absl::StatusOr> Release( bool wait_for_operations_to_complete); - - absl::StatusOr> DonateWithControlDependency( - Future<> dependency) override; }; // Allocates the device buffers for a buffer that will be used as the diff --git a/third_party/xla/xla/pjrt/se_raw_buffer.cc b/third_party/xla/xla/pjrt/se_raw_buffer.cc index 48610001432..bf346402a7f 100644 --- a/third_party/xla/xla/pjrt/se_raw_buffer.cc +++ b/third_party/xla/xla/pjrt/se_raw_buffer.cc @@ -107,8 +107,9 @@ void PjRtStreamExecutorDeviceEventPromise::SetFromSEEvent( event.AndThen([event = event_, original_event = event]() { if (auto* error = original_event.GetErrorIfPresent()) { event.SetError(*error); + } else { + event.SetStateConcrete(); } - event.SetStateConcrete(); }); } diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer.cc b/third_party/xla/xla/pjrt/tracked_device_buffer.cc index ffaeeaca092..ae198731505 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer.cc +++ b/third_party/xla/xla/pjrt/tracked_device_buffer.cc @@ -25,21 +25,26 @@ limitations under the License. #include #include #include +#include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/future.h" +#include "xla/pjrt/abstract_tracked_device_buffer.h" #include "xla/pjrt/buffer_sequencing_event.h" #include "xla/pjrt/device_event.h" #include "xla/pjrt/event_pool.h" #include "xla/pjrt/local_device_state.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_stream_executor_client.h" #include "xla/pjrt/se_raw_buffer.h" #include "xla/service/shaped_buffer.h" #include "xla/shape.h" @@ -47,10 +52,13 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/event.h" +#include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/status.h" #include "xla/tsl/platform/threadpool.h" +#include "tsl/platform/casts.h" #include "tsl/profiler/lib/connected_traceme.h" #include "tsl/profiler/lib/context_types.h" @@ -207,6 +215,78 @@ void TrackedDeviceBuffer::AddUsageEvent(BufferSequencingEventRef event, usage_events_.push_back({event, reference_held}); } +absl::StatusOr> +TrackedDeviceBuffer::CloneWithControlDependency(PjRtMemorySpace* memory_space, + Future<> dependency) { + auto* se_client = + tensorflow::down_cast(memory_space->client()); + + // Copy all the data in the existing tracked_buffer. + const auto& original_definition_events = definition_events(); + absl::InlinedVector definition_events; + + auto definition_event_for_status = + BufferSequencingEvent::Create(se_client->thread_pool()); + // definition_event_for_status must be the first one so that it blocks other + // actions like D2H transfer from execution before the buffer is ready. + definition_events.push_back(definition_event_for_status); + definition_events.insert(definition_events.end(), + original_definition_events.begin(), + original_definition_events.end()); + + auto new_device_buffer = std::make_unique( + device_, device_memory(), std::move(definition_events)); + + auto* device = tensorflow::down_cast( + memory_space->devices()[0]); + LocalDeviceState* local_device = device->local_device_state(); + dependency.OnReady( + [definition_event_for_status = std::move(definition_event_for_status), + local_device, client = se_client](absl::Status status) mutable { + // Forward the absl::Status from the supplied dependency to the + // definition event. + if (!status.ok()) { + client->SetEventAsError(definition_event_for_status, status); + return; + } + auto stream = local_device->BorrowStreamFromPool(); + TF_CHECK_OK(client->AllocateAndRecordEvent(definition_event_for_status, + local_device, stream.get())); + local_device->ReturnStreamToPool(std::move(stream)); + }); + return new_device_buffer; +} + +Future<> TrackedDeviceBuffer::GetReadyFuture(PjRtMemorySpace* memory_space) { + auto [promise, future] = Future<>::MakePromise(); + std::vector> definition_events; + definition_events.reserve(definition_events_.size()); + for (const auto& event : definition_events_) { + definition_events.push_back(event.CopyRCRef()); + } + absl::Span const> definition_events_span = + definition_events; + tsl::RunWhenReady( + definition_events_span, + [promise = std::move(promise), + definition_events = std::move(definition_events)]() mutable { + for (auto& event : definition_events) { + if (const absl::Status* error = event->GetErrorIfPresent()) { + promise.Set(*error); + return; + } + } + promise.Set(); + }); + return future; +} + +void TrackedDeviceBuffer::Delete(PjRtMemorySpace* memory_space) { + std::unique_ptr device_buffer(this); + // All events already hold onto refs to the buffer to ensure liveness so there + // is no work to do. +} + TrackedDeviceBuffer::StreamAndEventContainer TrackedDeviceBuffer::LockUseAndTransferUsageEvents() { CHECK(in_use_); diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer.h b/third_party/xla/xla/pjrt/tracked_device_buffer.h index 83cabee4116..24ed1db4501 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer.h +++ b/third_party/xla/xla/pjrt/tracked_device_buffer.h @@ -180,9 +180,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer { void AddUsageEvent(tsl::RCReference event) override; - void Delete(PjRtMemorySpace* memory_space) override { - LOG(FATAL) << "Implement"; - } + void Delete(PjRtMemorySpace* memory_space) override; absl::Status WaitUntilBufferReadyOnStream(std::intptr_t stream) override { for (const BufferSequencingEventRef& event : definition_events()) { @@ -191,6 +189,12 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer { return absl::OkStatus(); } + absl::StatusOr> + CloneWithControlDependency(PjRtMemorySpace* memory_space, + Future<> dependency) override; + + Future<> GetReadyFuture(PjRtMemorySpace* memory_space) override; + private: PjRtDevice* device_;