Add proto [de]serialization for HostExecuteStartThunk

PiperOrigin-RevId: 820645056
Aliia Khasanova 2025-10-17 05:11:47 -07:00 committed by TensorFlower Gardener
parent 0bb1532ddf
commit 30d25d6d18
6 changed files with 170 additions and 17 deletions
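
The serialize/deserialize round trip this change enables looks roughly as follows. This is a minimal sketch mirroring the ProtoRoundTrip test added below; start_thunk and buffer_allocations are assumed to already be in scope, and error handling on the StatusOr returns is elided.

// Serialization captures the embedded host-offloading executable, the
// arg/result shaped slices, and a unique id derived from the thunk's
// async events.
absl::StatusOr<ThunkProto> proto = start_thunk->ToProto();

// Deserialization resolves the slices against the caller's buffer
// allocations and threads a shared HostExecuteAsyncEventsMap through so
// that a start thunk and its paired done thunk end up on the same events.
HostExecuteAsyncEventsMap async_events_map;
absl::StatusOr<Thunk::ThunkInfo> thunk_info =
    Thunk::ThunkInfo::FromProto(proto->thunk_info());
absl::StatusOr<std::unique_ptr<HostExecuteStartThunk>> restored =
    HostExecuteStartThunk::FromProto(*thunk_info,
                                     proto->host_execute_start_thunk(),
                                     buffer_allocations, async_events_map);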

View File

@@ -2449,6 +2449,7 @@ tf_proto_library(
":convolution_filter_thunk_proto",
":dynamic_slice_thunk_proto",
"//xla:xla_data_proto",
"//xla/core/host_offloading:host_offloading_executable_proto",
"//xla/service:buffer_assignment_proto",
"//xla/service/gpu:backend_configs",
"//xla/service/gpu:gpu_conv_runner_proto",
@@ -2744,6 +2745,7 @@ xla_test(
deps = [
":host_execute_thunk",
":thunk",
":thunk_proto_cc",
"//xla:executable_run_options",
"//xla:literal",
"//xla:literal_util",
@@ -2770,12 +2772,14 @@ xla_test(
"//xla/tsl/lib/core:status_test_util",
"//xla/tsl/platform:status_matchers",
"//xla/tsl/platform:statusor",
"//xla/tsl/util/proto:proto_matchers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
"@local_tsl//tsl/platform:casts",
"@local_tsl//tsl/platform:protobuf",
],
)

View File

@@ -19,11 +19,13 @@ limitations under the License.
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "absl/base/call_once.h"
#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
@@ -425,24 +427,84 @@ HostExecuteStartThunk::HostExecuteStartThunk(
Thunk::ThunkInfo thunk_info,
const HostOffloadingExecutableProto& host_offloading_executable_proto,
absl::InlinedVector<HostExecuteStartThunk::SliceAndShape, 4> args,
absl::InlinedVector<HostExecuteStartThunk::SliceAndShape, 4> results)
absl::InlinedVector<HostExecuteStartThunk::SliceAndShape, 4> results,
std::shared_ptr<HostExecuteAsyncEvents> async_events)
: Thunk(Thunk::Kind::kHostExecuteStart, std::move(thunk_info)),
args_(std::move(args)),
results_(std::move(results)),
executable_proto_(host_offloading_executable_proto),
async_events_(std::make_shared<HostExecuteAsyncEvents>()) {}
executable_proto_(host_offloading_executable_proto) {
async_events_ =
async_events ? async_events : std::make_shared<HostExecuteAsyncEvents>();
}
std::string HostExecuteStartThunk::ToString(int indent) const { return ""; }
absl::StatusOr<ThunkProto> HostExecuteStartThunk::ToProto() const {
return Unimplemented("Not implemented yet.");
ThunkProto proto;
*proto.mutable_thunk_info() = thunk_info().ToProto();
HostExecuteStartThunkProto* host_execute_start_thunk_proto =
proto.mutable_host_execute_start_thunk();
*host_execute_start_thunk_proto->mutable_executable_proto() =
executable_proto_;
for (const auto& [slice, shape] : args_) {
ShapedSliceProto* arg_proto = host_execute_start_thunk_proto->add_args();
TF_ASSIGN_OR_RETURN(*arg_proto->mutable_slice(), slice.ToProto());
*arg_proto->mutable_shape() = shape.ToProto();
}
for (const auto& [slice, shape] : results_) {
ShapedSliceProto* result_proto =
host_execute_start_thunk_proto->add_results();
TF_ASSIGN_OR_RETURN(*result_proto->mutable_slice(), slice.ToProto());
*result_proto->mutable_shape() = shape.ToProto();
}
auto async_events_unique_id = GetAsyncEventsUniqueId();
// By design, async_events_unique_id should always be present for
// HostExecuteStartThunk.
CHECK_NE(async_events_unique_id, std::nullopt);
host_execute_start_thunk_proto->set_async_events_unique_id(
async_events_unique_id.value().value());
return proto;
}
absl::StatusOr<std::unique_ptr<HostExecuteStartThunk>>
HostExecuteStartThunk::FromProto(
ThunkInfo thunk_info, const HostExecuteStartThunkProto& proto,
absl::Span<const BufferAllocation> buffer_allocations) {
return Unimplemented("Not implemented yet.");
absl::Span<const BufferAllocation> buffer_allocations,
HostExecuteAsyncEventsMap& async_events_map) {
absl::InlinedVector<HostExecuteStartThunk::SliceAndShape, 4> args, results;
auto shaped_slice_from_proto =
[&](const auto& shaped_slice_protos,
absl::InlinedVector<HostExecuteStartThunk::SliceAndShape, 4>&
slices_and_shapes) -> absl::Status {
for (const auto& shaped_slice_proto : shaped_slice_protos) {
TF_ASSIGN_OR_RETURN(auto slice,
BufferAllocation::Slice::FromProto(
shaped_slice_proto.slice(), buffer_allocations));
TF_ASSIGN_OR_RETURN(auto shape,
Shape::FromProto(shaped_slice_proto.shape()));
slices_and_shapes.push_back({slice, shape});
}
return absl::OkStatus();
};
TF_RETURN_IF_ERROR(shaped_slice_from_proto(proto.args(), args));
TF_RETURN_IF_ERROR(shaped_slice_from_proto(proto.results(), results));
// If async_events_map already contains an entry for the given unique id,
// the paired done thunk has already been deserialized and we reuse its
// async events to connect the two thunks. Otherwise, create a new entry.
auto [async_event_it, _] = async_events_map.try_emplace(
AsyncEventsUniqueId(proto.async_events_unique_id()),
std::make_shared<HostExecuteAsyncEvents>());
return std::make_unique<HostExecuteStartThunk>(
thunk_info, proto.executable_proto(), std::move(args), std::move(results),
async_event_it->second);
}
static HostOffloadingAllocator* GetHostOffloadingAllocator(
@@ -556,6 +618,14 @@ absl::Status HostExecuteStartThunk::ExecuteOnStream(
return absl::OkStatus();
}
std::optional<AsyncEventsUniqueId>
HostExecuteStartThunk::GetAsyncEventsUniqueId() const {
CHECK(async_events_)
<< "async_events_ must not be null in HostExecuteStartThunk";
// We rely on the fact that the pointer to async_events_ is unique.
return absl::bit_cast<AsyncEventsUniqueId>(async_events_.get());
}
// HostExecuteDoneThunk
HostExecuteDoneThunk::HostExecuteDoneThunk(
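
The unique id is just the bit pattern of the shared HostExecuteAsyncEvents pointer (see GetAsyncEventsUniqueId above), so a start thunk and its paired done thunk serialize the same id, and a shared map restores that sharing during deserialization. A small sketch of the map semantics, assuming AsyncEventsUniqueId is constructible from a raw integer as in FromProto above; the id value here is made up:

HostExecuteAsyncEventsMap async_events_map;
AsyncEventsUniqueId id(0x1234);  // Illustrative; real ids come from the proto.

// The first thunk deserialized with this id inserts a fresh events object.
auto [first_it, inserted] = async_events_map.try_emplace(
    id, std::make_shared<HostExecuteAsyncEvents>());
CHECK(inserted);

// The paired thunk, deserialized later with the same id, reuses that entry
// instead of creating a second events object, so both thunks signal and wait
// on the same HostExecuteAsyncEvents instance.
auto [second_it, reinserted] = async_events_map.try_emplace(
    id, std::make_shared<HostExecuteAsyncEvents>());
CHECK(!reinserted);
CHECK_EQ(first_it->second.get(), second_it->second.get());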

View File

@@ -17,6 +17,7 @@ limitations under the License.
#define XLA_BACKENDS_GPU_RUNTIME_HOST_EXECUTE_THUNK_H_
#include <memory>
#include <optional>
#include <string>
#include <utility>
@@ -70,6 +71,10 @@ class HostExecuteAsyncEvents {
events_ ABSL_GUARDED_BY(events_mu_);
};
using HostExecuteAsyncEventsMap =
absl::flat_hash_map<AsyncEventsUniqueId,
std::shared_ptr<HostExecuteAsyncEvents>>;
class HostExecuteStartThunk : public Thunk {
public:
struct SliceAndShape {
@@ -95,9 +100,14 @@ class HostExecuteStartThunk : public Thunk {
std::string ToString(int indent) const override;
absl::StatusOr<ThunkProto> ToProto() const override;
// If async_events_map already contains an entry for the given unique id, we
// reuse that entry's async events to connect the start and done thunks.
// Otherwise, insert a new entry into the map.
static absl::StatusOr<std::unique_ptr<HostExecuteStartThunk>> FromProto(
ThunkInfo thunk_info, const HostExecuteStartThunkProto& proto,
absl::Span<const BufferAllocation> buffer_allocations);
absl::Span<const BufferAllocation> buffer_allocations,
HostExecuteAsyncEventsMap& async_events_map);
absl::Status Initialize(const InitializeParams& params) override;
absl::Status ExecuteOnStream(const ExecuteParams& params) override;
@@ -118,12 +128,14 @@ class HostExecuteStartThunk : public Thunk {
return &executable_proto_;
}
protected:
std::optional<AsyncEventsUniqueId> GetAsyncEventsUniqueId() const override;
HostExecuteStartThunk(
Thunk::ThunkInfo thunk_info,
const HostOffloadingExecutableProto& host_offloading_executable_proto,
absl::InlinedVector<SliceAndShape, 4> args,
absl::InlinedVector<SliceAndShape, 4> results);
absl::InlinedVector<SliceAndShape, 4> results,
std::shared_ptr<HostExecuteAsyncEvents> async_events = nullptr);
private:
absl::once_flag executable_init_flag_;

View File

@@ -19,6 +19,7 @@ limitations under the License.
#include <cstring>
#include <memory>
#include <utility>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
@@ -52,6 +53,7 @@ limitations under the License.
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/tsl/platform/status_matchers.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/util/proto/proto_matchers.h"
#include "xla/util.h"
#include "tsl/platform/casts.h"
@@ -584,6 +586,60 @@ TEST(HostExecuteDoneThunkTest, WaitingOnErrorEvent) {
absl_testing::StatusIs(absl::StatusCode::kInternal));
}
TEST(HostExecuteStartThunkTest, ProtoRoundTrip) {
static constexpr char const* kHloModule = R"(
HloModule module
ENTRY add_inplace {
p0 = s32[] parameter(0)
ROOT add = s32[] add(p0, p0)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
ParseAndReturnUnverifiedModule(kHloModule, {}));
BufferAllocation alloc_arg(/*index=*/0, 4, /*color=*/0);
BufferAllocation alloc_result(/*index=*/1, 4, /*color=*/0);
BufferAllocation::Slice slice_arg(&alloc_arg, 0, 4);
BufferAllocation::Slice slice_result(&alloc_result, 0, 4);
TF_ASSERT_OK_AND_ASSIGN(auto thunk,
CreateHostExecuteStartThunk(
Thunk::ThunkInfo(), *hlo_module,
{{slice_arg, ShapeUtil::MakeShape(S32, {})}},
{{slice_result, ShapeUtil::MakeShape(S32, {})}}));
TF_ASSERT_OK_AND_ASSIGN(ThunkProto proto, thunk->ToProto());
std::vector<BufferAllocation> buffer_allocations = {
BufferAllocation(/*index=*/0, /*size=*/4, /*color=*/0),
BufferAllocation(/*index=*/1, /*size=*/4, /*color=*/0)};
TF_ASSERT_OK_AND_ASSIGN(Thunk::ThunkInfo thunk_info,
Thunk::ThunkInfo::FromProto(proto.thunk_info()));
HostExecuteAsyncEventsMap async_events_map;
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HostExecuteStartThunk> round_trip_thunk,
HostExecuteStartThunk::FromProto(thunk_info,
proto.host_execute_start_thunk(),
buffer_allocations, async_events_map));
TF_ASSERT_OK_AND_ASSIGN(ThunkProto round_trip_proto,
round_trip_thunk->ToProto());
EXPECT_EQ(async_events_map.size(), 1);
EXPECT_EQ(async_events_map.begin()->first,
thunk->GetAsyncEventsUniqueId().value());
// The ids are expected to differ, so drop them for the comparison.
round_trip_proto.mutable_host_execute_start_thunk()
->clear_async_events_unique_id();
proto.mutable_host_execute_start_thunk()->clear_async_events_unique_id();
EXPECT_THAT(round_trip_proto, tsl::proto_testing::EqualsProto(proto));
}
} // namespace
} // namespace gpu

View File

@@ -19,6 +19,7 @@ package xla.gpu;
import "xla/backends/gpu/runtime/convolution_filter_thunk.proto";
import "xla/backends/gpu/runtime/dynamic_slice_thunk.proto";
import "xla/core/host_offloading/host_offloading_executable.proto";
import "xla/service/buffer_assignment.proto";
import "xla/service/gpu/gpu_conv_runner.proto";
import "xla/service/gpu/gpu_norm_runner.proto";
@@ -141,8 +142,21 @@ message CudnnThunkProto {
optional int64 sdpa_dropout_seed = 3;
}
message HostExecuteStartThunkProto {}
message HostExecuteDoneThunkProto {}
message ShapedSliceProto {
xla.buffer_assignment.BufferAllocationSliceProto slice = 1;
xla.ShapeProto shape = 2;
}
message HostExecuteStartThunkProto {
HostOffloadingExecutableProto executable_proto = 1;
repeated ShapedSliceProto args = 2;
repeated ShapedSliceProto results = 3;
uint64 async_events_unique_id = 4;
}
message HostExecuteDoneThunkProto {
uint64 async_events_unique_id = 1;
}
message DynamicSliceThunkProto {
SequentialThunkProto embedded_thunk = 1;
@@ -164,11 +178,6 @@ message Memset32BitValueThunkProto {
uint32 value = 2;
}
message ShapedSliceProto {
xla.buffer_assignment.BufferAllocationSliceProto slice = 1;
xla.ShapeProto shape = 2;
}
message InfeedThunkProto {
repeated ShapedSliceProto dest_slices = 1;
}

View File

@@ -118,7 +118,9 @@ strict_cc_test(
tf_proto_library(
name = "host_offloading_executable_proto",
srcs = ["host_offloading_executable.proto"],
compatible_with = get_compatible_with_libtpu_portable(),
compatible_with = get_compatible_with_libtpu_portable() + [
# copybara:uncomment "//buildenv/target:non_prod",
],
deps = [
"//xla/service:hlo_proto",
"//xla/service/cpu:executable_proto",