[XLA][Numerics][HLO Value Tracking] Handle original values in while loop fusible sinking pass

This reconstructs the original value for while loops with a rewritten input/output shape during the pass.

PiperOrigin-RevId: 822465131
This commit is contained in:
Jian Cai 2025-10-22 00:56:25 -07:00 committed by TensorFlower Gardener
parent add51a87c3
commit 95d3b6fe36
2 changed files with 133 additions and 0 deletions

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <cstdint>
#include <iterator>
#include <memory>
#include <optional>
#include <vector>
@ -34,6 +35,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_original_value.h"
#include "xla/hlo/utils/hlo_query.h"
#include "xla/service/pattern_matcher.h"
#include "xla/service/while_util.h"
@ -95,6 +97,29 @@ absl::Status UpdateWhileUsesWithTuple(HloInstruction* while_instr,
return absl::OkStatus();
}
void AppendOriginalValues(HloInstruction* instr,
const HloInstruction::InstructionVector& new_operands,
int64_t next_index) {
if (instr->original_value() != nullptr && !new_operands.empty()) {
std::shared_ptr<OriginalValue> old_original_value = instr->original_value(),
new_original_value =
std::make_shared<OriginalValue>(
instr->shape());
for (auto& [shape_index, original_array] : old_original_value->tree()) {
*new_original_value->mutable_tree()->mutable_element(shape_index) =
original_array;
}
for (int64_t i = 0; i < new_operands.size(); ++i) {
if (new_operands[i]->original_value() != nullptr) {
new_original_value->mutable_tree()->CopySubtreeFrom(
new_operands[i]->original_value()->tree(), {}, {next_index + i});
}
}
return instr->set_original_value(new_original_value);
}
}
// Appends the given new operand to while input and update loops computations
// and shape accordingly and returns the gte instruction within the body that
// represents the new operand.
@ -105,6 +130,8 @@ absl::StatusOr<HloInstruction*> AppendToWhileState(
ShapeUtil::AppendShapeToTuple(new_operand->shape(),
while_input->mutable_shape());
while_input->AppendOperand(new_operand);
AppendOriginalValues(while_input, {new_operand},
while_input->operand_count() - 1);
// Update the body computation.
HloComputation* body = while_instr->while_body();
*body->parameter_instruction(0)->mutable_shape() = while_input->shape();
@ -122,6 +149,9 @@ absl::StatusOr<HloInstruction*> AppendToWhileState(
TF_RETURN_IF_ERROR(
UpdateWhileUsesWithTuple(while_instr, while_input->operand_count() - 1));
*while_instr->mutable_shape() = while_input->shape();
AppendOriginalValues(while_instr, {new_operand},
while_input->operand_count() - 1);
return new_gte;
}
@ -459,6 +489,7 @@ absl::StatusOr<bool> WhileLoopFusibleSinking::TrySinkingFusiblesIntoWhileLoop(
HloInstruction* parameter = while_body->parameter_instruction(0);
int64_t next_index = init_value->operand_count();
new_operands.resize(fusion->operand_count());
for (int64_t i = 0; i < fusion->operand_count(); ++i) {
init_value->AppendOperand(fusion->mutable_operand(i));
parameter->mutable_shape()->mutable_tuple_shapes()->push_back(
@ -468,10 +499,15 @@ absl::StatusOr<bool> WhileLoopFusibleSinking::TrySinkingFusiblesIntoWhileLoop(
root->AppendOperand(new_operands[i]);
}
*(init_value->mutable_shape()) = parameter->shape();
AppendOriginalValues(init_value, fusion->operands(),
next_index - fusion->operand_count());
*(while_instr->mutable_shape()) = parameter->shape();
AppendOriginalValues(while_instr, fusion->operands(),
next_index - fusion->operand_count());
*(while_cond->parameter_instruction(0)->mutable_shape()) =
parameter->shape();
*(root->mutable_shape()) = parameter->shape();
auto cloned_fusion = while_body->AddInstruction(
fusion->CloneWithNewOperands(fusion->shape(), new_operands));
TF_RETURN_IF_ERROR(fusion->parent()->RemoveInstruction(fusion));
@ -539,6 +575,7 @@ absl::StatusOr<bool> WhileLoopFusibleSinking::Run(
}
}
}
return changed;
}
} // namespace xla

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/log/check.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h"
@ -429,5 +430,100 @@ TEST_F(WhileLoopFusibleSinkingTest, TestNoPlumbWithUnknonwnTripCount) {
EXPECT_FALSE(changed);
}
TEST_F(WhileLoopFusibleSinkingTest, SinkMaskWithOriginalValue) {
const char* const hlo_string = R"(
HloModule ModuleWithWhile
body {
p_body = (f32[5,7],f32[5,7]) parameter(0)
p_body.0 = get-tuple-element(p_body), index=0
p_body.1 = get-tuple-element(p_body), index=1
add.0 = add(p_body.0, p_body.1)
ROOT root = tuple(add.0, p_body.1)
}
condition {
p_cond = (f32[5,7],f32[5,7]) parameter(0)
ROOT result = pred[] constant(true)
}
ENTRY entry {
const_0 = f32[5,7] parameter(0), origin={{"constant"}}
p = f32[5] parameter(1), origin={{"parameter"}}
a = f32[5,7] iota(), iota_dimension=0
b = f32[5,7] iota(), iota_dimension=1
c = add(a, b)
d = f32[5,7] broadcast(p), dimensions={0}
mask = multiply(c,d), origin={{"mask"}}
while_init = tuple(const_0, mask), origin={({"constant"}, {"mask"})}
ROOT while = while(while_init), condition=condition, body=body, origin={({"while" {0}}, {"while" {1}})}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
WhileLoopFusibleSinking{}.Run(module.get()));
ASSERT_TRUE(changed);
HloInstruction* while_instr = FindInstruction(module.get(), "while");
ASSERT_NE(while_instr->original_value(), nullptr);
EXPECT_EQ(while_instr->original_value()->ToString(),
"({\"while\" {0}}, {\"while\" {1}}, {\"parameter\"})");
HloInstruction* while_init = while_instr->while_init();
ASSERT_NE(while_init->original_value(), nullptr);
EXPECT_EQ(while_init->original_value()->ToString(),
"({\"constant\"}, {\"mask\"}, {\"parameter\"})");
}
TEST_F(WhileLoopFusibleSinkingTest, PlumbSingleBroadcastWithOriginalValue) {
const std::string hlo_string_before = R"(
HloModule test
loop.body {
loop_var.1 = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) parameter(0)
get-tuple-element.1 = s32[]{:T(128)} get-tuple-element(loop_var.1), index=0
get-tuple-element.2 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} get-tuple-element(loop_var.1), index=1
get-tuple-element.3 = s32[4,3,5]{2,1,0} get-tuple-element(loop_var.1), index=2
bitcast.12855 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} bitcast(get-tuple-element.3)
add.40974 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} add(get-tuple-element.2, bitcast.12855)
constant.1 = s32[]{:T(128)} constant(1)
idx = s32[]{:T(128)} add(get-tuple-element.1, constant.1)
ROOT tuple = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) tuple(idx, add.40974, get-tuple-element.3)
}
loop.condition {
loop_var.2 = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) parameter(0)
get-tuple-element.3 = s32[]{:T(128)} get-tuple-element(loop_var.2), index=0
constant.2 = s32[]{:T(128)} constant(4)
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
ENTRY %main {
param.1 = s32[4,3,5]{2,1,0} iota(), iota_dimension=0
zero = s32[]{:T(128)} constant(0), origin={{"zero"}}
zeros32 = s32[]{:T(128)} constant(0), origin={{"zeros32"}}
broadcast = s32[1,1,1,4,3,5]{5,4,3,2,1,0} broadcast(zeros32), origin={{"broadcast"}}
input = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) tuple(zero, broadcast, param.1), origin={({"zero"}, {"zeros32"}, {"broadcast"})}
ROOT while = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) while(input), condition=loop.condition, body=loop.body, origin={({"while" {0}}, {"while" {1}}, {"while" {2}})}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string_before));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
WhileLoopFusibleSinking{}.Run(module.get()));
EXPECT_TRUE(changed);
HloInstruction* while_instr = FindInstruction(module.get(), "while");
ASSERT_NE(while_instr->original_value(), nullptr);
EXPECT_EQ(while_instr->original_value()->ToString(),
"({\"while\" {0}}, {\"while\" {1}}, {\"while\" {2}}, {\"zero\"})");
HloInstruction* while_init = while_instr->while_init();
ASSERT_NE(while_init->original_value(), nullptr);
EXPECT_EQ(while_init->original_value()->ToString(),
"({\"zero\"}, {\"zeros32\"}, {\"broadcast\"}, {\"zero\"})");
}
} // namespace
} // namespace xla