mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[xla:cpu] Pass HloModule pointer to Thunk SerDes
Reverts 993369077a
PiperOrigin-RevId: 826675119
This commit is contained in:
parent
00be2bc09e
commit
fbd032df67
|
|
@ -1203,6 +1203,7 @@ cc_library(
|
||||||
"//xla/backends/cpu/runtime/xnnpack:xnn_convolution_thunk",
|
"//xla/backends/cpu/runtime/xnnpack:xnn_convolution_thunk",
|
||||||
"//xla/backends/cpu/runtime/xnnpack:xnn_dot_thunk",
|
"//xla/backends/cpu/runtime/xnnpack:xnn_dot_thunk",
|
||||||
"//xla/backends/cpu/runtime/xnnpack:xnn_fusion_thunk",
|
"//xla/backends/cpu/runtime/xnnpack:xnn_fusion_thunk",
|
||||||
|
"//xla/hlo/ir:hlo",
|
||||||
"//xla/runtime:resource_use",
|
"//xla/runtime:resource_use",
|
||||||
"//xla/runtime:work_group",
|
"//xla/runtime:work_group",
|
||||||
"//xla/service:buffer_assignment",
|
"//xla/service:buffer_assignment",
|
||||||
|
|
|
||||||
|
|
@ -344,6 +344,7 @@ class ThunkSerDesProtobuf : public SerDesBase<Thunk> {
|
||||||
public:
|
public:
|
||||||
// Buffer allocations and resources are not needed for serialization.
|
// Buffer allocations and resources are not needed for serialization.
|
||||||
explicit ThunkSerDesProtobuf(
|
explicit ThunkSerDesProtobuf(
|
||||||
|
const HloModule* hlo_module = nullptr,
|
||||||
const std::vector<BufferAllocation>* buffer_allocations = nullptr,
|
const std::vector<BufferAllocation>* buffer_allocations = nullptr,
|
||||||
const std::vector<std::shared_ptr<Resource>>* thunk_resources = nullptr);
|
const std::vector<std::shared_ptr<Resource>>* thunk_resources = nullptr);
|
||||||
absl::StatusOr<std::string> Serialize(const Thunk& thunk) override;
|
absl::StatusOr<std::string> Serialize(const Thunk& thunk) override;
|
||||||
|
|
@ -356,16 +357,18 @@ class ThunkSerDesProtobuf : public SerDesBase<Thunk> {
|
||||||
const ThunkProto& proto) const;
|
const ThunkProto& proto) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// TODO(basiol) remove NOLINT when this actually gets used
|
const HloModule* hlo_module_;
|
||||||
const std::vector<BufferAllocation>* buffer_allocations_; // NOLINT
|
const std::vector<BufferAllocation>* buffer_allocations_;
|
||||||
|
|
||||||
const std::vector<std::shared_ptr<Resource>>* thunk_resources_;
|
const std::vector<std::shared_ptr<Resource>>* thunk_resources_;
|
||||||
};
|
};
|
||||||
|
|
||||||
ThunkSerDesProtobuf::ThunkSerDesProtobuf(
|
ThunkSerDesProtobuf::ThunkSerDesProtobuf(
|
||||||
|
const HloModule* hlo_module,
|
||||||
const std::vector<BufferAllocation>* buffer_allocations,
|
const std::vector<BufferAllocation>* buffer_allocations,
|
||||||
const std::vector<std::shared_ptr<Resource>>* thunk_resources)
|
const std::vector<std::shared_ptr<Resource>>* thunk_resources)
|
||||||
: buffer_allocations_(buffer_allocations),
|
: hlo_module_(hlo_module),
|
||||||
|
buffer_allocations_(buffer_allocations),
|
||||||
thunk_resources_(thunk_resources) {}
|
thunk_resources_(thunk_resources) {}
|
||||||
|
|
||||||
absl::StatusOr<std::string> ThunkSerDesProtobuf::Serialize(const Thunk& thunk) {
|
absl::StatusOr<std::string> ThunkSerDesProtobuf::Serialize(const Thunk& thunk) {
|
||||||
|
|
@ -1087,9 +1090,10 @@ ReduceScatterThunkFromProto(
|
||||||
}
|
}
|
||||||
|
|
||||||
static absl::StatusOr<std::unique_ptr<CallThunk>> CallThunkFromProto(
|
static absl::StatusOr<std::unique_ptr<CallThunk>> CallThunkFromProto(
|
||||||
const ThunkProto& proto,
|
const ThunkProto& proto, const HloModule* hlo_module,
|
||||||
const std::vector<BufferAllocation>& buffer_allocations) {
|
const std::vector<BufferAllocation>* buffer_allocations) {
|
||||||
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(&buffer_allocations);
|
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(hlo_module,
|
||||||
|
buffer_allocations);
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::unique_ptr<ThunkSequence> call_sequence,
|
std::unique_ptr<ThunkSequence> call_sequence,
|
||||||
|
|
@ -1101,9 +1105,10 @@ static absl::StatusOr<std::unique_ptr<CallThunk>> CallThunkFromProto(
|
||||||
|
|
||||||
static absl::StatusOr<std::unique_ptr<ConditionalThunk>>
|
static absl::StatusOr<std::unique_ptr<ConditionalThunk>>
|
||||||
ConditionalThunkFromProto(
|
ConditionalThunkFromProto(
|
||||||
const ThunkProto& proto,
|
const ThunkProto& proto, const HloModule* hlo_module,
|
||||||
const std::vector<BufferAllocation>& buffer_allocations) {
|
const std::vector<BufferAllocation>* buffer_allocations) {
|
||||||
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(&buffer_allocations);
|
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(hlo_module,
|
||||||
|
buffer_allocations);
|
||||||
|
|
||||||
std::vector<ThunkSequence> branch_sequences;
|
std::vector<ThunkSequence> branch_sequences;
|
||||||
for (const ThunkSequenceProto& branch_sequence_proto :
|
for (const ThunkSequenceProto& branch_sequence_proto :
|
||||||
|
|
@ -1114,10 +1119,10 @@ ConditionalThunkFromProto(
|
||||||
}
|
}
|
||||||
TF_ASSIGN_OR_RETURN(Thunk::Info info, ThunkInfoFromProto(proto.info()));
|
TF_ASSIGN_OR_RETURN(Thunk::Info info, ThunkInfoFromProto(proto.info()));
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice branch_index_buffer,
|
||||||
BufferAllocation::Slice branch_index_buffer,
|
BufferAllocation::Slice::FromProto(
|
||||||
BufferAllocation::Slice::FromProto(
|
proto.conditional_thunk().branch_index_buffer(),
|
||||||
proto.conditional_thunk().branch_index_buffer(), buffer_allocations));
|
*buffer_allocations));
|
||||||
|
|
||||||
return ConditionalThunk::Create(std::move(info),
|
return ConditionalThunk::Create(std::move(info),
|
||||||
std::move(branch_index_buffer),
|
std::move(branch_index_buffer),
|
||||||
|
|
@ -1480,9 +1485,10 @@ static absl::StatusOr<std::unique_ptr<TopKThunk>> TopKThunkFromProto(
|
||||||
}
|
}
|
||||||
|
|
||||||
static absl::StatusOr<std::unique_ptr<WhileThunk>> WhileThunkFromProto(
|
static absl::StatusOr<std::unique_ptr<WhileThunk>> WhileThunkFromProto(
|
||||||
const ThunkProto& proto,
|
const ThunkProto& proto, const HloModule* hlo_module,
|
||||||
const std::vector<BufferAllocation>& buffer_allocations) {
|
const std::vector<BufferAllocation>* buffer_allocations) {
|
||||||
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(&buffer_allocations);
|
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(hlo_module,
|
||||||
|
buffer_allocations);
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::unique_ptr<ThunkSequence> cond_sequence,
|
std::unique_ptr<ThunkSequence> cond_sequence,
|
||||||
|
|
@ -1496,7 +1502,7 @@ static absl::StatusOr<std::unique_ptr<WhileThunk>> WhileThunkFromProto(
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
BufferAllocation::Slice cond_buffer,
|
BufferAllocation::Slice cond_buffer,
|
||||||
BufferAllocation::Slice::FromProto(proto.while_thunk().cond_buffer(),
|
BufferAllocation::Slice::FromProto(proto.while_thunk().cond_buffer(),
|
||||||
buffer_allocations));
|
*buffer_allocations));
|
||||||
|
|
||||||
std::optional<int64_t> trip_count = std::nullopt;
|
std::optional<int64_t> trip_count = std::nullopt;
|
||||||
if (proto.while_thunk().trip_count().contains_value()) {
|
if (proto.while_thunk().trip_count().contains_value()) {
|
||||||
|
|
@ -1662,9 +1668,9 @@ absl::StatusOr<std::unique_ptr<Thunk>> ThunkSerDesProtobuf::FromProto(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case Thunk::Kind::kCall:
|
case Thunk::Kind::kCall:
|
||||||
return CallThunkFromProto(proto, *buffer_allocations_);
|
return CallThunkFromProto(proto, hlo_module_, buffer_allocations_);
|
||||||
case Thunk::Kind::kConditional:
|
case Thunk::Kind::kConditional:
|
||||||
return ConditionalThunkFromProto(proto, *buffer_allocations_);
|
return ConditionalThunkFromProto(proto, hlo_module_, buffer_allocations_);
|
||||||
case Thunk::Kind::kConvolution:
|
case Thunk::Kind::kConvolution:
|
||||||
return ConvolutionThunkFromProto(proto, *buffer_allocations_);
|
return ConvolutionThunkFromProto(proto, *buffer_allocations_);
|
||||||
case Thunk::Kind::kCopy:
|
case Thunk::Kind::kCopy:
|
||||||
|
|
@ -1688,7 +1694,7 @@ absl::StatusOr<std::unique_ptr<Thunk>> ThunkSerDesProtobuf::FromProto(
|
||||||
case Thunk::Kind::kTopK:
|
case Thunk::Kind::kTopK:
|
||||||
return TopKThunkFromProto(proto, *buffer_allocations_);
|
return TopKThunkFromProto(proto, *buffer_allocations_);
|
||||||
case Thunk::Kind::kWhile:
|
case Thunk::Kind::kWhile:
|
||||||
return WhileThunkFromProto(proto, *buffer_allocations_);
|
return WhileThunkFromProto(proto, hlo_module_, buffer_allocations_);
|
||||||
case Thunk::Kind::kXnnFusion: {
|
case Thunk::Kind::kXnnFusion: {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
auto xnn_fusion_kind,
|
auto xnn_fusion_kind,
|
||||||
|
|
@ -1715,8 +1721,9 @@ absl::StatusOr<std::unique_ptr<Thunk>> ThunkSerDesProtobuf::FromProto(
|
||||||
}
|
}
|
||||||
|
|
||||||
ThunkSequenceSerDesProtobuf::ThunkSequenceSerDesProtobuf(
|
ThunkSequenceSerDesProtobuf::ThunkSequenceSerDesProtobuf(
|
||||||
|
const HloModule* hlo_module,
|
||||||
const std::vector<BufferAllocation>* buffer_allocations)
|
const std::vector<BufferAllocation>* buffer_allocations)
|
||||||
: buffer_allocations_(buffer_allocations) {}
|
: hlo_module_(hlo_module), buffer_allocations_(buffer_allocations) {}
|
||||||
|
|
||||||
absl::StatusOr<std::string> ThunkSequenceSerDesProtobuf::Serialize(
|
absl::StatusOr<std::string> ThunkSequenceSerDesProtobuf::Serialize(
|
||||||
const ThunkSequence& thunk_sequence) {
|
const ThunkSequence& thunk_sequence) {
|
||||||
|
|
@ -1736,7 +1743,7 @@ ThunkSequenceSerDesProtobuf::Deserialize(const std::string& serialized) {
|
||||||
|
|
||||||
absl::StatusOr<ThunkSequenceProto> ThunkSequenceSerDesProtobuf::ToProto(
|
absl::StatusOr<ThunkSequenceProto> ThunkSequenceSerDesProtobuf::ToProto(
|
||||||
const ThunkSequence& thunk_sequence) const {
|
const ThunkSequence& thunk_sequence) const {
|
||||||
ThunkSerDesProtobuf thunk_serdes(buffer_allocations_);
|
ThunkSerDesProtobuf thunk_serdes(hlo_module_, buffer_allocations_);
|
||||||
ThunkSequenceProto proto;
|
ThunkSequenceProto proto;
|
||||||
proto.mutable_thunks()->Reserve(thunk_sequence.size());
|
proto.mutable_thunks()->Reserve(thunk_sequence.size());
|
||||||
|
|
||||||
|
|
@ -1798,7 +1805,7 @@ ThunkSequenceSerDesProtobuf::FromProto(const ThunkSequenceProto& proto) const {
|
||||||
|
|
||||||
size_t thunk_index = 0;
|
size_t thunk_index = 0;
|
||||||
for (const ThunkProto& thunk_proto : proto.thunks()) {
|
for (const ThunkProto& thunk_proto : proto.thunks()) {
|
||||||
ThunkSerDesProtobuf thunk_serdes(buffer_allocations_,
|
ThunkSerDesProtobuf thunk_serdes(hlo_module_, buffer_allocations_,
|
||||||
&thunk_resources[thunk_index++]);
|
&thunk_resources[thunk_index++]);
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> thunk,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> thunk,
|
||||||
thunk_serdes.FromProto(thunk_proto));
|
thunk_serdes.FromProto(thunk_proto));
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||||
#include "xla/backends/cpu/runtime/serdes_base.h"
|
#include "xla/backends/cpu/runtime/serdes_base.h"
|
||||||
#include "xla/backends/cpu/runtime/thunk.h"
|
#include "xla/backends/cpu/runtime/thunk.h"
|
||||||
#include "xla/backends/cpu/runtime/thunk.pb.h"
|
#include "xla/backends/cpu/runtime/thunk.pb.h"
|
||||||
|
#include "xla/hlo/ir/hlo_module.h"
|
||||||
#include "xla/service/buffer_assignment.h"
|
#include "xla/service/buffer_assignment.h"
|
||||||
|
|
||||||
namespace xla::cpu {
|
namespace xla::cpu {
|
||||||
|
|
@ -36,10 +37,13 @@ void ForEachThunkProto(const ThunkSequenceProto& proto,
|
||||||
|
|
||||||
class ThunkSequenceSerDesProtobuf : public SerDesBase<ThunkSequence> {
|
class ThunkSequenceSerDesProtobuf : public SerDesBase<ThunkSequence> {
|
||||||
public:
|
public:
|
||||||
|
// For serialization, `hlo_module` and `buffer_allocations` are optional. For
|
||||||
|
// deserialization, both are required as we rely on the HLO module to resolve
|
||||||
|
// thunks that were generated from `HloComputation`s, and we also need buffer
|
||||||
|
// allocations to resolve buffer slices.
|
||||||
explicit ThunkSequenceSerDesProtobuf(
|
explicit ThunkSequenceSerDesProtobuf(
|
||||||
const std::vector<BufferAllocation>* buffer_allocations =
|
const HloModule* hlo_module = nullptr,
|
||||||
nullptr); // NOTE buffer allocations aren't
|
const std::vector<BufferAllocation>* buffer_allocations = nullptr);
|
||||||
// needed for serialization.
|
|
||||||
|
|
||||||
absl::StatusOr<std::string> Serialize(
|
absl::StatusOr<std::string> Serialize(
|
||||||
const ThunkSequence& thunk_sequence) override;
|
const ThunkSequence& thunk_sequence) override;
|
||||||
|
|
@ -52,6 +56,7 @@ class ThunkSequenceSerDesProtobuf : public SerDesBase<ThunkSequence> {
|
||||||
const ThunkSequenceProto& proto) const;
|
const ThunkSequenceProto& proto) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
const HloModule* hlo_module_;
|
||||||
const std::vector<BufferAllocation>* buffer_allocations_;
|
const std::vector<BufferAllocation>* buffer_allocations_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -219,8 +219,8 @@ class ThunkSequenceSerdesTest : public ::testing::Test {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
void SetUp() override {
|
void SetUp() override {
|
||||||
thunk_sequence_serdes_ =
|
thunk_sequence_serdes_ = std::make_unique<T>(
|
||||||
std::make_unique<T>(&buffer_allocations_.GetUnderlyingVector());
|
nullptr, &buffer_allocations_.GetUnderlyingVector());
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|
|
||||||
|
|
@ -82,7 +82,7 @@ CpuAotCompilationResult::Create(
|
||||||
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
|
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
|
||||||
TargetMachineOptionsProto target_machine_options) {
|
TargetMachineOptionsProto target_machine_options) {
|
||||||
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(
|
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(
|
||||||
&buffer_assignment->Allocations());
|
hlo_module, &buffer_assignment->Allocations());
|
||||||
TF_ASSIGN_OR_RETURN(ThunkSequenceProto thunk_proto,
|
TF_ASSIGN_OR_RETURN(ThunkSequenceProto thunk_proto,
|
||||||
thunk_sequence_serdes.ToProto(thunks));
|
thunk_sequence_serdes.ToProto(thunks));
|
||||||
|
|
||||||
|
|
@ -143,7 +143,7 @@ CpuAotCompilationResult::CpuAotCompilationResult(
|
||||||
module_ = hlo_module->Clone();
|
module_ = hlo_module->Clone();
|
||||||
|
|
||||||
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(
|
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(
|
||||||
&buffer_assignment->Allocations());
|
hlo_module, &buffer_assignment->Allocations());
|
||||||
*proto_.mutable_thunk_sequence() = thunks;
|
*proto_.mutable_thunk_sequence() = thunks;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -181,7 +181,7 @@ CpuAotCompilationResult::LoadExecutable(
|
||||||
}
|
}
|
||||||
|
|
||||||
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(
|
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(
|
||||||
&buffer_assignment->Allocations());
|
module.get(), &buffer_assignment->Allocations());
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<ThunkSequence> thunks,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<ThunkSequence> thunks,
|
||||||
thunk_sequence_serdes.FromProto(proto_.thunk_sequence()));
|
thunk_sequence_serdes.FromProto(proto_.thunk_sequence()));
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user