mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
[XLA][Numerics][HLO Original Value] Support original values for more cases in while loop simplifier pass
This updates the original value of a while loop if its input was nested tuples and got flatten during the pass PiperOrigin-RevId: 826644894
This commit is contained in:
parent
a3f8740bc7
commit
9c620f90b8
|
|
@ -151,7 +151,7 @@ static absl::StatusOr<HloInstruction*> RemoveDeadTupleIndices(
|
|||
HloInstruction* while_op, absl::flat_hash_set<int64_t>& used_tuple_indices,
|
||||
std::optional<absl::flat_hash_map<int32_t, int32_t>>
|
||||
dead_to_surviving_index = std::nullopt) {
|
||||
auto copy_original_value =
|
||||
auto copy_remaining_original_arrays =
|
||||
[&](const HloInstruction* src_instruction,
|
||||
HloInstruction* dest_instruction,
|
||||
const absl::flat_hash_map<int64_t, int64_t>& old_to_new_tuple_idx) {
|
||||
|
|
@ -305,8 +305,9 @@ static absl::StatusOr<HloInstruction*> RemoveDeadTupleIndices(
|
|||
CopyFrontendAttributes(while_op, new_while_op);
|
||||
CopyMetadata(while_op, new_while_op);
|
||||
|
||||
copy_original_value(while_init, new_while_init, old_to_new_tuple_idx);
|
||||
copy_original_value(while_op, new_while_op, old_to_new_tuple_idx);
|
||||
copy_remaining_original_arrays(while_init, new_while_init,
|
||||
old_to_new_tuple_idx);
|
||||
copy_remaining_original_arrays(while_op, new_while_op, old_to_new_tuple_idx);
|
||||
|
||||
// Create a tuple op that recreates the output of the old while op. That is,
|
||||
// we transform to
|
||||
|
|
@ -1193,6 +1194,20 @@ static std::vector<HloInstruction*> GetFlatTupleElems(
|
|||
}
|
||||
|
||||
static absl::StatusOr<bool> TryFlattenNestedTuples(HloInstruction* while_op) {
|
||||
auto flatten_original_value = [&](HloInstruction* old_instr,
|
||||
HloInstruction* new_instr) {
|
||||
if (old_instr->original_value()) {
|
||||
auto new_original_value =
|
||||
std::make_shared<OriginalValue>(new_instr->shape());
|
||||
int64_t i = 0;
|
||||
for (auto& [shape_index, original_array] :
|
||||
old_instr->original_value()->tree().leaves()) {
|
||||
*new_original_value->mutable_tree()->mutable_element({i++}) =
|
||||
original_array;
|
||||
}
|
||||
new_instr->set_original_value(new_original_value);
|
||||
}
|
||||
};
|
||||
HloModule* module = while_op->GetModule();
|
||||
HloComputation* computation = while_op->parent();
|
||||
auto* while_init = while_op->mutable_operand(0);
|
||||
|
|
@ -1294,6 +1309,9 @@ static absl::StatusOr<bool> TryFlattenNestedTuples(HloInstruction* while_op) {
|
|||
for (auto& instr : new_instrs) {
|
||||
computation->AddInstruction(std::move(instr));
|
||||
}
|
||||
|
||||
flatten_original_value(while_init, new_while_op->mutable_operand(0));
|
||||
flatten_original_value(while_op, new_while_op);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1535,5 +1535,51 @@ ENTRY %main (arg.0: f32[3], arg.1: f32[2]) -> (f32[3], f32[2], f32[2], f32[3]) {
|
|||
R"(({"arg.0"}, {"arg.1"}, {"constant.0"}))");
|
||||
}
|
||||
|
||||
TEST_F(WhileLoopSimplifierTest, FlattenNestedTupleWithOriginalValue) {
|
||||
const std::string hlo_string = R"(
|
||||
HloModule Test
|
||||
Body {
|
||||
param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0)
|
||||
ta = (s32[1]) get-tuple-element(param), index=0
|
||||
a = s32[1] get-tuple-element(ta), index=0
|
||||
a.1 = s32[1] add(a, a)
|
||||
tbcd = (s32[2], s32[3], (s32[4])) get-tuple-element(param), index=1
|
||||
ROOT tuple = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd)
|
||||
}
|
||||
Cond {
|
||||
param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0)
|
||||
ROOT cond = pred[] constant(true)
|
||||
}
|
||||
ENTRY Loop {
|
||||
a = s32[1] constant({0})
|
||||
b = s32[2] constant({0,1})
|
||||
c = s32[3] constant({0,1,2})
|
||||
d = s32[4] constant({0,1,2,3})
|
||||
ta = (s32[1]) tuple(a)
|
||||
td = (s32[4]) tuple(d)
|
||||
tbcd = (s32[2], s32[3], (s32[4])) tuple(b, c, td)
|
||||
init = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd), origin={(({"a"}), (
|
||||
{"b"}, {"c"}, ({"d"})))}
|
||||
ROOT while = ((s32[1]), (s32[2], s32[3], (s32[4]))) while(init),
|
||||
condition=Cond, body=Body, origin={(({"while.116" {0}}), (
|
||||
{"while.116" {1}}, {"while.116" {2}}, ({"while.116" {3}})))}
|
||||
})";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||
WhileLoopSimplifier().Run(module.get()));
|
||||
EXPECT_TRUE(changed);
|
||||
HloInstruction* while_instr = FindFirstWhile(module.get());
|
||||
ASSERT_NE(while_instr->original_value(), nullptr);
|
||||
EXPECT_EQ(
|
||||
while_instr->original_value()->ToString(),
|
||||
R"(({"while.116" {0}}, {"while.116" {1}}, {"while.116" {2}}, {"while.116" {3}}))");
|
||||
HloInstruction* while_init = while_instr->while_init();
|
||||
ASSERT_NE(while_init->original_value(), nullptr);
|
||||
EXPECT_EQ(while_init->original_value()->ToString(),
|
||||
R"(({"a"}, {"b"}, {"c"}, {"d"}))");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user