[xla:cpu] Pass HloModule pointer to Thunk SerDes

PiperOrigin-RevId: 826312546
This commit is contained in:
Eugene Zhulenev 2025-10-30 21:37:36 -07:00 committed by TensorFlower Gardener
parent 56d3b19280
commit bf23bf1b32
5 changed files with 38 additions and 28 deletions

View File

@ -1203,6 +1203,7 @@ cc_library(
"//xla/backends/cpu/runtime/xnnpack:xnn_convolution_thunk",
"//xla/backends/cpu/runtime/xnnpack:xnn_dot_thunk",
"//xla/backends/cpu/runtime/xnnpack:xnn_fusion_thunk",
"//xla/hlo/ir:hlo",
"//xla/runtime:resource_use",
"//xla/runtime:work_group",
"//xla/service:buffer_assignment",

View File

@ -356,8 +356,8 @@ class ThunkSerDesProtobuf : public SerDesBase<Thunk> {
const ThunkProto& proto) const;
private:
// TODO(basiol) remove NOLINT when this actually gets used
const std::vector<BufferAllocation>* buffer_allocations_; // NOLINT
const HloModule* hlo_module_;
const std::vector<BufferAllocation>* buffer_allocations_;
const std::vector<std::shared_ptr<Resource>>* thunk_resources_;
};
@ -1087,9 +1087,10 @@ ReduceScatterThunkFromProto(
}
static absl::StatusOr<std::unique_ptr<CallThunk>> CallThunkFromProto(
const ThunkProto& proto,
const std::vector<BufferAllocation>& buffer_allocations) {
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(&buffer_allocations);
const ThunkProto& proto, const HloModule* hlo_module,
const std::vector<BufferAllocation>* buffer_allocations) {
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(hlo_module,
buffer_allocations);
TF_ASSIGN_OR_RETURN(
std::unique_ptr<ThunkSequence> call_sequence,
@ -1101,9 +1102,10 @@ static absl::StatusOr<std::unique_ptr<CallThunk>> CallThunkFromProto(
static absl::StatusOr<std::unique_ptr<ConditionalThunk>>
ConditionalThunkFromProto(
const ThunkProto& proto,
const std::vector<BufferAllocation>& buffer_allocations) {
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(&buffer_allocations);
const ThunkProto& proto, const HloModule* hlo_module,
const std::vector<BufferAllocation>* buffer_allocations) {
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(hlo_module,
buffer_allocations);
std::vector<ThunkSequence> branch_sequences;
for (const ThunkSequenceProto& branch_sequence_proto :
@ -1114,10 +1116,10 @@ ConditionalThunkFromProto(
}
TF_ASSIGN_OR_RETURN(Thunk::Info info, ThunkInfoFromProto(proto.info()));
TF_ASSIGN_OR_RETURN(
BufferAllocation::Slice branch_index_buffer,
BufferAllocation::Slice::FromProto(
proto.conditional_thunk().branch_index_buffer(), buffer_allocations));
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice branch_index_buffer,
BufferAllocation::Slice::FromProto(
proto.conditional_thunk().branch_index_buffer(),
*buffer_allocations));
return ConditionalThunk::Create(std::move(info),
std::move(branch_index_buffer),
@ -1480,9 +1482,10 @@ static absl::StatusOr<std::unique_ptr<TopKThunk>> TopKThunkFromProto(
}
static absl::StatusOr<std::unique_ptr<WhileThunk>> WhileThunkFromProto(
const ThunkProto& proto,
const std::vector<BufferAllocation>& buffer_allocations) {
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(&buffer_allocations);
const ThunkProto& proto, const HloModule* hlo_module,
const std::vector<BufferAllocation>* buffer_allocations) {
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(hlo_module,
buffer_allocations);
TF_ASSIGN_OR_RETURN(
std::unique_ptr<ThunkSequence> cond_sequence,
@ -1496,7 +1499,7 @@ static absl::StatusOr<std::unique_ptr<WhileThunk>> WhileThunkFromProto(
TF_ASSIGN_OR_RETURN(
BufferAllocation::Slice cond_buffer,
BufferAllocation::Slice::FromProto(proto.while_thunk().cond_buffer(),
buffer_allocations));
*buffer_allocations));
std::optional<int64_t> trip_count = std::nullopt;
if (proto.while_thunk().trip_count().contains_value()) {
@ -1662,9 +1665,9 @@ absl::StatusOr<std::unique_ptr<Thunk>> ThunkSerDesProtobuf::FromProto(
}
}
case Thunk::Kind::kCall:
return CallThunkFromProto(proto, *buffer_allocations_);
return CallThunkFromProto(proto, hlo_module_, buffer_allocations_);
case Thunk::Kind::kConditional:
return ConditionalThunkFromProto(proto, *buffer_allocations_);
return ConditionalThunkFromProto(proto, hlo_module_, buffer_allocations_);
case Thunk::Kind::kConvolution:
return ConvolutionThunkFromProto(proto, *buffer_allocations_);
case Thunk::Kind::kCopy:
@ -1688,7 +1691,7 @@ absl::StatusOr<std::unique_ptr<Thunk>> ThunkSerDesProtobuf::FromProto(
case Thunk::Kind::kTopK:
return TopKThunkFromProto(proto, *buffer_allocations_);
case Thunk::Kind::kWhile:
return WhileThunkFromProto(proto, *buffer_allocations_);
return WhileThunkFromProto(proto, hlo_module_, buffer_allocations_);
case Thunk::Kind::kXnnFusion: {
TF_ASSIGN_OR_RETURN(
auto xnn_fusion_kind,
@ -1715,8 +1718,9 @@ absl::StatusOr<std::unique_ptr<Thunk>> ThunkSerDesProtobuf::FromProto(
}
ThunkSequenceSerDesProtobuf::ThunkSequenceSerDesProtobuf(
const HloModule* hlo_module,
const std::vector<BufferAllocation>* buffer_allocations)
: buffer_allocations_(buffer_allocations) {}
: hlo_module_(hlo_module), buffer_allocations_(buffer_allocations) {}
absl::StatusOr<std::string> ThunkSequenceSerDesProtobuf::Serialize(
const ThunkSequence& thunk_sequence) {

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "xla/backends/cpu/runtime/serdes_base.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/backends/cpu/runtime/thunk.pb.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/buffer_assignment.h"
namespace xla::cpu {
@ -36,10 +37,13 @@ void ForEachThunkProto(const ThunkSequenceProto& proto,
class ThunkSequenceSerDesProtobuf : public SerDesBase<ThunkSequence> {
public:
// For serialization, `hlo_module` and `buffer_allocations` are optional. For
// deserialization, both are required as we rely on the HLO module to resolve
// thunks that were generated from `HloComputation`s, and we also need buffer
// allocations to resolve buffer slices.
explicit ThunkSequenceSerDesProtobuf(
const std::vector<BufferAllocation>* buffer_allocations =
nullptr); // NOTE buffer allocations aren't
// needed for serialization.
const HloModule* hlo_module = nullptr,
const std::vector<BufferAllocation>* buffer_allocations = nullptr);
absl::StatusOr<std::string> Serialize(
const ThunkSequence& thunk_sequence) override;
@ -52,6 +56,7 @@ class ThunkSequenceSerDesProtobuf : public SerDesBase<ThunkSequence> {
const ThunkSequenceProto& proto) const;
private:
const HloModule* hlo_module_;
const std::vector<BufferAllocation>* buffer_allocations_;
};

View File

@ -219,8 +219,8 @@ class ThunkSequenceSerdesTest : public ::testing::Test {
public:
void SetUp() override {
thunk_sequence_serdes_ =
std::make_unique<T>(&buffer_allocations_.GetUnderlyingVector());
thunk_sequence_serdes_ = std::make_unique<T>(
nullptr, &buffer_allocations_.GetUnderlyingVector());
}
protected:

View File

@ -82,7 +82,7 @@ CpuAotCompilationResult::Create(
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
TargetMachineOptionsProto target_machine_options) {
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(
&buffer_assignment->Allocations());
hlo_module, &buffer_assignment->Allocations());
TF_ASSIGN_OR_RETURN(ThunkSequenceProto thunk_proto,
thunk_sequence_serdes.ToProto(thunks));
@ -143,7 +143,7 @@ CpuAotCompilationResult::CpuAotCompilationResult(
module_ = hlo_module->Clone();
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(
&buffer_assignment->Allocations());
hlo_module, &buffer_assignment->Allocations());
*proto_.mutable_thunk_sequence() = thunks;
}
@ -181,7 +181,7 @@ CpuAotCompilationResult::LoadExecutable(
}
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(
&buffer_assignment->Allocations());
module.get(), &buffer_assignment->Allocations());
TF_ASSIGN_OR_RETURN(std::unique_ptr<ThunkSequence> thunks,
thunk_sequence_serdes.FromProto(proto_.thunk_sequence()));