mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[XLA][Numerics][HLO Original Value] Support original values for more cases in while loop simplifier pass
This updates the original value of a while loop if its input was nested tuples and got flatten during the pass PiperOrigin-RevId: 826644894
This commit is contained in:
parent
a3f8740bc7
commit
9c620f90b8
|
|
@ -151,7 +151,7 @@ static absl::StatusOr<HloInstruction*> RemoveDeadTupleIndices(
|
||||||
HloInstruction* while_op, absl::flat_hash_set<int64_t>& used_tuple_indices,
|
HloInstruction* while_op, absl::flat_hash_set<int64_t>& used_tuple_indices,
|
||||||
std::optional<absl::flat_hash_map<int32_t, int32_t>>
|
std::optional<absl::flat_hash_map<int32_t, int32_t>>
|
||||||
dead_to_surviving_index = std::nullopt) {
|
dead_to_surviving_index = std::nullopt) {
|
||||||
auto copy_original_value =
|
auto copy_remaining_original_arrays =
|
||||||
[&](const HloInstruction* src_instruction,
|
[&](const HloInstruction* src_instruction,
|
||||||
HloInstruction* dest_instruction,
|
HloInstruction* dest_instruction,
|
||||||
const absl::flat_hash_map<int64_t, int64_t>& old_to_new_tuple_idx) {
|
const absl::flat_hash_map<int64_t, int64_t>& old_to_new_tuple_idx) {
|
||||||
|
|
@ -305,8 +305,9 @@ static absl::StatusOr<HloInstruction*> RemoveDeadTupleIndices(
|
||||||
CopyFrontendAttributes(while_op, new_while_op);
|
CopyFrontendAttributes(while_op, new_while_op);
|
||||||
CopyMetadata(while_op, new_while_op);
|
CopyMetadata(while_op, new_while_op);
|
||||||
|
|
||||||
copy_original_value(while_init, new_while_init, old_to_new_tuple_idx);
|
copy_remaining_original_arrays(while_init, new_while_init,
|
||||||
copy_original_value(while_op, new_while_op, old_to_new_tuple_idx);
|
old_to_new_tuple_idx);
|
||||||
|
copy_remaining_original_arrays(while_op, new_while_op, old_to_new_tuple_idx);
|
||||||
|
|
||||||
// Create a tuple op that recreates the output of the old while op. That is,
|
// Create a tuple op that recreates the output of the old while op. That is,
|
||||||
// we transform to
|
// we transform to
|
||||||
|
|
@ -1193,6 +1194,20 @@ static std::vector<HloInstruction*> GetFlatTupleElems(
|
||||||
}
|
}
|
||||||
|
|
||||||
static absl::StatusOr<bool> TryFlattenNestedTuples(HloInstruction* while_op) {
|
static absl::StatusOr<bool> TryFlattenNestedTuples(HloInstruction* while_op) {
|
||||||
|
auto flatten_original_value = [&](HloInstruction* old_instr,
|
||||||
|
HloInstruction* new_instr) {
|
||||||
|
if (old_instr->original_value()) {
|
||||||
|
auto new_original_value =
|
||||||
|
std::make_shared<OriginalValue>(new_instr->shape());
|
||||||
|
int64_t i = 0;
|
||||||
|
for (auto& [shape_index, original_array] :
|
||||||
|
old_instr->original_value()->tree().leaves()) {
|
||||||
|
*new_original_value->mutable_tree()->mutable_element({i++}) =
|
||||||
|
original_array;
|
||||||
|
}
|
||||||
|
new_instr->set_original_value(new_original_value);
|
||||||
|
}
|
||||||
|
};
|
||||||
HloModule* module = while_op->GetModule();
|
HloModule* module = while_op->GetModule();
|
||||||
HloComputation* computation = while_op->parent();
|
HloComputation* computation = while_op->parent();
|
||||||
auto* while_init = while_op->mutable_operand(0);
|
auto* while_init = while_op->mutable_operand(0);
|
||||||
|
|
@ -1294,6 +1309,9 @@ static absl::StatusOr<bool> TryFlattenNestedTuples(HloInstruction* while_op) {
|
||||||
for (auto& instr : new_instrs) {
|
for (auto& instr : new_instrs) {
|
||||||
computation->AddInstruction(std::move(instr));
|
computation->AddInstruction(std::move(instr));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
flatten_original_value(while_init, new_while_op->mutable_operand(0));
|
||||||
|
flatten_original_value(while_op, new_while_op);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1535,5 +1535,51 @@ ENTRY %main (arg.0: f32[3], arg.1: f32[2]) -> (f32[3], f32[2], f32[2], f32[3]) {
|
||||||
R"(({"arg.0"}, {"arg.1"}, {"constant.0"}))");
|
R"(({"arg.0"}, {"arg.1"}, {"constant.0"}))");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(WhileLoopSimplifierTest, FlattenNestedTupleWithOriginalValue) {
|
||||||
|
const std::string hlo_string = R"(
|
||||||
|
HloModule Test
|
||||||
|
Body {
|
||||||
|
param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0)
|
||||||
|
ta = (s32[1]) get-tuple-element(param), index=0
|
||||||
|
a = s32[1] get-tuple-element(ta), index=0
|
||||||
|
a.1 = s32[1] add(a, a)
|
||||||
|
tbcd = (s32[2], s32[3], (s32[4])) get-tuple-element(param), index=1
|
||||||
|
ROOT tuple = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd)
|
||||||
|
}
|
||||||
|
Cond {
|
||||||
|
param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0)
|
||||||
|
ROOT cond = pred[] constant(true)
|
||||||
|
}
|
||||||
|
ENTRY Loop {
|
||||||
|
a = s32[1] constant({0})
|
||||||
|
b = s32[2] constant({0,1})
|
||||||
|
c = s32[3] constant({0,1,2})
|
||||||
|
d = s32[4] constant({0,1,2,3})
|
||||||
|
ta = (s32[1]) tuple(a)
|
||||||
|
td = (s32[4]) tuple(d)
|
||||||
|
tbcd = (s32[2], s32[3], (s32[4])) tuple(b, c, td)
|
||||||
|
init = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd), origin={(({"a"}), (
|
||||||
|
{"b"}, {"c"}, ({"d"})))}
|
||||||
|
ROOT while = ((s32[1]), (s32[2], s32[3], (s32[4]))) while(init),
|
||||||
|
condition=Cond, body=Body, origin={(({"while.116" {0}}), (
|
||||||
|
{"while.116" {1}}, {"while.116" {2}}, ({"while.116" {3}})))}
|
||||||
|
})";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||||
|
WhileLoopSimplifier().Run(module.get()));
|
||||||
|
EXPECT_TRUE(changed);
|
||||||
|
HloInstruction* while_instr = FindFirstWhile(module.get());
|
||||||
|
ASSERT_NE(while_instr->original_value(), nullptr);
|
||||||
|
EXPECT_EQ(
|
||||||
|
while_instr->original_value()->ToString(),
|
||||||
|
R"(({"while.116" {0}}, {"while.116" {1}}, {"while.116" {2}}, {"while.116" {3}}))");
|
||||||
|
HloInstruction* while_init = while_instr->while_init();
|
||||||
|
ASSERT_NE(while_init->original_value(), nullptr);
|
||||||
|
EXPECT_EQ(while_init->original_value()->ToString(),
|
||||||
|
R"(({"a"}, {"b"}, {"c"}, {"d"}))");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user