diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index cb5c4950aab..738278d94f6 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -1367,7 +1367,7 @@ StreamExecutorGpuClient::RunAsync( std::set buffers_in_result; - xla::ShapeTree> results( + xla::ShapeTree> results( gpu_exec->result_shape()); for (auto& p : results) { @@ -1452,7 +1452,7 @@ StreamExecutorGpuClient::RunAsync( TF_RETURN_IF_ERROR(buffer_allocations.TearDown(buffers_in_result, gpu_exec->GetAllocations())); - std::vector> to_be_released; + std::vector> to_be_released; // Free allocations for arguments. for (ShapeTree& input : arguments) { diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index a3877580f46..2a1c57e5960 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -789,7 +789,7 @@ PjRtStreamExecutorClient::CreateErrorBuffer(absl::Status error, // Create an empty buffer. auto dummy_device_buffer = std::make_unique( - device, tsl::AsyncValueRef(), + device, tsl::RCReference(), absl::MakeSpan(&definition_event, 1)); return std::make_unique( @@ -1168,13 +1168,13 @@ MakeTupleHelper(PjRtStreamExecutorClient* client, // Converts a ScopedShapedBuffer returned from an execution into a // PjRtBuffer. absl::StatusOr> OutputBufferHelper( - ShapeTree> result_buffer, + ShapeTree> result_buffer, BufferSequencingEventRef definition_event, PjRtClient* client, PjRtDevice* device, LocalDeviceState* local_device) { if (result_buffer.shape().IsTuple()) { return absl::InternalError("OutputBufferHelper called on tuple."); } - absl::InlinedVector, 1> buffers; + absl::InlinedVector, 1> buffers; for (auto& item : result_buffer) { buffers.push_back(std::move(item.second)); } @@ -1641,7 +1641,7 @@ PjRtStreamExecutorClient::RunAsync( ExecutionOutput output, exec.RunAsync(std::move(xla_arguments), std::move(run_options))); ScopedShapedBuffer ssb = output.ConsumeResult(); - xla::ShapeTree> results( + xla::ShapeTree> results( ssb.on_device_shape()); auto it = results.begin(); se::DeviceMemoryAllocator* allocator = ssb.memory_allocator(); @@ -1672,7 +1672,7 @@ PjRtStreamExecutorClient::RunAsync( // converted on success. // When `options` has non-zero `launch_id`, use `launch_id` instead of `run_id` // to initialize `run_options`. -absl::StatusOr>> +absl::StatusOr>> PjRtStreamExecutorLoadedExecutable::EnqueueExecution( absl::Span argument_handles, int replica, int partition, int executable_idx, const RunId& run_id, const ExecuteOptions& options, @@ -1930,7 +1930,7 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution( absl::StatusOr>> PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers( int device_ordinal, const ExecuteOptions& options, - ShapeTree> result_buffer, + ShapeTree> result_buffer, BufferSequencingEventRef definition_event, PjRtDevice* device, std::vector>& compute_callbacks) const { tsl::profiler::TraceMe traceme("MakeOutputBuffers"); @@ -1943,7 +1943,7 @@ PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers( // in result_buffer. for (int i = 0; i < tuple_count; ++i) { TF_ASSIGN_OR_RETURN( - ShapeTree> tuple_buffer, + ShapeTree> tuple_buffer, result_buffer.SubShapeTree({i})); TF_ASSIGN_OR_RETURN( std::unique_ptr buffer, @@ -2050,7 +2050,7 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper( std::vector> compute_callbacks; std::vector device_buffers; device_buffers.reserve(argument_handles.size()); - absl::StatusOr>> + absl::StatusOr>> result_buffer_or_status = EnqueueExecution(argument_handles, replica, partition, executable_idx, run_id, options, device, &device_buffers, @@ -2061,7 +2061,7 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper( << " failed: " << result_buffer_or_status.status(); return result_buffer_or_status.status(); } - ShapeTree> result_buffer = + ShapeTree> result_buffer = std::move(result_buffer_or_status).value(); LocalDeviceState* device_state = &(client_->device_state(device_ordinal)); @@ -2081,14 +2081,14 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper( } return definition_event_or.status(); } - std::vector> leaves_to_release; + std::vector> leaves_to_release; if (device_state->allocation_model() == LocalDeviceState::kSynchronous) { leaves_to_release.reserve(result_buffer.leaf_count()); for (auto& node : result_buffer.leaves()) { leaves_to_release.push_back(node.second); } } - std::vector> buffers_to_release; + std::vector> buffers_to_release; auto definition_event = tsl::MakeRef( *definition_event_or, "PjRtStreamExecutorLoadedExecutable", "Execute"); TF_ASSIGN_OR_RETURN( diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index 0f2999338d0..16935a9a521 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -83,13 +83,13 @@ struct PjRtStreamExecutorExecutionInput { // Donation is not complete until ReleaseDeviceMemory() is called on the // TrackedDeviceBuffer that provides buf. bool is_donated; - tsl::AsyncValueRef buf; + tsl::RCReference buf; }; struct PjRtStreamExecutorExecutionOutput { - ShapeTree> result; + ShapeTree> result; // Donated inputs which must be freed. - std::vector> to_be_released; + std::vector> to_be_released; // For PjRtStreamExecutorClient implementations that // use OwningDeviceMemory for donated inputs. std::vector se_to_be_released; @@ -672,7 +672,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable { absl::Span device_buffers, absl::flat_hash_set& events) const; - absl::StatusOr>> + absl::StatusOr>> EnqueueExecution( absl::Span argument_handles, int replica, int partition, int executable_idx, const RunId& run_id, @@ -684,7 +684,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable { virtual absl::StatusOr>> MakeOutputBuffers( int device_ordinal, const ExecuteOptions& options, - ShapeTree> result_buffer, + ShapeTree> result_buffer, BufferSequencingEventRef definition_event, PjRtDevice* device, std::vector>& compute_callbacks) const; diff --git a/third_party/xla/xla/pjrt/se_raw_buffer.h b/third_party/xla/xla/pjrt/se_raw_buffer.h index 3ab487a645f..85fa80992f3 100644 --- a/third_party/xla/xla/pjrt/se_raw_buffer.h +++ b/third_party/xla/xla/pjrt/se_raw_buffer.h @@ -91,10 +91,10 @@ class PjRtStreamExecutorDeviceEventPromise : public PjRtDeviceEventPromise { class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer { public: - PjRtStreamExecutorRawBuffer( - PjRtStreamExecutorClient* client, PjRtMemorySpace* memory_space, - LocalDeviceState* local_device, - tsl::AsyncValueRef device_buffer) + PjRtStreamExecutorRawBuffer(PjRtStreamExecutorClient* client, + PjRtMemorySpace* memory_space, + LocalDeviceState* local_device, + tsl::RCReference device_buffer) : client_(client), memory_space_(memory_space), local_device_(local_device), @@ -104,7 +104,7 @@ class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer { LocalDeviceState* local_device() const { return local_device_; } - const tsl::AsyncValueRef& device_buffer() const { + const tsl::RCReference& device_buffer() const { return device_buffer_; } @@ -150,7 +150,7 @@ class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer { PjRtStreamExecutorClient* client_; PjRtMemorySpace* memory_space_; LocalDeviceState* local_device_; - tsl::AsyncValueRef device_buffer_; + tsl::RCReference device_buffer_; }; } // namespace xla diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer.cc b/third_party/xla/xla/pjrt/tracked_device_buffer.cc index b22378974a6..ae198731505 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer.cc +++ b/third_party/xla/xla/pjrt/tracked_device_buffer.cc @@ -119,11 +119,11 @@ class AllocatedRawSEDeviceMemory : public RawSEDeviceMemory { size_t sync_point_ = std::numeric_limits::max(); }; -tsl::AsyncValueRef RawSEDeviceMemory::Create( +tsl::RCReference RawSEDeviceMemory::Create( se::DeviceMemoryBase value, LocalDeviceState* local_device, se::DeviceMemoryAllocator* allocator) { - return tsl::MakeAvailableAsyncValueRef( - value, local_device, allocator); + return tsl::MakeRef(value, local_device, + allocator); } class ForeignRawSEDeviceMemory : public RawSEDeviceMemory { @@ -143,11 +143,11 @@ class ForeignRawSEDeviceMemory : public RawSEDeviceMemory { absl::AnyInvocable on_delete_callback_; }; -tsl::AsyncValueRef RawSEDeviceMemory::CreateForeign( +tsl::RCReference RawSEDeviceMemory::CreateForeign( se::DeviceMemoryBase value, absl::AnyInvocable on_delete_callback) { - return tsl::MakeAvailableAsyncValueRef( - value, std::move(on_delete_callback)); + return tsl::MakeRef(value, + std::move(on_delete_callback)); } ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer( @@ -167,7 +167,7 @@ ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer( } TrackedDeviceBuffer::TrackedDeviceBuffer( - PjRtDevice* device, tsl::AsyncValueRef device_memory, + PjRtDevice* device, tsl::RCReference device_memory, absl::Span definition_events) : device_(device), device_memory_(std::move(device_memory)), @@ -178,7 +178,7 @@ TrackedDeviceBuffer::TrackedDeviceBuffer( TrackedDeviceBuffer::~TrackedDeviceBuffer() = default; void TrackedDeviceBuffer::ReleaseDeviceMemory() { - device_memory_ = tsl::AsyncValueRef(); + device_memory_ = tsl::RCReference(); } void TrackedDeviceBuffer::ConfirmDonation() { diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer.h b/third_party/xla/xla/pjrt/tracked_device_buffer.h index 4dc8f30093c..24ed1db4501 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer.h +++ b/third_party/xla/xla/pjrt/tracked_device_buffer.h @@ -51,6 +51,7 @@ limitations under the License. namespace xla { +// TODO(parkers): Implement PjRtRawBuffer API. class RawSEDeviceMemory : public tsl::ReferenceCounted { public: explicit RawSEDeviceMemory(se::DeviceMemoryBase value) : value_(value) {} @@ -69,10 +70,10 @@ class RawSEDeviceMemory : public tsl::ReferenceCounted { ShapedBuffer AsShapedBuffer(PjRtDevice* device, const Shape& on_device_shape) const; - static tsl::AsyncValueRef Create( + static tsl::RCReference Create( se::DeviceMemoryBase value, LocalDeviceState* local_device, se::DeviceMemoryAllocator* allocator); - static tsl::AsyncValueRef CreateForeign( + static tsl::RCReference CreateForeign( se::DeviceMemoryBase value, absl::AnyInvocable on_delete_callback); @@ -129,7 +130,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer { ExecutionInput* execution_input, se::DeviceMemoryAllocator* allocator) const; - const tsl::AsyncValueRef& device_memory() const { + const tsl::RCReference& device_memory() const { return device_memory_; } @@ -167,7 +168,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer { StreamAndEventContainer LockUseAndTransferUsageEvents(); TrackedDeviceBuffer( - PjRtDevice* device, tsl::AsyncValueRef device_memory, + PjRtDevice* device, tsl::RCReference device_memory, absl::Span definition_events); ~TrackedDeviceBuffer() override; @@ -198,7 +199,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer { PjRtDevice* device_; // Each host-side buffer may have several buffers on-device. - tsl::AsyncValueRef device_memory_; + tsl::RCReference device_memory_; // Events that are triggered when the content of one or more buffers is ready // during multistream execution. May be nullptr, which is used in the diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc b/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc index 0431bf5a522..c757f3c6cbb 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc +++ b/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc @@ -83,7 +83,7 @@ class TestDevice : public PjRtDevice { absl::StatusOr> MakeArray( const Shape& shape, LocalClient* client, PjRtDevice* device) { - std::vector> device_buffers; + std::vector> device_buffers; TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( client->backend().transfer_manager()->HostShapeToDeviceShape(shape), [&](const Shape& subshape, const ShapeIndex&) -> absl::Status {