Update calls to HloModule::CreateFromProto in hlo_module_util to remap instruction ids by default. This should speed up compilation.

PiperOrigin-RevId: 824521542
This commit is contained in:
A. Unique TensorFlower 2025-10-27 08:03:10 -07:00 committed by TensorFlower Gardener
parent 60ac8fa628
commit fd2941bc67
3 changed files with 88 additions and 5 deletions

View File

@ -73,15 +73,20 @@ absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
TF_ASSIGN_OR_RETURN(
HloModuleConfig config,
HloModule::CreateModuleConfigFromProto(proto, debug_options));
return HloModule::CreateFromProto(proto, config);
return HloModule::CreateFromProto(proto, config,
/*buffer_assignment_proto=*/nullptr,
/*preserve_instruction_ids=*/false);
}
absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
const HloModuleProto& proto, const HloModuleConfig& module_config,
bool is_module_post_optimizations) {
VLOG(4) << proto.ShortDebugString();
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
HloModule::CreateFromProto(proto, module_config));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModule> module,
HloModule::CreateFromProto(proto, module_config,
/*buffer_assignment_proto=*/nullptr,
/*preserve_instruction_ids=*/false));
TF_RETURN_IF_ERROR(
HloVerifier(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/is_module_post_optimizations)
@ -133,7 +138,9 @@ absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromModuleBinaryProtofile(
HloModuleConfig module_config,
HloModule::CreateModuleConfigFromProto(module_proto, debug_options));
return HloModule::CreateFromProto(module_proto, module_config);
return HloModule::CreateFromProto(module_proto, module_config,
/*buffer_assignment_proto=*/nullptr,
/*preserve_instruction_ids=*/false);
}
absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromModuleTextProtoFile(
@ -146,7 +153,9 @@ absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromModuleTextProtoFile(
HloModuleConfig module_config,
HloModule::CreateModuleConfigFromProto(module_proto, debug_options));
return HloModule::CreateFromProto(module_proto, module_config);
return HloModule::CreateFromProto(module_proto, module_config,
/*buffer_assignment_proto=*/nullptr,
/*preserve_instruction_ids=*/false);
}
absl::StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(

View File

@ -256,6 +256,7 @@ xla_test(
"//xla:status_macros",
"//xla:xla_data_proto_cc",
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:filecheck",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_executable",

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "xla/tools/multihost_hlo_runner/functional_hlo_runner.h"
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <random>
@ -33,6 +34,8 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "xla/debug_options_flags.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/testlib/filecheck.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_executable.h"
@ -62,6 +65,11 @@ limitations under the License.
namespace xla {
namespace {
using ::testing::Each;
using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::Lt;
using ::testing::Property;
using ::testing::SizeIs;
using ::tsl::testing::IsOkAndHolds;
using ::tsl::testing::StatusIs;
@ -795,6 +803,71 @@ TEST_F(FunctionalHloRunnerTest, ReadHloUnoptimizedSnapshot) {
hlo_module_and_arguments_from_binary.arguments.size());
}
TEST_F(FunctionalHloRunnerTest,
ReadHloModuleProtoDoesNotPreserveInstructionIds) {
std::string path_to_text_hlo =
GetHloPath("sharded_unoptimized_hlo_snapshot.pbtxt");
tsl::Env* env = tsl::Env::Default();
// Read the text proto
HloUnoptimizedSnapshot message;
TF_ASSERT_OK(tsl::ReadTextProto(env, path_to_text_hlo, &message));
// Manually modify instruction ids in the proto.
int64_t instruction_id_offset = 1000;
for (HloComputationProto& computation :
*message.mutable_hlo_module()->mutable_computations()) {
for (HloInstructionProto& instruction :
*computation.mutable_instructions()) {
instruction.set_id(instruction.id() + instruction_id_offset);
for (int64_t& operand_id : *instruction.mutable_operand_ids()) {
operand_id += instruction_id_offset;
}
}
computation.set_root_id(computation.root_id() + instruction_id_offset);
}
// Dump message in the custom binary format
std::string path_to_binary_hlo =
tsl::io::JoinPath(std::getenv("TEST_UNDECLARED_OUTPUTS_DIR"),
"sharded_unoptimized_hlo_snapshot_modified_ids.pb");
std::unique_ptr<tsl::WritableFile> file;
TF_ASSERT_OK(env->NewWritableFile(path_to_binary_hlo, &file));
tsl::WritableFileCopyingOutputStream output(file.get());
tsl::protobuf::io::CopyingOutputStreamAdaptor adaptor(&output);
EXPECT_TRUE(message.SerializeToZeroCopyStream(&adaptor));
adaptor.Flush();
TF_ASSERT_OK(file->Close());
// Read HloModuleAndArguments from binary dump.
TF_ASSERT_OK_AND_ASSIGN(
HloModuleAndArguments hlo_module_and_arguments_from_binary,
FunctionalHloRunner::LoadHloModuleAndArguments(
path_to_binary_hlo, InputFormat::kUnoptimizedSnapshotProtoBinary));
// Check if ids have been re-assigned in a compact way
HloComputation* entry_computation =
hlo_module_and_arguments_from_binary.hlo_module->entry_computation();
EXPECT_THAT(entry_computation->instructions(),
ElementsAre(Property(&HloInstruction::local_id, Eq(0)),
Property(&HloInstruction::local_id, Eq(1)),
Property(&HloInstruction::local_id, Eq(2)),
Property(&HloInstruction::local_id, Eq(3))));
// Check that all operand ids are also within the re-assigned range.
EXPECT_THAT(entry_computation->instructions(),
Each(Property(&HloInstruction::operands,
Each(Property(&HloInstruction::local_id, Lt(4))))));
EXPECT_THAT(entry_computation->root_instruction()->local_id(), Eq(3));
}
TEST_F(FunctionalHloRunnerTest, FixFakeArguments) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
GetPjRtClient());