PiperOrigin-RevId: 825779916
This commit is contained in:
A. Unique TensorFlower 2025-10-29 17:41:51 -07:00 committed by TensorFlower Gardener
parent fe2a783077
commit c7055c2e5b
7 changed files with 39 additions and 38 deletions

View File

@ -1367,7 +1367,7 @@ StreamExecutorGpuClient::RunAsync(
std::set<se::DeviceMemoryBase> buffers_in_result; std::set<se::DeviceMemoryBase> buffers_in_result;
xla::ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> results( xla::ShapeTree<tsl::RCReference<RawSEDeviceMemory>> results(
gpu_exec->result_shape()); gpu_exec->result_shape());
for (auto& p : results) { for (auto& p : results) {
@ -1452,7 +1452,7 @@ StreamExecutorGpuClient::RunAsync(
TF_RETURN_IF_ERROR(buffer_allocations.TearDown(buffers_in_result, TF_RETURN_IF_ERROR(buffer_allocations.TearDown(buffers_in_result,
gpu_exec->GetAllocations())); gpu_exec->GetAllocations()));
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> to_be_released; std::vector<tsl::RCReference<RawSEDeviceMemory>> to_be_released;
// Free allocations for arguments. // Free allocations for arguments.
for (ShapeTree<PjRtStreamExecutorExecutionInput>& input : arguments) { for (ShapeTree<PjRtStreamExecutorExecutionInput>& input : arguments) {

View File

@ -789,7 +789,7 @@ PjRtStreamExecutorClient::CreateErrorBuffer(absl::Status error,
// Create an empty buffer. // Create an empty buffer.
auto dummy_device_buffer = std::make_unique<TrackedDeviceBuffer>( auto dummy_device_buffer = std::make_unique<TrackedDeviceBuffer>(
device, tsl::AsyncValueRef<RawSEDeviceMemory>(), device, tsl::RCReference<RawSEDeviceMemory>(),
absl::MakeSpan(&definition_event, 1)); absl::MakeSpan(&definition_event, 1));
return std::make_unique<CommonPjRtBufferImpl>( return std::make_unique<CommonPjRtBufferImpl>(
@ -1168,13 +1168,13 @@ MakeTupleHelper(PjRtStreamExecutorClient* client,
// Converts a ScopedShapedBuffer returned from an execution into a // Converts a ScopedShapedBuffer returned from an execution into a
// PjRtBuffer. // PjRtBuffer.
absl::StatusOr<std::unique_ptr<PjRtBuffer>> OutputBufferHelper( absl::StatusOr<std::unique_ptr<PjRtBuffer>> OutputBufferHelper(
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result_buffer, ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
BufferSequencingEventRef definition_event, PjRtClient* client, BufferSequencingEventRef definition_event, PjRtClient* client,
PjRtDevice* device, LocalDeviceState* local_device) { PjRtDevice* device, LocalDeviceState* local_device) {
if (result_buffer.shape().IsTuple()) { if (result_buffer.shape().IsTuple()) {
return absl::InternalError("OutputBufferHelper called on tuple."); return absl::InternalError("OutputBufferHelper called on tuple.");
} }
absl::InlinedVector<tsl::AsyncValueRef<RawSEDeviceMemory>, 1> buffers; absl::InlinedVector<tsl::RCReference<RawSEDeviceMemory>, 1> buffers;
for (auto& item : result_buffer) { for (auto& item : result_buffer) {
buffers.push_back(std::move(item.second)); buffers.push_back(std::move(item.second));
} }
@ -1641,7 +1641,7 @@ PjRtStreamExecutorClient::RunAsync(
ExecutionOutput output, ExecutionOutput output,
exec.RunAsync(std::move(xla_arguments), std::move(run_options))); exec.RunAsync(std::move(xla_arguments), std::move(run_options)));
ScopedShapedBuffer ssb = output.ConsumeResult(); ScopedShapedBuffer ssb = output.ConsumeResult();
xla::ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> results( xla::ShapeTree<tsl::RCReference<RawSEDeviceMemory>> results(
ssb.on_device_shape()); ssb.on_device_shape());
auto it = results.begin(); auto it = results.begin();
se::DeviceMemoryAllocator* allocator = ssb.memory_allocator(); se::DeviceMemoryAllocator* allocator = ssb.memory_allocator();
@ -1672,7 +1672,7 @@ PjRtStreamExecutorClient::RunAsync(
// converted on success. // converted on success.
// When `options` has non-zero `launch_id`, use `launch_id` instead of `run_id` // When `options` has non-zero `launch_id`, use `launch_id` instead of `run_id`
// to initialize `run_options`. // to initialize `run_options`.
absl::StatusOr<ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>>> absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>>
PjRtStreamExecutorLoadedExecutable::EnqueueExecution( PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition, absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
int executable_idx, const RunId& run_id, const ExecuteOptions& options, int executable_idx, const RunId& run_id, const ExecuteOptions& options,
@ -1930,7 +1930,7 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers( PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
int device_ordinal, const ExecuteOptions& options, int device_ordinal, const ExecuteOptions& options,
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result_buffer, ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
BufferSequencingEventRef definition_event, PjRtDevice* device, BufferSequencingEventRef definition_event, PjRtDevice* device,
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const { std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const {
tsl::profiler::TraceMe traceme("MakeOutputBuffers"); tsl::profiler::TraceMe traceme("MakeOutputBuffers");
@ -1943,7 +1943,7 @@ PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
// in result_buffer. // in result_buffer.
for (int i = 0; i < tuple_count; ++i) { for (int i = 0; i < tuple_count; ++i) {
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> tuple_buffer, ShapeTree<tsl::RCReference<RawSEDeviceMemory>> tuple_buffer,
result_buffer.SubShapeTree({i})); result_buffer.SubShapeTree({i}));
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtBuffer> buffer, std::unique_ptr<PjRtBuffer> buffer,
@ -2050,7 +2050,7 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
std::vector<absl::AnyInvocable<void() &&>> compute_callbacks; std::vector<absl::AnyInvocable<void() &&>> compute_callbacks;
std::vector<CommonPjRtBuffer::ScopedHold> device_buffers; std::vector<CommonPjRtBuffer::ScopedHold> device_buffers;
device_buffers.reserve(argument_handles.size()); device_buffers.reserve(argument_handles.size());
absl::StatusOr<ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>>> absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>>
result_buffer_or_status = result_buffer_or_status =
EnqueueExecution(argument_handles, replica, partition, executable_idx, EnqueueExecution(argument_handles, replica, partition, executable_idx,
run_id, options, device, &device_buffers, run_id, options, device, &device_buffers,
@ -2061,7 +2061,7 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
<< " failed: " << result_buffer_or_status.status(); << " failed: " << result_buffer_or_status.status();
return result_buffer_or_status.status(); return result_buffer_or_status.status();
} }
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result_buffer = ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer =
std::move(result_buffer_or_status).value(); std::move(result_buffer_or_status).value();
LocalDeviceState* device_state = &(client_->device_state(device_ordinal)); LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
@ -2081,14 +2081,14 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
} }
return definition_event_or.status(); return definition_event_or.status();
} }
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> leaves_to_release; std::vector<tsl::RCReference<RawSEDeviceMemory>> leaves_to_release;
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) { if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
leaves_to_release.reserve(result_buffer.leaf_count()); leaves_to_release.reserve(result_buffer.leaf_count());
for (auto& node : result_buffer.leaves()) { for (auto& node : result_buffer.leaves()) {
leaves_to_release.push_back(node.second); leaves_to_release.push_back(node.second);
} }
} }
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> buffers_to_release; std::vector<tsl::RCReference<RawSEDeviceMemory>> buffers_to_release;
auto definition_event = tsl::MakeRef<PjRtStreamExecutorDeviceEvent>( auto definition_event = tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
*definition_event_or, "PjRtStreamExecutorLoadedExecutable", "Execute"); *definition_event_or, "PjRtStreamExecutorLoadedExecutable", "Execute");
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(

View File

@ -83,13 +83,13 @@ struct PjRtStreamExecutorExecutionInput {
// Donation is not complete until ReleaseDeviceMemory() is called on the // Donation is not complete until ReleaseDeviceMemory() is called on the
// TrackedDeviceBuffer that provides buf. // TrackedDeviceBuffer that provides buf.
bool is_donated; bool is_donated;
tsl::AsyncValueRef<RawSEDeviceMemory> buf; tsl::RCReference<RawSEDeviceMemory> buf;
}; };
struct PjRtStreamExecutorExecutionOutput { struct PjRtStreamExecutorExecutionOutput {
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result; ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result;
// Donated inputs which must be freed. // Donated inputs which must be freed.
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> to_be_released; std::vector<tsl::RCReference<RawSEDeviceMemory>> to_be_released;
// For PjRtStreamExecutorClient implementations that // For PjRtStreamExecutorClient implementations that
// use OwningDeviceMemory for donated inputs. // use OwningDeviceMemory for donated inputs.
std::vector<se::OwningDeviceMemory> se_to_be_released; std::vector<se::OwningDeviceMemory> se_to_be_released;
@ -672,7 +672,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
absl::Span<const CommonPjRtBuffer::ScopedHold> device_buffers, absl::Span<const CommonPjRtBuffer::ScopedHold> device_buffers,
absl::flat_hash_set<BufferSequencingEvent*>& events) const; absl::flat_hash_set<BufferSequencingEvent*>& events) const;
absl::StatusOr<ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>>> absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>>
EnqueueExecution( EnqueueExecution(
absl::Span<PjRtBuffer* const> argument_handles, int replica, absl::Span<PjRtBuffer* const> argument_handles, int replica,
int partition, int executable_idx, const RunId& run_id, int partition, int executable_idx, const RunId& run_id,
@ -684,7 +684,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
virtual absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> virtual absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
MakeOutputBuffers( MakeOutputBuffers(
int device_ordinal, const ExecuteOptions& options, int device_ordinal, const ExecuteOptions& options,
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result_buffer, ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
BufferSequencingEventRef definition_event, PjRtDevice* device, BufferSequencingEventRef definition_event, PjRtDevice* device,
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const; std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const;

View File

@ -91,10 +91,10 @@ class PjRtStreamExecutorDeviceEventPromise : public PjRtDeviceEventPromise {
class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer { class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer {
public: public:
PjRtStreamExecutorRawBuffer( PjRtStreamExecutorRawBuffer(PjRtStreamExecutorClient* client,
PjRtStreamExecutorClient* client, PjRtMemorySpace* memory_space, PjRtMemorySpace* memory_space,
LocalDeviceState* local_device, LocalDeviceState* local_device,
tsl::AsyncValueRef<RawSEDeviceMemory> device_buffer) tsl::RCReference<RawSEDeviceMemory> device_buffer)
: client_(client), : client_(client),
memory_space_(memory_space), memory_space_(memory_space),
local_device_(local_device), local_device_(local_device),
@ -104,7 +104,7 @@ class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer {
LocalDeviceState* local_device() const { return local_device_; } LocalDeviceState* local_device() const { return local_device_; }
const tsl::AsyncValueRef<RawSEDeviceMemory>& device_buffer() const { const tsl::RCReference<RawSEDeviceMemory>& device_buffer() const {
return device_buffer_; return device_buffer_;
} }
@ -150,7 +150,7 @@ class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer {
PjRtStreamExecutorClient* client_; PjRtStreamExecutorClient* client_;
PjRtMemorySpace* memory_space_; PjRtMemorySpace* memory_space_;
LocalDeviceState* local_device_; LocalDeviceState* local_device_;
tsl::AsyncValueRef<RawSEDeviceMemory> device_buffer_; tsl::RCReference<RawSEDeviceMemory> device_buffer_;
}; };
} // namespace xla } // namespace xla

View File

@ -119,11 +119,11 @@ class AllocatedRawSEDeviceMemory : public RawSEDeviceMemory {
size_t sync_point_ = std::numeric_limits<size_t>::max(); size_t sync_point_ = std::numeric_limits<size_t>::max();
}; };
tsl::AsyncValueRef<RawSEDeviceMemory> RawSEDeviceMemory::Create( tsl::RCReference<RawSEDeviceMemory> RawSEDeviceMemory::Create(
se::DeviceMemoryBase value, LocalDeviceState* local_device, se::DeviceMemoryBase value, LocalDeviceState* local_device,
se::DeviceMemoryAllocator* allocator) { se::DeviceMemoryAllocator* allocator) {
return tsl::MakeAvailableAsyncValueRef<AllocatedRawSEDeviceMemory>( return tsl::MakeRef<AllocatedRawSEDeviceMemory>(value, local_device,
value, local_device, allocator); allocator);
} }
class ForeignRawSEDeviceMemory : public RawSEDeviceMemory { class ForeignRawSEDeviceMemory : public RawSEDeviceMemory {
@ -143,11 +143,11 @@ class ForeignRawSEDeviceMemory : public RawSEDeviceMemory {
absl::AnyInvocable<void() &&> on_delete_callback_; absl::AnyInvocable<void() &&> on_delete_callback_;
}; };
tsl::AsyncValueRef<RawSEDeviceMemory> RawSEDeviceMemory::CreateForeign( tsl::RCReference<RawSEDeviceMemory> RawSEDeviceMemory::CreateForeign(
se::DeviceMemoryBase value, se::DeviceMemoryBase value,
absl::AnyInvocable<void() &&> on_delete_callback) { absl::AnyInvocable<void() &&> on_delete_callback) {
return tsl::MakeAvailableAsyncValueRef<ForeignRawSEDeviceMemory>( return tsl::MakeRef<ForeignRawSEDeviceMemory>(value,
value, std::move(on_delete_callback)); std::move(on_delete_callback));
} }
ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer( ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer(
@ -167,7 +167,7 @@ ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer(
} }
TrackedDeviceBuffer::TrackedDeviceBuffer( TrackedDeviceBuffer::TrackedDeviceBuffer(
PjRtDevice* device, tsl::AsyncValueRef<RawSEDeviceMemory> device_memory, PjRtDevice* device, tsl::RCReference<RawSEDeviceMemory> device_memory,
absl::Span<const BufferSequencingEventRef> definition_events) absl::Span<const BufferSequencingEventRef> definition_events)
: device_(device), : device_(device),
device_memory_(std::move(device_memory)), device_memory_(std::move(device_memory)),
@ -178,7 +178,7 @@ TrackedDeviceBuffer::TrackedDeviceBuffer(
TrackedDeviceBuffer::~TrackedDeviceBuffer() = default; TrackedDeviceBuffer::~TrackedDeviceBuffer() = default;
void TrackedDeviceBuffer::ReleaseDeviceMemory() { void TrackedDeviceBuffer::ReleaseDeviceMemory() {
device_memory_ = tsl::AsyncValueRef<RawSEDeviceMemory>(); device_memory_ = tsl::RCReference<RawSEDeviceMemory>();
} }
void TrackedDeviceBuffer::ConfirmDonation() { void TrackedDeviceBuffer::ConfirmDonation() {

View File

@ -51,6 +51,7 @@ limitations under the License.
namespace xla { namespace xla {
// TODO(parkers): Implement PjRtRawBuffer API.
class RawSEDeviceMemory : public tsl::ReferenceCounted<RawSEDeviceMemory> { class RawSEDeviceMemory : public tsl::ReferenceCounted<RawSEDeviceMemory> {
public: public:
explicit RawSEDeviceMemory(se::DeviceMemoryBase value) : value_(value) {} explicit RawSEDeviceMemory(se::DeviceMemoryBase value) : value_(value) {}
@ -69,10 +70,10 @@ class RawSEDeviceMemory : public tsl::ReferenceCounted<RawSEDeviceMemory> {
ShapedBuffer AsShapedBuffer(PjRtDevice* device, ShapedBuffer AsShapedBuffer(PjRtDevice* device,
const Shape& on_device_shape) const; const Shape& on_device_shape) const;
static tsl::AsyncValueRef<RawSEDeviceMemory> Create( static tsl::RCReference<RawSEDeviceMemory> Create(
se::DeviceMemoryBase value, LocalDeviceState* local_device, se::DeviceMemoryBase value, LocalDeviceState* local_device,
se::DeviceMemoryAllocator* allocator); se::DeviceMemoryAllocator* allocator);
static tsl::AsyncValueRef<RawSEDeviceMemory> CreateForeign( static tsl::RCReference<RawSEDeviceMemory> CreateForeign(
se::DeviceMemoryBase value, se::DeviceMemoryBase value,
absl::AnyInvocable<void() &&> on_delete_callback); absl::AnyInvocable<void() &&> on_delete_callback);
@ -129,7 +130,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {
ExecutionInput* execution_input, ExecutionInput* execution_input,
se::DeviceMemoryAllocator* allocator) const; se::DeviceMemoryAllocator* allocator) const;
const tsl::AsyncValueRef<RawSEDeviceMemory>& device_memory() const { const tsl::RCReference<RawSEDeviceMemory>& device_memory() const {
return device_memory_; return device_memory_;
} }
@ -167,7 +168,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {
StreamAndEventContainer LockUseAndTransferUsageEvents(); StreamAndEventContainer LockUseAndTransferUsageEvents();
TrackedDeviceBuffer( TrackedDeviceBuffer(
PjRtDevice* device, tsl::AsyncValueRef<RawSEDeviceMemory> device_memory, PjRtDevice* device, tsl::RCReference<RawSEDeviceMemory> device_memory,
absl::Span<const BufferSequencingEventRef> definition_events); absl::Span<const BufferSequencingEventRef> definition_events);
~TrackedDeviceBuffer() override; ~TrackedDeviceBuffer() override;
@ -198,7 +199,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {
PjRtDevice* device_; PjRtDevice* device_;
// Each host-side buffer may have several buffers on-device. // Each host-side buffer may have several buffers on-device.
tsl::AsyncValueRef<RawSEDeviceMemory> device_memory_; tsl::RCReference<RawSEDeviceMemory> device_memory_;
// Events that are triggered when the content of one or more buffers is ready // Events that are triggered when the content of one or more buffers is ready
// during multistream execution. May be nullptr, which is used in the // during multistream execution. May be nullptr, which is used in the

View File

@ -83,7 +83,7 @@ class TestDevice : public PjRtDevice {
absl::StatusOr<std::shared_ptr<TrackedDeviceBuffer>> MakeArray( absl::StatusOr<std::shared_ptr<TrackedDeviceBuffer>> MakeArray(
const Shape& shape, LocalClient* client, PjRtDevice* device) { const Shape& shape, LocalClient* client, PjRtDevice* device) {
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> device_buffers; std::vector<tsl::RCReference<RawSEDeviceMemory>> device_buffers;
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
client->backend().transfer_manager()->HostShapeToDeviceShape(shape), client->backend().transfer_manager()->HostShapeToDeviceShape(shape),
[&](const Shape& subshape, const ShapeIndex&) -> absl::Status { [&](const Shape& subshape, const ShapeIndex&) -> absl::Status {