Implement TrackedDeviceBuffer::GetReadyFuture, TrackedDeviceBuffer::Delete, and
TrackedDeviceBuffer::CloneWithControlDependency.

PiperOrigin-RevId: 822804278
This commit is contained in: parent 03be5156fe, commit c3e202374c
@@ -674,63 +674,6 @@ bool PjRtStreamExecutorClient::IsOnCpu(PjRtMemorySpace* memory_space) {
  return memory_space->kind() == PinnedHostMemorySpace::kKind;
}

absl::StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorBuffer::DonateWithControlDependency(Future<> dependency) {
  VLOG(1) << "PjRtStreamExecutorBuffer::DonateWithControlDependency";
  std::unique_ptr<PjRtBuffer> new_buffer;

  auto hold = GetBufferWithHold(CommonPjRtBuffer::ScopedHold::kDonation);

  if (!hold.ok()) {
    return InvalidArgument(
        "Invalid buffer passed to DonateWithControlDependency: %s",
        hold.status().ToString());
  }

  auto* tracked_buffer =
      tensorflow::down_cast<TrackedDeviceBuffer*>(hold.buffer());
  // Copy all the data in the existing tracked_buffer.
  const auto& original_definition_events = tracked_buffer->definition_events();
  absl::InlinedVector<BufferSequencingEventRef, 4> definition_events;
  auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client());

  auto definition_event_for_status =
      BufferSequencingEvent::Create(se_client->thread_pool());
  // definition_event_for_status must be the first one so that it blocks other
  // actions like D2H transfer from execution before the buffer is ready.
  definition_events.push_back(definition_event_for_status);
  definition_events.insert(definition_events.end(),
                           original_definition_events.begin(),
                           original_definition_events.end());

  auto new_device_buffer = std::make_unique<TrackedDeviceBuffer>(
      device(), tracked_buffer->device_memory(), std::move(definition_events));

  // Make the new buffer which is identical to the old, except for the new
  // definition event.
  new_buffer =
      std::unique_ptr<PjRtBuffer>(std::make_unique<PjRtStreamExecutorBuffer>(
          on_device_shape(), std::move(new_device_buffer), se_client, device(),
          device()->default_memory_space().value_or(nullptr)));

  auto* device =
      tensorflow::down_cast<PjRtStreamExecutorDevice*>(this->device());
  LocalDeviceState* local_device = device->local_device_state();
  dependency.OnReady(
      [definition_event_for_status = std::move(definition_event_for_status),
       local_device, client = se_client](absl::Status status) mutable {
        // Forward the absl::Status from the supplied dependency to the
        // definition event.
        auto stream = local_device->BorrowStreamFromPool();
        TF_CHECK_OK(client->AllocateAndRecordEvent(definition_event_for_status,
                                                   local_device, stream.get()));
        local_device->ReturnStreamToPool(std::move(stream));
      });

  hold.ConfirmDonation();
  return new_buffer;
}

absl::StatusOr<tsl::RCReference<PjRtDeviceEvent>>
PjRtStreamExecutorClient::LinearizeHostBufferInto(
    const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
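For context, not part of the diff: a minimal caller-side sketch of the PjRtBuffer-level DonateWithControlDependency API whose stream-executor-specific override is removed above. The DonateAfter wrapper, the buffer, and the dependency future are hypothetical, and the xla:: qualification of Future is assumed; only the DonateWithControlDependency signature comes from this change.

// Hypothetical caller-side sketch; assumes an already-created PjRtClient and
// a live PjRtBuffer.
#include <memory>
#include <utility>

#include "absl/status/statusor.h"
#include "xla/future.h"
#include "xla/pjrt/pjrt_client.h"

absl::StatusOr<std::unique_ptr<xla::PjRtBuffer>> DonateAfter(
    std::unique_ptr<xla::PjRtBuffer> buffer, xla::Future<> dependency) {
  // The original buffer is donated and becomes invalid; the returned buffer
  // only becomes ready once `dependency` resolves successfully.
  return buffer->DonateWithControlDependency(std::move(dependency));
}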
@@ -1330,111 +1273,6 @@ PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) {
  return device_memory;
}

void PjRtStreamExecutorBuffer::Delete() {
  VLOG(3) << "PjRtStreamExecutorBuffer::Delete";

  // When wait_for_reads_to_complete is false, Release should never fail.
  //
  // The only usage events that
  // Release(/*wait_for_operations_to_complete=*/false) doesn't wait for are
  // events defined on the compute stream. All streams other than the compute
  // stream are expected to WaitFor compute stream before any write operations.
  TF_CHECK_OK(Release(/*wait_for_operations_to_complete=*/false).status());
}

Future<> PjRtStreamExecutorBuffer::GetReadyFuture() {
  absl::InlinedVector<BufferSequencingEventRef, 2> definition_events;
  Promise<> definition_promise;
  Future<> definition_future;
  {
    absl::MutexLock lock(&mu_);
    if (device_buffer() == nullptr) {
      return Future<>(InvalidArgument(
          "GetReadyFuture() called on deleted or donated buffer"));
    }
    if (!definition_future_) {
      definition_events =
          tensorflow::down_cast<TrackedDeviceBuffer*>(device_buffer())
              ->definition_events();
      std::tie(definition_promise, definition_future_) =
          Future<>::MakePromise();
    }
    definition_future = definition_future_;
  }

  if (!definition_events.empty()) {
    auto* se_device =
        tensorflow::down_cast<PjRtStreamExecutorDevice*>(device());
    LocalDeviceState* local_device_state = se_device->local_device_state();
    auto first_definition_event = definition_events[0];
    auto async_wait_for_events =
        [definition_events = std::move(definition_events),
         local_device_state = std::move(local_device_state),
         definition_promise = std::make_shared<Promise<>>(
             std::move(definition_promise))]() mutable {
          std::unique_ptr<se::Stream> stream;
          absl::Status defined_status =
              definition_events[0]->GetDefinedStatus();
          if (!defined_status.ok()) {
            definition_promise->Set(defined_status);
            return;
          }
          for (auto& event : definition_events) {
            if (!event->IsComplete()) {
              if (stream == nullptr) {
                stream = local_device_state->BorrowStreamFromPool();
              }
              event->WaitForEventOnStream(stream.get());
            }
          }

          if (stream != nullptr) {
            auto* stream_ptr = stream.release();
            // We already borrowed a stream from the pool so we can safely do
            // the callback directly on that stream instead of bouncing through
            // local_device_state->ThenExecuteCallback. The direct callback
            // saves significant time.
            auto status = stream_ptr->DoHostCallback(
                [definition_promise, stream_ptr, local_device_state,
                 event_with_status = definition_events[0]]() mutable {
                  local_device_state->ReturnStreamToPool(
                      std::unique_ptr<se::Stream>(stream_ptr));
                  definition_promise->Set(
                      event_with_status->GetDefinedStatus());
                });
            if (!status.ok()) {
              definition_promise->Set(status);
              return;
            }
          } else {
            // All events are already complete; set the `definition_promise`
            // with the status of the buffer's first definition event which may
            // have error status to propagate.
            definition_promise->Set(definition_events[0]->GetDefinedStatus());
          }
        };
    first_definition_event->ExecuteOrAddToFutureTasks(
        absl::StrFormat("async_wait_for_events_%p", &async_wait_for_events),
        std::move(async_wait_for_events));
  }

  return FutureHelpers::WithProfiling(
      std::move(definition_future),
      /*on_block_start=*/
      [] {
        tsl::profiler::TraceMeProducer traceme(
            "PjRtStreamExecutorBuffer::Await");
        VLOG(3) << "PjRtStreamExecutorBuffer::Await";
        return FutureHelpers::ProfilingKeys(
            {/*traceme_context_id=*/traceme.GetContextId()});
      },
      /*on_block_end=*/
      [](FutureHelpers::ProfilingKeys keys) {
        tsl::profiler::TraceMeConsumer traceme(
            "PjRtStreamExecutorBuffer::Await", keys.traceme_context_id);
      });
}

namespace {

// Helper struct for the tuple that is transiently constructed to hold the
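Again for context rather than part of the diff: a hedged sketch of how a caller might observe the readiness future and then delete the buffer, matching the Delete semantics documented in the header hunk below. The function name and logging are illustrative assumptions; only GetReadyFuture(), OnReady(), and Delete() are taken from this change.

// Hypothetical caller-side sketch; buffer creation and client setup assumed.
#include <memory>

#include "absl/log/log.h"
#include "absl/status/status.h"
#include "xla/pjrt/pjrt_client.h"

void LogWhenReadyThenDelete(std::unique_ptr<xla::PjRtBuffer> buffer) {
  // GetReadyFuture() resolves once every definition event has completed, and
  // carries the first definition error if one occurred.
  buffer->GetReadyFuture().OnReady([](absl::Status status) {
    LOG(INFO) << "Buffer ready with status: " << status;
  });
  // Delete() drops the reference to device memory without blocking on stream
  // operations; the memory is freed once pending work completes.
  buffer->Delete();
}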
@@ -539,18 +539,6 @@ class PjRtStreamExecutorBuffer : public CommonPjRtBufferImpl {
  PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) = delete;
  PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete;

  // Drops the buffer's reference to its associated device memory, leaving the
  // buffer in an invalid state. The memory will be freed lazily when all async
  // operations using the buffer have completed, according to the allocation
  // semantics of the underlying platform. Delete may briefly block if another
  // thread is in the process of enqueuing an operation on this buffer, but it
  // will never block for a stream operation to complete. If an external
  // framework holds a reference to the TrackedDeviceBuffer via
  // GetBufferWithExternalReference, the memory will not be freed until the
  // external framework drops the reference.
  void Delete() override;

  Future<> GetReadyFuture() override;

  // Similar to Delete, drops the buffer's reference to its associated device
  // memory, leaving the buffer in an invalid state, but returns the
@@ -569,9 +557,6 @@ class PjRtStreamExecutorBuffer : public CommonPjRtBufferImpl {
  // accesses via the buffer returned from Release.
  absl::StatusOr<tsl::RCReference<RawSEDeviceMemory>> Release(
      bool wait_for_operations_to_complete);

  absl::StatusOr<std::unique_ptr<PjRtBuffer>> DonateWithControlDependency(
      Future<> dependency) override;
};

// Allocates the device buffers for a buffer that will be used as the
third_party/xla/xla/pjrt/se_raw_buffer.cc (3 changes, vendored)
@@ -107,8 +107,9 @@ void PjRtStreamExecutorDeviceEventPromise::SetFromSEEvent(
  event.AndThen([event = event_, original_event = event]() {
    if (auto* error = original_event.GetErrorIfPresent()) {
      event.SetError(*error);
    } else {
      event.SetStateConcrete();
    }
    event.SetStateConcrete();
  });
}
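A standalone sketch of the completion-forwarding pattern the fix above tightens: propagate either the source value's error or a concrete state to the destination, never both. ForwardCompletion, the int payload, and the construction of the values are illustrative assumptions, not code from this change.

// Illustrative sketch: `dst` is assumed to be constructed but not yet
// available, e.g. created via tsl::MakeConstructedAsyncValueRef<int>(0).
#include "xla/tsl/concurrency/async_value_ref.h"

void ForwardCompletion(tsl::AsyncValueRef<int> src,
                       tsl::AsyncValueRef<int> dst) {
  src.AndThen([src, dst]() mutable {
    if (src.IsError()) {
      // Forward the source error instead of marking `dst` concrete.
      dst.SetError(src.GetError());
    } else {
      dst.SetStateConcrete();
    }
  });
}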
@@ -25,21 +25,26 @@ limitations under the License.
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/functional/any_invocable.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "xla/future.h"
#include "xla/pjrt/abstract_tracked_device_buffer.h"
#include "xla/pjrt/buffer_sequencing_event.h"
#include "xla/pjrt/device_event.h"
#include "xla/pjrt/event_pool.h"
#include "xla/pjrt/local_device_state.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_common.h"
#include "xla/pjrt/pjrt_stream_executor_client.h"
#include "xla/pjrt/se_raw_buffer.h"
#include "xla/service/shaped_buffer.h"
#include "xla/shape.h"

@@ -47,10 +52,13 @@ limitations under the License.
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/event.h"
#include "xla/tsl/concurrency/async_value.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/status.h"
#include "xla/tsl/platform/threadpool.h"
#include "tsl/platform/casts.h"
#include "tsl/profiler/lib/connected_traceme.h"
#include "tsl/profiler/lib/context_types.h"
@@ -207,6 +215,78 @@ void TrackedDeviceBuffer::AddUsageEvent(BufferSequencingEventRef event,
  usage_events_.push_back({event, reference_held});
}

absl::StatusOr<std::unique_ptr<AbstractTrackedDeviceBuffer>>
TrackedDeviceBuffer::CloneWithControlDependency(PjRtMemorySpace* memory_space,
                                                Future<> dependency) {
  auto* se_client =
      tensorflow::down_cast<PjRtStreamExecutorClient*>(memory_space->client());

  // Copy all the data in the existing tracked_buffer.
  const auto& original_definition_events = definition_events();
  absl::InlinedVector<BufferSequencingEventRef, 4> definition_events;

  auto definition_event_for_status =
      BufferSequencingEvent::Create(se_client->thread_pool());
  // definition_event_for_status must be the first one so that it blocks other
  // actions like D2H transfer from execution before the buffer is ready.
  definition_events.push_back(definition_event_for_status);
  definition_events.insert(definition_events.end(),
                           original_definition_events.begin(),
                           original_definition_events.end());

  auto new_device_buffer = std::make_unique<TrackedDeviceBuffer>(
      device_, device_memory(), std::move(definition_events));

  auto* device = tensorflow::down_cast<PjRtStreamExecutorDevice*>(
      memory_space->devices()[0]);
  LocalDeviceState* local_device = device->local_device_state();
  dependency.OnReady(
      [definition_event_for_status = std::move(definition_event_for_status),
       local_device, client = se_client](absl::Status status) mutable {
        // Forward the absl::Status from the supplied dependency to the
        // definition event.
        if (!status.ok()) {
          client->SetEventAsError(definition_event_for_status, status);
          return;
        }
        auto stream = local_device->BorrowStreamFromPool();
        TF_CHECK_OK(client->AllocateAndRecordEvent(definition_event_for_status,
                                                   local_device, stream.get()));
        local_device->ReturnStreamToPool(std::move(stream));
      });
  return new_device_buffer;
}

Future<> TrackedDeviceBuffer::GetReadyFuture(PjRtMemorySpace* memory_space) {
  auto [promise, future] = Future<>::MakePromise();
  std::vector<tsl::RCReference<tsl::AsyncValue>> definition_events;
  definition_events.reserve(definition_events_.size());
  for (const auto& event : definition_events_) {
    definition_events.push_back(event.CopyRCRef());
  }
  absl::Span<tsl::RCReference<tsl::AsyncValue> const> definition_events_span =
      definition_events;
  tsl::RunWhenReady(
      definition_events_span,
      [promise = std::move(promise),
       definition_events = std::move(definition_events)]() mutable {
        for (auto& event : definition_events) {
          if (const absl::Status* error = event->GetErrorIfPresent()) {
            promise.Set(*error);
            return;
          }
        }
        promise.Set();
      });
  return future;
}

void TrackedDeviceBuffer::Delete(PjRtMemorySpace* memory_space) {
  std::unique_ptr<TrackedDeviceBuffer> device_buffer(this);
  // All events already hold onto refs to the buffer to ensure liveness so
  // there is no work to do.
}

TrackedDeviceBuffer::StreamAndEventContainer
TrackedDeviceBuffer::LockUseAndTransferUsageEvents() {
  CHECK(in_use_);
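To make the control flow of TrackedDeviceBuffer::GetReadyFuture above easier to follow in isolation, here is a hedged, self-contained sketch of the same RunWhenReady-plus-promise pattern. ReadyWhenAll is a hypothetical helper and the xla::Future spelling is an assumption; MakePromise, RunWhenReady, GetErrorIfPresent, and Set are used exactly as in the diff.

// Illustrative helper: resolve a Future<> once all async values are ready,
// surfacing the first error encountered, if any.
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/types/span.h"
#include "xla/future.h"
#include "xla/tsl/concurrency/async_value.h"
#include "xla/tsl/concurrency/ref_count.h"

xla::Future<> ReadyWhenAll(
    std::vector<tsl::RCReference<tsl::AsyncValue>> events) {
  auto [promise, future] = xla::Future<>::MakePromise();
  // Take the span before the vector is moved into the callback; the moved-to
  // vector keeps the underlying storage alive (same trick as in the diff).
  absl::Span<tsl::RCReference<tsl::AsyncValue> const> events_span = events;
  tsl::RunWhenReady(
      events_span,
      [promise = std::move(promise), events = std::move(events)]() mutable {
        for (auto& event : events) {
          if (const absl::Status* error = event->GetErrorIfPresent()) {
            promise.Set(*error);
            return;
          }
        }
        promise.Set();
      });
  return future;
}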
third_party/xla/xla/pjrt/tracked_device_buffer.h (10 changes, vendored)
@@ -180,9 +180,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {

  void AddUsageEvent(tsl::RCReference<PjRtDeviceEvent> event) override;

  void Delete(PjRtMemorySpace* memory_space) override {
    LOG(FATAL) << "Implement";
  }
  void Delete(PjRtMemorySpace* memory_space) override;

  absl::Status WaitUntilBufferReadyOnStream(std::intptr_t stream) override {
    for (const BufferSequencingEventRef& event : definition_events()) {

@@ -191,6 +189,12 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {
    return absl::OkStatus();
  }

  absl::StatusOr<std::unique_ptr<AbstractTrackedDeviceBuffer>>
  CloneWithControlDependency(PjRtMemorySpace* memory_space,
                             Future<> dependency) override;

  Future<> GetReadyFuture(PjRtMemorySpace* memory_space) override;

 private:
  PjRtDevice* device_;