mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Update PjRtStreamExecutorClient main execute path to use CommonPjRtBuffer::ScopedHold. Crucially
this now passes reference_held=true always. This is fine because the only time this was ever passed as false was if this was already on the compute stream and this bool is basically ignored if the stream is the compute stream (see MaybeWaitForEventOnStream). PiperOrigin-RevId: 822758577
This commit is contained in:
parent
880f245b56
commit
13ea97f3a9
1
third_party/xla/xla/pjrt/gpu/BUILD
vendored
1
third_party/xla/xla/pjrt/gpu/BUILD
vendored
|
|
@ -79,6 +79,7 @@ cc_library(
|
||||||
"//xla/core/collectives:communicator",
|
"//xla/core/collectives:communicator",
|
||||||
"//xla/core/collectives:rank_id",
|
"//xla/core/collectives:rank_id",
|
||||||
"//xla/hlo/builder:xla_computation",
|
"//xla/hlo/builder:xla_computation",
|
||||||
|
"//xla/pjrt:abstract_tracked_device_buffer",
|
||||||
"//xla/pjrt:common_pjrt_client",
|
"//xla/pjrt:common_pjrt_client",
|
||||||
"//xla/pjrt:device_event",
|
"//xla/pjrt:device_event",
|
||||||
"//xla/pjrt:event_pool",
|
"//xla/pjrt:event_pool",
|
||||||
|
|
|
||||||
|
|
@ -60,6 +60,7 @@ limitations under the License.
|
||||||
#include "xla/hlo/builder/xla_computation.h"
|
#include "xla/hlo/builder/xla_computation.h"
|
||||||
#include "xla/layout.h"
|
#include "xla/layout.h"
|
||||||
#include "xla/literal.h"
|
#include "xla/literal.h"
|
||||||
|
#include "xla/pjrt/abstract_tracked_device_buffer.h"
|
||||||
#include "xla/pjrt/device_event.h"
|
#include "xla/pjrt/device_event.h"
|
||||||
#include "xla/pjrt/distributed/client.h"
|
#include "xla/pjrt/distributed/client.h"
|
||||||
#include "xla/pjrt/distributed/in_memory_key_value_store.h"
|
#include "xla/pjrt/distributed/in_memory_key_value_store.h"
|
||||||
|
|
@ -1043,11 +1044,14 @@ StreamExecutorGpuClient::MakeCrossHostReceiveBuffers(
|
||||||
definition_event));
|
definition_event));
|
||||||
|
|
||||||
// Acquire a hold on the buffer to access the underlying memory.
|
// Acquire a hold on the buffer to access the underlying memory.
|
||||||
PjRtStreamExecutorBuffer::ScopedHold hold = buffer->GetBufferWithUsageHold();
|
CommonPjRtBuffer::ScopedHold hold =
|
||||||
|
buffer->GetBufferWithHold(CommonPjRtBuffer::ScopedHold::kUsage);
|
||||||
|
|
||||||
auto recv = [this, gpu_collectives, notifier, local_device, definition_event,
|
auto recv = [this, gpu_collectives, notifier, local_device, definition_event,
|
||||||
stream, mem = hold->device_memory(), shape = shapes[0],
|
stream,
|
||||||
dtype = buffer->element_type()]() mutable {
|
mem = tensorflow::down_cast<TrackedDeviceBuffer*>(hold.buffer())
|
||||||
|
->device_memory(),
|
||||||
|
shape = shapes[0], dtype = buffer->element_type()]() mutable {
|
||||||
auto f = [&]() -> absl::Status {
|
auto f = [&]() -> absl::Status {
|
||||||
// Create a CliqueId.
|
// Create a CliqueId.
|
||||||
TF_ASSIGN_OR_RETURN(CliqueId clique_id,
|
TF_ASSIGN_OR_RETURN(CliqueId clique_id,
|
||||||
|
|
|
||||||
|
|
@ -727,15 +727,16 @@ PjRtStreamExecutorBuffer::DonateWithControlDependency(Future<> dependency) {
|
||||||
VLOG(1) << "PjRtStreamExecutorBuffer::DonateWithControlDependency";
|
VLOG(1) << "PjRtStreamExecutorBuffer::DonateWithControlDependency";
|
||||||
std::unique_ptr<PjRtBuffer> new_buffer;
|
std::unique_ptr<PjRtBuffer> new_buffer;
|
||||||
|
|
||||||
auto tracked_buffer =
|
auto hold = GetBufferWithHold(CommonPjRtBuffer::ScopedHold::kDonation);
|
||||||
GetBufferWithHold(PjRtStreamExecutorBuffer::ScopedHold::kDonation);
|
|
||||||
|
|
||||||
if (!tracked_buffer.ok()) {
|
if (!hold.ok()) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"Invalid buffer passed to DonateWithControlDependency: %s",
|
"Invalid buffer passed to DonateWithControlDependency: %s",
|
||||||
tracked_buffer.status().ToString());
|
hold.status().ToString());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto* tracked_buffer =
|
||||||
|
tensorflow::down_cast<TrackedDeviceBuffer*>(hold.buffer());
|
||||||
// Copy all the data in the existing tracked_buffer.
|
// Copy all the data in the existing tracked_buffer.
|
||||||
const auto& original_definition_events = tracked_buffer->definition_events();
|
const auto& original_definition_events = tracked_buffer->definition_events();
|
||||||
absl::InlinedVector<BufferSequencingEventRef, 4> definition_events;
|
absl::InlinedVector<BufferSequencingEventRef, 4> definition_events;
|
||||||
|
|
@ -774,7 +775,7 @@ PjRtStreamExecutorBuffer::DonateWithControlDependency(Future<> dependency) {
|
||||||
local_device->ReturnStreamToPool(std::move(stream));
|
local_device->ReturnStreamToPool(std::move(stream));
|
||||||
});
|
});
|
||||||
|
|
||||||
tracked_buffer.ConfirmDonation();
|
hold.ConfirmDonation();
|
||||||
return new_buffer;
|
return new_buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1420,7 +1421,9 @@ Future<> PjRtStreamExecutorBuffer::GetReadyFuture() {
|
||||||
"GetReadyFuture() called on deleted or donated buffer"));
|
"GetReadyFuture() called on deleted or donated buffer"));
|
||||||
}
|
}
|
||||||
if (!definition_future_) {
|
if (!definition_future_) {
|
||||||
definition_events = device_buffer()->definition_events();
|
definition_events =
|
||||||
|
tensorflow::down_cast<TrackedDeviceBuffer*>(device_buffer())
|
||||||
|
->definition_events();
|
||||||
std::tie(definition_promise, definition_future_) =
|
std::tie(definition_promise, definition_future_) =
|
||||||
Future<>::MakePromise();
|
Future<>::MakePromise();
|
||||||
}
|
}
|
||||||
|
|
@ -1563,12 +1566,12 @@ absl::Status CheckCompatibleShapes(bool strict_shape_checking,
|
||||||
// Makes a tuple from the arguments to an execution.
|
// Makes a tuple from the arguments to an execution.
|
||||||
static absl::StatusOr<std::pair<ShapeTree<PjRtStreamExecutorExecutionInput>,
|
static absl::StatusOr<std::pair<ShapeTree<PjRtStreamExecutorExecutionInput>,
|
||||||
BufferSequencingEventRef>>
|
BufferSequencingEventRef>>
|
||||||
MakeTupleHelper(
|
MakeTupleHelper(PjRtStreamExecutorClient* client,
|
||||||
PjRtStreamExecutorClient* client, LocalDeviceState* local_device,
|
LocalDeviceState* local_device, bool strict_shape_checking,
|
||||||
bool strict_shape_checking, const Shape& tupled_parameter_shape,
|
const Shape& tupled_parameter_shape,
|
||||||
absl::Span<PjRtBuffer* const> py_buffers,
|
absl::Span<PjRtBuffer* const> py_buffers,
|
||||||
absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
|
absl::Span<const CommonPjRtBuffer::ScopedHold> device_buffers,
|
||||||
int device_ordinal) {
|
int device_ordinal) {
|
||||||
se::DeviceMemoryAllocator* allocator = client->allocator();
|
se::DeviceMemoryAllocator* allocator = client->allocator();
|
||||||
TransferManager* transfer_manager =
|
TransferManager* transfer_manager =
|
||||||
client->client()->backend().transfer_manager();
|
client->client()->backend().transfer_manager();
|
||||||
|
|
@ -1610,11 +1613,11 @@ MakeTupleHelper(
|
||||||
local_device, allocator)};
|
local_device, allocator)};
|
||||||
++input_iterator;
|
++input_iterator;
|
||||||
// Then set each sub-tuple in turn from the parameters.
|
// Then set each sub-tuple in turn from the parameters.
|
||||||
for (const PjRtStreamExecutorBuffer::ScopedHold& device_buffer :
|
for (const CommonPjRtBuffer::ScopedHold& device_buffer : device_buffers) {
|
||||||
device_buffers) {
|
|
||||||
input_iterator->second = {
|
input_iterator->second = {
|
||||||
device_buffer.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation,
|
device_buffer.type() == CommonPjRtBuffer::ScopedHold::kDonation,
|
||||||
device_buffer->device_memory()};
|
tensorflow::down_cast<TrackedDeviceBuffer*>(device_buffer.buffer())
|
||||||
|
->device_memory()};
|
||||||
++input_iterator;
|
++input_iterator;
|
||||||
}
|
}
|
||||||
CHECK(input_iterator == iterator_end);
|
CHECK(input_iterator == iterator_end);
|
||||||
|
|
@ -1643,8 +1646,7 @@ MakeTupleHelper(
|
||||||
absl::StatusOr<std::unique_ptr<PjRtBuffer>> OutputBufferHelper(
|
absl::StatusOr<std::unique_ptr<PjRtBuffer>> OutputBufferHelper(
|
||||||
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
|
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
|
||||||
BufferSequencingEventRef definition_event, PjRtClient* client,
|
BufferSequencingEventRef definition_event, PjRtClient* client,
|
||||||
PjRtDevice* device, LocalDeviceState* local_device,
|
PjRtDevice* device, LocalDeviceState* local_device) {
|
||||||
std::vector<tsl::RCReference<RawSEDeviceMemory>>& buffers_to_release) {
|
|
||||||
if (result_buffer.shape().IsTuple()) {
|
if (result_buffer.shape().IsTuple()) {
|
||||||
return absl::InternalError("OutputBufferHelper called on tuple.");
|
return absl::InternalError("OutputBufferHelper called on tuple.");
|
||||||
}
|
}
|
||||||
|
|
@ -1680,9 +1682,6 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> OutputBufferHelper(
|
||||||
auto pjrt_buffer = std::make_unique<PjRtStreamExecutorBuffer>(
|
auto pjrt_buffer = std::make_unique<PjRtStreamExecutorBuffer>(
|
||||||
result_buffer.shape(), std::move(out_buffer), client, device,
|
result_buffer.shape(), std::move(out_buffer), client, device,
|
||||||
memory_space);
|
memory_space);
|
||||||
RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device, local_device,
|
|
||||||
definition_event, local_device->compute_stream(),
|
|
||||||
&buffers_to_release);
|
|
||||||
return std::unique_ptr<PjRtBuffer>(std::move(pjrt_buffer));
|
return std::unique_ptr<PjRtBuffer>(std::move(pjrt_buffer));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1798,7 +1797,7 @@ PjRtStreamExecutorLoadedExecutable::MakeExecutionInputsAndWaitForEvents(
|
||||||
int device_ordinal, const ExecuteOptions& options,
|
int device_ordinal, const ExecuteOptions& options,
|
||||||
absl::Span<const Shape> executable_parameter_shapes,
|
absl::Span<const Shape> executable_parameter_shapes,
|
||||||
absl::Span<PjRtBuffer* const> argument_handles,
|
absl::Span<PjRtBuffer* const> argument_handles,
|
||||||
absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
|
absl::Span<const CommonPjRtBuffer::ScopedHold> device_buffers,
|
||||||
absl::flat_hash_set<BufferSequencingEvent*>& events) const {
|
absl::flat_hash_set<BufferSequencingEvent*>& events) const {
|
||||||
std::vector<ShapeTree<PjRtStreamExecutorExecutionInput>> execution_inputs;
|
std::vector<ShapeTree<PjRtStreamExecutorExecutionInput>> execution_inputs;
|
||||||
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
|
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
|
||||||
|
|
@ -1834,11 +1833,12 @@ PjRtStreamExecutorLoadedExecutable::MakeExecutionInputsAndWaitForEvents(
|
||||||
execution_inputs.back();
|
execution_inputs.back();
|
||||||
auto input_iterator = execution_input.begin();
|
auto input_iterator = execution_input.begin();
|
||||||
auto iterator_end = execution_input.end();
|
auto iterator_end = execution_input.end();
|
||||||
const auto& buf = device_buffers[i]->device_memory();
|
const auto& buf = tensorflow::down_cast<TrackedDeviceBuffer*>(
|
||||||
|
device_buffers[i].buffer())
|
||||||
|
->device_memory();
|
||||||
CHECK(input_iterator != iterator_end);
|
CHECK(input_iterator != iterator_end);
|
||||||
input_iterator->second = {
|
input_iterator->second = {
|
||||||
device_buffers[i].type() ==
|
device_buffers[i].type() == CommonPjRtBuffer::ScopedHold::kDonation,
|
||||||
PjRtStreamExecutorBuffer::ScopedHold::kDonation,
|
|
||||||
buf};
|
buf};
|
||||||
++input_iterator;
|
++input_iterator;
|
||||||
CHECK(input_iterator == iterator_end);
|
CHECK(input_iterator == iterator_end);
|
||||||
|
|
@ -2154,7 +2154,7 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
|
||||||
absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
|
absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
|
||||||
int executable_idx, const RunId& run_id, const ExecuteOptions& options,
|
int executable_idx, const RunId& run_id, const ExecuteOptions& options,
|
||||||
PjRtDevice* device,
|
PjRtDevice* device,
|
||||||
std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
|
std::vector<CommonPjRtBuffer::ScopedHold>* device_buffers,
|
||||||
std::shared_ptr<DeviceAssignment> device_assignment,
|
std::shared_ptr<DeviceAssignment> device_assignment,
|
||||||
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const {
|
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const {
|
||||||
int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
|
int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
|
||||||
|
|
@ -2194,16 +2194,17 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
|
||||||
TF_RETURN_IF_ERROR(TestBufferDonationClashes(
|
TF_RETURN_IF_ERROR(TestBufferDonationClashes(
|
||||||
handle, donation_clashes, must_donate, i, replica, partition));
|
handle, donation_clashes, must_donate, i, replica, partition));
|
||||||
device_buffers->emplace_back(handle->GetBufferWithHold(
|
device_buffers->emplace_back(handle->GetBufferWithHold(
|
||||||
must_donate ? PjRtStreamExecutorBuffer::ScopedHold::kDonation
|
must_donate ? CommonPjRtBuffer::ScopedHold::kDonation
|
||||||
: PjRtStreamExecutorBuffer::ScopedHold::kUsage));
|
: CommonPjRtBuffer::ScopedHold::kUsage));
|
||||||
PjRtStreamExecutorBuffer::ScopedHold& device_buffer =
|
CommonPjRtBuffer::ScopedHold& hold = device_buffers->back();
|
||||||
device_buffers->back();
|
if (!hold.ok()) {
|
||||||
if (!device_buffer.ok()) {
|
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"Invalid buffer passed to Execute() as argument %d to replica %d: "
|
"Invalid buffer passed to Execute() as argument %d to replica %d: "
|
||||||
"%s",
|
"%s",
|
||||||
i, replica, device_buffer.status().ToString());
|
i, replica, hold.status().ToString());
|
||||||
}
|
}
|
||||||
|
auto* device_buffer =
|
||||||
|
tensorflow::down_cast<TrackedDeviceBuffer*>(hold.buffer());
|
||||||
// If we are trying to donate the buffer wait on the usage events as well
|
// If we are trying to donate the buffer wait on the usage events as well
|
||||||
// as the definition events to ensure that all reads have been completed
|
// as the definition events to ensure that all reads have been completed
|
||||||
// before the buffer is mutated. Usage holds are excluded during a donation
|
// before the buffer is mutated. Usage holds are excluded during a donation
|
||||||
|
|
@ -2408,9 +2409,7 @@ PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
|
||||||
int device_ordinal, const ExecuteOptions& options,
|
int device_ordinal, const ExecuteOptions& options,
|
||||||
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
|
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
|
||||||
BufferSequencingEventRef definition_event, PjRtDevice* device,
|
BufferSequencingEventRef definition_event, PjRtDevice* device,
|
||||||
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks,
|
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const {
|
||||||
std::vector<tsl::RCReference<RawSEDeviceMemory>>& buffers_to_release)
|
|
||||||
const {
|
|
||||||
tsl::profiler::TraceMe traceme("MakeOutputBuffers");
|
tsl::profiler::TraceMe traceme("MakeOutputBuffers");
|
||||||
std::vector<std::unique_ptr<PjRtBuffer>> outputs;
|
std::vector<std::unique_ptr<PjRtBuffer>> outputs;
|
||||||
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
|
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
|
||||||
|
|
@ -2426,7 +2425,7 @@ PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::unique_ptr<PjRtBuffer> buffer,
|
std::unique_ptr<PjRtBuffer> buffer,
|
||||||
OutputBufferHelper(std::move(tuple_buffer), definition_event, client_,
|
OutputBufferHelper(std::move(tuple_buffer), definition_event, client_,
|
||||||
device, device_state, buffers_to_release));
|
device, device_state));
|
||||||
outputs.push_back(std::move(buffer));
|
outputs.push_back(std::move(buffer));
|
||||||
}
|
}
|
||||||
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
|
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
|
||||||
|
|
@ -2438,7 +2437,7 @@ PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::unique_ptr<PjRtBuffer> buffer,
|
std::unique_ptr<PjRtBuffer> buffer,
|
||||||
OutputBufferHelper(std::move(result_buffer), definition_event, client_,
|
OutputBufferHelper(std::move(result_buffer), definition_event, client_,
|
||||||
device, device_state, buffers_to_release));
|
device, device_state));
|
||||||
outputs.push_back(std::move(buffer));
|
outputs.push_back(std::move(buffer));
|
||||||
}
|
}
|
||||||
return outputs;
|
return outputs;
|
||||||
|
|
@ -2447,13 +2446,15 @@ PjRtStreamExecutorLoadedExecutable::MakeOutputBuffers(
|
||||||
static absl::Status GetFirstInputError(
|
static absl::Status GetFirstInputError(
|
||||||
absl::Span<PjRtBuffer* const> argument_handles) {
|
absl::Span<PjRtBuffer* const> argument_handles) {
|
||||||
for (auto* handle : argument_handles) {
|
for (auto* handle : argument_handles) {
|
||||||
auto* buffer = tensorflow::down_cast<PjRtStreamExecutorBuffer*>(handle);
|
auto* buffer = tensorflow::down_cast<CommonPjRtBuffer*>(handle);
|
||||||
PjRtStreamExecutorBuffer::ScopedHold hold =
|
CommonPjRtBuffer::ScopedHold hold =
|
||||||
buffer->GetBufferWithUsageHold();
|
buffer->GetBufferWithHold(CommonPjRtBuffer::ScopedHold::kUsage);
|
||||||
if (!hold.ok()) {
|
if (!hold.ok()) {
|
||||||
return hold.status();
|
return hold.status();
|
||||||
}
|
}
|
||||||
for (const auto& event : hold->definition_events()) {
|
auto* tracked_buffer =
|
||||||
|
tensorflow::down_cast<TrackedDeviceBuffer*>(hold.buffer());
|
||||||
|
for (const auto& event : tracked_buffer->definition_events()) {
|
||||||
if (event->IsPredeterminedError()) {
|
if (event->IsPredeterminedError()) {
|
||||||
return event->GetDefinedStatus();
|
return event->GetDefinedStatus();
|
||||||
}
|
}
|
||||||
|
|
@ -2524,7 +2525,7 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
|
||||||
int executable_idx = executables_.size() > 1 ? partition : 0;
|
int executable_idx = executables_.size() > 1 ? partition : 0;
|
||||||
|
|
||||||
std::vector<absl::AnyInvocable<void() &&>> compute_callbacks;
|
std::vector<absl::AnyInvocable<void() &&>> compute_callbacks;
|
||||||
std::vector<PjRtStreamExecutorBuffer::ScopedHold> device_buffers;
|
std::vector<CommonPjRtBuffer::ScopedHold> device_buffers;
|
||||||
device_buffers.reserve(argument_handles.size());
|
device_buffers.reserve(argument_handles.size());
|
||||||
absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>>
|
absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>>
|
||||||
result_buffer_or_status =
|
result_buffer_or_status =
|
||||||
|
|
@ -2543,33 +2544,46 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
|
||||||
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
|
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
|
||||||
se::Stream* stream = device_state->compute_stream();
|
se::Stream* stream = device_state->compute_stream();
|
||||||
|
|
||||||
auto definition_event = device_state->GetEventForComputeStreamSyncPoint(
|
auto definition_event_or = device_state->GetEventForComputeStreamSyncPoint(
|
||||||
device_state->GetNextComputeStreamSyncPoint(), client_->thread_pool());
|
device_state->GetNextComputeStreamSyncPoint(), client_->thread_pool());
|
||||||
if (!definition_event.ok()) {
|
if (!definition_event_or.ok()) {
|
||||||
StallStreamOnError(device_state, stream);
|
StallStreamOnError(device_state, stream);
|
||||||
for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
|
for (CommonPjRtBuffer::ScopedHold& b : device_buffers) {
|
||||||
if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation) {
|
if (b.type() == CommonPjRtBuffer::ScopedHold::kDonation) {
|
||||||
// Even though there was an error we need to call ConfirmDonation, which
|
// Even though there was an error we need to call ConfirmDonation, which
|
||||||
// renders b invalid, since the computation has been enqueued and b has
|
// renders b invalid, since the computation has been enqueued and b has
|
||||||
// been donated.
|
// been donated.
|
||||||
b.ConfirmDonation();
|
b.ConfirmDonation();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return definition_event.status();
|
return definition_event_or.status();
|
||||||
|
}
|
||||||
|
std::vector<tsl::RCReference<RawSEDeviceMemory>> leaves_to_release;
|
||||||
|
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
|
||||||
|
leaves_to_release.reserve(result_buffer.leaf_count());
|
||||||
|
for (auto& node : result_buffer.leaves()) {
|
||||||
|
leaves_to_release.push_back(node.second);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
std::vector<tsl::RCReference<RawSEDeviceMemory>> buffers_to_release;
|
std::vector<tsl::RCReference<RawSEDeviceMemory>> buffers_to_release;
|
||||||
|
auto definition_event = tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
|
||||||
|
*definition_event_or, "PjRtStreamExecutorLoadedExecutable", "Execute");
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::vector<std::unique_ptr<PjRtBuffer>> outputs,
|
std::vector<std::unique_ptr<PjRtBuffer>> outputs,
|
||||||
MakeOutputBuffers(device_ordinal, options, std::move(result_buffer),
|
MakeOutputBuffers(device_ordinal, options, std::move(result_buffer),
|
||||||
*definition_event, device, compute_callbacks,
|
*definition_event_or, device, compute_callbacks));
|
||||||
buffers_to_release));
|
|
||||||
|
|
||||||
for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
|
for (CommonPjRtBuffer::ScopedHold& b : device_buffers) {
|
||||||
if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kUsage) {
|
if (b.type() == CommonPjRtBuffer::ScopedHold::kUsage) {
|
||||||
RecordUsage(std::move(b), device_state, device_state, *definition_event,
|
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
|
||||||
stream, &buffers_to_release);
|
buffers_to_release.push_back(
|
||||||
|
tensorflow::down_cast<PjRtStreamExecutorRawBuffer*>(
|
||||||
|
b.buffer()->GetRawBuffer(b.parent()->memory_space()).get())
|
||||||
|
->device_buffer());
|
||||||
|
}
|
||||||
|
b.ConvertUsageHold(definition_event);
|
||||||
} else {
|
} else {
|
||||||
CHECK(b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation);
|
CHECK(b.type() == CommonPjRtBuffer::ScopedHold::kDonation);
|
||||||
b.ConfirmDonation();
|
b.ConfirmDonation();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -2583,7 +2597,8 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
|
||||||
}
|
}
|
||||||
definition_event->AndThen(
|
definition_event->AndThen(
|
||||||
[callbacks{std::move(compute_callbacks)},
|
[callbacks{std::move(compute_callbacks)},
|
||||||
buffers_to_release{std::move(buffers_to_release)}]() mutable {
|
buffers_to_release{std::move(buffers_to_release)},
|
||||||
|
leaves_to_release = std::move(leaves_to_release)]() mutable {
|
||||||
for (auto& fn : callbacks) {
|
for (auto& fn : callbacks) {
|
||||||
std::move(fn)();
|
std::move(fn)();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -824,7 +824,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
|
||||||
int device_ordinal, const ExecuteOptions& options,
|
int device_ordinal, const ExecuteOptions& options,
|
||||||
absl::Span<const Shape> executable_parameter_shapes,
|
absl::Span<const Shape> executable_parameter_shapes,
|
||||||
absl::Span<PjRtBuffer* const> argument_handles,
|
absl::Span<PjRtBuffer* const> argument_handles,
|
||||||
absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
|
absl::Span<const CommonPjRtBuffer::ScopedHold> device_buffers,
|
||||||
absl::flat_hash_set<BufferSequencingEvent*>& events) const;
|
absl::flat_hash_set<BufferSequencingEvent*>& events) const;
|
||||||
|
|
||||||
absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>>
|
absl::StatusOr<ShapeTree<tsl::RCReference<RawSEDeviceMemory>>>
|
||||||
|
|
@ -832,7 +832,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
|
||||||
absl::Span<PjRtBuffer* const> argument_handles, int replica,
|
absl::Span<PjRtBuffer* const> argument_handles, int replica,
|
||||||
int partition, int executable_idx, const RunId& run_id,
|
int partition, int executable_idx, const RunId& run_id,
|
||||||
const ExecuteOptions& options, PjRtDevice* device,
|
const ExecuteOptions& options, PjRtDevice* device,
|
||||||
std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
|
std::vector<CommonPjRtBuffer::ScopedHold>* device_buffers,
|
||||||
std::shared_ptr<DeviceAssignment> device_assignment,
|
std::shared_ptr<DeviceAssignment> device_assignment,
|
||||||
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const;
|
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const;
|
||||||
|
|
||||||
|
|
@ -841,9 +841,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
|
||||||
int device_ordinal, const ExecuteOptions& options,
|
int device_ordinal, const ExecuteOptions& options,
|
||||||
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
|
ShapeTree<tsl::RCReference<RawSEDeviceMemory>> result_buffer,
|
||||||
BufferSequencingEventRef definition_event, PjRtDevice* device,
|
BufferSequencingEventRef definition_event, PjRtDevice* device,
|
||||||
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks,
|
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const;
|
||||||
std::vector<tsl::RCReference<RawSEDeviceMemory>>& buffers_to_release)
|
|
||||||
const;
|
|
||||||
|
|
||||||
absl::StatusOr<Result> ExecuteHelper(
|
absl::StatusOr<Result> ExecuteHelper(
|
||||||
absl::Span<PjRtBuffer* const> argument_handles, int replica,
|
absl::Span<PjRtBuffer* const> argument_handles, int replica,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user