[XLA][Numerics][HLO Original Value] Support original values for more cases in while loop simplifier pass

This updates the original value of a while loop if its input was nested tuples and got flatten during the pass

PiperOrigin-RevId: 826644894
This commit is contained in:
Jian Cai 2025-10-31 15:10:43 -07:00 committed by TensorFlower Gardener
parent a3f8740bc7
commit 9c620f90b8
2 changed files with 67 additions and 3 deletions

View File

@ -151,7 +151,7 @@ static absl::StatusOr<HloInstruction*> RemoveDeadTupleIndices(
HloInstruction* while_op, absl::flat_hash_set<int64_t>& used_tuple_indices,
std::optional<absl::flat_hash_map<int32_t, int32_t>>
dead_to_surviving_index = std::nullopt) {
auto copy_original_value =
auto copy_remaining_original_arrays =
[&](const HloInstruction* src_instruction,
HloInstruction* dest_instruction,
const absl::flat_hash_map<int64_t, int64_t>& old_to_new_tuple_idx) {
@ -305,8 +305,9 @@ static absl::StatusOr<HloInstruction*> RemoveDeadTupleIndices(
CopyFrontendAttributes(while_op, new_while_op);
CopyMetadata(while_op, new_while_op);
copy_original_value(while_init, new_while_init, old_to_new_tuple_idx);
copy_original_value(while_op, new_while_op, old_to_new_tuple_idx);
copy_remaining_original_arrays(while_init, new_while_init,
old_to_new_tuple_idx);
copy_remaining_original_arrays(while_op, new_while_op, old_to_new_tuple_idx);
// Create a tuple op that recreates the output of the old while op. That is,
// we transform to
@ -1193,6 +1194,20 @@ static std::vector<HloInstruction*> GetFlatTupleElems(
}
static absl::StatusOr<bool> TryFlattenNestedTuples(HloInstruction* while_op) {
auto flatten_original_value = [&](HloInstruction* old_instr,
HloInstruction* new_instr) {
if (old_instr->original_value()) {
auto new_original_value =
std::make_shared<OriginalValue>(new_instr->shape());
int64_t i = 0;
for (auto& [shape_index, original_array] :
old_instr->original_value()->tree().leaves()) {
*new_original_value->mutable_tree()->mutable_element({i++}) =
original_array;
}
new_instr->set_original_value(new_original_value);
}
};
HloModule* module = while_op->GetModule();
HloComputation* computation = while_op->parent();
auto* while_init = while_op->mutable_operand(0);
@ -1294,6 +1309,9 @@ static absl::StatusOr<bool> TryFlattenNestedTuples(HloInstruction* while_op) {
for (auto& instr : new_instrs) {
computation->AddInstruction(std::move(instr));
}
flatten_original_value(while_init, new_while_op->mutable_operand(0));
flatten_original_value(while_op, new_while_op);
return true;
}

View File

@ -1535,5 +1535,51 @@ ENTRY %main (arg.0: f32[3], arg.1: f32[2]) -> (f32[3], f32[2], f32[2], f32[3]) {
R"(({"arg.0"}, {"arg.1"}, {"constant.0"}))");
}
TEST_F(WhileLoopSimplifierTest, FlattenNestedTupleWithOriginalValue) {
const std::string hlo_string = R"(
HloModule Test
Body {
param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0)
ta = (s32[1]) get-tuple-element(param), index=0
a = s32[1] get-tuple-element(ta), index=0
a.1 = s32[1] add(a, a)
tbcd = (s32[2], s32[3], (s32[4])) get-tuple-element(param), index=1
ROOT tuple = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd)
}
Cond {
param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0)
ROOT cond = pred[] constant(true)
}
ENTRY Loop {
a = s32[1] constant({0})
b = s32[2] constant({0,1})
c = s32[3] constant({0,1,2})
d = s32[4] constant({0,1,2,3})
ta = (s32[1]) tuple(a)
td = (s32[4]) tuple(d)
tbcd = (s32[2], s32[3], (s32[4])) tuple(b, c, td)
init = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd), origin={(({"a"}), (
{"b"}, {"c"}, ({"d"})))}
ROOT while = ((s32[1]), (s32[2], s32[3], (s32[4]))) while(init),
condition=Cond, body=Body, origin={(({"while.116" {0}}), (
{"while.116" {1}}, {"while.116" {2}}, ({"while.116" {3}})))}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
WhileLoopSimplifier().Run(module.get()));
EXPECT_TRUE(changed);
HloInstruction* while_instr = FindFirstWhile(module.get());
ASSERT_NE(while_instr->original_value(), nullptr);
EXPECT_EQ(
while_instr->original_value()->ToString(),
R"(({"while.116" {0}}, {"while.116" {1}}, {"while.116" {2}}, {"while.116" {3}}))");
HloInstruction* while_init = while_instr->while_init();
ASSERT_NE(while_init->original_value(), nullptr);
EXPECT_EQ(while_init->original_value()->ToString(),
R"(({"a"}, {"b"}, {"c"}, {"d"}))");
}
} // namespace
} // namespace xla