[xla:cpu] Pass HloModule pointer to Thunk SerDes

Reverts 993369077a

PiperOrigin-RevId: 826675119
This commit is contained in:
Eugene Zhulenev 2025-10-31 16:51:10 -07:00 committed by TensorFlower Gardener
parent 00be2bc09e
commit fbd032df67
5 changed files with 44 additions and 31 deletions

View File

@ -1203,6 +1203,7 @@ cc_library(
"//xla/backends/cpu/runtime/xnnpack:xnn_convolution_thunk", "//xla/backends/cpu/runtime/xnnpack:xnn_convolution_thunk",
"//xla/backends/cpu/runtime/xnnpack:xnn_dot_thunk", "//xla/backends/cpu/runtime/xnnpack:xnn_dot_thunk",
"//xla/backends/cpu/runtime/xnnpack:xnn_fusion_thunk", "//xla/backends/cpu/runtime/xnnpack:xnn_fusion_thunk",
"//xla/hlo/ir:hlo",
"//xla/runtime:resource_use", "//xla/runtime:resource_use",
"//xla/runtime:work_group", "//xla/runtime:work_group",
"//xla/service:buffer_assignment", "//xla/service:buffer_assignment",

View File

@ -344,6 +344,7 @@ class ThunkSerDesProtobuf : public SerDesBase<Thunk> {
public: public:
// Buffer allocations and resources are not needed for serialization. // Buffer allocations and resources are not needed for serialization.
explicit ThunkSerDesProtobuf( explicit ThunkSerDesProtobuf(
const HloModule* hlo_module = nullptr,
const std::vector<BufferAllocation>* buffer_allocations = nullptr, const std::vector<BufferAllocation>* buffer_allocations = nullptr,
const std::vector<std::shared_ptr<Resource>>* thunk_resources = nullptr); const std::vector<std::shared_ptr<Resource>>* thunk_resources = nullptr);
absl::StatusOr<std::string> Serialize(const Thunk& thunk) override; absl::StatusOr<std::string> Serialize(const Thunk& thunk) override;
@ -356,16 +357,18 @@ class ThunkSerDesProtobuf : public SerDesBase<Thunk> {
const ThunkProto& proto) const; const ThunkProto& proto) const;
private: private:
// TODO(basiol) remove NOLINT when this actually gets used const HloModule* hlo_module_;
const std::vector<BufferAllocation>* buffer_allocations_; // NOLINT const std::vector<BufferAllocation>* buffer_allocations_;
const std::vector<std::shared_ptr<Resource>>* thunk_resources_; const std::vector<std::shared_ptr<Resource>>* thunk_resources_;
}; };
ThunkSerDesProtobuf::ThunkSerDesProtobuf( ThunkSerDesProtobuf::ThunkSerDesProtobuf(
const HloModule* hlo_module,
const std::vector<BufferAllocation>* buffer_allocations, const std::vector<BufferAllocation>* buffer_allocations,
const std::vector<std::shared_ptr<Resource>>* thunk_resources) const std::vector<std::shared_ptr<Resource>>* thunk_resources)
: buffer_allocations_(buffer_allocations), : hlo_module_(hlo_module),
buffer_allocations_(buffer_allocations),
thunk_resources_(thunk_resources) {} thunk_resources_(thunk_resources) {}
absl::StatusOr<std::string> ThunkSerDesProtobuf::Serialize(const Thunk& thunk) { absl::StatusOr<std::string> ThunkSerDesProtobuf::Serialize(const Thunk& thunk) {
@ -1087,9 +1090,10 @@ ReduceScatterThunkFromProto(
} }
static absl::StatusOr<std::unique_ptr<CallThunk>> CallThunkFromProto( static absl::StatusOr<std::unique_ptr<CallThunk>> CallThunkFromProto(
const ThunkProto& proto, const ThunkProto& proto, const HloModule* hlo_module,
const std::vector<BufferAllocation>& buffer_allocations) { const std::vector<BufferAllocation>* buffer_allocations) {
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(&buffer_allocations); ThunkSequenceSerDesProtobuf thunk_sequence_serdes(hlo_module,
buffer_allocations);
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
std::unique_ptr<ThunkSequence> call_sequence, std::unique_ptr<ThunkSequence> call_sequence,
@ -1101,9 +1105,10 @@ static absl::StatusOr<std::unique_ptr<CallThunk>> CallThunkFromProto(
static absl::StatusOr<std::unique_ptr<ConditionalThunk>> static absl::StatusOr<std::unique_ptr<ConditionalThunk>>
ConditionalThunkFromProto( ConditionalThunkFromProto(
const ThunkProto& proto, const ThunkProto& proto, const HloModule* hlo_module,
const std::vector<BufferAllocation>& buffer_allocations) { const std::vector<BufferAllocation>* buffer_allocations) {
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(&buffer_allocations); ThunkSequenceSerDesProtobuf thunk_sequence_serdes(hlo_module,
buffer_allocations);
std::vector<ThunkSequence> branch_sequences; std::vector<ThunkSequence> branch_sequences;
for (const ThunkSequenceProto& branch_sequence_proto : for (const ThunkSequenceProto& branch_sequence_proto :
@ -1114,10 +1119,10 @@ ConditionalThunkFromProto(
} }
TF_ASSIGN_OR_RETURN(Thunk::Info info, ThunkInfoFromProto(proto.info())); TF_ASSIGN_OR_RETURN(Thunk::Info info, ThunkInfoFromProto(proto.info()));
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(BufferAllocation::Slice branch_index_buffer,
BufferAllocation::Slice branch_index_buffer, BufferAllocation::Slice::FromProto(
BufferAllocation::Slice::FromProto( proto.conditional_thunk().branch_index_buffer(),
proto.conditional_thunk().branch_index_buffer(), buffer_allocations)); *buffer_allocations));
return ConditionalThunk::Create(std::move(info), return ConditionalThunk::Create(std::move(info),
std::move(branch_index_buffer), std::move(branch_index_buffer),
@ -1480,9 +1485,10 @@ static absl::StatusOr<std::unique_ptr<TopKThunk>> TopKThunkFromProto(
} }
static absl::StatusOr<std::unique_ptr<WhileThunk>> WhileThunkFromProto( static absl::StatusOr<std::unique_ptr<WhileThunk>> WhileThunkFromProto(
const ThunkProto& proto, const ThunkProto& proto, const HloModule* hlo_module,
const std::vector<BufferAllocation>& buffer_allocations) { const std::vector<BufferAllocation>* buffer_allocations) {
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(&buffer_allocations); ThunkSequenceSerDesProtobuf thunk_sequence_serdes(hlo_module,
buffer_allocations);
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
std::unique_ptr<ThunkSequence> cond_sequence, std::unique_ptr<ThunkSequence> cond_sequence,
@ -1496,7 +1502,7 @@ static absl::StatusOr<std::unique_ptr<WhileThunk>> WhileThunkFromProto(
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
BufferAllocation::Slice cond_buffer, BufferAllocation::Slice cond_buffer,
BufferAllocation::Slice::FromProto(proto.while_thunk().cond_buffer(), BufferAllocation::Slice::FromProto(proto.while_thunk().cond_buffer(),
buffer_allocations)); *buffer_allocations));
std::optional<int64_t> trip_count = std::nullopt; std::optional<int64_t> trip_count = std::nullopt;
if (proto.while_thunk().trip_count().contains_value()) { if (proto.while_thunk().trip_count().contains_value()) {
@ -1662,9 +1668,9 @@ absl::StatusOr<std::unique_ptr<Thunk>> ThunkSerDesProtobuf::FromProto(
} }
} }
case Thunk::Kind::kCall: case Thunk::Kind::kCall:
return CallThunkFromProto(proto, *buffer_allocations_); return CallThunkFromProto(proto, hlo_module_, buffer_allocations_);
case Thunk::Kind::kConditional: case Thunk::Kind::kConditional:
return ConditionalThunkFromProto(proto, *buffer_allocations_); return ConditionalThunkFromProto(proto, hlo_module_, buffer_allocations_);
case Thunk::Kind::kConvolution: case Thunk::Kind::kConvolution:
return ConvolutionThunkFromProto(proto, *buffer_allocations_); return ConvolutionThunkFromProto(proto, *buffer_allocations_);
case Thunk::Kind::kCopy: case Thunk::Kind::kCopy:
@ -1688,7 +1694,7 @@ absl::StatusOr<std::unique_ptr<Thunk>> ThunkSerDesProtobuf::FromProto(
case Thunk::Kind::kTopK: case Thunk::Kind::kTopK:
return TopKThunkFromProto(proto, *buffer_allocations_); return TopKThunkFromProto(proto, *buffer_allocations_);
case Thunk::Kind::kWhile: case Thunk::Kind::kWhile:
return WhileThunkFromProto(proto, *buffer_allocations_); return WhileThunkFromProto(proto, hlo_module_, buffer_allocations_);
case Thunk::Kind::kXnnFusion: { case Thunk::Kind::kXnnFusion: {
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
auto xnn_fusion_kind, auto xnn_fusion_kind,
@ -1715,8 +1721,9 @@ absl::StatusOr<std::unique_ptr<Thunk>> ThunkSerDesProtobuf::FromProto(
} }
ThunkSequenceSerDesProtobuf::ThunkSequenceSerDesProtobuf( ThunkSequenceSerDesProtobuf::ThunkSequenceSerDesProtobuf(
const HloModule* hlo_module,
const std::vector<BufferAllocation>* buffer_allocations) const std::vector<BufferAllocation>* buffer_allocations)
: buffer_allocations_(buffer_allocations) {} : hlo_module_(hlo_module), buffer_allocations_(buffer_allocations) {}
absl::StatusOr<std::string> ThunkSequenceSerDesProtobuf::Serialize( absl::StatusOr<std::string> ThunkSequenceSerDesProtobuf::Serialize(
const ThunkSequence& thunk_sequence) { const ThunkSequence& thunk_sequence) {
@ -1736,7 +1743,7 @@ ThunkSequenceSerDesProtobuf::Deserialize(const std::string& serialized) {
absl::StatusOr<ThunkSequenceProto> ThunkSequenceSerDesProtobuf::ToProto( absl::StatusOr<ThunkSequenceProto> ThunkSequenceSerDesProtobuf::ToProto(
const ThunkSequence& thunk_sequence) const { const ThunkSequence& thunk_sequence) const {
ThunkSerDesProtobuf thunk_serdes(buffer_allocations_); ThunkSerDesProtobuf thunk_serdes(hlo_module_, buffer_allocations_);
ThunkSequenceProto proto; ThunkSequenceProto proto;
proto.mutable_thunks()->Reserve(thunk_sequence.size()); proto.mutable_thunks()->Reserve(thunk_sequence.size());
@ -1798,7 +1805,7 @@ ThunkSequenceSerDesProtobuf::FromProto(const ThunkSequenceProto& proto) const {
size_t thunk_index = 0; size_t thunk_index = 0;
for (const ThunkProto& thunk_proto : proto.thunks()) { for (const ThunkProto& thunk_proto : proto.thunks()) {
ThunkSerDesProtobuf thunk_serdes(buffer_allocations_, ThunkSerDesProtobuf thunk_serdes(hlo_module_, buffer_allocations_,
&thunk_resources[thunk_index++]); &thunk_resources[thunk_index++]);
TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> thunk, TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> thunk,
thunk_serdes.FromProto(thunk_proto)); thunk_serdes.FromProto(thunk_proto));

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "xla/backends/cpu/runtime/serdes_base.h" #include "xla/backends/cpu/runtime/serdes_base.h"
#include "xla/backends/cpu/runtime/thunk.h" #include "xla/backends/cpu/runtime/thunk.h"
#include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/backends/cpu/runtime/thunk.pb.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/buffer_assignment.h" #include "xla/service/buffer_assignment.h"
namespace xla::cpu { namespace xla::cpu {
@ -36,10 +37,13 @@ void ForEachThunkProto(const ThunkSequenceProto& proto,
class ThunkSequenceSerDesProtobuf : public SerDesBase<ThunkSequence> { class ThunkSequenceSerDesProtobuf : public SerDesBase<ThunkSequence> {
public: public:
// For serialization, `hlo_module` and `buffer_allocations` are optional. For
// deserialization, both are required as we rely on the HLO module to resolve
// thunks that were generated from `HloComputation`s, and we also need buffer
// allocations to resolve buffer slices.
explicit ThunkSequenceSerDesProtobuf( explicit ThunkSequenceSerDesProtobuf(
const std::vector<BufferAllocation>* buffer_allocations = const HloModule* hlo_module = nullptr,
nullptr); // NOTE buffer allocations aren't const std::vector<BufferAllocation>* buffer_allocations = nullptr);
// needed for serialization.
absl::StatusOr<std::string> Serialize( absl::StatusOr<std::string> Serialize(
const ThunkSequence& thunk_sequence) override; const ThunkSequence& thunk_sequence) override;
@ -52,6 +56,7 @@ class ThunkSequenceSerDesProtobuf : public SerDesBase<ThunkSequence> {
const ThunkSequenceProto& proto) const; const ThunkSequenceProto& proto) const;
private: private:
const HloModule* hlo_module_;
const std::vector<BufferAllocation>* buffer_allocations_; const std::vector<BufferAllocation>* buffer_allocations_;
}; };

View File

@ -219,8 +219,8 @@ class ThunkSequenceSerdesTest : public ::testing::Test {
public: public:
void SetUp() override { void SetUp() override {
thunk_sequence_serdes_ = thunk_sequence_serdes_ = std::make_unique<T>(
std::make_unique<T>(&buffer_allocations_.GetUnderlyingVector()); nullptr, &buffer_allocations_.GetUnderlyingVector());
} }
protected: protected:

View File

@ -82,7 +82,7 @@ CpuAotCompilationResult::Create(
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data, std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
TargetMachineOptionsProto target_machine_options) { TargetMachineOptionsProto target_machine_options) {
ThunkSequenceSerDesProtobuf thunk_sequence_serdes( ThunkSequenceSerDesProtobuf thunk_sequence_serdes(
&buffer_assignment->Allocations()); hlo_module, &buffer_assignment->Allocations());
TF_ASSIGN_OR_RETURN(ThunkSequenceProto thunk_proto, TF_ASSIGN_OR_RETURN(ThunkSequenceProto thunk_proto,
thunk_sequence_serdes.ToProto(thunks)); thunk_sequence_serdes.ToProto(thunks));
@ -143,7 +143,7 @@ CpuAotCompilationResult::CpuAotCompilationResult(
module_ = hlo_module->Clone(); module_ = hlo_module->Clone();
ThunkSequenceSerDesProtobuf thunk_sequence_serdes( ThunkSequenceSerDesProtobuf thunk_sequence_serdes(
&buffer_assignment->Allocations()); hlo_module, &buffer_assignment->Allocations());
*proto_.mutable_thunk_sequence() = thunks; *proto_.mutable_thunk_sequence() = thunks;
} }
@ -181,7 +181,7 @@ CpuAotCompilationResult::LoadExecutable(
} }
ThunkSequenceSerDesProtobuf thunk_sequence_serdes( ThunkSequenceSerDesProtobuf thunk_sequence_serdes(
&buffer_assignment->Allocations()); module.get(), &buffer_assignment->Allocations());
TF_ASSIGN_OR_RETURN(std::unique_ptr<ThunkSequence> thunks, TF_ASSIGN_OR_RETURN(std::unique_ptr<ThunkSequence> thunks,
thunk_sequence_serdes.FromProto(proto_.thunk_sequence())); thunk_sequence_serdes.FromProto(proto_.thunk_sequence()));