mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
parent
fe2a783077
commit
c7055c2e5b
|
|
@ -1367,7 +1367,7 @@ StreamExecutorGpuClient::RunAsync(
|
|||
|
||||
std::set<se::DeviceMemoryBase> buffers_in_result;
|
||||
|
||||
xla::ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> results(
|
||||
xla::ShapeTree<tsl::RCReference<RawSEDeviceMemory>> results(
|
||||
gpu_exec->result_shape());
|
||||
|
||||
for (auto& p : results) {
|
||||
|
|
@ -1452,7 +1452,7 @@ StreamExecutorGpuClient::RunAsync(
|
|||
TF_RETURN_IF_ERROR(buffer_allocations.TearDown(buffers_in_result,
|
||||
gpu_exec->GetAllocations()));
|
||||
|
||||
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> to_be_released;
|
||||
std::vector<tsl::RCReference<RawSEDeviceMemory>> to_be_released;
|
||||
|
||||
// Free allocations for arguments.
|
||||
for (ShapeTree<PjRtStreamExecutorExecutionInput>& input : arguments) {
|
||||
|
|
|
|||
|
|
@ -789,7 +789,7 @@ PjRtStreamExecutorClient::CreateErrorBuffer(absl::Status error,
|
|||
|
||||
// Create an empty buffer.
|
||||
auto dummy_device_buffer = std::make_unique<TrackedDeviceBuffer>(
|
||||
device, tsl::AsyncValueRef<RawSEDeviceMemory>(),
|
||||
device, tsl::RCReference<RawSEDeviceMemory>(),
|
||||
absl::MakeSpan(&definition_event, 1));
|
||||
|
||||
return std::make_unique<CommonPjRtBufferImpl>(
|
||||
|
|
@ -1168,13 +1168,13 @@ MakeTupleHelper(PjRtStreamExecutorClient* client,
|
|||
// Converts a ScopedShapedBuffer returned from an execution into a
|
||||
// PjRtBuffer.
|
||||
absl::StatusOr<std::unique_ptr<PjRtBuffer>> OutputBufferHelper(
|
||||
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result_buffer,
|
||||
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
|
||||
BufferSequencingEventRef definition_event, PjRtClient* client,
|
||||
PjRtDevice* device, LocalDeviceState* local_device) {
|
||||
if (result_buffer.shape().IsTuple()) {
|
||||
return absl::InternalError("OutputBufferHelper called on tuple.");
|
||||
}
|
||||
absl::InlinedVector<tsl::AsyncValueRef<RawSEDeviceMemory>, 1> buffers;
|
||||
absl::InlinedVector<tsl::RCReference<RawSEDeviceMemory>, 1> buffers;
|
||||
for (auto& item : result_buffer) {
|
||||
buffers.push_back(std::move(item.second));
|
||||
}
|
||||
|
|
@ -1641,7 +1641,7 @@ PjRtStreamExecutorClient::RunAsync(
|
|||
ExecutionOutput output,
|
||||
exec.RunAsync(std::move(xla_arguments), std::move(run_options)));
|
||||
ScopedShapedBuffer ssb = output.ConsumeResult();
|
||||
xla::ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> results(
|
||||
xla::ShapeTree<tsl::RCReference<RawSEDeviceMemory>> results(
|
||||
ssb.on_device_shape());
|
||||
auto it = results.begin();
|
||||
se::DeviceMemoryAllocator* allocator = ssb.memory_allocator();
|
||||
|
|
@ -1672,7 +1672,7 @@ PjRtStreamExecutorClient::RunAsync(
|
|||
// converted on success.
|
||||
// When `options` has non-zero `launch_id`, use `launch_id` instead of `run_id`
|
||||
// to initialize `run_options`.
|
||||
absl::StatusOr<ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>>>
|
||||
absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>>
|
||||
PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
|
||||
absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
|
||||
int executable_idx, const RunId& run_id, const ExecuteOptions& options,
|
||||
|
|
@ -1930,7 +1930,7 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
|
|||
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
|
||||
PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
|
||||
int device_ordinal, const ExecuteOptions& options,
|
||||
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result_buffer,
|
||||
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
|
||||
BufferSequencingEventRef definition_event, PjRtDevice* device,
|
||||
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const {
|
||||
tsl::profiler::TraceMe traceme("MakeOutputBuffers");
|
||||
|
|
@ -1943,7 +1943,7 @@ PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
|
|||
// in result_buffer.
|
||||
for (int i = 0; i < tuple_count; ++i) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> tuple_buffer,
|
||||
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> tuple_buffer,
|
||||
result_buffer.SubShapeTree({i}));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<PjRtBuffer> buffer,
|
||||
|
|
@ -2050,7 +2050,7 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
|
|||
std::vector<absl::AnyInvocable<void() &&>> compute_callbacks;
|
||||
std::vector<CommonPjRtBuffer::ScopedHold> device_buffers;
|
||||
device_buffers.reserve(argument_handles.size());
|
||||
absl::StatusOr<ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>>>
|
||||
absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>>
|
||||
result_buffer_or_status =
|
||||
EnqueueExecution(argument_handles, replica, partition, executable_idx,
|
||||
run_id, options, device, &device_buffers,
|
||||
|
|
@ -2061,7 +2061,7 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
|
|||
<< " failed: " << result_buffer_or_status.status();
|
||||
return result_buffer_or_status.status();
|
||||
}
|
||||
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result_buffer =
|
||||
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer =
|
||||
std::move(result_buffer_or_status).value();
|
||||
|
||||
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
|
||||
|
|
@ -2081,14 +2081,14 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
|
|||
}
|
||||
return definition_event_or.status();
|
||||
}
|
||||
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> leaves_to_release;
|
||||
std::vector<tsl::RCReference<RawSEDeviceMemory>> leaves_to_release;
|
||||
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
|
||||
leaves_to_release.reserve(result_buffer.leaf_count());
|
||||
for (auto& node : result_buffer.leaves()) {
|
||||
leaves_to_release.push_back(node.second);
|
||||
}
|
||||
}
|
||||
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> buffers_to_release;
|
||||
std::vector<tsl::RCReference<RawSEDeviceMemory>> buffers_to_release;
|
||||
auto definition_event = tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
|
||||
*definition_event_or, "PjRtStreamExecutorLoadedExecutable", "Execute");
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
|
|
|
|||
|
|
@ -83,13 +83,13 @@ struct PjRtStreamExecutorExecutionInput {
|
|||
// Donation is not complete until ReleaseDeviceMemory() is called on the
|
||||
// TrackedDeviceBuffer that provides buf.
|
||||
bool is_donated;
|
||||
tsl::AsyncValueRef<RawSEDeviceMemory> buf;
|
||||
tsl::RCReference<RawSEDeviceMemory> buf;
|
||||
};
|
||||
|
||||
struct PjRtStreamExecutorExecutionOutput {
|
||||
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result;
|
||||
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result;
|
||||
// Donated inputs which must be freed.
|
||||
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> to_be_released;
|
||||
std::vector<tsl::RCReference<RawSEDeviceMemory>> to_be_released;
|
||||
// For PjRtStreamExecutorClient implementations that
|
||||
// use OwningDeviceMemory for donated inputs.
|
||||
std::vector<se::OwningDeviceMemory> se_to_be_released;
|
||||
|
|
@ -672,7 +672,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
|
|||
absl::Span<const CommonPjRtBuffer::ScopedHold> device_buffers,
|
||||
absl::flat_hash_set<BufferSequencingEvent*>& events) const;
|
||||
|
||||
absl::StatusOr<ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>>>
|
||||
absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>>
|
||||
EnqueueExecution(
|
||||
absl::Span<PjRtBuffer* const> argument_handles, int replica,
|
||||
int partition, int executable_idx, const RunId& run_id,
|
||||
|
|
@ -684,7 +684,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
|
|||
virtual absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
|
||||
MakeOutputBuffers(
|
||||
int device_ordinal, const ExecuteOptions& options,
|
||||
ShapeTree<tsl::AsyncValueRef<RawSEDeviceMemory>> result_buffer,
|
||||
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
|
||||
BufferSequencingEventRef definition_event, PjRtDevice* device,
|
||||
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const;
|
||||
|
||||
|
|
|
|||
12
third_party/xla/xla/pjrt/se_raw_buffer.h
vendored
12
third_party/xla/xla/pjrt/se_raw_buffer.h
vendored
|
|
@ -91,10 +91,10 @@ class PjRtStreamExecutorDeviceEventPromise : public PjRtDeviceEventPromise {
|
|||
|
||||
class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer {
|
||||
public:
|
||||
PjRtStreamExecutorRawBuffer(
|
||||
PjRtStreamExecutorClient* client, PjRtMemorySpace* memory_space,
|
||||
LocalDeviceState* local_device,
|
||||
tsl::AsyncValueRef<RawSEDeviceMemory> device_buffer)
|
||||
PjRtStreamExecutorRawBuffer(PjRtStreamExecutorClient* client,
|
||||
PjRtMemorySpace* memory_space,
|
||||
LocalDeviceState* local_device,
|
||||
tsl::RCReference<RawSEDeviceMemory> device_buffer)
|
||||
: client_(client),
|
||||
memory_space_(memory_space),
|
||||
local_device_(local_device),
|
||||
|
|
@ -104,7 +104,7 @@ class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer {
|
|||
|
||||
LocalDeviceState* local_device() const { return local_device_; }
|
||||
|
||||
const tsl::AsyncValueRef<RawSEDeviceMemory>& device_buffer() const {
|
||||
const tsl::RCReference<RawSEDeviceMemory>& device_buffer() const {
|
||||
return device_buffer_;
|
||||
}
|
||||
|
||||
|
|
@ -150,7 +150,7 @@ class PjRtStreamExecutorRawBuffer : public CommonPjRtRawBuffer {
|
|||
PjRtStreamExecutorClient* client_;
|
||||
PjRtMemorySpace* memory_space_;
|
||||
LocalDeviceState* local_device_;
|
||||
tsl::AsyncValueRef<RawSEDeviceMemory> device_buffer_;
|
||||
tsl::RCReference<RawSEDeviceMemory> device_buffer_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
|
|
|||
|
|
@ -119,11 +119,11 @@ class AllocatedRawSEDeviceMemory : public RawSEDeviceMemory {
|
|||
size_t sync_point_ = std::numeric_limits<size_t>::max();
|
||||
};
|
||||
|
||||
tsl::AsyncValueRef<RawSEDeviceMemory> RawSEDeviceMemory::Create(
|
||||
tsl::RCReference<RawSEDeviceMemory> RawSEDeviceMemory::Create(
|
||||
se::DeviceMemoryBase value, LocalDeviceState* local_device,
|
||||
se::DeviceMemoryAllocator* allocator) {
|
||||
return tsl::MakeAvailableAsyncValueRef<AllocatedRawSEDeviceMemory>(
|
||||
value, local_device, allocator);
|
||||
return tsl::MakeRef<AllocatedRawSEDeviceMemory>(value, local_device,
|
||||
allocator);
|
||||
}
|
||||
|
||||
class ForeignRawSEDeviceMemory : public RawSEDeviceMemory {
|
||||
|
|
@ -143,11 +143,11 @@ class ForeignRawSEDeviceMemory : public RawSEDeviceMemory {
|
|||
absl::AnyInvocable<void() &&> on_delete_callback_;
|
||||
};
|
||||
|
||||
tsl::AsyncValueRef<RawSEDeviceMemory> RawSEDeviceMemory::CreateForeign(
|
||||
tsl::RCReference<RawSEDeviceMemory> RawSEDeviceMemory::CreateForeign(
|
||||
se::DeviceMemoryBase value,
|
||||
absl::AnyInvocable<void() &&> on_delete_callback) {
|
||||
return tsl::MakeAvailableAsyncValueRef<ForeignRawSEDeviceMemory>(
|
||||
value, std::move(on_delete_callback));
|
||||
return tsl::MakeRef<ForeignRawSEDeviceMemory>(value,
|
||||
std::move(on_delete_callback));
|
||||
}
|
||||
|
||||
ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer(
|
||||
|
|
@ -167,7 +167,7 @@ ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer(
|
|||
}
|
||||
|
||||
TrackedDeviceBuffer::TrackedDeviceBuffer(
|
||||
PjRtDevice* device, tsl::AsyncValueRef<RawSEDeviceMemory> device_memory,
|
||||
PjRtDevice* device, tsl::RCReference<RawSEDeviceMemory> device_memory,
|
||||
absl::Span<const BufferSequencingEventRef> definition_events)
|
||||
: device_(device),
|
||||
device_memory_(std::move(device_memory)),
|
||||
|
|
@ -178,7 +178,7 @@ TrackedDeviceBuffer::TrackedDeviceBuffer(
|
|||
TrackedDeviceBuffer::~TrackedDeviceBuffer() = default;
|
||||
|
||||
void TrackedDeviceBuffer::ReleaseDeviceMemory() {
|
||||
device_memory_ = tsl::AsyncValueRef<RawSEDeviceMemory>();
|
||||
device_memory_ = tsl::RCReference<RawSEDeviceMemory>();
|
||||
}
|
||||
|
||||
void TrackedDeviceBuffer::ConfirmDonation() {
|
||||
|
|
|
|||
11
third_party/xla/xla/pjrt/tracked_device_buffer.h
vendored
11
third_party/xla/xla/pjrt/tracked_device_buffer.h
vendored
|
|
@ -51,6 +51,7 @@ limitations under the License.
|
|||
|
||||
namespace xla {
|
||||
|
||||
// TODO(parkers): Implement PjRtRawBuffer API.
|
||||
class RawSEDeviceMemory : public tsl::ReferenceCounted<RawSEDeviceMemory> {
|
||||
public:
|
||||
explicit RawSEDeviceMemory(se::DeviceMemoryBase value) : value_(value) {}
|
||||
|
|
@ -69,10 +70,10 @@ class RawSEDeviceMemory : public tsl::ReferenceCounted<RawSEDeviceMemory> {
|
|||
ShapedBuffer AsShapedBuffer(PjRtDevice* device,
|
||||
const Shape& on_device_shape) const;
|
||||
|
||||
static tsl::AsyncValueRef<RawSEDeviceMemory> Create(
|
||||
static tsl::RCReference<RawSEDeviceMemory> Create(
|
||||
se::DeviceMemoryBase value, LocalDeviceState* local_device,
|
||||
se::DeviceMemoryAllocator* allocator);
|
||||
static tsl::AsyncValueRef<RawSEDeviceMemory> CreateForeign(
|
||||
static tsl::RCReference<RawSEDeviceMemory> CreateForeign(
|
||||
se::DeviceMemoryBase value,
|
||||
absl::AnyInvocable<void() &&> on_delete_callback);
|
||||
|
||||
|
|
@ -129,7 +130,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {
|
|||
ExecutionInput* execution_input,
|
||||
se::DeviceMemoryAllocator* allocator) const;
|
||||
|
||||
const tsl::AsyncValueRef<RawSEDeviceMemory>& device_memory() const {
|
||||
const tsl::RCReference<RawSEDeviceMemory>& device_memory() const {
|
||||
return device_memory_;
|
||||
}
|
||||
|
||||
|
|
@ -167,7 +168,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {
|
|||
StreamAndEventContainer LockUseAndTransferUsageEvents();
|
||||
|
||||
TrackedDeviceBuffer(
|
||||
PjRtDevice* device, tsl::AsyncValueRef<RawSEDeviceMemory> device_memory,
|
||||
PjRtDevice* device, tsl::RCReference<RawSEDeviceMemory> device_memory,
|
||||
absl::Span<const BufferSequencingEventRef> definition_events);
|
||||
~TrackedDeviceBuffer() override;
|
||||
|
||||
|
|
@ -198,7 +199,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {
|
|||
PjRtDevice* device_;
|
||||
|
||||
// Each host-side buffer may have several buffers on-device.
|
||||
tsl::AsyncValueRef<RawSEDeviceMemory> device_memory_;
|
||||
tsl::RCReference<RawSEDeviceMemory> device_memory_;
|
||||
|
||||
// Events that are triggered when the content of one or more buffers is ready
|
||||
// during multistream execution. May be nullptr, which is used in the
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ class TestDevice : public PjRtDevice {
|
|||
|
||||
absl::StatusOr<std::shared_ptr<TrackedDeviceBuffer>> MakeArray(
|
||||
const Shape& shape, LocalClient* client, PjRtDevice* device) {
|
||||
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> device_buffers;
|
||||
std::vector<tsl::RCReference<RawSEDeviceMemory>> device_buffers;
|
||||
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
|
||||
client->backend().transfer_manager()->HostShapeToDeviceShape(shape),
|
||||
[&](const Shape& subshape, const ShapeIndex&) -> absl::Status {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user