PiperOrigin-RevId: 825779916
This commit is contained in:
A. Unique TensorFlower 2025-10-29 17:41:51 -07:00 committed by TensorFlower Gardener
parent fe2a783077
commit c7055c2e5b
7 changed files with 39 additions and 38 deletions

View File

@ -1367,7 +1367,7 @@ StreamExecutorGpuClient::RunAsync(
std::set<se::DeviceMemoryBase> buffers_in_result;
xla::ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> results(
xla::ShapeTree<tsl::RCReference<RawSEDeviceMemory>> results(
gpu_exec->result_shape());
for (auto& p : results) {
@ -1452,7 +1452,7 @@ StreamExecutorGpuClient::RunAsync(
TF_RETURN_IF_ERROR(buffer_allocations.TearDown(buffers_in_result,
gpu_exec->GetAllocations()));
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> to_be_released;
std::vector<tsl::RCReference<RawSEDeviceMemory>> to_be_released;
// Free allocations for arguments.
for (ShapeTree<PjRtStreamExecutorExecutionInput>& input : arguments) {

View File

@ -789,7 +789,7 @@ PjRtStreamExecutorClient::CreateErrorBuffer(absl::Status error,
// Create an empty buffer.
auto dummy_device_buffer = std::make_unique<TrackedDeviceBuffer>(
device, tsl::AsyncValueRef<RawSEDeviceMemory>(),
device, tsl::RCReference<RawSEDeviceMemory>(),
absl::MakeSpan(&definition_event, 1));
return std::make_unique<CommonPjRtBufferImpl>(
@ -1168,13 +1168,13 @@ MakeTupleHelper(PjRtStreamExecutorClient* client,
// Converts a ScopedShapedBuffer returned from an execution into a
// PjRtBuffer.
absl::StatusOr<std::unique_ptr<PjRtBuffer>> OutputBufferHelper(
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result_buffer,
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
BufferSequencingEventRef definition_event, PjRtClient* client,
PjRtDevice* device, LocalDeviceState* local_device) {
if (result_buffer.shape().IsTuple()) {
return absl::InternalError("OutputBufferHelper called on tuple.");
}
absl::InlinedVector<tsl::AsyncValueRef<RawSEDeviceMemory>, 1> buffers;
absl::InlinedVector<tsl::RCReference<RawSEDeviceMemory>, 1> buffers;
for (auto& item : result_buffer) {
buffers.push_back(std::move(item.second));
}
@ -1641,7 +1641,7 @@ PjRtStreamExecutorClient::RunAsync(
ExecutionOutput output,
exec.RunAsync(std::move(xla_arguments), std::move(run_options)));
ScopedShapedBuffer ssb = output.ConsumeResult();
xla::ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> results(
xla::ShapeTree<tsl::RCReference<RawSEDeviceMemory>> results(
ssb.on_device_shape());
auto it = results.begin();
se::DeviceMemoryAllocator* allocator = ssb.memory_allocator();
@ -1672,7 +1672,7 @@ PjRtStreamExecutorClient::RunAsync(
// converted on success.
// When `options` has non-zero `launch_id`, use `launch_id` instead of `run_id`
// to initialize `run_options`.
absl::StatusOr<ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>>>
absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>>
PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
int executable_idx, const RunId& run_id, const ExecuteOptions& options,
@ -1930,7 +1930,7 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
int device_ordinal, const ExecuteOptions& options,
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result_buffer,
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
BufferSequencingEventRef definition_event, PjRtDevice* device,
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const {
tsl::profiler::TraceMe traceme("MakeOutputBuffers");
@ -1943,7 +1943,7 @@ PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
// in result_buffer.
for (int i = 0; i < tuple_count; ++i) {
TF_ASSIGN_OR_RETURN(
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> tuple_buffer,
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> tuple_buffer,
result_buffer.SubShapeTree({i}));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtBuffer> buffer,
@ -2050,7 +2050,7 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
std::vector<absl::AnyInvocable<void() &&>> compute_callbacks;
std::vector<CommonPjRtBuffer::ScopedHold> device_buffers;
device_buffers.reserve(argument_handles.size());
absl::StatusOr<ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>>>
absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>>
result_buffer_or_status =
EnqueueExecution(argument_handles, replica, partition, executable_idx,
run_id, options, device, &device_buffers,
@ -2061,7 +2061,7 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
<< " failed: " << result_buffer_or_status.status();
return result_buffer_or_status.status();
}
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result_buffer =
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer =
std::move(result_buffer_or_status).value();
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
@ -2081,14 +2081,14 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
}
return definition_event_or.status();
}
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> leaves_to_release;
std::vector<tsl::RCReference<RawSEDeviceMemory>> leaves_to_release;
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
leaves_to_release.reserve(result_buffer.leaf_count());
for (auto& node : result_buffer.leaves()) {
leaves_to_release.push_back(node.second);
}
}
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> buffers_to_release;
std::vector<tsl::RCReference<RawSEDeviceMemory>> buffers_to_release;
auto definition_event = tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
*definition_event_or, "PjRtStreamExecutorLoadedExecutable", "Execute");
TF_ASSIGN_OR_RETURN(

View File

@ -83,13 +83,13 @@ struct PjRtStreamExecutorExecutionInput {
// Donation is not complete until ReleaseDeviceMemory() is called on the
// TrackedDeviceBuffer that provides buf.
bool is_donated;
tsl::AsyncValueRef<RawSEDeviceMemory> buf;
tsl::RCReference<RawSEDeviceMemory> buf;
};
struct PjRtStreamExecutorExecutionOutput {
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result;
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result;
// Donated inputs which must be freed.
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> to_be_released;
std::vector<tsl::RCReference<RawSEDeviceMemory>> to_be_released;
// For PjRtStreamExecutorClient implementations that
// use OwningDeviceMemory for donated inputs.
std::vector<se::OwningDeviceMemory> se_to_be_released;
@ -672,7 +672,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
absl::Span<const CommonPjRtBuffer::ScopedHold> device_buffers,
absl::flat_hash_set<BufferSequencingEvent*>& events) const;
absl::StatusOr<ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>>>
absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>>
EnqueueExecution(
absl::Span<PjRtBuffer* const> argument_handles, int replica,
int partition, int executable_idx, const RunId& run_id,
@ -684,7 +684,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
virtual absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
MakeOutputBuffers(
int device_ordinal, const ExecuteOptions& options,
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result_buffer,
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
BufferSequencingEventRef definition_event, PjRtDevice* device,
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const;

View File

@ -91,10 +91,10 @@ class PjRtStreamExecutorDeviceEventPromise : public PjRtDeviceEventPromise {
class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer {
public:
PjRtStreamExecutorRawBuffer(
PjRtStreamExecutorClient* client, PjRtMemorySpace* memory_space,
PjRtStreamExecutorRawBuffer(PjRtStreamExecutorClient* client,
PjRtMemorySpace* memory_space,
LocalDeviceState* local_device,
tsl::AsyncValueRef<RawSEDeviceMemory> device_buffer)
tsl::RCReference<RawSEDeviceMemory> device_buffer)
: client_(client),
memory_space_(memory_space),
local_device_(local_device),
@ -104,7 +104,7 @@ class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer {
LocalDeviceState* local_device() const { return local_device_; }
const tsl::AsyncValueRef<RawSEDeviceMemory>& device_buffer() const {
const tsl::RCReference<RawSEDeviceMemory>& device_buffer() const {
return device_buffer_;
}
@ -150,7 +150,7 @@ class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer {
PjRtStreamExecutorClient* client_;
PjRtMemorySpace* memory_space_;
LocalDeviceState* local_device_;
tsl::AsyncValueRef<RawSEDeviceMemory> device_buffer_;
tsl::RCReference<RawSEDeviceMemory> device_buffer_;
};
} // namespace xla

View File

@ -119,11 +119,11 @@ class AllocatedRawSEDeviceMemory : public RawSEDeviceMemory {
size_t sync_point_ = std::numeric_limits<size_t>::max();
};
tsl::AsyncValueRef<RawSEDeviceMemory> RawSEDeviceMemory::Create(
tsl::RCReference<RawSEDeviceMemory> RawSEDeviceMemory::Create(
se::DeviceMemoryBase value, LocalDeviceState* local_device,
se::DeviceMemoryAllocator* allocator) {
return tsl::MakeAvailableAsyncValueRef<AllocatedRawSEDeviceMemory>(
value, local_device, allocator);
return tsl::MakeRef<AllocatedRawSEDeviceMemory>(value, local_device,
allocator);
}
class ForeignRawSEDeviceMemory : public RawSEDeviceMemory {
@ -143,11 +143,11 @@ class ForeignRawSEDeviceMemory : public RawSEDeviceMemory {
absl::AnyInvocable<void() &&> on_delete_callback_;
};
tsl::AsyncValueRef<RawSEDeviceMemory> RawSEDeviceMemory::CreateForeign(
tsl::RCReference<RawSEDeviceMemory> RawSEDeviceMemory::CreateForeign(
se::DeviceMemoryBase value,
absl::AnyInvocable<void() &&> on_delete_callback) {
return tsl::MakeAvailableAsyncValueRef<ForeignRawSEDeviceMemory>(
value, std::move(on_delete_callback));
return tsl::MakeRef<ForeignRawSEDeviceMemory>(value,
std::move(on_delete_callback));
}
ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer(
@ -167,7 +167,7 @@ ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer(
}
TrackedDeviceBuffer::TrackedDeviceBuffer(
PjRtDevice* device, tsl::AsyncValueRef<RawSEDeviceMemory> device_memory,
PjRtDevice* device, tsl::RCReference<RawSEDeviceMemory> device_memory,
absl::Span<const BufferSequencingEventRef> definition_events)
: device_(device),
device_memory_(std::move(device_memory)),
@ -178,7 +178,7 @@ TrackedDeviceBuffer::TrackedDeviceBuffer(
TrackedDeviceBuffer::~TrackedDeviceBuffer() = default;
void TrackedDeviceBuffer::ReleaseDeviceMemory() {
device_memory_ = tsl::AsyncValueRef<RawSEDeviceMemory>();
device_memory_ = tsl::RCReference<RawSEDeviceMemory>();
}
void TrackedDeviceBuffer::ConfirmDonation() {

View File

@ -51,6 +51,7 @@ limitations under the License.
namespace xla {
// TODO(parkers): Implement PjRtRawBuffer API.
class RawSEDeviceMemory : public tsl::ReferenceCounted<RawSEDeviceMemory> {
public:
explicit RawSEDeviceMemory(se::DeviceMemoryBase value) : value_(value) {}
@ -69,10 +70,10 @@ class RawSEDeviceMemory : public tsl::ReferenceCounted<RawSEDeviceMemory> {
ShapedBuffer AsShapedBuffer(PjRtDevice* device,
const Shape& on_device_shape) const;
static tsl::AsyncValueRef<RawSEDeviceMemory> Create(
static tsl::RCReference<RawSEDeviceMemory> Create(
se::DeviceMemoryBase value, LocalDeviceState* local_device,
se::DeviceMemoryAllocator* allocator);
static tsl::AsyncValueRef<RawSEDeviceMemory> CreateForeign(
static tsl::RCReference<RawSEDeviceMemory> CreateForeign(
se::DeviceMemoryBase value,
absl::AnyInvocable<void() &&> on_delete_callback);
@ -129,7 +130,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {
ExecutionInput* execution_input,
se::DeviceMemoryAllocator* allocator) const;
const tsl::AsyncValueRef<RawSEDeviceMemory>& device_memory() const {
const tsl::RCReference<RawSEDeviceMemory>& device_memory() const {
return device_memory_;
}
@ -167,7 +168,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {
StreamAndEventContainer LockUseAndTransferUsageEvents();
TrackedDeviceBuffer(
PjRtDevice* device, tsl::AsyncValueRef<RawSEDeviceMemory> device_memory,
PjRtDevice* device, tsl::RCReference<RawSEDeviceMemory> device_memory,
absl::Span<const BufferSequencingEventRef> definition_events);
~TrackedDeviceBuffer() override;
@ -198,7 +199,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {
PjRtDevice* device_;
// Each host-side buffer may have several buffers on-device.
tsl::AsyncValueRef<RawSEDeviceMemory> device_memory_;
tsl::RCReference<RawSEDeviceMemory> device_memory_;
// Events that are triggered when the content of one or more buffers is ready
// during multistream execution. May be nullptr, which is used in the

View File

@ -83,7 +83,7 @@ class TestDevice : public PjRtDevice {
absl::StatusOr<std::shared_ptr<TrackedDeviceBuffer>> MakeArray(
const Shape& shape, LocalClient* client, PjRtDevice* device) {
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> device_buffers;
std::vector<tsl::RCReference<RawSEDeviceMemory>> device_buffers;
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
client->backend().transfer_manager()->HostShapeToDeviceShape(shape),
[&](const Shape& subshape, const ShapeIndex&) -> absl::Status {