From fbd032df67dee28f887b87c82cdbdcdafe99f8a7 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 31 Oct 2025 16:51:10 -0700 Subject: [PATCH] [xla:cpu] Pass HloModule pointer to Thunk SerDes Reverts 993369077a7da432be091fdf7abc3fbd370ecdf1 PiperOrigin-RevId: 826675119 --- .../xla/xla/backends/cpu/runtime/BUILD | 1 + .../cpu/runtime/thunk_proto_serdes.cc | 53 +++++++++++-------- .../backends/cpu/runtime/thunk_proto_serdes.h | 11 ++-- .../cpu/runtime/thunk_sequence_serdes_test.cc | 4 +- .../service/cpu/cpu_aot_compilation_result.cc | 6 +-- 5 files changed, 44 insertions(+), 31 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/runtime/BUILD b/third_party/xla/xla/backends/cpu/runtime/BUILD index 6f103e8060f..48a7c4d7977 100644 --- a/third_party/xla/xla/backends/cpu/runtime/BUILD +++ b/third_party/xla/xla/backends/cpu/runtime/BUILD @@ -1203,6 +1203,7 @@ cc_library( "//xla/backends/cpu/runtime/xnnpack:xnn_convolution_thunk", "//xla/backends/cpu/runtime/xnnpack:xnn_dot_thunk", "//xla/backends/cpu/runtime/xnnpack:xnn_fusion_thunk", + "//xla/hlo/ir:hlo", "//xla/runtime:resource_use", "//xla/runtime:work_group", "//xla/service:buffer_assignment", diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk_proto_serdes.cc b/third_party/xla/xla/backends/cpu/runtime/thunk_proto_serdes.cc index e60e29f7abf..c611b356c7c 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk_proto_serdes.cc +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_proto_serdes.cc @@ -344,6 +344,7 @@ class ThunkSerDesProtobuf : public SerDesBase { public: // Buffer allocations and resources are not needed for serialization. explicit ThunkSerDesProtobuf( + const HloModule* hlo_module = nullptr, const std::vector* buffer_allocations = nullptr, const std::vector>* thunk_resources = nullptr); absl::StatusOr Serialize(const Thunk& thunk) override; @@ -356,16 +357,18 @@ class ThunkSerDesProtobuf : public SerDesBase { const ThunkProto& proto) const; private: - // TODO(basiol) remove NOLINT when this actually gets used - const std::vector* buffer_allocations_; // NOLINT + const HloModule* hlo_module_; + const std::vector* buffer_allocations_; const std::vector>* thunk_resources_; }; ThunkSerDesProtobuf::ThunkSerDesProtobuf( + const HloModule* hlo_module, const std::vector* buffer_allocations, const std::vector>* thunk_resources) - : buffer_allocations_(buffer_allocations), + : hlo_module_(hlo_module), + buffer_allocations_(buffer_allocations), thunk_resources_(thunk_resources) {} absl::StatusOr ThunkSerDesProtobuf::Serialize(const Thunk& thunk) { @@ -1087,9 +1090,10 @@ ReduceScatterThunkFromProto( } static absl::StatusOr> CallThunkFromProto( - const ThunkProto& proto, - const std::vector& buffer_allocations) { - ThunkSequenceSerDesProtobuf thunk_sequence_serdes(&buffer_allocations); + const ThunkProto& proto, const HloModule* hlo_module, + const std::vector* buffer_allocations) { + ThunkSequenceSerDesProtobuf thunk_sequence_serdes(hlo_module, + buffer_allocations); TF_ASSIGN_OR_RETURN( std::unique_ptr call_sequence, @@ -1101,9 +1105,10 @@ static absl::StatusOr> CallThunkFromProto( static absl::StatusOr> ConditionalThunkFromProto( - const ThunkProto& proto, - const std::vector& buffer_allocations) { - ThunkSequenceSerDesProtobuf thunk_sequence_serdes(&buffer_allocations); + const ThunkProto& proto, const HloModule* hlo_module, + const std::vector* buffer_allocations) { + ThunkSequenceSerDesProtobuf thunk_sequence_serdes(hlo_module, + buffer_allocations); std::vector branch_sequences; for (const ThunkSequenceProto& branch_sequence_proto : @@ -1114,10 +1119,10 @@ ConditionalThunkFromProto( } TF_ASSIGN_OR_RETURN(Thunk::Info info, ThunkInfoFromProto(proto.info())); - TF_ASSIGN_OR_RETURN( - BufferAllocation::Slice branch_index_buffer, - BufferAllocation::Slice::FromProto( - proto.conditional_thunk().branch_index_buffer(), buffer_allocations)); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice branch_index_buffer, + BufferAllocation::Slice::FromProto( + proto.conditional_thunk().branch_index_buffer(), + *buffer_allocations)); return ConditionalThunk::Create(std::move(info), std::move(branch_index_buffer), @@ -1480,9 +1485,10 @@ static absl::StatusOr> TopKThunkFromProto( } static absl::StatusOr> WhileThunkFromProto( - const ThunkProto& proto, - const std::vector& buffer_allocations) { - ThunkSequenceSerDesProtobuf thunk_sequence_serdes(&buffer_allocations); + const ThunkProto& proto, const HloModule* hlo_module, + const std::vector* buffer_allocations) { + ThunkSequenceSerDesProtobuf thunk_sequence_serdes(hlo_module, + buffer_allocations); TF_ASSIGN_OR_RETURN( std::unique_ptr cond_sequence, @@ -1496,7 +1502,7 @@ static absl::StatusOr> WhileThunkFromProto( TF_ASSIGN_OR_RETURN( BufferAllocation::Slice cond_buffer, BufferAllocation::Slice::FromProto(proto.while_thunk().cond_buffer(), - buffer_allocations)); + *buffer_allocations)); std::optional trip_count = std::nullopt; if (proto.while_thunk().trip_count().contains_value()) { @@ -1662,9 +1668,9 @@ absl::StatusOr> ThunkSerDesProtobuf::FromProto( } } case Thunk::Kind::kCall: - return CallThunkFromProto(proto, *buffer_allocations_); + return CallThunkFromProto(proto, hlo_module_, buffer_allocations_); case Thunk::Kind::kConditional: - return ConditionalThunkFromProto(proto, *buffer_allocations_); + return ConditionalThunkFromProto(proto, hlo_module_, buffer_allocations_); case Thunk::Kind::kConvolution: return ConvolutionThunkFromProto(proto, *buffer_allocations_); case Thunk::Kind::kCopy: @@ -1688,7 +1694,7 @@ absl::StatusOr> ThunkSerDesProtobuf::FromProto( case Thunk::Kind::kTopK: return TopKThunkFromProto(proto, *buffer_allocations_); case Thunk::Kind::kWhile: - return WhileThunkFromProto(proto, *buffer_allocations_); + return WhileThunkFromProto(proto, hlo_module_, buffer_allocations_); case Thunk::Kind::kXnnFusion: { TF_ASSIGN_OR_RETURN( auto xnn_fusion_kind, @@ -1715,8 +1721,9 @@ absl::StatusOr> ThunkSerDesProtobuf::FromProto( } ThunkSequenceSerDesProtobuf::ThunkSequenceSerDesProtobuf( + const HloModule* hlo_module, const std::vector* buffer_allocations) - : buffer_allocations_(buffer_allocations) {} + : hlo_module_(hlo_module), buffer_allocations_(buffer_allocations) {} absl::StatusOr ThunkSequenceSerDesProtobuf::Serialize( const ThunkSequence& thunk_sequence) { @@ -1736,7 +1743,7 @@ ThunkSequenceSerDesProtobuf::Deserialize(const std::string& serialized) { absl::StatusOr ThunkSequenceSerDesProtobuf::ToProto( const ThunkSequence& thunk_sequence) const { - ThunkSerDesProtobuf thunk_serdes(buffer_allocations_); + ThunkSerDesProtobuf thunk_serdes(hlo_module_, buffer_allocations_); ThunkSequenceProto proto; proto.mutable_thunks()->Reserve(thunk_sequence.size()); @@ -1798,7 +1805,7 @@ ThunkSequenceSerDesProtobuf::FromProto(const ThunkSequenceProto& proto) const { size_t thunk_index = 0; for (const ThunkProto& thunk_proto : proto.thunks()) { - ThunkSerDesProtobuf thunk_serdes(buffer_allocations_, + ThunkSerDesProtobuf thunk_serdes(hlo_module_, buffer_allocations_, &thunk_resources[thunk_index++]); TF_ASSIGN_OR_RETURN(std::unique_ptr thunk, thunk_serdes.FromProto(thunk_proto)); diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk_proto_serdes.h b/third_party/xla/xla/backends/cpu/runtime/thunk_proto_serdes.h index b18f727ccea..08b4deaaf91 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk_proto_serdes.h +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_proto_serdes.h @@ -25,6 +25,7 @@ limitations under the License. #include "xla/backends/cpu/runtime/serdes_base.h" #include "xla/backends/cpu/runtime/thunk.h" #include "xla/backends/cpu/runtime/thunk.pb.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/buffer_assignment.h" namespace xla::cpu { @@ -36,10 +37,13 @@ void ForEachThunkProto(const ThunkSequenceProto& proto, class ThunkSequenceSerDesProtobuf : public SerDesBase { public: + // For serialization, `hlo_module` and `buffer_allocations` are optional. For + // deserialization, both are required as we rely on the HLO module to resolve + // thunks that were generated from `HloComputation`s, and we also need buffer + // allocations to resolve buffer slices. explicit ThunkSequenceSerDesProtobuf( - const std::vector* buffer_allocations = - nullptr); // NOTE buffer allocations aren't - // needed for serialization. + const HloModule* hlo_module = nullptr, + const std::vector* buffer_allocations = nullptr); absl::StatusOr Serialize( const ThunkSequence& thunk_sequence) override; @@ -52,6 +56,7 @@ class ThunkSequenceSerDesProtobuf : public SerDesBase { const ThunkSequenceProto& proto) const; private: + const HloModule* hlo_module_; const std::vector* buffer_allocations_; }; diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk_sequence_serdes_test.cc b/third_party/xla/xla/backends/cpu/runtime/thunk_sequence_serdes_test.cc index 8daa2cb544d..2fc4d850d10 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk_sequence_serdes_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_sequence_serdes_test.cc @@ -219,8 +219,8 @@ class ThunkSequenceSerdesTest : public ::testing::Test { public: void SetUp() override { - thunk_sequence_serdes_ = - std::make_unique(&buffer_allocations_.GetUnderlyingVector()); + thunk_sequence_serdes_ = std::make_unique( + nullptr, &buffer_allocations_.GetUnderlyingVector()); } protected: diff --git a/third_party/xla/xla/service/cpu/cpu_aot_compilation_result.cc b/third_party/xla/xla/service/cpu/cpu_aot_compilation_result.cc index 3107b59df3f..959cde89ffa 100644 --- a/third_party/xla/xla/service/cpu/cpu_aot_compilation_result.cc +++ b/third_party/xla/xla/service/cpu/cpu_aot_compilation_result.cc @@ -82,7 +82,7 @@ CpuAotCompilationResult::Create( std::unique_ptr hlo_profile_printer_data, TargetMachineOptionsProto target_machine_options) { ThunkSequenceSerDesProtobuf thunk_sequence_serdes( - &buffer_assignment->Allocations()); + hlo_module, &buffer_assignment->Allocations()); TF_ASSIGN_OR_RETURN(ThunkSequenceProto thunk_proto, thunk_sequence_serdes.ToProto(thunks)); @@ -143,7 +143,7 @@ CpuAotCompilationResult::CpuAotCompilationResult( module_ = hlo_module->Clone(); ThunkSequenceSerDesProtobuf thunk_sequence_serdes( - &buffer_assignment->Allocations()); + hlo_module, &buffer_assignment->Allocations()); *proto_.mutable_thunk_sequence() = thunks; } @@ -181,7 +181,7 @@ CpuAotCompilationResult::LoadExecutable( } ThunkSequenceSerDesProtobuf thunk_sequence_serdes( - &buffer_assignment->Allocations()); + module.get(), &buffer_assignment->Allocations()); TF_ASSIGN_OR_RETURN(std::unique_ptr thunks, thunk_sequence_serdes.FromProto(proto_.thunk_sequence()));