mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Add proto [de]serialization for HostExecuteDoneThunk
PiperOrigin-RevId: 821029998
This commit is contained in:
parent
083e682264
commit
4985a1c2f3
|
|
@ -639,14 +639,35 @@ HostExecuteDoneThunk::HostExecuteDoneThunk(
|
||||||
std::string HostExecuteDoneThunk::ToString(int indent) const { return ""; }
|
std::string HostExecuteDoneThunk::ToString(int indent) const { return ""; }
|
||||||
|
|
||||||
absl::StatusOr<ThunkProto> HostExecuteDoneThunk::ToProto() const {
|
absl::StatusOr<ThunkProto> HostExecuteDoneThunk::ToProto() const {
|
||||||
return Unimplemented("Not implemented yet.");
|
ThunkProto proto;
|
||||||
|
*proto.mutable_thunk_info() = thunk_info().ToProto();
|
||||||
|
HostExecuteDoneThunkProto* host_execute_done_thunk_proto =
|
||||||
|
proto.mutable_host_execute_done_thunk();
|
||||||
|
|
||||||
|
auto async_events_unique_id = GetAsyncEventsUniqueId();
|
||||||
|
// By design, async_events_unique_id should always be present for
|
||||||
|
// HostExecuteDoneThunk.
|
||||||
|
CHECK_NE(async_events_unique_id, std::nullopt);
|
||||||
|
|
||||||
|
host_execute_done_thunk_proto->set_async_events_unique_id(
|
||||||
|
async_events_unique_id.value().value());
|
||||||
|
|
||||||
|
return proto;
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<std::unique_ptr<HostExecuteDoneThunk>>
|
absl::StatusOr<std::unique_ptr<HostExecuteDoneThunk>>
|
||||||
HostExecuteDoneThunk::FromProto(
|
HostExecuteDoneThunk::FromProto(
|
||||||
ThunkInfo thunk_info, const HostExecuteDoneThunkProto& proto,
|
ThunkInfo thunk_info, const HostExecuteDoneThunkProto& proto,
|
||||||
absl::Span<const BufferAllocation> buffer_allocations) {
|
absl::Span<const BufferAllocation> buffer_allocations,
|
||||||
return Unimplemented("Not implemented yet.");
|
HostExecuteAsyncEventsMap& async_events_map) {
|
||||||
|
// If async_events_map already contains an entry for the given unique id,
|
||||||
|
// that means that the pairing start thunk is already serialized and we reuse
|
||||||
|
// the id to connect them. Otherwise, create a new entry.
|
||||||
|
auto [async_event_it, _] = async_events_map.try_emplace(
|
||||||
|
AsyncEventsUniqueId(proto.async_events_unique_id()),
|
||||||
|
std::make_shared<HostExecuteAsyncEvents>());
|
||||||
|
return std::make_unique<HostExecuteDoneThunk>(thunk_info,
|
||||||
|
async_event_it->second);
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status HostExecuteDoneThunk::Initialize(const InitializeParams& params) {
|
absl::Status HostExecuteDoneThunk::Initialize(const InitializeParams& params) {
|
||||||
|
|
@ -674,5 +695,13 @@ absl::Status HostExecuteDoneThunk::ExecuteOnStream(
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::optional<AsyncEventsUniqueId>
|
||||||
|
HostExecuteDoneThunk::GetAsyncEventsUniqueId() const {
|
||||||
|
CHECK(async_events_)
|
||||||
|
<< "async_events_ must not be null in HostExecuteDoneThunk";
|
||||||
|
// We rely on the fact that the pointer to async_events_ is unique.
|
||||||
|
return absl::bit_cast<AsyncEventsUniqueId>(async_events_.get());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
|
||||||
|
|
@ -161,11 +161,14 @@ class HostExecuteDoneThunk : public Thunk {
|
||||||
absl::StatusOr<ThunkProto> ToProto() const override;
|
absl::StatusOr<ThunkProto> ToProto() const override;
|
||||||
static absl::StatusOr<std::unique_ptr<HostExecuteDoneThunk>> FromProto(
|
static absl::StatusOr<std::unique_ptr<HostExecuteDoneThunk>> FromProto(
|
||||||
ThunkInfo thunk_info, const HostExecuteDoneThunkProto& proto,
|
ThunkInfo thunk_info, const HostExecuteDoneThunkProto& proto,
|
||||||
absl::Span<const BufferAllocation> buffer_allocations);
|
absl::Span<const BufferAllocation> buffer_allocations,
|
||||||
|
HostExecuteAsyncEventsMap& async_events_map);
|
||||||
|
|
||||||
absl::Status Initialize(const InitializeParams& params) override;
|
absl::Status Initialize(const InitializeParams& params) override;
|
||||||
absl::Status ExecuteOnStream(const ExecuteParams& params) override;
|
absl::Status ExecuteOnStream(const ExecuteParams& params) override;
|
||||||
|
|
||||||
|
std::optional<AsyncEventsUniqueId> GetAsyncEventsUniqueId() const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<HostExecuteAsyncEvents> async_events_;
|
std::shared_ptr<HostExecuteAsyncEvents> async_events_;
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -640,6 +640,69 @@ TEST(HostExecuteStartThunkTest, ProtoRoundTrip) {
|
||||||
EXPECT_THAT(round_trip_proto, tsl::proto_testing::EqualsProto(proto));
|
EXPECT_THAT(round_trip_proto, tsl::proto_testing::EqualsProto(proto));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(HostExecuteThunkTest, ProtoRoundTripPairing) {
|
||||||
|
static constexpr char const* kHloModule = R"(
|
||||||
|
HloModule module
|
||||||
|
ENTRY add_inplace {
|
||||||
|
p0 = s32[] parameter(0)
|
||||||
|
ROOT add = s32[] add(p0, p0)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
|
||||||
|
ParseAndReturnUnverifiedModule(kHloModule, {}));
|
||||||
|
|
||||||
|
BufferAllocation alloc_arg(/*index=*/0, 4, /*color=*/0);
|
||||||
|
BufferAllocation alloc_result(/*index=*/1, 4, /*color=*/0);
|
||||||
|
|
||||||
|
BufferAllocation::Slice slice_arg(&alloc_arg, 0, 4);
|
||||||
|
BufferAllocation::Slice slice_result(&alloc_result, 0, 4);
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto start_thunk_orig,
|
||||||
|
CreateHostExecuteStartThunk(
|
||||||
|
Thunk::ThunkInfo(), *hlo_module,
|
||||||
|
{{slice_arg, ShapeUtil::MakeShape(S32, {})}},
|
||||||
|
{{slice_result, ShapeUtil::MakeShape(S32, {})}}));
|
||||||
|
|
||||||
|
HostExecuteDoneThunk done_thunk_orig(Thunk::ThunkInfo(),
|
||||||
|
start_thunk_orig->async_events());
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(ThunkProto start_proto, start_thunk_orig->ToProto());
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(ThunkProto done_proto, done_thunk_orig.ToProto());
|
||||||
|
|
||||||
|
// Check that the ids are matching.
|
||||||
|
EXPECT_EQ(start_proto.host_execute_start_thunk().async_events_unique_id(),
|
||||||
|
done_proto.host_execute_done_thunk().async_events_unique_id());
|
||||||
|
|
||||||
|
std::vector<BufferAllocation> buffer_allocations = {
|
||||||
|
BufferAllocation(/*index=*/0, /*size=*/4, /*color=*/0),
|
||||||
|
BufferAllocation(/*index=*/1, /*size=*/4, /*color=*/0)};
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
Thunk::ThunkInfo start_thunk_info,
|
||||||
|
Thunk::ThunkInfo::FromProto(start_proto.thunk_info()));
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(Thunk::ThunkInfo done_thunk_info,
|
||||||
|
Thunk::ThunkInfo::FromProto(done_proto.thunk_info()));
|
||||||
|
|
||||||
|
HostExecuteAsyncEventsMap async_events_map;
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
std::unique_ptr<HostExecuteDoneThunk> done_thunk,
|
||||||
|
HostExecuteDoneThunk::FromProto(done_thunk_info,
|
||||||
|
done_proto.host_execute_done_thunk(),
|
||||||
|
buffer_allocations, async_events_map));
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
std::unique_ptr<HostExecuteStartThunk> start_thunk,
|
||||||
|
HostExecuteStartThunk::FromProto(start_thunk_info,
|
||||||
|
start_proto.host_execute_start_thunk(),
|
||||||
|
buffer_allocations, async_events_map));
|
||||||
|
|
||||||
|
EXPECT_EQ(async_events_map.size(), 1);
|
||||||
|
EXPECT_EQ(start_thunk->GetAsyncEventsUniqueId(),
|
||||||
|
done_thunk->GetAsyncEventsUniqueId());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user