Rollforward with fixes of "Change RawSEDeviceMemory to be AsyncValueRef".

Reverts c7055c2e5b

PiperOrigin-RevId: 826608975
This commit is contained in:
Parker Schuh 2025-10-31 13:30:51 -07:00 committed by TensorFlower Gardener
parent d008dc3999
commit eef0661fc5
8 changed files with 41 additions and 41 deletions

View File

@ -180,7 +180,6 @@ xla_cc_test(
":pjrt_client", ":pjrt_client",
":pjrt_common", ":pjrt_common",
":pjrt_stream_executor_client", ":pjrt_stream_executor_client",
"//xla:future",
"//xla:literal", "//xla:literal",
"//xla:literal_util", "//xla:literal_util",
"//xla:shape_util", "//xla:shape_util",
@ -192,6 +191,7 @@ xla_cc_test(
"//xla/hlo/testlib:test", "//xla/hlo/testlib:test",
"//xla/service:cpu_plugin", "//xla/service:cpu_plugin",
"//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:device_memory_allocator",
"//xla/tsl/concurrency:async_value",
"//xla/tsl/concurrency:ref_count", "//xla/tsl/concurrency:ref_count",
"@com_google_absl//absl/log", "@com_google_absl//absl/log",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",

View File

@ -1367,7 +1367,7 @@ StreamExecutorGpuClient::RunAsync(
std::set<se::DeviceMemoryBase> buffers_in_result; std::set<se::DeviceMemoryBase> buffers_in_result;
xla::ShapeTree<tsl::RCReference<RawSEDeviceMemory>> results( xla::ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> results(
gpu_exec->result_shape()); gpu_exec->result_shape());
for (auto& p : results) { for (auto& p : results) {
@ -1452,7 +1452,7 @@ StreamExecutorGpuClient::RunAsync(
TF_RETURN_IF_ERROR(buffer_allocations.TearDown(buffers_in_result, TF_RETURN_IF_ERROR(buffer_allocations.TearDown(buffers_in_result,
gpu_exec->GetAllocations())); gpu_exec->GetAllocations()));
std::vector<tsl::RCReference<RawSEDeviceMemory>> to_be_released; std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> to_be_released;
// Free allocations for arguments. // Free allocations for arguments.
for (ShapeTree<PjRtStreamExecutorExecutionInput>& input : arguments) { for (ShapeTree<PjRtStreamExecutorExecutionInput>& input : arguments) {

View File

@ -789,7 +789,7 @@ PjRtStreamExecutorClient::CreateErrorBuffer(absl::Status error,
// Create an empty buffer. // Create an empty buffer.
auto dummy_device_buffer = std::make_unique<TrackedDeviceBuffer>( auto dummy_device_buffer = std::make_unique<TrackedDeviceBuffer>(
device, tsl::RCReference<RawSEDeviceMemory>(), device, tsl::AsyncValueRef<RawSEDeviceMemory>(),
absl::MakeSpan(&definition_event, 1)); absl::MakeSpan(&definition_event, 1));
return std::make_unique<CommonPjRtBufferImpl>( return std::make_unique<CommonPjRtBufferImpl>(
@ -1168,13 +1168,13 @@ MakeTupleHelper(PjRtStreamExecutorClient* client,
// Converts a ScopedShapedBuffer returned from an execution into a // Converts a ScopedShapedBuffer returned from an execution into a
// PjRtBuffer. // PjRtBuffer.
absl::StatusOr<std::unique_ptr<PjRtBuffer>> OutputBufferHelper( absl::StatusOr<std::unique_ptr<PjRtBuffer>> OutputBufferHelper(
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer, ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result_buffer,
BufferSequencingEventRef definition_event, PjRtClient* client, BufferSequencingEventRef definition_event, PjRtClient* client,
PjRtDevice* device, LocalDeviceState* local_device) { PjRtDevice* device, LocalDeviceState* local_device) {
if (result_buffer.shape().IsTuple()) { if (result_buffer.shape().IsTuple()) {
return absl::InternalError("OutputBufferHelper called on tuple."); return absl::InternalError("OutputBufferHelper called on tuple.");
} }
absl::InlinedVector<tsl::RCReference<RawSEDeviceMemory>, 1> buffers; absl::InlinedVector<tsl::AsyncValueRef<RawSEDeviceMemory>, 1> buffers;
for (auto& item : result_buffer) { for (auto& item : result_buffer) {
buffers.push_back(std::move(item.second)); buffers.push_back(std::move(item.second));
} }
@ -1641,7 +1641,7 @@ PjRtStreamExecutorClient::RunAsync(
ExecutionOutput output, ExecutionOutput output,
exec.RunAsync(std::move(xla_arguments), std::move(run_options))); exec.RunAsync(std::move(xla_arguments), std::move(run_options)));
ScopedShapedBuffer ssb = output.ConsumeResult(); ScopedShapedBuffer ssb = output.ConsumeResult();
xla::ShapeTree<tsl::RCReference<RawSEDeviceMemory>> results( xla::ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> results(
ssb.on_device_shape()); ssb.on_device_shape());
auto it = results.begin(); auto it = results.begin();
se::DeviceMemoryAllocator* allocator = ssb.memory_allocator(); se::DeviceMemoryAllocator* allocator = ssb.memory_allocator();
@ -1672,7 +1672,7 @@ PjRtStreamExecutorClient::RunAsync(
// converted on success. // converted on success.
// When `options` has non-zero `launch_id`, use `launch_id` instead of `run_id` // When `options` has non-zero `launch_id`, use `launch_id` instead of `run_id`
// to initialize `run_options`. // to initialize `run_options`.
absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>> absl::StatusOr<ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>>>
PjRtStreamExecutorLoadedExecutable::EnqueueExecution( PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition, absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
int executable_idx, const RunId& run_id, const ExecuteOptions& options, int executable_idx, const RunId& run_id, const ExecuteOptions& options,
@ -1930,7 +1930,7 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers( PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
int device_ordinal, const ExecuteOptions& options, int device_ordinal, const ExecuteOptions& options,
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer, ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result_buffer,
BufferSequencingEventRef definition_event, PjRtDevice* device, BufferSequencingEventRef definition_event, PjRtDevice* device,
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const { std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const {
tsl::profiler::TraceMe traceme("MakeOutputBuffers"); tsl::profiler::TraceMe traceme("MakeOutputBuffers");
@ -1943,7 +1943,7 @@ PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
// in result_buffer. // in result_buffer.
for (int i = 0; i < tuple_count; ++i) { for (int i = 0; i < tuple_count; ++i) {
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> tuple_buffer, ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> tuple_buffer,
result_buffer.SubShapeTree({i})); result_buffer.SubShapeTree({i}));
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtBuffer> buffer, std::unique_ptr<PjRtBuffer> buffer,
@ -2050,7 +2050,7 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
std::vector<absl::AnyInvocable<void() &&>> compute_callbacks; std::vector<absl::AnyInvocable<void() &&>> compute_callbacks;
std::vector<CommonPjRtBuffer::ScopedHold> device_buffers; std::vector<CommonPjRtBuffer::ScopedHold> device_buffers;
device_buffers.reserve(argument_handles.size()); device_buffers.reserve(argument_handles.size());
absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>> absl::StatusOr<ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>>>
result_buffer_or_status = result_buffer_or_status =
EnqueueExecution(argument_handles, replica, partition, executable_idx, EnqueueExecution(argument_handles, replica, partition, executable_idx,
run_id, options, device, &device_buffers, run_id, options, device, &device_buffers,
@ -2061,7 +2061,7 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
<< " failed: " << result_buffer_or_status.status(); << " failed: " << result_buffer_or_status.status();
return result_buffer_or_status.status(); return result_buffer_or_status.status();
} }
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer = ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result_buffer =
std::move(result_buffer_or_status).value(); std::move(result_buffer_or_status).value();
LocalDeviceState* device_state = &(client_->device_state(device_ordinal)); LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
@ -2081,14 +2081,14 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
} }
return definition_event_or.status(); return definition_event_or.status();
} }
std::vector<tsl::RCReference<RawSEDeviceMemory>> leaves_to_release; std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> leaves_to_release;
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) { if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
leaves_to_release.reserve(result_buffer.leaf_count()); leaves_to_release.reserve(result_buffer.leaf_count());
for (auto& node : result_buffer.leaves()) { for (auto& node : result_buffer.leaves()) {
leaves_to_release.push_back(node.second); leaves_to_release.push_back(node.second);
} }
} }
std::vector<tsl::RCReference<RawSEDeviceMemory>> buffers_to_release; std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> buffers_to_release;
auto definition_event = tsl::MakeRef<PjRtStreamExecutorDeviceEvent>( auto definition_event = tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
*definition_event_or, "PjRtStreamExecutorLoadedExecutable", "Execute"); *definition_event_or, "PjRtStreamExecutorLoadedExecutable", "Execute");
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(

View File

@ -83,13 +83,13 @@ struct PjRtStreamExecutorExecutionInput {
// Donation is not complete until ReleaseDeviceMemory() is called on the // Donation is not complete until ReleaseDeviceMemory() is called on the
// TrackedDeviceBuffer that provides buf. // TrackedDeviceBuffer that provides buf.
bool is_donated; bool is_donated;
tsl::RCReference<RawSEDeviceMemory> buf; tsl::AsyncValueRef<RawSEDeviceMemory> buf;
}; };
struct PjRtStreamExecutorExecutionOutput { struct PjRtStreamExecutorExecutionOutput {
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result; ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result;
// Donated inputs which must be freed. // Donated inputs which must be freed.
std::vector<tsl::RCReference<RawSEDeviceMemory>> to_be_released; std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> to_be_released;
// For PjRtStreamExecutorClient implementations that // For PjRtStreamExecutorClient implementations that
// use OwningDeviceMemory for donated inputs. // use OwningDeviceMemory for donated inputs.
std::vector<se::OwningDeviceMemory> se_to_be_released; std::vector<se::OwningDeviceMemory> se_to_be_released;
@ -672,7 +672,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
absl::Span<const CommonPjRtBuffer::ScopedHold> device_buffers, absl::Span<const CommonPjRtBuffer::ScopedHold> device_buffers,
absl::flat_hash_set<BufferSequencingEvent*>& events) const; absl::flat_hash_set<BufferSequencingEvent*>& events) const;
absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>> absl::StatusOr<ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>>>
EnqueueExecution( EnqueueExecution(
absl::Span<PjRtBuffer* const> argument_handles, int replica, absl::Span<PjRtBuffer* const> argument_handles, int replica,
int partition, int executable_idx, const RunId& run_id, int partition, int executable_idx, const RunId& run_id,
@ -684,7 +684,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
virtual absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> virtual absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
MakeOutputBuffers( MakeOutputBuffers(
int device_ordinal, const ExecuteOptions& options, int device_ordinal, const ExecuteOptions& options,
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer, ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result_buffer,
BufferSequencingEventRef definition_event, PjRtDevice* device, BufferSequencingEventRef definition_event, PjRtDevice* device,
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const; std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const;

View File

@ -91,10 +91,10 @@ class PjRtStreamExecutorDeviceEventPromise : public PjRtDeviceEventPromise {
class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer { class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer {
public: public:
PjRtStreamExecutorRawBuffer(PjRtStreamExecutorClient* client, PjRtStreamExecutorRawBuffer(
PjRtMemorySpace* memory_space, PjRtStreamExecutorClient* client, PjRtMemorySpace* memory_space,
LocalDeviceState* local_device, LocalDeviceState* local_device,
tsl::RCReference<RawSEDeviceMemory> device_buffer) tsl::AsyncValueRef<RawSEDeviceMemory> device_buffer)
: client_(client), : client_(client),
memory_space_(memory_space), memory_space_(memory_space),
local_device_(local_device), local_device_(local_device),
@ -104,7 +104,7 @@ class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer {
LocalDeviceState* local_device() const { return local_device_; } LocalDeviceState* local_device() const { return local_device_; }
const tsl::RCReference<RawSEDeviceMemory>& device_buffer() const { const tsl::AsyncValueRef<RawSEDeviceMemory>& device_buffer() const {
return device_buffer_; return device_buffer_;
} }
@ -150,7 +150,7 @@ class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer {
PjRtStreamExecutorClient* client_; PjRtStreamExecutorClient* client_;
PjRtMemorySpace* memory_space_; PjRtMemorySpace* memory_space_;
LocalDeviceState* local_device_; LocalDeviceState* local_device_;
tsl::RCReference<RawSEDeviceMemory> device_buffer_; tsl::AsyncValueRef<RawSEDeviceMemory> device_buffer_;
}; };
} // namespace xla } // namespace xla

View File

@ -119,11 +119,11 @@ class AllocatedRawSEDeviceMemory : public RawSEDeviceMemory {
size_t sync_point_ = std::numeric_limits<size_t>::max(); size_t sync_point_ = std::numeric_limits<size_t>::max();
}; };
tsl::RCReference<RawSEDeviceMemory> RawSEDeviceMemory::Create( tsl::AsyncValueRef<RawSEDeviceMemory> RawSEDeviceMemory::Create(
se::DeviceMemoryBase value, LocalDeviceState* local_device, se::DeviceMemoryBase value, LocalDeviceState* local_device,
se::DeviceMemoryAllocator* allocator) { se::DeviceMemoryAllocator* allocator) {
return tsl::MakeRef<AllocatedRawSEDeviceMemory>(value, local_device, return tsl::MakeAvailableAsyncValueRef<AllocatedRawSEDeviceMemory>(
allocator); value, local_device, allocator);
} }
class ForeignRawSEDeviceMemory : public RawSEDeviceMemory { class ForeignRawSEDeviceMemory : public RawSEDeviceMemory {
@ -143,11 +143,11 @@ class ForeignRawSEDeviceMemory : public RawSEDeviceMemory {
absl::AnyInvocable<void() &&> on_delete_callback_; absl::AnyInvocable<void() &&> on_delete_callback_;
}; };
tsl::RCReference<RawSEDeviceMemory> RawSEDeviceMemory::CreateForeign( tsl::AsyncValueRef<RawSEDeviceMemory> RawSEDeviceMemory::CreateForeign(
se::DeviceMemoryBase value, se::DeviceMemoryBase value,
absl::AnyInvocable<void() &&> on_delete_callback) { absl::AnyInvocable<void() &&> on_delete_callback) {
return tsl::MakeRef<ForeignRawSEDeviceMemory>(value, return tsl::MakeAvailableAsyncValueRef<ForeignRawSEDeviceMemory>(
std::move(on_delete_callback)); value, std::move(on_delete_callback));
} }
ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer( ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer(
@ -167,7 +167,7 @@ ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer(
} }
TrackedDeviceBuffer::TrackedDeviceBuffer( TrackedDeviceBuffer::TrackedDeviceBuffer(
PjRtDevice* device, tsl::RCReference<RawSEDeviceMemory> device_memory, PjRtDevice* device, tsl::AsyncValueRef<RawSEDeviceMemory> device_memory,
absl::Span<const BufferSequencingEventRef> definition_events) absl::Span<const BufferSequencingEventRef> definition_events)
: device_(device), : device_(device),
device_memory_(std::move(device_memory)), device_memory_(std::move(device_memory)),
@ -178,7 +178,7 @@ TrackedDeviceBuffer::TrackedDeviceBuffer(
TrackedDeviceBuffer::~TrackedDeviceBuffer() = default; TrackedDeviceBuffer::~TrackedDeviceBuffer() = default;
void TrackedDeviceBuffer::ReleaseDeviceMemory() { void TrackedDeviceBuffer::ReleaseDeviceMemory() {
device_memory_ = tsl::RCReference<RawSEDeviceMemory>(); device_memory_ = tsl::AsyncValueRef<RawSEDeviceMemory>();
} }
void TrackedDeviceBuffer::ConfirmDonation() { void TrackedDeviceBuffer::ConfirmDonation() {

View File

@ -51,8 +51,7 @@ limitations under the License.
namespace xla { namespace xla {
// TODO(parkers): Implement PjRtRawBuffer API. class RawSEDeviceMemory {
class RawSEDeviceMemory : public tsl::ReferenceCounted<RawSEDeviceMemory> {
public: public:
explicit RawSEDeviceMemory(se::DeviceMemoryBase value) : value_(value) {} explicit RawSEDeviceMemory(se::DeviceMemoryBase value) : value_(value) {}
@ -70,10 +69,10 @@ class RawSEDeviceMemory : public tsl::ReferenceCounted<RawSEDeviceMemory> {
ShapedBuffer AsShapedBuffer(PjRtDevice* device, ShapedBuffer AsShapedBuffer(PjRtDevice* device,
const Shape& on_device_shape) const; const Shape& on_device_shape) const;
static tsl::RCReference<RawSEDeviceMemory> Create( static tsl::AsyncValueRef<RawSEDeviceMemory> Create(
se::DeviceMemoryBase value, LocalDeviceState* local_device, se::DeviceMemoryBase value, LocalDeviceState* local_device,
se::DeviceMemoryAllocator* allocator); se::DeviceMemoryAllocator* allocator);
static tsl::RCReference<RawSEDeviceMemory> CreateForeign( static tsl::AsyncValueRef<RawSEDeviceMemory> CreateForeign(
se::DeviceMemoryBase value, se::DeviceMemoryBase value,
absl::AnyInvocable<void() &&> on_delete_callback); absl::AnyInvocable<void() &&> on_delete_callback);
@ -130,7 +129,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {
ExecutionInput* execution_input, ExecutionInput* execution_input,
se::DeviceMemoryAllocator* allocator) const; se::DeviceMemoryAllocator* allocator) const;
const tsl::RCReference<RawSEDeviceMemory>& device_memory() const { const tsl::AsyncValueRef<RawSEDeviceMemory>& device_memory() const {
return device_memory_; return device_memory_;
} }
@ -168,7 +167,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {
StreamAndEventContainer LockUseAndTransferUsageEvents(); StreamAndEventContainer LockUseAndTransferUsageEvents();
TrackedDeviceBuffer( TrackedDeviceBuffer(
PjRtDevice* device, tsl::RCReference<RawSEDeviceMemory> device_memory, PjRtDevice* device, tsl::AsyncValueRef<RawSEDeviceMemory> device_memory,
absl::Span<const BufferSequencingEventRef> definition_events); absl::Span<const BufferSequencingEventRef> definition_events);
~TrackedDeviceBuffer() override; ~TrackedDeviceBuffer() override;
@ -199,7 +198,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {
PjRtDevice* device_; PjRtDevice* device_;
// Each host-side buffer may have several buffers on-device. // Each host-side buffer may have several buffers on-device.
tsl::RCReference<RawSEDeviceMemory> device_memory_; tsl::AsyncValueRef<RawSEDeviceMemory> device_memory_;
// Events that are triggered when the content of one or more buffers is ready // Events that are triggered when the content of one or more buffers is ready
// during multistream execution. May be nullptr, which is used in the // during multistream execution. May be nullptr, which is used in the

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "xla/shape_util.h" #include "xla/shape_util.h"
#include "xla/status_macros.h" #include "xla/status_macros.h"
#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/device_memory_allocator.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/concurrency/ref_count.h"
#include "xla/util.h" #include "xla/util.h"
#include "xla/xla_data.pb.h" #include "xla/xla_data.pb.h"
@ -83,7 +84,7 @@ class TestDevice : public PjRtDevice {
absl::StatusOr<std::shared_ptr<TrackedDeviceBuffer>> MakeArray( absl::StatusOr<std::shared_ptr<TrackedDeviceBuffer>> MakeArray(
const Shape& shape, LocalClient* client, PjRtDevice* device) { const Shape& shape, LocalClient* client, PjRtDevice* device) {
std::vector<tsl::RCReference<RawSEDeviceMemory>> device_buffers; std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> device_buffers;
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
client->backend().transfer_manager()->HostShapeToDeviceShape(shape), client->backend().transfer_manager()->HostShapeToDeviceShape(shape),
[&](const Shape& subshape, const ShapeIndex&) -> absl::Status { [&](const Shape& subshape, const ShapeIndex&) -> absl::Status {