mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
parent
2b17e0e0c0
commit
baf408c724
|
|
@ -718,6 +718,11 @@ void PjRtStreamExecutorBuffer::ScopedHold::ConvertUsageHold(
|
|||
SetState(kConverted);
|
||||
}
|
||||
|
||||
// Reports whether this buffer is associated with the pinned-host memory space.
//
// Returns true only when memory_space() is non-null and its kind() equals
// PinnedHostMemorySpace::kKind. A buffer without a memory space yields false.
bool PjRtStreamExecutorBuffer::IsOnCpu() const {
|
||||
return memory_space() != nullptr &&
|
||||
memory_space()->kind() == PinnedHostMemorySpace::kKind;
|
||||
}
|
||||
|
||||
// Reports whether `memory_space` is the pinned-host memory space, i.e. whether
// its kind() equals PinnedHostMemorySpace::kKind.
//
// NOTE: `memory_space` is dereferenced unconditionally (unlike
// PjRtStreamExecutorBuffer::IsOnCpu, which tolerates a null memory space), so
// callers must pass a non-null pointer.
bool PjRtStreamExecutorClient::IsOnCpu(PjRtMemorySpace* memory_space) {
|
||||
return memory_space->kind() == PinnedHostMemorySpace::kKind;
|
||||
}
|
||||
|
|
@ -1399,6 +1404,30 @@ void PjRtStreamExecutorBuffer::ConvertUsageHold(TrackedDeviceBuffer* buffer,
|
|||
DecrementUsage();
|
||||
}
|
||||
|
||||
// Returns the size reported by the device memory that backs this buffer.
//
// Holds `mu_` for the duration of the call. Returns an InvalidArgument error
// when device_buffer() is null or its device_memory() is unset, which the error
// message attributes to the buffer having been deleted or donated.
absl::StatusOr<size_t> PjRtStreamExecutorBuffer::GetOnDeviceSizeInBytes()
|
||||
const {
|
||||
absl::MutexLock lock(&mu_);
|
||||
// Both the tracked buffer and its device memory must be present before the
// size can be read.
if (device_buffer() == nullptr || !device_buffer()->device_memory()) {
|
||||
return InvalidArgument(
|
||||
"GetOnDeviceSizeInBytes called on deleted or donated buffer");
|
||||
}
|
||||
return device_buffer()->device_memory()->mem().size();
|
||||
}
|
||||
|
||||
// Copies `transfer_size` bytes starting at `offset` of this buffer into the
// host pointer `dst`.
//
// Thin wrapper: wraps `dst` in a Future<void*> and forwards to
// PjRtStreamExecutorClient::CopyRawSubBufferToHost, returning the future that
// call produces. The owning client is down-cast to PjRtStreamExecutorClient.
Future<> PjRtStreamExecutorBuffer::CopyRawToHost(void* dst, int64_t offset,
|
||||
int64_t transfer_size) {
|
||||
auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client());
|
||||
return se_client->CopyRawSubBufferToHost(this, Future<void*>(dst), offset,
|
||||
transfer_size);
|
||||
}
|
||||
|
||||
// Variant of CopyRawToHost whose destination pointer is supplied as a
// Future<void*>.
//
// Forwards `dst`, `offset` and `transfer_size` unchanged to
// PjRtStreamExecutorClient::CopyRawSubBufferToHost on the owning client
// (down-cast to PjRtStreamExecutorClient) and returns its result.
Future<> PjRtStreamExecutorBuffer::CopyRawToHostFuture(Future<void*> dst,
|
||||
int64_t offset,
|
||||
int64_t transfer_size) {
|
||||
auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client());
|
||||
return se_client->CopyRawSubBufferToHost(this, dst, offset, transfer_size);
|
||||
}
|
||||
|
||||
PjRtStreamExecutorBuffer::ScopedHold
|
||||
PjRtStreamExecutorBuffer::GetBufferWithHold(ScopedHold::Type type) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
|
|
|
|||
|
|
@ -600,6 +600,14 @@ class PjRtStreamExecutorBuffer : public CommonPjRtBufferImpl {
|
|||
PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) = delete;
|
||||
PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete;
|
||||
|
||||
// Returns the size of the device memory backing this buffer, or an
// InvalidArgument error if the buffer has been deleted or donated.
absl::StatusOr<size_t> GetOnDeviceSizeInBytes() const override;
|
||||
|
||||
// Copies `transfer_size` bytes starting at `offset` of this buffer into the
// host pointer `dst`; delegates to the client's CopyRawSubBufferToHost.
Future<> CopyRawToHost(void* dst, int64_t offset,
|
||||
int64_t transfer_size) override;
|
||||
|
||||
// Same as CopyRawToHost, except the destination pointer is provided as a
// Future<void*>; delegates to the client's CopyRawSubBufferToHost.
Future<> CopyRawToHostFuture(Future<void*> dst, int64_t offset,
|
||||
int64_t transfer_size) override;
|
||||
|
||||
// Drops the buffer's reference to its associated device memory, leaving the
|
||||
// buffer in an invalid state. The memory will be freed lazily when all async
|
||||
// operations using the buffer have completed, according to the allocation
|
||||
|
|
@ -629,6 +637,8 @@ class PjRtStreamExecutorBuffer : public CommonPjRtBufferImpl {
|
|||
|
||||
Future<> GetReadyFuture() override;
|
||||
|
||||
// Returns true when this buffer's memory space exists and is the pinned-host
// memory space (PinnedHostMemorySpace::kKind).
bool IsOnCpu() const override;
|
||||
|
||||
// Similar to Delete, drops the buffer's reference to its associated device
|
||||
// memory, leaving the buffer in an invalid state, but returns the
|
||||
// TrackedDeviceBuffer rather than freeing the device memory, so that another
|
||||
|
|
|
|||
126
third_party/xla/xla/pjrt/se_raw_buffer.cc
vendored
126
third_party/xla/xla/pjrt/se_raw_buffer.cc
vendored
|
|
@ -16,7 +16,6 @@ limitations under the License.
|
|||
#include "xla/pjrt/se_raw_buffer.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
|
|
@ -132,55 +131,23 @@ PjRtStreamExecutorRawBuffer::CopyRawHostToDeviceAndReturnEvent(
|
|||
const void* src, int64_t offset, int64_t transfer_size) {
|
||||
se::Stream* stream = local_device_->host_to_device_stream();
|
||||
auto device_event = BufferSequencingEvent::Create(client_->thread_pool());
|
||||
client_->thread_pool()->Schedule([client = client_, device_event,
|
||||
local_device = local_device_, stream, src,
|
||||
offset, transfer_size,
|
||||
buf = tsl::FormRef(this)]() mutable {
|
||||
se::DeviceMemoryBase sub_buffer = buf->device_buffer_->mem();
|
||||
if (transfer_size < sub_buffer.size()) {
|
||||
sub_buffer = sub_buffer.GetByteSlice(offset, transfer_size);
|
||||
}
|
||||
client->WaitForAllocation(stream, *buf);
|
||||
std::shared_ptr<void> staging_buffer;
|
||||
auto status = [&]() -> absl::Status {
|
||||
if (transfer_size > 0) {
|
||||
if (client->should_stage_host_to_device_transfers() &&
|
||||
!client->IsDmaMapped(src, transfer_size)) {
|
||||
if (client->host_memory_allocator() == nullptr) {
|
||||
return absl::InvalidArgumentError(
|
||||
"host_memory_allocator should be initialized for "
|
||||
"staging buffer transfer.");
|
||||
}
|
||||
void* ptr = client->host_memory_allocator()->AllocateRaw(
|
||||
tsl::Allocator::kAllocatorAlignment, transfer_size);
|
||||
staging_buffer = std::shared_ptr<void>(
|
||||
ptr,
|
||||
[host_memory_allocator = client->host_memory_allocator()](
|
||||
void* ptr) { host_memory_allocator->DeallocateRaw(ptr); });
|
||||
auto copy_to_staging_buffer = [src, transfer_size,
|
||||
staging_buffer]() mutable {
|
||||
std::memcpy(staging_buffer.get(), src, transfer_size);
|
||||
};
|
||||
TF_RETURN_IF_ERROR(stream->DoHostCallback(copy_to_staging_buffer));
|
||||
TF_RETURN_IF_ERROR(
|
||||
stream->Memcpy(&sub_buffer, staging_buffer.get(), transfer_size));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(stream->Memcpy(&sub_buffer, src, transfer_size));
|
||||
client_->thread_pool()->Schedule(
|
||||
[client = client_, device_event, local_device = local_device_, stream,
|
||||
src, offset, transfer_size, buf = tsl::FormRef(this)]() mutable {
|
||||
se::DeviceMemoryBase sub_buffer = buf->device_buffer_->mem();
|
||||
if (transfer_size < sub_buffer.size()) {
|
||||
sub_buffer = sub_buffer.GetByteSlice(offset, transfer_size);
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}();
|
||||
if (status.ok()) {
|
||||
status =
|
||||
client->AllocateAndRecordEvent(device_event, local_device, stream);
|
||||
if (staging_buffer) {
|
||||
device_event.AndThen([staging_buffer = std::move(staging_buffer)]() {});
|
||||
}
|
||||
}
|
||||
if (!status.ok()) {
|
||||
client->SetEventAsError(device_event, status);
|
||||
}
|
||||
});
|
||||
client->WaitForAllocation(stream, *buf);
|
||||
auto status = stream->Memcpy(&sub_buffer, src, transfer_size);
|
||||
if (status.ok()) {
|
||||
status = client->AllocateAndRecordEvent(device_event, local_device,
|
||||
stream);
|
||||
}
|
||||
if (!status.ok()) {
|
||||
client->SetEventAsError(device_event, status);
|
||||
}
|
||||
});
|
||||
return tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
|
||||
std::move(device_event), "PjRtStreamExecutorRawBuffer",
|
||||
"CopyRawHostToDevice");
|
||||
|
|
@ -191,52 +158,23 @@ PjRtStreamExecutorRawBuffer::CopyRawDeviceToHostAndReturnEvent(
|
|||
void* dst, int64_t offset, int64_t transfer_size) {
|
||||
se::Stream* stream = local_device_->GetDeviceToHostStream();
|
||||
auto device_event = BufferSequencingEvent::Create(client_->thread_pool());
|
||||
client_->thread_pool()->Schedule([client = client_, device_event,
|
||||
local_device = local_device_, stream, dst,
|
||||
offset, transfer_size,
|
||||
buf = tsl::FormRef(this)]() mutable {
|
||||
se::DeviceMemoryBase sub_buffer = buf->device_buffer_->mem();
|
||||
if (transfer_size < sub_buffer.size()) {
|
||||
sub_buffer = sub_buffer.GetByteSlice(offset, transfer_size);
|
||||
}
|
||||
client->WaitForAllocation(stream, *buf);
|
||||
auto status = [&]() -> absl::Status {
|
||||
if (transfer_size > 0) {
|
||||
if (client->should_stage_host_to_device_transfers() &&
|
||||
!client->IsDmaMapped(dst, transfer_size)) {
|
||||
if (client->host_memory_allocator() == nullptr) {
|
||||
return absl::InvalidArgumentError(
|
||||
"host_memory_allocator should be initialized for "
|
||||
"staging buffer transfer.");
|
||||
}
|
||||
void* ptr = client->host_memory_allocator()->AllocateRaw(
|
||||
tsl::Allocator::kAllocatorAlignment, transfer_size);
|
||||
std::shared_ptr<void> staging_buffer = std::shared_ptr<void>(
|
||||
ptr,
|
||||
[host_memory_allocator = client->host_memory_allocator()](
|
||||
void* ptr) { host_memory_allocator->DeallocateRaw(ptr); });
|
||||
TF_RETURN_IF_ERROR(
|
||||
stream->Memcpy(staging_buffer.get(), sub_buffer, transfer_size));
|
||||
auto copy_from_staging_buffer = [dst, transfer_size,
|
||||
staging_buffer]() mutable {
|
||||
std::memcpy(dst, staging_buffer.get(), transfer_size);
|
||||
};
|
||||
// TODO(parkers): This failing maybe constitutes a race.
|
||||
TF_RETURN_IF_ERROR(stream->DoHostCallback(copy_from_staging_buffer));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(stream->Memcpy(dst, sub_buffer, transfer_size));
|
||||
client_->thread_pool()->Schedule(
|
||||
[client = client_, device_event, local_device = local_device_, stream,
|
||||
dst, offset, transfer_size, buf = tsl::FormRef(this)]() mutable {
|
||||
se::DeviceMemoryBase sub_buffer = buf->device_buffer_->mem();
|
||||
if (transfer_size < sub_buffer.size()) {
|
||||
sub_buffer = sub_buffer.GetByteSlice(offset, transfer_size);
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}();
|
||||
if (status.ok()) {
|
||||
status =
|
||||
client->AllocateAndRecordEvent(device_event, local_device, stream);
|
||||
}
|
||||
if (!status.ok()) {
|
||||
client->SetEventAsError(device_event, status);
|
||||
}
|
||||
});
|
||||
client->WaitForAllocation(stream, *buf);
|
||||
auto status = stream->Memcpy(dst, sub_buffer, transfer_size);
|
||||
if (status.ok()) {
|
||||
status = client->AllocateAndRecordEvent(device_event, local_device,
|
||||
stream);
|
||||
}
|
||||
if (!status.ok()) {
|
||||
client->SetEventAsError(device_event, status);
|
||||
}
|
||||
});
|
||||
return tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
|
||||
std::move(device_event), "PjRtStreamExecutorRawBuffer",
|
||||
"CopyRawDeviceToHost");
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user