PiperOrigin-RevId: 819762394
This commit is contained in:
Peter Hawkins 2025-10-15 08:14:19 -07:00 committed by TensorFlower Gardener
parent 2b17e0e0c0
commit baf408c724
3 changed files with 71 additions and 94 deletions

View File

@ -718,6 +718,11 @@ void PjRtStreamExecutorBuffer::ScopedHold::ConvertUsageHold(
SetState(kConverted);
}
bool PjRtStreamExecutorBuffer::IsOnCpu() const {
return memory_space() != nullptr &&
memory_space()->kind() == PinnedHostMemorySpace::kKind;
}
bool PjRtStreamExecutorClient::IsOnCpu(PjRtMemorySpace* memory_space) {
return memory_space->kind() == PinnedHostMemorySpace::kKind;
}
@ -1399,6 +1404,30 @@ void PjRtStreamExecutorBuffer::ConvertUsageHold(TrackedDeviceBuffer* buffer,
DecrementUsage();
}
absl::StatusOr<size_t> PjRtStreamExecutorBuffer::GetOnDeviceSizeInBytes()
const {
absl::MutexLock lock(&mu_);
if (device_buffer() == nullptr || !device_buffer()->device_memory()) {
return InvalidArgument(
"GetOnDeviceSizeInBytes called on deleted or donated buffer");
}
return device_buffer()->device_memory()->mem().size();
}
Future<> PjRtStreamExecutorBuffer::CopyRawToHost(void* dst, int64_t offset,
int64_t transfer_size) {
auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client());
return se_client->CopyRawSubBufferToHost(this, Future<void*>(dst), offset,
transfer_size);
}
Future<> PjRtStreamExecutorBuffer::CopyRawToHostFuture(Future<void*> dst,
int64_t offset,
int64_t transfer_size) {
auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client());
return se_client->CopyRawSubBufferToHost(this, dst, offset, transfer_size);
}
PjRtStreamExecutorBuffer::ScopedHold
PjRtStreamExecutorBuffer::GetBufferWithHold(ScopedHold::Type type) {
absl::MutexLock lock(&mu_);

View File

@ -600,6 +600,14 @@ class PjRtStreamExecutorBuffer : public CommonPjRtBufferImpl {
PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) = delete;
PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete;
absl::StatusOr<size_t> GetOnDeviceSizeInBytes() const override;
Future<> CopyRawToHost(void* dst, int64_t offset,
int64_t transfer_size) override;
Future<> CopyRawToHostFuture(Future<void*> dst, int64_t offset,
int64_t transfer_size) override;
// Drops the buffer's reference to its associated device memory, leaving the
// buffer in an invalid state. The memory will be freed lazily when all async
// operations using the buffer have completed, according to the allocation
@ -629,6 +637,8 @@ class PjRtStreamExecutorBuffer : public CommonPjRtBufferImpl {
Future<> GetReadyFuture() override;
bool IsOnCpu() const override;
// Similar to Delete, drops the buffer's reference to its associated device
// memory, leaving the buffer in an invalid state, but returns the
// TrackedDeviceBuffer rather than freeing the device memory, so that another

View File

@ -16,7 +16,6 @@ limitations under the License.
#include "xla/pjrt/se_raw_buffer.h"
#include <cstdint>
#include <cstring>
#include <memory>
#include <optional>
#include <utility>
@ -132,55 +131,23 @@ PjRtStreamExecutorRawBuffer::CopyRawHostToDeviceAndReturnEvent(
const void* src, int64_t offset, int64_t transfer_size) {
se::Stream* stream = local_device_->host_to_device_stream();
auto device_event = BufferSequencingEvent::Create(client_->thread_pool());
client_->thread_pool()->Schedule([client = client_, device_event,
local_device = local_device_, stream, src,
offset, transfer_size,
buf = tsl::FormRef(this)]() mutable {
se::DeviceMemoryBase sub_buffer = buf->device_buffer_->mem();
if (transfer_size < sub_buffer.size()) {
sub_buffer = sub_buffer.GetByteSlice(offset, transfer_size);
}
client->WaitForAllocation(stream, *buf);
std::shared_ptr<void> staging_buffer;
auto status = [&]() -> absl::Status {
if (transfer_size > 0) {
if (client->should_stage_host_to_device_transfers() &&
!client->IsDmaMapped(src, transfer_size)) {
if (client->host_memory_allocator() == nullptr) {
return absl::InvalidArgumentError(
"host_memory_allocator should be initialized for "
"staging buffer transfer.");
}
void* ptr = client->host_memory_allocator()->AllocateRaw(
tsl::Allocator::kAllocatorAlignment, transfer_size);
staging_buffer = std::shared_ptr<void>(
ptr,
[host_memory_allocator = client->host_memory_allocator()](
void* ptr) { host_memory_allocator->DeallocateRaw(ptr); });
auto copy_to_staging_buffer = [src, transfer_size,
staging_buffer]() mutable {
std::memcpy(staging_buffer.get(), src, transfer_size);
};
TF_RETURN_IF_ERROR(stream->DoHostCallback(copy_to_staging_buffer));
TF_RETURN_IF_ERROR(
stream->Memcpy(&sub_buffer, staging_buffer.get(), transfer_size));
} else {
TF_RETURN_IF_ERROR(stream->Memcpy(&sub_buffer, src, transfer_size));
client_->thread_pool()->Schedule(
[client = client_, device_event, local_device = local_device_, stream,
src, offset, transfer_size, buf = tsl::FormRef(this)]() mutable {
se::DeviceMemoryBase sub_buffer = buf->device_buffer_->mem();
if (transfer_size < sub_buffer.size()) {
sub_buffer = sub_buffer.GetByteSlice(offset, transfer_size);
}
}
return absl::OkStatus();
}();
if (status.ok()) {
status =
client->AllocateAndRecordEvent(device_event, local_device, stream);
if (staging_buffer) {
device_event.AndThen([staging_buffer = std::move(staging_buffer)]() {});
}
}
if (!status.ok()) {
client->SetEventAsError(device_event, status);
}
});
client->WaitForAllocation(stream, *buf);
auto status = stream->Memcpy(&sub_buffer, src, transfer_size);
if (status.ok()) {
status = client->AllocateAndRecordEvent(device_event, local_device,
stream);
}
if (!status.ok()) {
client->SetEventAsError(device_event, status);
}
});
return tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
std::move(device_event), "PjRtStreamExecutorRawBuffer",
"CopyRawHostToDevice");
@ -191,52 +158,23 @@ PjRtStreamExecutorRawBuffer::CopyRawDeviceToHostAndReturnEvent(
void* dst, int64_t offset, int64_t transfer_size) {
se::Stream* stream = local_device_->GetDeviceToHostStream();
auto device_event = BufferSequencingEvent::Create(client_->thread_pool());
client_->thread_pool()->Schedule([client = client_, device_event,
local_device = local_device_, stream, dst,
offset, transfer_size,
buf = tsl::FormRef(this)]() mutable {
se::DeviceMemoryBase sub_buffer = buf->device_buffer_->mem();
if (transfer_size < sub_buffer.size()) {
sub_buffer = sub_buffer.GetByteSlice(offset, transfer_size);
}
client->WaitForAllocation(stream, *buf);
auto status = [&]() -> absl::Status {
if (transfer_size > 0) {
if (client->should_stage_host_to_device_transfers() &&
!client->IsDmaMapped(dst, transfer_size)) {
if (client->host_memory_allocator() == nullptr) {
return absl::InvalidArgumentError(
"host_memory_allocator should be initialized for "
"staging buffer transfer.");
}
void* ptr = client->host_memory_allocator()->AllocateRaw(
tsl::Allocator::kAllocatorAlignment, transfer_size);
std::shared_ptr<void> staging_buffer = std::shared_ptr<void>(
ptr,
[host_memory_allocator = client->host_memory_allocator()](
void* ptr) { host_memory_allocator->DeallocateRaw(ptr); });
TF_RETURN_IF_ERROR(
stream->Memcpy(staging_buffer.get(), sub_buffer, transfer_size));
auto copy_from_staging_buffer = [dst, transfer_size,
staging_buffer]() mutable {
std::memcpy(dst, staging_buffer.get(), transfer_size);
};
// TODO(parkers): This failing maybe consitutes a race.
TF_RETURN_IF_ERROR(stream->DoHostCallback(copy_from_staging_buffer));
} else {
TF_RETURN_IF_ERROR(stream->Memcpy(dst, sub_buffer, transfer_size));
client_->thread_pool()->Schedule(
[client = client_, device_event, local_device = local_device_, stream,
dst, offset, transfer_size, buf = tsl::FormRef(this)]() mutable {
se::DeviceMemoryBase sub_buffer = buf->device_buffer_->mem();
if (transfer_size < sub_buffer.size()) {
sub_buffer = sub_buffer.GetByteSlice(offset, transfer_size);
}
}
return absl::OkStatus();
}();
if (status.ok()) {
status =
client->AllocateAndRecordEvent(device_event, local_device, stream);
}
if (!status.ok()) {
client->SetEventAsError(device_event, status);
}
});
client->WaitForAllocation(stream, *buf);
auto status = stream->Memcpy(dst, sub_buffer, transfer_size);
if (status.ok()) {
status = client->AllocateAndRecordEvent(device_event, local_device,
stream);
}
if (!status.ok()) {
client->SetEventAsError(device_event, status);
}
});
return tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
std::move(device_event), "PjRtStreamExecutorRawBuffer",
"CopyRawDeviceToHost");