mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
parent
fe2a783077
commit
c7055c2e5b
|
|
@ -1367,7 +1367,7 @@ StreamExecutorGpuClient::RunAsync(
|
||||||
|
|
||||||
std::set<se::DeviceMemoryBase> buffers_in_result;
|
std::set<se::DeviceMemoryBase> buffers_in_result;
|
||||||
|
|
||||||
xla::ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> results(
|
xla::ShapeTree<tsl::RCReference<RawSEDeviceMemory>> results(
|
||||||
gpu_exec->result_shape());
|
gpu_exec->result_shape());
|
||||||
|
|
||||||
for (auto& p : results) {
|
for (auto& p : results) {
|
||||||
|
|
@ -1452,7 +1452,7 @@ StreamExecutorGpuClient::RunAsync(
|
||||||
TF_RETURN_IF_ERROR(buffer_allocations.TearDown(buffers_in_result,
|
TF_RETURN_IF_ERROR(buffer_allocations.TearDown(buffers_in_result,
|
||||||
gpu_exec->GetAllocations()));
|
gpu_exec->GetAllocations()));
|
||||||
|
|
||||||
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> to_be_released;
|
std::vector<tsl::RCReference<RawSEDeviceMemory>> to_be_released;
|
||||||
|
|
||||||
// Free allocations for arguments.
|
// Free allocations for arguments.
|
||||||
for (ShapeTree<PjRtStreamExecutorExecutionInput>& input : arguments) {
|
for (ShapeTree<PjRtStreamExecutorExecutionInput>& input : arguments) {
|
||||||
|
|
|
||||||
|
|
@ -789,7 +789,7 @@ PjRtStreamExecutorClient::CreateErrorBuffer(absl::Status error,
|
||||||
|
|
||||||
// Create an empty buffer.
|
// Create an empty buffer.
|
||||||
auto dummy_device_buffer = std::make_unique<TrackedDeviceBuffer>(
|
auto dummy_device_buffer = std::make_unique<TrackedDeviceBuffer>(
|
||||||
device, tsl::AsyncValueRef<RawSEDeviceMemory>(),
|
device, tsl::RCReference<RawSEDeviceMemory>(),
|
||||||
absl::MakeSpan(&definition_event, 1));
|
absl::MakeSpan(&definition_event, 1));
|
||||||
|
|
||||||
return std::make_unique<CommonPjRtBufferImpl>(
|
return std::make_unique<CommonPjRtBufferImpl>(
|
||||||
|
|
@ -1168,13 +1168,13 @@ MakeTupleHelper(PjRtStreamExecutorClient* client,
|
||||||
// Converts a ScopedShapedBuffer returned from an execution into a
|
// Converts a ScopedShapedBuffer returned from an execution into a
|
||||||
// PjRtBuffer.
|
// PjRtBuffer.
|
||||||
absl::StatusOr<std::unique_ptr<PjRtBuffer>> OutputBufferHelper(
|
absl::StatusOr<std::unique_ptr<PjRtBuffer>> OutputBufferHelper(
|
||||||
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result_buffer,
|
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
|
||||||
BufferSequencingEventRef definition_event, PjRtClient* client,
|
BufferSequencingEventRef definition_event, PjRtClient* client,
|
||||||
PjRtDevice* device, LocalDeviceState* local_device) {
|
PjRtDevice* device, LocalDeviceState* local_device) {
|
||||||
if (result_buffer.shape().IsTuple()) {
|
if (result_buffer.shape().IsTuple()) {
|
||||||
return absl::InternalError("OutputBufferHelper called on tuple.");
|
return absl::InternalError("OutputBufferHelper called on tuple.");
|
||||||
}
|
}
|
||||||
absl::InlinedVector<tsl::AsyncValueRef<RawSEDeviceMemory>, 1> buffers;
|
absl::InlinedVector<tsl::RCReference<RawSEDeviceMemory>, 1> buffers;
|
||||||
for (auto& item : result_buffer) {
|
for (auto& item : result_buffer) {
|
||||||
buffers.push_back(std::move(item.second));
|
buffers.push_back(std::move(item.second));
|
||||||
}
|
}
|
||||||
|
|
@ -1641,7 +1641,7 @@ PjRtStreamExecutorClient::RunAsync(
|
||||||
ExecutionOutput output,
|
ExecutionOutput output,
|
||||||
exec.RunAsync(std::move(xla_arguments), std::move(run_options)));
|
exec.RunAsync(std::move(xla_arguments), std::move(run_options)));
|
||||||
ScopedShapedBuffer ssb = output.ConsumeResult();
|
ScopedShapedBuffer ssb = output.ConsumeResult();
|
||||||
xla::ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> results(
|
xla::ShapeTree<tsl::RCReference<RawSEDeviceMemory>> results(
|
||||||
ssb.on_device_shape());
|
ssb.on_device_shape());
|
||||||
auto it = results.begin();
|
auto it = results.begin();
|
||||||
se::DeviceMemoryAllocator* allocator = ssb.memory_allocator();
|
se::DeviceMemoryAllocator* allocator = ssb.memory_allocator();
|
||||||
|
|
@ -1672,7 +1672,7 @@ PjRtStreamExecutorClient::RunAsync(
|
||||||
// converted on success.
|
// converted on success.
|
||||||
// When `options` has non-zero `launch_id`, use `launch_id` instead of `run_id`
|
// When `options` has non-zero `launch_id`, use `launch_id` instead of `run_id`
|
||||||
// to initialize `run_options`.
|
// to initialize `run_options`.
|
||||||
absl::StatusOr<ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>>>
|
absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>>
|
||||||
PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
|
PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
|
||||||
absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
|
absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
|
||||||
int executable_idx, const RunId& run_id, const ExecuteOptions& options,
|
int executable_idx, const RunId& run_id, const ExecuteOptions& options,
|
||||||
|
|
@ -1930,7 +1930,7 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
|
||||||
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
|
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
|
||||||
PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
|
PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
|
||||||
int device_ordinal, const ExecuteOptions& options,
|
int device_ordinal, const ExecuteOptions& options,
|
||||||
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result_buffer,
|
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
|
||||||
BufferSequencingEventRef definition_event, PjRtDevice* device,
|
BufferSequencingEventRef definition_event, PjRtDevice* device,
|
||||||
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const {
|
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const {
|
||||||
tsl::profiler::TraceMe traceme("MakeOutputBuffers");
|
tsl::profiler::TraceMe traceme("MakeOutputBuffers");
|
||||||
|
|
@ -1943,7 +1943,7 @@ PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
|
||||||
// in result_buffer.
|
// in result_buffer.
|
||||||
for (int i = 0; i < tuple_count; ++i) {
|
for (int i = 0; i < tuple_count; ++i) {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> tuple_buffer,
|
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> tuple_buffer,
|
||||||
result_buffer.SubShapeTree({i}));
|
result_buffer.SubShapeTree({i}));
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::unique_ptr<PjRtBuffer> buffer,
|
std::unique_ptr<PjRtBuffer> buffer,
|
||||||
|
|
@ -2050,7 +2050,7 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
|
||||||
std::vector<absl::AnyInvocable<void() &&>> compute_callbacks;
|
std::vector<absl::AnyInvocable<void() &&>> compute_callbacks;
|
||||||
std::vector<CommonPjRtBuffer::ScopedHold> device_buffers;
|
std::vector<CommonPjRtBuffer::ScopedHold> device_buffers;
|
||||||
device_buffers.reserve(argument_handles.size());
|
device_buffers.reserve(argument_handles.size());
|
||||||
absl::StatusOr<ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>>>
|
absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>>
|
||||||
result_buffer_or_status =
|
result_buffer_or_status =
|
||||||
EnqueueExecution(argument_handles, replica, partition, executable_idx,
|
EnqueueExecution(argument_handles, replica, partition, executable_idx,
|
||||||
run_id, options, device, &device_buffers,
|
run_id, options, device, &device_buffers,
|
||||||
|
|
@ -2061,7 +2061,7 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
|
||||||
<< " failed: " << result_buffer_or_status.status();
|
<< " failed: " << result_buffer_or_status.status();
|
||||||
return result_buffer_or_status.status();
|
return result_buffer_or_status.status();
|
||||||
}
|
}
|
||||||
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result_buffer =
|
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer =
|
||||||
std::move(result_buffer_or_status).value();
|
std::move(result_buffer_or_status).value();
|
||||||
|
|
||||||
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
|
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
|
||||||
|
|
@ -2081,14 +2081,14 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
|
||||||
}
|
}
|
||||||
return definition_event_or.status();
|
return definition_event_or.status();
|
||||||
}
|
}
|
||||||
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> leaves_to_release;
|
std::vector<tsl::RCReference<RawSEDeviceMemory>> leaves_to_release;
|
||||||
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
|
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
|
||||||
leaves_to_release.reserve(result_buffer.leaf_count());
|
leaves_to_release.reserve(result_buffer.leaf_count());
|
||||||
for (auto& node : result_buffer.leaves()) {
|
for (auto& node : result_buffer.leaves()) {
|
||||||
leaves_to_release.push_back(node.second);
|
leaves_to_release.push_back(node.second);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> buffers_to_release;
|
std::vector<tsl::RCReference<RawSEDeviceMemory>> buffers_to_release;
|
||||||
auto definition_event = tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
|
auto definition_event = tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
|
||||||
*definition_event_or, "PjRtStreamExecutorLoadedExecutable", "Execute");
|
*definition_event_or, "PjRtStreamExecutorLoadedExecutable", "Execute");
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
|
|
||||||
|
|
@ -83,13 +83,13 @@ struct PjRtStreamExecutorExecutionInput {
|
||||||
// Donation is not complete until ReleaseDeviceMemory() is called on the
|
// Donation is not complete until ReleaseDeviceMemory() is called on the
|
||||||
// TrackedDeviceBuffer that provides buf.
|
// TrackedDeviceBuffer that provides buf.
|
||||||
bool is_donated;
|
bool is_donated;
|
||||||
tsl::AsyncValueRef<RawSEDeviceMemory> buf;
|
tsl::RCReference<RawSEDeviceMemory> buf;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct PjRtStreamExecutorExecutionOutput {
|
struct PjRtStreamExecutorExecutionOutput {
|
||||||
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result;
|
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result;
|
||||||
// Donated inputs which must be freed.
|
// Donated inputs which must be freed.
|
||||||
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> to_be_released;
|
std::vector<tsl::RCReference<RawSEDeviceMemory>> to_be_released;
|
||||||
// For PjRtStreamExecutorClient implementations that
|
// For PjRtStreamExecutorClient implementations that
|
||||||
// use OwningDeviceMemory for donated inputs.
|
// use OwningDeviceMemory for donated inputs.
|
||||||
std::vector<se::OwningDeviceMemory> se_to_be_released;
|
std::vector<se::OwningDeviceMemory> se_to_be_released;
|
||||||
|
|
@ -672,7 +672,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
|
||||||
absl::Span<const CommonPjRtBuffer::ScopedHold> device_buffers,
|
absl::Span<const CommonPjRtBuffer::ScopedHold> device_buffers,
|
||||||
absl::flat_hash_set<BufferSequencingEvent*>& events) const;
|
absl::flat_hash_set<BufferSequencingEvent*>& events) const;
|
||||||
|
|
||||||
absl::StatusOr<ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>>>
|
absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>>
|
||||||
EnqueueExecution(
|
EnqueueExecution(
|
||||||
absl::Span<PjRtBuffer* const> argument_handles, int replica,
|
absl::Span<PjRtBuffer* const> argument_handles, int replica,
|
||||||
int partition, int executable_idx, const RunId& run_id,
|
int partition, int executable_idx, const RunId& run_id,
|
||||||
|
|
@ -684,7 +684,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
|
||||||
virtual absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
|
virtual absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
|
||||||
MakeOutputBuffers(
|
MakeOutputBuffers(
|
||||||
int device_ordinal, const ExecuteOptions& options,
|
int device_ordinal, const ExecuteOptions& options,
|
||||||
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result_buffer,
|
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
|
||||||
BufferSequencingEventRef definition_event, PjRtDevice* device,
|
BufferSequencingEventRef definition_event, PjRtDevice* device,
|
||||||
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const;
|
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const;
|
||||||
|
|
||||||
|
|
|
||||||
12
third_party/xla/xla/pjrt/se_raw_buffer.h
vendored
12
third_party/xla/xla/pjrt/se_raw_buffer.h
vendored
|
|
@ -91,10 +91,10 @@ class PjRtStreamExecutorDeviceEventPromise : public PjRtDeviceEventPromise {
|
||||||
|
|
||||||
class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer {
|
class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer {
|
||||||
public:
|
public:
|
||||||
PjRtStreamExecutorRawBuffer(
|
PjRtStreamExecutorRawBuffer(PjRtStreamExecutorClient* client,
|
||||||
PjRtStreamExecutorClient* client, PjRtMemorySpace* memory_space,
|
PjRtMemorySpace* memory_space,
|
||||||
LocalDeviceState* local_device,
|
LocalDeviceState* local_device,
|
||||||
tsl::AsyncValueRef<RawSEDeviceMemory> device_buffer)
|
tsl::RCReference<RawSEDeviceMemory> device_buffer)
|
||||||
: client_(client),
|
: client_(client),
|
||||||
memory_space_(memory_space),
|
memory_space_(memory_space),
|
||||||
local_device_(local_device),
|
local_device_(local_device),
|
||||||
|
|
@ -104,7 +104,7 @@ class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer {
|
||||||
|
|
||||||
LocalDeviceState* local_device() const { return local_device_; }
|
LocalDeviceState* local_device() const { return local_device_; }
|
||||||
|
|
||||||
const tsl::AsyncValueRef<RawSEDeviceMemory>& device_buffer() const {
|
const tsl::RCReference<RawSEDeviceMemory>& device_buffer() const {
|
||||||
return device_buffer_;
|
return device_buffer_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -150,7 +150,7 @@ class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer {
|
||||||
PjRtStreamExecutorClient* client_;
|
PjRtStreamExecutorClient* client_;
|
||||||
PjRtMemorySpace* memory_space_;
|
PjRtMemorySpace* memory_space_;
|
||||||
LocalDeviceState* local_device_;
|
LocalDeviceState* local_device_;
|
||||||
tsl::AsyncValueRef<RawSEDeviceMemory> device_buffer_;
|
tsl::RCReference<RawSEDeviceMemory> device_buffer_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
|
||||||
|
|
@ -119,11 +119,11 @@ class AllocatedRawSEDeviceMemory : public RawSEDeviceMemory {
|
||||||
size_t sync_point_ = std::numeric_limits<size_t>::max();
|
size_t sync_point_ = std::numeric_limits<size_t>::max();
|
||||||
};
|
};
|
||||||
|
|
||||||
tsl::AsyncValueRef<RawSEDeviceMemory> RawSEDeviceMemory::Create(
|
tsl::RCReference<RawSEDeviceMemory> RawSEDeviceMemory::Create(
|
||||||
se::DeviceMemoryBase value, LocalDeviceState* local_device,
|
se::DeviceMemoryBase value, LocalDeviceState* local_device,
|
||||||
se::DeviceMemoryAllocator* allocator) {
|
se::DeviceMemoryAllocator* allocator) {
|
||||||
return tsl::MakeAvailableAsyncValueRef<AllocatedRawSEDeviceMemory>(
|
return tsl::MakeRef<AllocatedRawSEDeviceMemory>(value, local_device,
|
||||||
value, local_device, allocator);
|
allocator);
|
||||||
}
|
}
|
||||||
|
|
||||||
class ForeignRawSEDeviceMemory : public RawSEDeviceMemory {
|
class ForeignRawSEDeviceMemory : public RawSEDeviceMemory {
|
||||||
|
|
@ -143,11 +143,11 @@ class ForeignRawSEDeviceMemory : public RawSEDeviceMemory {
|
||||||
absl::AnyInvocable<void() &&> on_delete_callback_;
|
absl::AnyInvocable<void() &&> on_delete_callback_;
|
||||||
};
|
};
|
||||||
|
|
||||||
tsl::AsyncValueRef<RawSEDeviceMemory> RawSEDeviceMemory::CreateForeign(
|
tsl::RCReference<RawSEDeviceMemory> RawSEDeviceMemory::CreateForeign(
|
||||||
se::DeviceMemoryBase value,
|
se::DeviceMemoryBase value,
|
||||||
absl::AnyInvocable<void() &&> on_delete_callback) {
|
absl::AnyInvocable<void() &&> on_delete_callback) {
|
||||||
return tsl::MakeAvailableAsyncValueRef<ForeignRawSEDeviceMemory>(
|
return tsl::MakeRef<ForeignRawSEDeviceMemory>(value,
|
||||||
value, std::move(on_delete_callback));
|
std::move(on_delete_callback));
|
||||||
}
|
}
|
||||||
|
|
||||||
ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer(
|
ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer(
|
||||||
|
|
@ -167,7 +167,7 @@ ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer(
|
||||||
}
|
}
|
||||||
|
|
||||||
TrackedDeviceBuffer::TrackedDeviceBuffer(
|
TrackedDeviceBuffer::TrackedDeviceBuffer(
|
||||||
PjRtDevice* device, tsl::AsyncValueRef<RawSEDeviceMemory> device_memory,
|
PjRtDevice* device, tsl::RCReference<RawSEDeviceMemory> device_memory,
|
||||||
absl::Span<const BufferSequencingEventRef> definition_events)
|
absl::Span<const BufferSequencingEventRef> definition_events)
|
||||||
: device_(device),
|
: device_(device),
|
||||||
device_memory_(std::move(device_memory)),
|
device_memory_(std::move(device_memory)),
|
||||||
|
|
@ -178,7 +178,7 @@ TrackedDeviceBuffer::TrackedDeviceBuffer(
|
||||||
TrackedDeviceBuffer::~TrackedDeviceBuffer() = default;
|
TrackedDeviceBuffer::~TrackedDeviceBuffer() = default;
|
||||||
|
|
||||||
void TrackedDeviceBuffer::ReleaseDeviceMemory() {
|
void TrackedDeviceBuffer::ReleaseDeviceMemory() {
|
||||||
device_memory_ = tsl::AsyncValueRef<RawSEDeviceMemory>();
|
device_memory_ = tsl::RCReference<RawSEDeviceMemory>();
|
||||||
}
|
}
|
||||||
|
|
||||||
void TrackedDeviceBuffer::ConfirmDonation() {
|
void TrackedDeviceBuffer::ConfirmDonation() {
|
||||||
|
|
|
||||||
11
third_party/xla/xla/pjrt/tracked_device_buffer.h
vendored
11
third_party/xla/xla/pjrt/tracked_device_buffer.h
vendored
|
|
@ -51,6 +51,7 @@ limitations under the License.
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
|
// TODO(parkers): Implement PjRtRawBuffer API.
|
||||||
class RawSEDeviceMemory : public tsl::ReferenceCounted<RawSEDeviceMemory> {
|
class RawSEDeviceMemory : public tsl::ReferenceCounted<RawSEDeviceMemory> {
|
||||||
public:
|
public:
|
||||||
explicit RawSEDeviceMemory(se::DeviceMemoryBase value) : value_(value) {}
|
explicit RawSEDeviceMemory(se::DeviceMemoryBase value) : value_(value) {}
|
||||||
|
|
@ -69,10 +70,10 @@ class RawSEDeviceMemory : public tsl::ReferenceCounted<RawSEDeviceMemory> {
|
||||||
ShapedBuffer AsShapedBuffer(PjRtDevice* device,
|
ShapedBuffer AsShapedBuffer(PjRtDevice* device,
|
||||||
const Shape& on_device_shape) const;
|
const Shape& on_device_shape) const;
|
||||||
|
|
||||||
static tsl::AsyncValueRef<RawSEDeviceMemory> Create(
|
static tsl::RCReference<RawSEDeviceMemory> Create(
|
||||||
se::DeviceMemoryBase value, LocalDeviceState* local_device,
|
se::DeviceMemoryBase value, LocalDeviceState* local_device,
|
||||||
se::DeviceMemoryAllocator* allocator);
|
se::DeviceMemoryAllocator* allocator);
|
||||||
static tsl::AsyncValueRef<RawSEDeviceMemory> CreateForeign(
|
static tsl::RCReference<RawSEDeviceMemory> CreateForeign(
|
||||||
se::DeviceMemoryBase value,
|
se::DeviceMemoryBase value,
|
||||||
absl::AnyInvocable<void() &&> on_delete_callback);
|
absl::AnyInvocable<void() &&> on_delete_callback);
|
||||||
|
|
||||||
|
|
@ -129,7 +130,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {
|
||||||
ExecutionInput* execution_input,
|
ExecutionInput* execution_input,
|
||||||
se::DeviceMemoryAllocator* allocator) const;
|
se::DeviceMemoryAllocator* allocator) const;
|
||||||
|
|
||||||
const tsl::AsyncValueRef<RawSEDeviceMemory>& device_memory() const {
|
const tsl::RCReference<RawSEDeviceMemory>& device_memory() const {
|
||||||
return device_memory_;
|
return device_memory_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -167,7 +168,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {
|
||||||
StreamAndEventContainer LockUseAndTransferUsageEvents();
|
StreamAndEventContainer LockUseAndTransferUsageEvents();
|
||||||
|
|
||||||
TrackedDeviceBuffer(
|
TrackedDeviceBuffer(
|
||||||
PjRtDevice* device, tsl::AsyncValueRef<RawSEDeviceMemory> device_memory,
|
PjRtDevice* device, tsl::RCReference<RawSEDeviceMemory> device_memory,
|
||||||
absl::Span<const BufferSequencingEventRef> definition_events);
|
absl::Span<const BufferSequencingEventRef> definition_events);
|
||||||
~TrackedDeviceBuffer() override;
|
~TrackedDeviceBuffer() override;
|
||||||
|
|
||||||
|
|
@ -198,7 +199,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {
|
||||||
PjRtDevice* device_;
|
PjRtDevice* device_;
|
||||||
|
|
||||||
// Each host-side buffer may have several buffers on-device.
|
// Each host-side buffer may have several buffers on-device.
|
||||||
tsl::AsyncValueRef<RawSEDeviceMemory> device_memory_;
|
tsl::RCReference<RawSEDeviceMemory> device_memory_;
|
||||||
|
|
||||||
// Events that are triggered when the content of one or more buffers is ready
|
// Events that are triggered when the content of one or more buffers is ready
|
||||||
// during multistream execution. May be nullptr, which is used in the
|
// during multistream execution. May be nullptr, which is used in the
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,7 @@ class TestDevice : public PjRtDevice {
|
||||||
|
|
||||||
absl::StatusOr<std::shared_ptr<TrackedDeviceBuffer>> MakeArray(
|
absl::StatusOr<std::shared_ptr<TrackedDeviceBuffer>> MakeArray(
|
||||||
const Shape& shape, LocalClient* client, PjRtDevice* device) {
|
const Shape& shape, LocalClient* client, PjRtDevice* device) {
|
||||||
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> device_buffers;
|
std::vector<tsl::RCReference<RawSEDeviceMemory>> device_buffers;
|
||||||
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
|
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
|
||||||
client->backend().transfer_manager()->HostShapeToDeviceShape(shape),
|
client->backend().transfer_manager()->HostShapeToDeviceShape(shape),
|
||||||
[&](const Shape& subshape, const ShapeIndex&) -> absl::Status {
|
[&](const Shape& subshape, const ShapeIndex&) -> absl::Status {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user