From eef0661fc5074d6461cbf8e79acad5058b1aaef6 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Fri, 31 Oct 2025 13:30:51 -0700 Subject: [PATCH] Rollforward with fixes of "Change RawSEDeviceMemory to be AsyncValueRef". Reverts c7055c2e5bb6e70365de867dc0ac56e56689a122 PiperOrigin-RevId: 826608975 --- third_party/xla/xla/pjrt/BUILD | 2 +- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 4 ++-- .../xla/pjrt/pjrt_stream_executor_client.cc | 22 +++++++++---------- .../xla/pjrt/pjrt_stream_executor_client.h | 10 ++++----- third_party/xla/xla/pjrt/se_raw_buffer.h | 12 +++++----- .../xla/xla/pjrt/tracked_device_buffer.cc | 16 +++++++------- .../xla/xla/pjrt/tracked_device_buffer.h | 13 +++++------ .../xla/pjrt/tracked_device_buffer_test.cc | 3 ++- 8 files changed, 41 insertions(+), 41 deletions(-) diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 8bfc6a903e0..1309d8309eb 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -180,7 +180,6 @@ xla_cc_test( ":pjrt_client", ":pjrt_common", ":pjrt_stream_executor_client", - "//xla:future", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -192,6 +191,7 @@ xla_cc_test( "//xla/hlo/testlib:test", "//xla/service:cpu_plugin", "//xla/stream_executor:device_memory_allocator", + "//xla/tsl/concurrency:async_value", "//xla/tsl/concurrency:ref_count", "@com_google_absl//absl/log", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 738278d94f6..cb5c4950aab 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -1367,7 +1367,7 @@ StreamExecutorGpuClient::RunAsync( std::set buffers_in_result; - xla::ShapeTree> results( + xla::ShapeTree> results( gpu_exec->result_shape()); for (auto& p : results) { @@ -1452,7 +1452,7 @@ StreamExecutorGpuClient::RunAsync( TF_RETURN_IF_ERROR(buffer_allocations.TearDown(buffers_in_result, gpu_exec->GetAllocations())); - std::vector> to_be_released; + std::vector> to_be_released; // Free allocations for arguments. for (ShapeTree& input : arguments) { diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index 2a1c57e5960..a3877580f46 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -789,7 +789,7 @@ PjRtStreamExecutorClient::CreateErrorBuffer(absl::Status error, // Create an empty buffer. auto dummy_device_buffer = std::make_unique( - device, tsl::RCReference(), + device, tsl::AsyncValueRef(), absl::MakeSpan(&definition_event, 1)); return std::make_unique( @@ -1168,13 +1168,13 @@ MakeTupleHelper(PjRtStreamExecutorClient* client, // Converts a ScopedShapedBuffer returned from an execution into a // PjRtBuffer. absl::StatusOr> OutputBufferHelper( - ShapeTree> result_buffer, + ShapeTree> result_buffer, BufferSequencingEventRef definition_event, PjRtClient* client, PjRtDevice* device, LocalDeviceState* local_device) { if (result_buffer.shape().IsTuple()) { return absl::InternalError("OutputBufferHelper called on tuple."); } - absl::InlinedVector, 1> buffers; + absl::InlinedVector, 1> buffers; for (auto& item : result_buffer) { buffers.push_back(std::move(item.second)); } @@ -1641,7 +1641,7 @@ PjRtStreamExecutorClient::RunAsync( ExecutionOutput output, exec.RunAsync(std::move(xla_arguments), std::move(run_options))); ScopedShapedBuffer ssb = output.ConsumeResult(); - xla::ShapeTree> results( + xla::ShapeTree> results( ssb.on_device_shape()); auto it = results.begin(); se::DeviceMemoryAllocator* allocator = ssb.memory_allocator(); @@ -1672,7 +1672,7 @@ PjRtStreamExecutorClient::RunAsync( // converted on success. // When `options` has non-zero `launch_id`, use `launch_id` instead of `run_id` // to initialize `run_options`. -absl::StatusOr>> +absl::StatusOr>> PjRtStreamExecutorLoadedExecutable::EnqueueExecution( absl::Span argument_handles, int replica, int partition, int executable_idx, const RunId& run_id, const ExecuteOptions& options, @@ -1930,7 +1930,7 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution( absl::StatusOr>> PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers( int device_ordinal, const ExecuteOptions& options, - ShapeTree> result_buffer, + ShapeTree> result_buffer, BufferSequencingEventRef definition_event, PjRtDevice* device, std::vector>& compute_callbacks) const { tsl::profiler::TraceMe traceme("MakeOutputBuffers"); @@ -1943,7 +1943,7 @@ PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers( // in result_buffer. for (int i = 0; i < tuple_count; ++i) { TF_ASSIGN_OR_RETURN( - ShapeTree> tuple_buffer, + ShapeTree> tuple_buffer, result_buffer.SubShapeTree({i})); TF_ASSIGN_OR_RETURN( std::unique_ptr buffer, @@ -2050,7 +2050,7 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper( std::vector> compute_callbacks; std::vector device_buffers; device_buffers.reserve(argument_handles.size()); - absl::StatusOr>> + absl::StatusOr>> result_buffer_or_status = EnqueueExecution(argument_handles, replica, partition, executable_idx, run_id, options, device, &device_buffers, @@ -2061,7 +2061,7 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper( << " failed: " << result_buffer_or_status.status(); return result_buffer_or_status.status(); } - ShapeTree> result_buffer = + ShapeTree> result_buffer = std::move(result_buffer_or_status).value(); LocalDeviceState* device_state = &(client_->device_state(device_ordinal)); @@ -2081,14 +2081,14 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper( } return definition_event_or.status(); } - std::vector> leaves_to_release; + std::vector> leaves_to_release; if (device_state->allocation_model() == LocalDeviceState::kSynchronous) { leaves_to_release.reserve(result_buffer.leaf_count()); for (auto& node : result_buffer.leaves()) { leaves_to_release.push_back(node.second); } } - std::vector> buffers_to_release; + std::vector> buffers_to_release; auto definition_event = tsl::MakeRef( *definition_event_or, "PjRtStreamExecutorLoadedExecutable", "Execute"); TF_ASSIGN_OR_RETURN( diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index 16935a9a521..0f2999338d0 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -83,13 +83,13 @@ struct PjRtStreamExecutorExecutionInput { // Donation is not complete until ReleaseDeviceMemory() is called on the // TrackedDeviceBuffer that provides buf. bool is_donated; - tsl::RCReference buf; + tsl::AsyncValueRef buf; }; struct PjRtStreamExecutorExecutionOutput { - ShapeTree> result; + ShapeTree> result; // Donated inputs which must be freed. - std::vector> to_be_released; + std::vector> to_be_released; // For PjRtStreamExecutorClient implementations that // use OwningDeviceMemory for donated inputs. std::vector se_to_be_released; @@ -672,7 +672,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable { absl::Span device_buffers, absl::flat_hash_set& events) const; - absl::StatusOr>> + absl::StatusOr>> EnqueueExecution( absl::Span argument_handles, int replica, int partition, int executable_idx, const RunId& run_id, @@ -684,7 +684,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable { virtual absl::StatusOr>> MakeOutputBuffers( int device_ordinal, const ExecuteOptions& options, - ShapeTree> result_buffer, + ShapeTree> result_buffer, BufferSequencingEventRef definition_event, PjRtDevice* device, std::vector>& compute_callbacks) const; diff --git a/third_party/xla/xla/pjrt/se_raw_buffer.h b/third_party/xla/xla/pjrt/se_raw_buffer.h index 85fa80992f3..3ab487a645f 100644 --- a/third_party/xla/xla/pjrt/se_raw_buffer.h +++ b/third_party/xla/xla/pjrt/se_raw_buffer.h @@ -91,10 +91,10 @@ class PjRtStreamExecutorDeviceEventPromise : public PjRtDeviceEventPromise { class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer { public: - PjRtStreamExecutorRawBuffer(PjRtStreamExecutorClient* client, - PjRtMemorySpace* memory_space, - LocalDeviceState* local_device, - tsl::RCReference device_buffer) + PjRtStreamExecutorRawBuffer( + PjRtStreamExecutorClient* client, PjRtMemorySpace* memory_space, + LocalDeviceState* local_device, + tsl::AsyncValueRef device_buffer) : client_(client), memory_space_(memory_space), local_device_(local_device), @@ -104,7 +104,7 @@ class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer { LocalDeviceState* local_device() const { return local_device_; } - const tsl::RCReference& device_buffer() const { + const tsl::AsyncValueRef& device_buffer() const { return device_buffer_; } @@ -150,7 +150,7 @@ class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer { PjRtStreamExecutorClient* client_; PjRtMemorySpace* memory_space_; LocalDeviceState* local_device_; - tsl::RCReference device_buffer_; + tsl::AsyncValueRef device_buffer_; }; } // namespace xla diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer.cc b/third_party/xla/xla/pjrt/tracked_device_buffer.cc index ae198731505..b22378974a6 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer.cc +++ b/third_party/xla/xla/pjrt/tracked_device_buffer.cc @@ -119,11 +119,11 @@ class AllocatedRawSEDeviceMemory : public RawSEDeviceMemory { size_t sync_point_ = std::numeric_limits::max(); }; -tsl::RCReference RawSEDeviceMemory::Create( +tsl::AsyncValueRef RawSEDeviceMemory::Create( se::DeviceMemoryBase value, LocalDeviceState* local_device, se::DeviceMemoryAllocator* allocator) { - return tsl::MakeRef(value, local_device, - allocator); + return tsl::MakeAvailableAsyncValueRef( + value, local_device, allocator); } class ForeignRawSEDeviceMemory : public RawSEDeviceMemory { @@ -143,11 +143,11 @@ class ForeignRawSEDeviceMemory : public RawSEDeviceMemory { absl::AnyInvocable on_delete_callback_; }; -tsl::RCReference RawSEDeviceMemory::CreateForeign( +tsl::AsyncValueRef RawSEDeviceMemory::CreateForeign( se::DeviceMemoryBase value, absl::AnyInvocable on_delete_callback) { - return tsl::MakeRef(value, - std::move(on_delete_callback)); + return tsl::MakeAvailableAsyncValueRef( + value, std::move(on_delete_callback)); } ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer( @@ -167,7 +167,7 @@ ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer( } TrackedDeviceBuffer::TrackedDeviceBuffer( - PjRtDevice* device, tsl::RCReference device_memory, + PjRtDevice* device, tsl::AsyncValueRef device_memory, absl::Span definition_events) : device_(device), device_memory_(std::move(device_memory)), @@ -178,7 +178,7 @@ TrackedDeviceBuffer::TrackedDeviceBuffer( TrackedDeviceBuffer::~TrackedDeviceBuffer() = default; void TrackedDeviceBuffer::ReleaseDeviceMemory() { - device_memory_ = tsl::RCReference(); + device_memory_ = tsl::AsyncValueRef(); } void TrackedDeviceBuffer::ConfirmDonation() { diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer.h b/third_party/xla/xla/pjrt/tracked_device_buffer.h index 24ed1db4501..062287ad773 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer.h +++ b/third_party/xla/xla/pjrt/tracked_device_buffer.h @@ -51,8 +51,7 @@ limitations under the License. namespace xla { -// TODO(parkers): Implement PjRtRawBuffer API. -class RawSEDeviceMemory : public tsl::ReferenceCounted { +class RawSEDeviceMemory { public: explicit RawSEDeviceMemory(se::DeviceMemoryBase value) : value_(value) {} @@ -70,10 +69,10 @@ class RawSEDeviceMemory : public tsl::ReferenceCounted { ShapedBuffer AsShapedBuffer(PjRtDevice* device, const Shape& on_device_shape) const; - static tsl::RCReference Create( + static tsl::AsyncValueRef Create( se::DeviceMemoryBase value, LocalDeviceState* local_device, se::DeviceMemoryAllocator* allocator); - static tsl::RCReference CreateForeign( + static tsl::AsyncValueRef CreateForeign( se::DeviceMemoryBase value, absl::AnyInvocable on_delete_callback); @@ -130,7 +129,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer { ExecutionInput* execution_input, se::DeviceMemoryAllocator* allocator) const; - const tsl::RCReference& device_memory() const { + const tsl::AsyncValueRef& device_memory() const { return device_memory_; } @@ -168,7 +167,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer { StreamAndEventContainer LockUseAndTransferUsageEvents(); TrackedDeviceBuffer( - PjRtDevice* device, tsl::RCReference device_memory, + PjRtDevice* device, tsl::AsyncValueRef device_memory, absl::Span definition_events); ~TrackedDeviceBuffer() override; @@ -199,7 +198,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer { PjRtDevice* device_; // Each host-side buffer may have several buffers on-device. - tsl::RCReference device_memory_; + tsl::AsyncValueRef device_memory_; // Events that are triggered when the content of one or more buffers is ready // during multistream execution. May be nullptr, which is used in the diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc b/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc index c757f3c6cbb..bc35b8cc115 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc +++ b/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc @@ -35,6 +35,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -83,7 +84,7 @@ class TestDevice : public PjRtDevice { absl::StatusOr> MakeArray( const Shape& shape, LocalClient* client, PjRtDevice* device) { - std::vector> device_buffers; + std::vector> device_buffers; TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( client->backend().transfer_manager()->HostShapeToDeviceShape(shape), [&](const Shape& subshape, const ShapeIndex&) -> absl::Status {