Add proto [de]serialization for HostExecuteDoneThunk

PiperOrigin-RevId: 821029998
This commit is contained in:
Aliia Khasanova 2025-10-18 04:48:57 -07:00 committed by TensorFlower Gardener
parent 083e682264
commit 4985a1c2f3
3 changed files with 99 additions and 4 deletions

View File

@ -639,14 +639,35 @@ HostExecuteDoneThunk::HostExecuteDoneThunk(
// Returns the textual form of this thunk. The body ignores `indent` and always
// yields an empty string.
std::string HostExecuteDoneThunk::ToString(int indent) const { return ""; } std::string HostExecuteDoneThunk::ToString(int indent) const { return ""; }
// Serializes this thunk into a ThunkProto: the thunk info is copied into
// `thunk_info`, and the async-events unique id is stored in the
// `host_execute_done_thunk` sub-message.
//
// NOTE(review): lines that carry two copies of a statement come from the
// side-by-side diff rendering (old text first, new text second).
absl::StatusOr<ThunkProto> HostExecuteDoneThunk::ToProto() const { absl::StatusOr<ThunkProto> HostExecuteDoneThunk::ToProto() const {
return Unimplemented("Not implemented yet."); ThunkProto proto;
*proto.mutable_thunk_info() = thunk_info().ToProto();
HostExecuteDoneThunkProto* host_execute_done_thunk_proto =
proto.mutable_host_execute_done_thunk();
auto async_events_unique_id = GetAsyncEventsUniqueId();
// By design, async_events_unique_id should always be present for
// HostExecuteDoneThunk.
CHECK_NE(async_events_unique_id, std::nullopt);
// The outer value() unwraps the optional; the inner value() presumably
// extracts the raw integer from AsyncEventsUniqueId -- confirm against its
// definition.
host_execute_done_thunk_proto->set_async_events_unique_id(
async_events_unique_id.value().value());
return proto;
} }
// Builds a HostExecuteDoneThunk from its proto form.
//
// `async_events_map` maps an async-events unique id to the shared
// HostExecuteAsyncEvents object; it is how a done thunk gets connected to its
// matching start thunk during deserialization. `buffer_allocations` is not
// read by the new body shown here.
//
// NOTE(review): lines that carry two copies of a statement come from the
// side-by-side diff rendering (old text first, new text second).
absl::StatusOr<std::unique_ptr<HostExecuteDoneThunk>> absl::StatusOr<std::unique_ptr<HostExecuteDoneThunk>>
HostExecuteDoneThunk::FromProto( HostExecuteDoneThunk::FromProto(
ThunkInfo thunk_info, const HostExecuteDoneThunkProto& proto, ThunkInfo thunk_info, const HostExecuteDoneThunkProto& proto,
absl::Span<const BufferAllocation> buffer_allocations) { absl::Span<const BufferAllocation> buffer_allocations,
return Unimplemented("Not implemented yet."); HostExecuteAsyncEventsMap& async_events_map) {
// If async_events_map already contains an entry for the given unique id,
// the paired start thunk has already been deserialized, so its events object
// is reused to connect the two thunks. Otherwise a new entry is created here
// (assuming the map's try_emplace leaves an existing entry untouched).
auto [async_event_it, _] = async_events_map.try_emplace(
AsyncEventsUniqueId(proto.async_events_unique_id()),
std::make_shared<HostExecuteAsyncEvents>());
return std::make_unique<HostExecuteDoneThunk>(thunk_info,
async_event_it->second);
} }
absl::Status HostExecuteDoneThunk::Initialize(const InitializeParams& params) { absl::Status HostExecuteDoneThunk::Initialize(const InitializeParams& params) {
@ -674,5 +695,13 @@ absl::Status HostExecuteDoneThunk::ExecuteOnStream(
return absl::OkStatus(); return absl::OkStatus();
} }
// Returns the id that identifies the async events object held by this thunk.
// The id is the bit pattern of the events object's address, so two thunks that
// hold the same object report equal ids. CHECK-fails when async_events_ is
// null.
std::optional<AsyncEventsUniqueId>
HostExecuteDoneThunk::GetAsyncEventsUniqueId() const {
  CHECK(async_events_)
      << "async_events_ must not be null in HostExecuteDoneThunk";
  // The address of the events object doubles as its unique id.
  auto* const events_address = async_events_.get();
  return absl::bit_cast<AsyncEventsUniqueId>(events_address);
}
} // namespace gpu } // namespace gpu
} // namespace xla } // namespace xla

View File

@ -161,11 +161,14 @@ class HostExecuteDoneThunk : public Thunk {
absl::StatusOr<ThunkProto> ToProto() const override; absl::StatusOr<ThunkProto> ToProto() const override;
// Creates a HostExecuteDoneThunk from `proto`. `async_events_map` (the
// parameter added by this change) maps an async-events unique id to the shared
// events object, so that a done thunk and its start thunk end up holding the
// same HostExecuteAsyncEvents.
static absl::StatusOr<std::unique_ptr<HostExecuteDoneThunk>> FromProto( static absl::StatusOr<std::unique_ptr<HostExecuteDoneThunk>> FromProto(
ThunkInfo thunk_info, const HostExecuteDoneThunkProto& proto, ThunkInfo thunk_info, const HostExecuteDoneThunkProto& proto,
absl::Span<const BufferAllocation> buffer_allocations); absl::Span<const BufferAllocation> buffer_allocations,
HostExecuteAsyncEventsMap& async_events_map);
absl::Status Initialize(const InitializeParams& params) override; absl::Status Initialize(const InitializeParams& params) override;
absl::Status ExecuteOnStream(const ExecuteParams& params) override; absl::Status ExecuteOnStream(const ExecuteParams& params) override;
// Returns the unique id of the async events object this thunk holds; the
// definition derives it from the address of `async_events_`.
std::optional<AsyncEventsUniqueId> GetAsyncEventsUniqueId() const override;
private: private:
std::shared_ptr<HostExecuteAsyncEvents> async_events_; std::shared_ptr<HostExecuteAsyncEvents> async_events_;
}; };

View File

@ -640,6 +640,69 @@ TEST(HostExecuteStartThunkTest, ProtoRoundTrip) {
EXPECT_THAT(round_trip_proto, tsl::proto_testing::EqualsProto(proto)); EXPECT_THAT(round_trip_proto, tsl::proto_testing::EqualsProto(proto));
} }
// Round-trips a paired start/done thunk through their protos and checks that
// deserialization links the pair again through a single async-events map
// entry.
TEST(HostExecuteThunkTest, ProtoRoundTripPairing) {
  static constexpr char const* kHloModule = R"(
HloModule module
ENTRY add_inplace {
p0 = s32[] parameter(0)
ROOT add = s32[] add(p0, p0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloModule, {}));

  // One 4-byte allocation (and a slice covering it) for the argument and one
  // for the result.
  BufferAllocation arg_allocation(/*index=*/0, 4, /*color=*/0);
  BufferAllocation result_allocation(/*index=*/1, 4, /*color=*/0);
  BufferAllocation::Slice arg_slice(&arg_allocation, 0, 4);
  BufferAllocation::Slice result_slice(&result_allocation, 0, 4);

  // Original pair: the done thunk is built on the start thunk's async events.
  TF_ASSERT_OK_AND_ASSIGN(
      auto original_start,
      CreateHostExecuteStartThunk(
          Thunk::ThunkInfo(), *parsed_module,
          {{arg_slice, ShapeUtil::MakeShape(S32, {})}},
          {{result_slice, ShapeUtil::MakeShape(S32, {})}}));
  HostExecuteDoneThunk original_done(Thunk::ThunkInfo(),
                                     original_start->async_events());

  TF_ASSERT_OK_AND_ASSIGN(ThunkProto serialized_start,
                          original_start->ToProto());
  TF_ASSERT_OK_AND_ASSIGN(ThunkProto serialized_done, original_done.ToProto());

  // Both serialized thunks must carry the same async-events id.
  EXPECT_EQ(
      serialized_start.host_execute_start_thunk().async_events_unique_id(),
      serialized_done.host_execute_done_thunk().async_events_unique_id());

  std::vector<BufferAllocation> buffer_allocations = {
      BufferAllocation(/*index=*/0, /*size=*/4, /*color=*/0),
      BufferAllocation(/*index=*/1, /*size=*/4, /*color=*/0)};

  TF_ASSERT_OK_AND_ASSIGN(
      Thunk::ThunkInfo restored_start_info,
      Thunk::ThunkInfo::FromProto(serialized_start.thunk_info()));
  TF_ASSERT_OK_AND_ASSIGN(
      Thunk::ThunkInfo restored_done_info,
      Thunk::ThunkInfo::FromProto(serialized_done.thunk_info()));

  // The done thunk is deserialized first, then the start thunk, both against
  // the same map.
  HostExecuteAsyncEventsMap events_by_id;
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HostExecuteDoneThunk> restored_done,
      HostExecuteDoneThunk::FromProto(
          restored_done_info, serialized_done.host_execute_done_thunk(),
          buffer_allocations, events_by_id));
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HostExecuteStartThunk> restored_start,
      HostExecuteStartThunk::FromProto(
          restored_start_info, serialized_start.host_execute_start_thunk(),
          buffer_allocations, events_by_id));

  // A single map entry, shared by both restored thunks.
  EXPECT_EQ(events_by_id.size(), 1);
  EXPECT_EQ(restored_start->GetAsyncEventsUniqueId(),
            restored_done->GetAsyncEventsUniqueId());
}
} // namespace } // namespace
} // namespace gpu } // namespace gpu