mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
[xla:cpu] Pass HloModule pointer to Thunk SerDes
PiperOrigin-RevId: 826312546
This commit is contained in:
parent
56d3b19280
commit
bf23bf1b32
|
|
@ -1203,6 +1203,7 @@ cc_library(
|
|||
"//xla/backends/cpu/runtime/xnnpack:xnn_convolution_thunk",
|
||||
"//xla/backends/cpu/runtime/xnnpack:xnn_dot_thunk",
|
||||
"//xla/backends/cpu/runtime/xnnpack:xnn_fusion_thunk",
|
||||
"//xla/hlo/ir:hlo",
|
||||
"//xla/runtime:resource_use",
|
||||
"//xla/runtime:work_group",
|
||||
"//xla/service:buffer_assignment",
|
||||
|
|
|
|||
|
|
@ -356,8 +356,8 @@ class ThunkSerDesProtobuf : public SerDesBase<Thunk> {
|
|||
const ThunkProto& proto) const;
|
||||
|
||||
private:
|
||||
// TODO(basiol) remove NOLINT when this actually gets used
|
||||
const std::vector<BufferAllocation>* buffer_allocations_; // NOLINT
|
||||
const HloModule* hlo_module_;
|
||||
const std::vector<BufferAllocation>* buffer_allocations_;
|
||||
|
||||
const std::vector<std::shared_ptr<Resource>>* thunk_resources_;
|
||||
};
|
||||
|
|
@ -1087,9 +1087,10 @@ ReduceScatterThunkFromProto(
|
|||
}
|
||||
|
||||
static absl::StatusOr<std::unique_ptr<CallThunk>> CallThunkFromProto(
|
||||
const ThunkProto& proto,
|
||||
const std::vector<BufferAllocation>& buffer_allocations) {
|
||||
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(&buffer_allocations);
|
||||
const ThunkProto& proto, const HloModule* hlo_module,
|
||||
const std::vector<BufferAllocation>* buffer_allocations) {
|
||||
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(hlo_module,
|
||||
buffer_allocations);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<ThunkSequence> call_sequence,
|
||||
|
|
@ -1101,9 +1102,10 @@ static absl::StatusOr<std::unique_ptr<CallThunk>> CallThunkFromProto(
|
|||
|
||||
static absl::StatusOr<std::unique_ptr<ConditionalThunk>>
|
||||
ConditionalThunkFromProto(
|
||||
const ThunkProto& proto,
|
||||
const std::vector<BufferAllocation>& buffer_allocations) {
|
||||
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(&buffer_allocations);
|
||||
const ThunkProto& proto, const HloModule* hlo_module,
|
||||
const std::vector<BufferAllocation>* buffer_allocations) {
|
||||
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(hlo_module,
|
||||
buffer_allocations);
|
||||
|
||||
std::vector<ThunkSequence> branch_sequences;
|
||||
for (const ThunkSequenceProto& branch_sequence_proto :
|
||||
|
|
@ -1114,10 +1116,10 @@ ConditionalThunkFromProto(
|
|||
}
|
||||
TF_ASSIGN_OR_RETURN(Thunk::Info info, ThunkInfoFromProto(proto.info()));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
BufferAllocation::Slice branch_index_buffer,
|
||||
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice branch_index_buffer,
|
||||
BufferAllocation::Slice::FromProto(
|
||||
proto.conditional_thunk().branch_index_buffer(), buffer_allocations));
|
||||
proto.conditional_thunk().branch_index_buffer(),
|
||||
*buffer_allocations));
|
||||
|
||||
return ConditionalThunk::Create(std::move(info),
|
||||
std::move(branch_index_buffer),
|
||||
|
|
@ -1480,9 +1482,10 @@ static absl::StatusOr<std::unique_ptr<TopKThunk>> TopKThunkFromProto(
|
|||
}
|
||||
|
||||
static absl::StatusOr<std::unique_ptr<WhileThunk>> WhileThunkFromProto(
|
||||
const ThunkProto& proto,
|
||||
const std::vector<BufferAllocation>& buffer_allocations) {
|
||||
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(&buffer_allocations);
|
||||
const ThunkProto& proto, const HloModule* hlo_module,
|
||||
const std::vector<BufferAllocation>* buffer_allocations) {
|
||||
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(hlo_module,
|
||||
buffer_allocations);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<ThunkSequence> cond_sequence,
|
||||
|
|
@ -1496,7 +1499,7 @@ static absl::StatusOr<std::unique_ptr<WhileThunk>> WhileThunkFromProto(
|
|||
TF_ASSIGN_OR_RETURN(
|
||||
BufferAllocation::Slice cond_buffer,
|
||||
BufferAllocation::Slice::FromProto(proto.while_thunk().cond_buffer(),
|
||||
buffer_allocations));
|
||||
*buffer_allocations));
|
||||
|
||||
std::optional<int64_t> trip_count = std::nullopt;
|
||||
if (proto.while_thunk().trip_count().contains_value()) {
|
||||
|
|
@ -1662,9 +1665,9 @@ absl::StatusOr<std::unique_ptr<Thunk>> ThunkSerDesProtobuf::FromProto(
|
|||
}
|
||||
}
|
||||
case Thunk::Kind::kCall:
|
||||
return CallThunkFromProto(proto, *buffer_allocations_);
|
||||
return CallThunkFromProto(proto, hlo_module_, buffer_allocations_);
|
||||
case Thunk::Kind::kConditional:
|
||||
return ConditionalThunkFromProto(proto, *buffer_allocations_);
|
||||
return ConditionalThunkFromProto(proto, hlo_module_, buffer_allocations_);
|
||||
case Thunk::Kind::kConvolution:
|
||||
return ConvolutionThunkFromProto(proto, *buffer_allocations_);
|
||||
case Thunk::Kind::kCopy:
|
||||
|
|
@ -1688,7 +1691,7 @@ absl::StatusOr<std::unique_ptr<Thunk>> ThunkSerDesProtobuf::FromProto(
|
|||
case Thunk::Kind::kTopK:
|
||||
return TopKThunkFromProto(proto, *buffer_allocations_);
|
||||
case Thunk::Kind::kWhile:
|
||||
return WhileThunkFromProto(proto, *buffer_allocations_);
|
||||
return WhileThunkFromProto(proto, hlo_module_, buffer_allocations_);
|
||||
case Thunk::Kind::kXnnFusion: {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto xnn_fusion_kind,
|
||||
|
|
@ -1715,8 +1718,9 @@ absl::StatusOr<std::unique_ptr<Thunk>> ThunkSerDesProtobuf::FromProto(
|
|||
}
|
||||
|
||||
ThunkSequenceSerDesProtobuf::ThunkSequenceSerDesProtobuf(
|
||||
const HloModule* hlo_module,
|
||||
const std::vector<BufferAllocation>* buffer_allocations)
|
||||
: buffer_allocations_(buffer_allocations) {}
|
||||
: hlo_module_(hlo_module), buffer_allocations_(buffer_allocations) {}
|
||||
|
||||
absl::StatusOr<std::string> ThunkSequenceSerDesProtobuf::Serialize(
|
||||
const ThunkSequence& thunk_sequence) {
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||
#include "xla/backends/cpu/runtime/serdes_base.h"
|
||||
#include "xla/backends/cpu/runtime/thunk.h"
|
||||
#include "xla/backends/cpu/runtime/thunk.pb.h"
|
||||
#include "xla/hlo/ir/hlo_module.h"
|
||||
#include "xla/service/buffer_assignment.h"
|
||||
|
||||
namespace xla::cpu {
|
||||
|
|
@ -36,10 +37,13 @@ void ForEachThunkProto(const ThunkSequenceProto& proto,
|
|||
|
||||
class ThunkSequenceSerDesProtobuf : public SerDesBase<ThunkSequence> {
|
||||
public:
|
||||
// For serialization, `hlo_module` and `buffer_allocations` are optional. For
|
||||
// deserialization, both are required as we rely on the HLO module to resolve
|
||||
// thunks that were generated from `HloComputation`s, and we also need buffer
|
||||
// allocations to resolve buffer slices.
|
||||
explicit ThunkSequenceSerDesProtobuf(
|
||||
const std::vector<BufferAllocation>* buffer_allocations =
|
||||
nullptr); // NOTE buffer allocations aren't
|
||||
// needed for serialization.
|
||||
const HloModule* hlo_module = nullptr,
|
||||
const std::vector<BufferAllocation>* buffer_allocations = nullptr);
|
||||
|
||||
absl::StatusOr<std::string> Serialize(
|
||||
const ThunkSequence& thunk_sequence) override;
|
||||
|
|
@ -52,6 +56,7 @@ class ThunkSequenceSerDesProtobuf : public SerDesBase<ThunkSequence> {
|
|||
const ThunkSequenceProto& proto) const;
|
||||
|
||||
private:
|
||||
const HloModule* hlo_module_;
|
||||
const std::vector<BufferAllocation>* buffer_allocations_;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -219,8 +219,8 @@ class ThunkSequenceSerdesTest : public ::testing::Test {
|
|||
|
||||
public:
|
||||
void SetUp() override {
|
||||
thunk_sequence_serdes_ =
|
||||
std::make_unique<T>(&buffer_allocations_.GetUnderlyingVector());
|
||||
thunk_sequence_serdes_ = std::make_unique<T>(
|
||||
nullptr, &buffer_allocations_.GetUnderlyingVector());
|
||||
}
|
||||
|
||||
protected:
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ CpuAotCompilationResult::Create(
|
|||
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
|
||||
TargetMachineOptionsProto target_machine_options) {
|
||||
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(
|
||||
&buffer_assignment->Allocations());
|
||||
hlo_module, &buffer_assignment->Allocations());
|
||||
TF_ASSIGN_OR_RETURN(ThunkSequenceProto thunk_proto,
|
||||
thunk_sequence_serdes.ToProto(thunks));
|
||||
|
||||
|
|
@ -143,7 +143,7 @@ CpuAotCompilationResult::CpuAotCompilationResult(
|
|||
module_ = hlo_module->Clone();
|
||||
|
||||
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(
|
||||
&buffer_assignment->Allocations());
|
||||
hlo_module, &buffer_assignment->Allocations());
|
||||
*proto_.mutable_thunk_sequence() = thunks;
|
||||
}
|
||||
|
||||
|
|
@ -181,7 +181,7 @@ CpuAotCompilationResult::LoadExecutable(
|
|||
}
|
||||
|
||||
ThunkSequenceSerDesProtobuf thunk_sequence_serdes(
|
||||
&buffer_assignment->Allocations());
|
||||
module.get(), &buffer_assignment->Allocations());
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<ThunkSequence> thunks,
|
||||
thunk_sequence_serdes.FromProto(proto_.thunk_sequence()));
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user