[XLA][Numerics][HLO Original Value] Support original values for some cases in while loop simplifier pass

This updates the original value of a while loop if any unused parameters are removed.

PiperOrigin-RevId: 825221785
This commit is contained in:
Jian Cai 2025-10-28 15:07:03 -07:00 committed by TensorFlower Gardener
parent 2feb74eeff
commit 66078903f7
2 changed files with 87 additions and 0 deletions

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_original_value.h"
#include "xla/hlo/ir/hlo_print_options.h"
#include "xla/hlo/transforms/simplifiers/hlo_dce.h"
#include "xla/hlo/utils/hlo_query.h"
@ -150,6 +151,35 @@ static absl::StatusOr<HloInstruction*> RemoveDeadTupleIndices(
HloInstruction* while_op, absl::flat_hash_set<int64_t>& used_tuple_indices,
std::optional<absl::flat_hash_map<int32_t, int32_t>>
dead_to_surviving_index = std::nullopt) {
// Rebuilds the original value of `dest_instruction` from that of
// `src_instruction` after tuple elements have been renumbered.
//
// `old_to_new_tuple_idx` maps a top-level tuple index of `src_instruction` to
// the matching top-level tuple index of `dest_instruction`. For each entry the
// subtree rooted at {old_idx} in the source original value is copied to
// {new_idx} in a fresh OriginalValue built from `dest_instruction`'s shape.
// Both instructions are treated as tuple-shaped: their shapes' tuple_shapes()
// are used for bounds checking.
//
// This is best effort. `dest_instruction` is left untouched when the source
// has no original value, or when any mapped index falls outside either tuple.
auto copy_original_value =
    [](const HloInstruction* src_instruction, HloInstruction* dest_instruction,
       const absl::flat_hash_map<int64_t, int64_t>& old_to_new_tuple_idx) {
      // Fetch the source original value once; every later read uses this
      // handle.
      const std::shared_ptr<OriginalValue> old_original_value =
          src_instruction->original_value();
      if (!old_original_value) {
        return;
      }
      const int64_t src_tuple_size =
          src_instruction->shape().tuple_shapes().size();
      const int64_t dest_tuple_size =
          dest_instruction->shape().tuple_shapes().size();
      std::shared_ptr<OriginalValue> new_original_value =
          std::make_shared<OriginalValue>(dest_instruction->shape());
      for (const auto& [old_idx, new_idx] : old_to_new_tuple_idx) {
        // An out-of-range mapping means the map does not describe these two
        // instructions; give up rather than attach a partially filled value.
        if (old_idx < 0 || old_idx >= src_tuple_size || new_idx < 0 ||
            new_idx >= dest_tuple_size) {
          return;
        }
        new_original_value->mutable_tree()->CopySubtreeFrom(
            old_original_value->tree(), {old_idx}, {new_idx});
      }
      dest_instruction->set_original_value(new_original_value);
    };
// Build up maps from the old/new to the new/old tuple indices.
std::vector<int64_t> new_to_old_tuple_idx(used_tuple_indices.begin(),
used_tuple_indices.end());
@ -275,6 +305,9 @@ static absl::StatusOr<HloInstruction*> RemoveDeadTupleIndices(
CopyFrontendAttributes(while_op, new_while_op);
CopyMetadata(while_op, new_while_op);
copy_original_value(while_init, new_while_init, old_to_new_tuple_idx);
copy_original_value(while_op, new_while_op, old_to_new_tuple_idx);
// Create a tuple op that recreates the output of the old while op. That is,
// we transform to
//

View File

@ -1481,5 +1481,59 @@ ENTRY %main (arg.0: f32[3], arg.1: f32[2]) -> (f32[3], f32[2], f32[2], f32[3]) {
op::GetTupleElement(op::While(), 0)));
}
// Runs WhileLoopSimplifier on a five-element while loop whose init tuple and
// while instruction both carry `origin` annotations, then checks that the
// original values of the resulting while and of its init tuple list only the
// entries for old tuple elements 0, 1 and 4.
TEST_F(WhileLoopSimplifierTest, RemoveDeadTupleIndicesWithOriginalValue) {
  const std::string module_text = R"(
HloModule dus
%while.body (arg_tuple: (f32[3], f32[2], f32[2], f32[3], s32[])) -> (f32[3], f32[2], f32[2], f32[3], s32[]) {
%arg_tuple = (f32[3], f32[2], f32[2], f32[3], s32[]) parameter(0)
%get-tuple-element.0 = f32[3] get-tuple-element(%arg_tuple), index=0
%get-tuple-element.1 = f32[2] get-tuple-element(%arg_tuple), index=1
%get-tuple-element.2 = f32[2] get-tuple-element(%arg_tuple), index=2
%get-tuple-element.3 = f32[3] get-tuple-element(%arg_tuple), index=3
%get-tuple-element.4 = s32[] get-tuple-element(%arg_tuple), index=4
%constant.1 = s32[] constant(1)
%constant.v0 = f32[1] constant({0.0})
%constant.v1 = f32[1] constant({1.0})
%dynamic-update-slice.0 = f32[3] dynamic-update-slice(%get-tuple-element.0, %constant.v0, s32[] %constant.1)
%dynamic-update-slice.3 = f32[3] dynamic-update-slice(%get-tuple-element.3, %constant.v0, s32[] %constant.1)
%add = add(s32[] %get-tuple-element.4, s32[] %constant.1)
ROOT %tuple = tuple(%dynamic-update-slice.0, %get-tuple-element.1, %get-tuple-element.2, %dynamic-update-slice.3, %add)
}
%while.condition (arg_tuple.cond:(f32[3], f32[2], f32[2], f32[3], s32[])) -> pred[] {
%arg_tuple.cond = (f32[3], f32[2], f32[2], f32[3], s32[]) parameter(0)
%get-tuple-element.cond = s32[] get-tuple-element(%arg_tuple.cond), index=4
%constant.3 = s32[] constant(3)
ROOT %compare = pred[] compare(s32[] %get-tuple-element.cond, s32[] %constant.3), direction=LT
}
ENTRY %main (arg.0: f32[3], arg.1: f32[2]) -> (f32[3], f32[2], f32[2], f32[3]) {
%constant.0 = s32[] constant(0)
%arg.0 = f32[3] parameter(0)
%arg.1 = f32[2] parameter(1)
%input = tuple(%arg.0, %arg.1, %arg.1, %arg.0, %constant.0), origin={({"arg.0"}, {"arg.1"}, {"arg.1"}, {"arg.0"}, {"constant.0"})}
%while = while(%input), condition=%while.condition, body=%while.body, origin={({"while.116" {0}}, {"while.116" {1}}, {"while.116" {2}}, {"while.116" {3}}, {"while.116" {4}})}
%get-tuple-element.out0 = f32[3] get-tuple-element(%while), index=0
%get-tuple-element.out1 = f32[2] get-tuple-element(%while), index=1
%get-tuple-element.out2 = f32[2] get-tuple-element(%while), index=2
%get-tuple-element.out3 = f32[3] get-tuple-element(%while), index=3
ROOT %root = tuple(%get-tuple-element.out0, %get-tuple-element.out1, %get-tuple-element.out2, %get-tuple-element.out3)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));

  // The pass must report that it changed the module.
  const bool module_changed = WhileLoopSimplifier().Run(module.get()).value();
  ASSERT_TRUE(module_changed);

  // The while's original value keeps only the entries for old indices 0, 1, 4.
  HloInstruction* simplified_while = FindFirstWhile(module.get());
  const auto while_origin = simplified_while->original_value();
  ASSERT_NE(while_origin, nullptr);
  EXPECT_EQ(while_origin->ToString(),
            R"(({"while.116" {0}}, {"while.116" {1}}, {"while.116" {4}}))");

  // The init tuple's original value is filtered the same way.
  HloInstruction* init_tuple = simplified_while->while_init();
  const auto init_origin = init_tuple->original_value();
  ASSERT_NE(init_origin, nullptr);
  EXPECT_EQ(init_origin->ToString(),
            R"(({"arg.0"}, {"arg.1"}, {"constant.0"}))");
}
} // namespace
} // namespace xla