mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Update calls to HloModule::CreateFromProto in hlo_module_util to remap instruction ids by default. This should speed up compilation.
PiperOrigin-RevId: 824521542
This commit is contained in:
parent
60ac8fa628
commit
fd2941bc67
19
third_party/xla/xla/service/hlo_module_util.cc
vendored
19
third_party/xla/xla/service/hlo_module_util.cc
vendored
|
|
@ -73,15 +73,20 @@ absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
|
|||
TF_ASSIGN_OR_RETURN(
|
||||
HloModuleConfig config,
|
||||
HloModule::CreateModuleConfigFromProto(proto, debug_options));
|
||||
return HloModule::CreateFromProto(proto, config);
|
||||
return HloModule::CreateFromProto(proto, config,
|
||||
/*buffer_assignment_proto=*/nullptr,
|
||||
/*preserve_instruction_ids=*/false);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
|
||||
const HloModuleProto& proto, const HloModuleConfig& module_config,
|
||||
bool is_module_post_optimizations) {
|
||||
VLOG(4) << proto.ShortDebugString();
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
|
||||
HloModule::CreateFromProto(proto, module_config));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
HloModule::CreateFromProto(proto, module_config,
|
||||
/*buffer_assignment_proto=*/nullptr,
|
||||
/*preserve_instruction_ids=*/false));
|
||||
TF_RETURN_IF_ERROR(
|
||||
HloVerifier(/*layout_sensitive=*/false,
|
||||
/*allow_mixed_precision=*/is_module_post_optimizations)
|
||||
|
|
@ -133,7 +138,9 @@ absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromModuleBinaryProtofile(
|
|||
HloModuleConfig module_config,
|
||||
HloModule::CreateModuleConfigFromProto(module_proto, debug_options));
|
||||
|
||||
return HloModule::CreateFromProto(module_proto, module_config);
|
||||
return HloModule::CreateFromProto(module_proto, module_config,
|
||||
/*buffer_assignment_proto=*/nullptr,
|
||||
/*preserve_instruction_ids=*/false);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromModuleTextProtoFile(
|
||||
|
|
@ -146,7 +153,9 @@ absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromModuleTextProtoFile(
|
|||
HloModuleConfig module_config,
|
||||
HloModule::CreateModuleConfigFromProto(module_proto, debug_options));
|
||||
|
||||
return HloModule::CreateFromProto(module_proto, module_config);
|
||||
return HloModule::CreateFromProto(module_proto, module_config,
|
||||
/*buffer_assignment_proto=*/nullptr,
|
||||
/*preserve_instruction_ids=*/false);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
|
||||
|
|
|
|||
|
|
@ -256,6 +256,7 @@ xla_test(
|
|||
"//xla:status_macros",
|
||||
"//xla:xla_data_proto_cc",
|
||||
"//xla:xla_proto_cc",
|
||||
"//xla/hlo/ir:hlo",
|
||||
"//xla/hlo/testlib:filecheck",
|
||||
"//xla/pjrt:pjrt_client",
|
||||
"//xla/pjrt:pjrt_executable",
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
|
||||
#include "xla/tools/multihost_hlo_runner/functional_hlo_runner.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
|
|
@ -33,6 +34,8 @@ limitations under the License.
|
|||
#include "absl/strings/string_view.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "xla/debug_options_flags.h"
|
||||
#include "xla/hlo/ir/hlo_computation.h"
|
||||
#include "xla/hlo/ir/hlo_instruction.h"
|
||||
#include "xla/hlo/testlib/filecheck.h"
|
||||
#include "xla/pjrt/pjrt_client.h"
|
||||
#include "xla/pjrt/pjrt_executable.h"
|
||||
|
|
@ -62,6 +65,11 @@ limitations under the License.
|
|||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using ::testing::Each;
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::Eq;
|
||||
using ::testing::Lt;
|
||||
using ::testing::Property;
|
||||
using ::testing::SizeIs;
|
||||
using ::tsl::testing::IsOkAndHolds;
|
||||
using ::tsl::testing::StatusIs;
|
||||
|
|
@ -795,6 +803,71 @@ TEST_F(FunctionalHloRunnerTest, ReadHloUnoptimizedSnapshot) {
|
|||
hlo_module_and_arguments_from_binary.arguments.size());
|
||||
}
|
||||
|
||||
TEST_F(FunctionalHloRunnerTest,
|
||||
ReadHloModuleProtoDoesNotPreserveInstructionIds) {
|
||||
std::string path_to_text_hlo =
|
||||
GetHloPath("sharded_unoptimized_hlo_snapshot.pbtxt");
|
||||
|
||||
tsl::Env* env = tsl::Env::Default();
|
||||
|
||||
// Read the text proto
|
||||
HloUnoptimizedSnapshot message;
|
||||
TF_ASSERT_OK(tsl::ReadTextProto(env, path_to_text_hlo, &message));
|
||||
|
||||
// Manually modify instruction ids in the proto.
|
||||
int64_t instruction_id_offset = 1000;
|
||||
for (HloComputationProto& computation :
|
||||
*message.mutable_hlo_module()->mutable_computations()) {
|
||||
for (HloInstructionProto& instruction :
|
||||
*computation.mutable_instructions()) {
|
||||
instruction.set_id(instruction.id() + instruction_id_offset);
|
||||
for (int64_t& operand_id : *instruction.mutable_operand_ids()) {
|
||||
operand_id += instruction_id_offset;
|
||||
}
|
||||
}
|
||||
computation.set_root_id(computation.root_id() + instruction_id_offset);
|
||||
}
|
||||
|
||||
// Dump message in the custom binary format
|
||||
std::string path_to_binary_hlo =
|
||||
tsl::io::JoinPath(std::getenv("TEST_UNDECLARED_OUTPUTS_DIR"),
|
||||
"sharded_unoptimized_hlo_snapshot_modified_ids.pb");
|
||||
|
||||
std::unique_ptr<tsl::WritableFile> file;
|
||||
TF_ASSERT_OK(env->NewWritableFile(path_to_binary_hlo, &file));
|
||||
|
||||
tsl::WritableFileCopyingOutputStream output(file.get());
|
||||
|
||||
tsl::protobuf::io::CopyingOutputStreamAdaptor adaptor(&output);
|
||||
EXPECT_TRUE(message.SerializeToZeroCopyStream(&adaptor));
|
||||
adaptor.Flush();
|
||||
|
||||
TF_ASSERT_OK(file->Close());
|
||||
|
||||
// Read HloModuleAndArguments from binary dump.
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
HloModuleAndArguments hlo_module_and_arguments_from_binary,
|
||||
FunctionalHloRunner::LoadHloModuleAndArguments(
|
||||
path_to_binary_hlo, InputFormat::kUnoptimizedSnapshotProtoBinary));
|
||||
|
||||
// Check if ids have been re-assigned in a compact way
|
||||
HloComputation* entry_computation =
|
||||
hlo_module_and_arguments_from_binary.hlo_module->entry_computation();
|
||||
|
||||
EXPECT_THAT(entry_computation->instructions(),
|
||||
ElementsAre(Property(&HloInstruction::local_id, Eq(0)),
|
||||
Property(&HloInstruction::local_id, Eq(1)),
|
||||
Property(&HloInstruction::local_id, Eq(2)),
|
||||
Property(&HloInstruction::local_id, Eq(3))));
|
||||
|
||||
// Check that all operand ids are also within the re-assigned range.
|
||||
EXPECT_THAT(entry_computation->instructions(),
|
||||
Each(Property(&HloInstruction::operands,
|
||||
Each(Property(&HloInstruction::local_id, Lt(4))))));
|
||||
|
||||
EXPECT_THAT(entry_computation->root_instruction()->local_id(), Eq(3));
|
||||
}
|
||||
|
||||
TEST_F(FunctionalHloRunnerTest, FixFakeArguments) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
|
||||
GetPjRtClient());
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user