PiperOrigin-RevId: 819777372
This commit is contained in:
Peter Hawkins 2025-10-15 08:57:03 -07:00 committed by TensorFlower Gardener
parent 4626ec956f
commit 009d8fdbf4
4 changed files with 235 additions and 117 deletions

View File

@ -1404,6 +1404,230 @@ void PjRtStreamExecutorBuffer::ConvertUsageHold(TrackedDeviceBuffer* buffer,
DecrementUsage();
}
Future<> PjRtStreamExecutorBuffer::LazyToLiteral(
absl::AnyInvocable<Future<MutableLiteralBase*>() &&> generator) {
auto buffer = std::move(generator)();
return ToLiteralHelper(std::move(buffer));
}
Future<> PjRtStreamExecutorBuffer::ToLiteral(MutableLiteralBase* literal) {
return ToLiteralHelper(Future<MutableLiteralBase*>(literal));
}
Future<> PjRtStreamExecutorBuffer::ToLiteralHelper(
Future<MutableLiteralBase*> literal) {
VLOG(3) << "PjRtStreamExecutorBuffer::ToLiteral";
auto* se_device = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device());
auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client());
LocalDeviceState* local_device = se_device->local_device_state();
se::Stream* stream = local_device->GetDeviceToHostStream();
auto device_buffer = GetBufferWithUsageHold();
if (!device_buffer.ok()) {
return Future<>(
InvalidArgument("ToLiteral() called on deleted or donated buffer: %s",
device_buffer.status().ToString()));
}
auto [promise, future] = Future<>::MakePromise();
auto usage_event = BufferSequencingEvent::Create(se_client->thread_pool());
TransferManager* transfer_manager =
se_client->client()->backend().transfer_manager();
auto device_memory = device_buffer->device_memory();
auto definition_events = device_buffer->definition_events();
auto first_definition_event = definition_events[0];
// When using the ComputeSynchronized allocation model, retain a
// reference to the device_buffer until the copy completes, to
// ensure that the buffer isn't deleted or donated while it is still
// in use. The choice of retaining a reference at the host is a
// heuristic; the alternative is to ensure, before freeing the
// buffer, that the compute stream is synchronized past the
// transfer, but it seems better to hold onto the buffer too long
// than to stall the compute stream, particularly since the
// overwhelmingly common use case of CopyToHostAsync will hold onto
// the reference long enough to read the buffer in a subsequent call
// to ToLiteral.
device_buffer.ConvertUsageHold(stream, usage_event, /*reference_held=*/true);
auto [literal_and_transpose_promise, literal_and_transpose_future] =
Future<std::pair<MutableLiteralBase*,
std::shared_ptr<TransposePlan>>>::MakePromise();
literal.OnReady(
[client = se_client, on_device_shape{on_device_shape()},
promise = std::move(literal_and_transpose_promise)](
const absl::StatusOr<MutableLiteralBase*>& value) mutable {
if (!value.ok()) {
promise.Set(value.status());
return;
}
MutableLiteralBase* literal = *std::move(value);
std::shared_ptr<TransposePlan> transpose;
if (on_device_shape.IsArray()) {
xla::Layout literal_layout;
if (literal->shape().has_layout()) {
literal_layout = literal->shape().layout();
} else {
literal_layout = LayoutUtil::MakeDescendingLayout(
on_device_shape.dimensions().size());
}
if (on_device_shape.layout() != literal_layout) {
absl::InlinedVector<int64_t, 4> byte_strides(
on_device_shape.dimensions().size());
absl::Status s = ShapeUtil::ByteStrides(
on_device_shape, absl::MakeSpan(byte_strides));
if (!s.ok()) {
promise.Set(s);
return;
}
absl::Span<const int64_t> dims = on_device_shape.dimensions();
absl::InlinedVector<int64_t, 4> permutation(dims.size());
absl::c_reverse_copy(literal_layout.minor_to_major(),
permutation.begin());
TransposePlan::Options options;
options.elem_size_in_bytes =
primitive_util::ByteWidth(on_device_shape.element_type());
options.dims = on_device_shape.dimensions();
options.permutation = permutation;
options.input_layout = TransposePlan::Striding{byte_strides};
{
absl::MutexLock lock(&client->transpose_mu_);
absl::StatusOr<std::shared_ptr<TransposePlan>> t =
client->transpose_cache_.GetOrCreate(options);
if (!t.ok()) {
promise.Set(t.status());
return;
}
transpose = *std::move(t);
}
}
}
promise.Set(std::make_pair(literal, std::move(transpose)));
});
auto async_to_literal = [client = se_client, usage_event,
device_memory = std::move(device_memory),
definition_events = std::move(definition_events),
stream, device = se_device,
transfer_manager = std::move(transfer_manager),
on_device_shape{on_device_shape()},
literal_and_transpose =
std::move(literal_and_transpose_future),
promise = std::move(promise).ToShared(),
local_device]() mutable {
absl::StatusOr<EventPool::Handle> event_or =
local_device->event_pool().AllocateEvent(stream->parent());
if (!event_or.ok()) {
promise->Set(event_or.status());
return;
}
absl::Status defined_status = definition_events[0]->GetDefinedStatus();
if (!defined_status.ok()) {
promise->Set(defined_status);
return;
}
literal_and_transpose.OnReady(
[client, usage_event = std::move(usage_event),
device_memory = std::move(device_memory),
definition_events = std::move(definition_events),
stream = std::move(stream), device,
transfer_manager = std::move(transfer_manager),
on_device_shape = std::move(on_device_shape),
promise = std::move(promise), local_device = std::move(local_device),
event_or = std::move(event_or)](
const absl::StatusOr<
std::pair<MutableLiteralBase*, std::shared_ptr<TransposePlan>>>&
value) mutable {
if (!value.ok()) {
promise->Set(value.status());
return;
}
auto [literal, transpose] = *std::move(value);
WaitForBufferDefinitionEventsOnStream(
absl::MakeSpan(definition_events), stream);
ShapedBuffer shaped_buffer =
device_memory->AsShapedBuffer(device, on_device_shape);
GenericTransferManager::LiteralFromDeviceMetadata transfer_metadata;
// We never call device functions from the `done` callback.
transfer_metadata.callback_is_host_callback_safe = true;
TransferManager::TransferMetadata* transfer_metadata_ptr =
(dynamic_cast<GenericTransferManager*>(transfer_manager) !=
nullptr)
? &transfer_metadata
: nullptr;
if (transpose) {
// Copy the device buffer to a temporary literal with descending
// layout and transpose to the requested layout.
Shape stage_shape = literal->shape();
*stage_shape.mutable_layout() = LayoutUtil::MakeDescendingLayout(
stage_shape.dimensions().size());
auto staged = std::make_shared<Literal>(stage_shape);
transfer_manager->TransferLiteralFromDevice(
stream, shaped_buffer, staged.get(),
[transpose = std::move(transpose), promise, staged,
literal = std::move(literal)](absl::Status status) mutable {
if (status.ok()) {
transpose->Execute(staged->untyped_data(),
literal->untyped_data());
}
promise->Set(std::move(status));
},
transfer_metadata_ptr);
} else {
transfer_manager->TransferLiteralFromDevice(
stream, shaped_buffer, literal,
[promise](absl::Status status) mutable {
promise->Set(std::move(status));
},
transfer_metadata_ptr);
}
client->ThenRecordEvent(usage_event, local_device,
std::move(event_or).value(), stream);
absl::Status defined_status =
local_device->ThenRelease(stream, device_memory);
if (!defined_status.ok()) {
promise->Set(defined_status);
}
});
};
first_definition_event->ExecuteOrAddToFutureTasks(
"async_to_literal", std::move(async_to_literal));
return FutureHelpers::WithProfiling(
std::move(future),
/*on_block_start=*/
[]() {
tsl::profiler::TraceMeProducer traceme(
"PjRtStreamExecutorBuffer::ToLiteral");
VLOG(3) << "PjRtStreamExecutorBuffer::ToLiteral";
return FutureHelpers::ProfilingKeys(
{/*traceme_context_id =*/traceme.GetContextId()});
},
/*on_block_end=*/
[](FutureHelpers::ProfilingKeys keys) {
tsl::profiler::TraceMeConsumer traceme(
"PjRtStreamExecutorBuffer::ToLiteral", keys.traceme_context_id);
});
}
absl::StatusOr<size_t> PjRtStreamExecutorBuffer::GetOnDeviceSizeInBytes()
const {
absl::MutexLock lock(&mu_);

View File

@ -600,6 +600,11 @@ class PjRtStreamExecutorBuffer : public CommonPjRtBufferImpl {
PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) = delete;
PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete;
using PjRtBuffer::ToLiteralSync;
Future<> ToLiteral(MutableLiteralBase* literal) override;
Future<> LazyToLiteral(
absl::AnyInvocable<Future<MutableLiteralBase*>() &&> generator) override;
absl::StatusOr<size_t> GetOnDeviceSizeInBytes() const override;
Future<> CopyRawToHost(void* dst, int64_t offset,
@ -684,6 +689,8 @@ class PjRtStreamExecutorBuffer : public CommonPjRtBufferImpl {
const TrackedDeviceBuffer& src_device_buffer);
absl::StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDeviceMemorySpace(
PjRtDevice* dst_device, PjRtMemorySpace* dst_memory_space = nullptr);
Future<> ToLiteralHelper(Future<MutableLiteralBase*> literal);
};
// Allocates the device buffers for a buffer that will be used as the

View File

@ -16,22 +16,13 @@ limitations under the License.
#include "xla/pjrt/se_raw_buffer.h"
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "xla/future.h"
#include "xla/layout.h"
#include "xla/layout_util.h"
#include "xla/literal.h"
#include "xla/pjrt/buffer_sequencing_event.h"
#include "xla/pjrt/device_event.h"
#include "xla/pjrt/local_device_state.h"
@ -40,10 +31,6 @@ limitations under the License.
#include "xla/pjrt/pjrt_stream_executor_client.h"
#include "xla/pjrt/raw_buffer.h"
#include "xla/pjrt/tracked_device_buffer.h"
#include "xla/pjrt/transpose.h"
#include "xla/primitive_util.h"
#include "xla/service/generic_transfer_manager.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/stream.h"
#include "xla/tsl/concurrency/async_value_ref.h"
@ -214,107 +201,9 @@ void PjRtStreamExecutorRawBuffer::ReadDynamicShape(
void PjRtStreamExecutorRawBuffer::CopyToLiteralAsync(
Promise<> promise, tsl::RCReference<PjRtDeviceEventPromise> device_promise,
MutableLiteralBase* literal, xla::Shape shape) {
auto usage_event = BufferSequencingEvent::Create(client_->thread_pool());
client_->async_work_runner()->Schedule(
[usage_event, local_device = local_device_,
on_device_shape = std::move(shape), promise = std::move(promise),
literal, client = client_, memory_space = memory_space_,
device_buffer = device_buffer_]() mutable {
std::shared_ptr<TransposePlan> transpose;
se::Stream* stream = local_device->GetDeviceToHostStream();
TransferManager* transfer_manager =
client->client()->backend().transfer_manager();
if (on_device_shape.IsArray()) {
xla::Layout literal_layout;
if (literal->shape().has_layout()) {
literal_layout = literal->shape().layout();
} else {
literal_layout = LayoutUtil::MakeDescendingLayout(
on_device_shape.dimensions().size());
}
if (on_device_shape.layout() != literal_layout) {
absl::InlinedVector<int64_t, 4> byte_strides(
on_device_shape.dimensions().size());
absl::Status s = ShapeUtil::ByteStrides(
on_device_shape, absl::MakeSpan(byte_strides));
if (!s.ok()) {
promise.Set(s);
client->SetEventAsError(usage_event, s);
return;
}
absl::Span<const int64_t> dims = on_device_shape.dimensions();
absl::InlinedVector<int64_t, 4> permutation(dims.size());
absl::c_reverse_copy(literal_layout.minor_to_major(),
permutation.begin());
TransposePlan::Options options;
options.elem_size_in_bytes =
primitive_util::ByteWidth(on_device_shape.element_type());
options.dims = on_device_shape.dimensions();
options.permutation = permutation;
options.input_layout = TransposePlan::Striding{byte_strides};
{
absl::MutexLock lock(&client->transpose_mu_);
absl::StatusOr<std::shared_ptr<TransposePlan>> t =
client->transpose_cache_.GetOrCreate(options);
if (!t.ok()) {
promise.Set(t.status());
client->SetEventAsError(usage_event, t.status());
return;
}
transpose = *std::move(t);
}
}
}
ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(
memory_space->devices()[0], on_device_shape);
GenericTransferManager::LiteralFromDeviceMetadata transfer_metadata;
// We never call device functions from the `done` callback.
transfer_metadata.callback_is_host_callback_safe = true;
TransferManager::TransferMetadata* transfer_metadata_ptr =
(dynamic_cast<GenericTransferManager*>(transfer_manager) != nullptr)
? &transfer_metadata
: nullptr;
if (transpose) {
// Copy the device buffer to a temporary literal with descending
// layout and transpose to the requested layout.
Shape stage_shape = literal->shape();
*stage_shape.mutable_layout() =
LayoutUtil::MakeDescendingLayout(stage_shape.dimensions().size());
auto staged = std::make_shared<Literal>(stage_shape);
transfer_manager->TransferLiteralFromDevice(
stream, shaped_buffer, staged.get(),
[transpose = std::move(transpose),
promise = std::move(promise).ToShared(), staged,
literal = std::move(literal)](absl::Status status) mutable {
if (status.ok()) {
transpose->Execute(staged->untyped_data(),
literal->untyped_data());
}
promise->Set(std::move(status));
},
transfer_metadata_ptr);
} else {
transfer_manager->TransferLiteralFromDevice(
stream, shaped_buffer, literal,
[promise =
std::move(promise).ToShared()](absl::Status status) mutable {
promise->Set(std::move(status));
},
transfer_metadata_ptr);
}
CHECK_OK(local_device->AllocateAndRecordEvent(usage_event, stream));
});
device_promise->Set(
tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(std::move(usage_event)));
device_promise->SetError(
absl::UnimplementedError("Cannot CopyToLiteralAsync."));
promise.Set(absl::UnimplementedError("Cannot CopyToLiteralAsync."));
}
absl::StatusOr<tsl::RCReference<PjRtDeviceEvent>>

View File

@ -372,10 +372,8 @@ absl::StatusOr<PerDeviceLiteralVecType> FetchAndLogOutput(
TF_RET_CHECK(buffer->device() == output_buffers[i][0]->device())
<< "All outputs from a given vector of outputs should be for the "
"same device";
TF_ASSIGN_OR_RETURN(auto logical_shape,
buffer->logical_on_device_shape());
output_slice.emplace_back(
ShapeUtil::DeviceShapeToHostShape(logical_shape));
ShapeUtil::DeviceShapeToHostShape(buffer->on_device_shape()));
buffer->ToLiteral(&output_slice.back()).OnReady([&](absl::Status s) {
absl::MutexLock lock(mu);
--num_pending_transfers;