[XLA][Numerics][HLO Original Values] Handles original values of while loops in TPU reduce code motion pass

This updates the original value of a while loop after its input/output shape gets changed after the pass sinks qualified reduce instructions into its body.

PiperOrigin-RevId: 826618908
This commit is contained in:
Jian Cai 2025-10-31 13:58:55 -07:00 committed by TensorFlower Gardener
parent eef0661fc5
commit a6e123761d
7 changed files with 92 additions and 42 deletions

View File

@ -2467,7 +2467,7 @@ class alignas(kInstructionTypeMask + 1) HloInstruction {
std::shared_ptr<OriginalValue> original_value() const; std::shared_ptr<OriginalValue> original_value() const;
void set_original_value(std::shared_ptr<OriginalValue> original_value); void set_original_value(std::shared_ptr<OriginalValue> original_value);
// Copy original value from the input instruction if the source and // Copies original value from the input instruction if the source and
// destination shapes are compatible. This performs a deep copy if clone is // destination shapes are compatible. This performs a deep copy if clone is
// set to true. Otherwise, it performs a shallow copy. Print a warning if the // set to true. Otherwise, it performs a shallow copy. Print a warning if the
// shapes are not compatible and issue_warning is set to true. // shapes are not compatible and issue_warning is set to true.
@ -2475,8 +2475,8 @@ class alignas(kInstructionTypeMask + 1) HloInstruction {
bool issue_warning = false); bool issue_warning = false);
protected: protected:
// Internal constructor for a given opcode/shape, other fields must be filled // Internal constructor for a given opcode/shape, other fields must be
// by factory methods. // filled by factory methods.
HloInstruction(HloOpcode opcode, const Shape& shape); HloInstruction(HloOpcode opcode, const Shape& shape);
void RemoveAllOperands() { operands_.clear(); } void RemoveAllOperands() { operands_.clear(); }

View File

@ -129,6 +129,8 @@ class OriginalValue {
bool IsCompatibleWith(const Shape& shape) const; bool IsCompatibleWith(const Shape& shape) const;
bool IsTuple() const { return tree().IsTuple(); }
bool operator==(const OriginalValue& other) const; bool operator==(const OriginalValue& other) const;
bool operator!=(const OriginalValue& other) const { bool operator!=(const OriginalValue& other) const {

View File

@ -4639,6 +4639,7 @@ xla_cc_test(
srcs = ["while_util_test.cc"], srcs = ["while_util_test.cc"],
deps = [ deps = [
":while_util", ":while_util",
"//xla:shape_util",
"//xla:util", "//xla:util",
"//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo",
"//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/testlib:hlo_hardware_independent_test_base",
@ -4906,8 +4907,6 @@ cc_library(
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/strings:string_view",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
], ],
) )

View File

@ -42,8 +42,6 @@ limitations under the License.
#include "xla/shape_util.h" #include "xla/shape_util.h"
#include "xla/util.h" #include "xla/util.h"
#include "xla/xla_data.pb.h" #include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
namespace xla { namespace xla {
@ -97,29 +95,6 @@ absl::Status UpdateWhileUsesWithTuple(HloInstruction* while_instr,
return absl::OkStatus(); return absl::OkStatus();
} }
void AppendOriginalValues(HloInstruction* instr,
const HloInstruction::InstructionVector& new_operands,
int64_t next_index) {
if (instr->original_value() != nullptr && !new_operands.empty()) {
std::shared_ptr<OriginalValue> old_original_value = instr->original_value(),
new_original_value =
std::make_shared<OriginalValue>(
instr->shape());
for (auto& [shape_index, original_array] : old_original_value->tree()) {
*new_original_value->mutable_tree()->mutable_element(shape_index) =
original_array;
}
for (int64_t i = 0; i < new_operands.size(); ++i) {
if (new_operands[i]->original_value() != nullptr) {
new_original_value->mutable_tree()->CopySubtreeFrom(
new_operands[i]->original_value()->tree(), {}, {next_index + i});
}
}
return instr->set_original_value(new_original_value);
}
}
// Appends the given new operand to while input and update loops computations // Appends the given new operand to while input and update loops computations
// and shape accordingly and returns the gte instruction within the body that // and shape accordingly and returns the gte instruction within the body that
// represents the new operand. // represents the new operand.
@ -130,8 +105,6 @@ absl::StatusOr<HloInstruction*> AppendToWhileState(
ShapeUtil::AppendShapeToTuple(new_operand->shape(), ShapeUtil::AppendShapeToTuple(new_operand->shape(),
while_input->mutable_shape()); while_input->mutable_shape());
while_input->AppendOperand(new_operand); while_input->AppendOperand(new_operand);
AppendOriginalValues(while_input, {new_operand},
while_input->operand_count() - 1);
// Update the body computation. // Update the body computation.
HloComputation* body = while_instr->while_body(); HloComputation* body = while_instr->while_body();
*body->parameter_instruction(0)->mutable_shape() = while_input->shape(); *body->parameter_instruction(0)->mutable_shape() = while_input->shape();
@ -149,8 +122,8 @@ absl::StatusOr<HloInstruction*> AppendToWhileState(
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
UpdateWhileUsesWithTuple(while_instr, while_input->operand_count() - 1)); UpdateWhileUsesWithTuple(while_instr, while_input->operand_count() - 1));
*while_instr->mutable_shape() = while_input->shape(); *while_instr->mutable_shape() = while_input->shape();
AppendOriginalValues(while_instr, {new_operand}, // The new body root tuple element has the same value as the new operand.
while_input->operand_count() - 1); AppendToWhileLoopOriginalValue(while_instr, {new_operand});
return new_gte; return new_gte;
} }
@ -499,11 +472,11 @@ absl::StatusOr<bool> WhileLoopFusibleSinking::TrySinkingFusiblesIntoWhileLoop(
root->AppendOperand(new_operands[i]); root->AppendOperand(new_operands[i]);
} }
*(init_value->mutable_shape()) = parameter->shape(); *(init_value->mutable_shape()) = parameter->shape();
AppendOriginalValues(init_value, fusion->operands(),
next_index - fusion->operand_count());
*(while_instr->mutable_shape()) = parameter->shape(); *(while_instr->mutable_shape()) = parameter->shape();
AppendOriginalValues(while_instr, fusion->operands(), //
next_index - fusion->operand_count()); // The new body root tuple elements have the same value as the fusion
// operands.
AppendToWhileLoopOriginalValue(while_instr, fusion->operands());
*(while_cond->parameter_instruction(0)->mutable_shape()) = *(while_cond->parameter_instruction(0)->mutable_shape()) =
parameter->shape(); parameter->shape();
*(root->mutable_shape()) = parameter->shape(); *(root->mutable_shape()) = parameter->shape();

View File

@ -471,11 +471,11 @@ ENTRY entry {
HloInstruction* while_instr = FindInstruction(module.get(), "while"); HloInstruction* while_instr = FindInstruction(module.get(), "while");
ASSERT_NE(while_instr->original_value(), nullptr); ASSERT_NE(while_instr->original_value(), nullptr);
EXPECT_EQ(while_instr->original_value()->ToString(), EXPECT_EQ(while_instr->original_value()->ToString(),
"({\"while\" {0}}, {\"while\" {1}}, {\"parameter\"})"); R"(({"while" {0}}, {"while" {1}}, {"parameter"}))");
HloInstruction* while_init = while_instr->while_init(); HloInstruction* while_init = while_instr->while_init();
ASSERT_NE(while_init->original_value(), nullptr); ASSERT_NE(while_init->original_value(), nullptr);
EXPECT_EQ(while_init->original_value()->ToString(), EXPECT_EQ(while_init->original_value()->ToString(),
"({\"constant\"}, {\"mask\"}, {\"parameter\"})"); R"(({"constant"}, {"mask"}, {"parameter"}))");
} }
TEST_F(WhileLoopFusibleSinkingTest, PlumbSingleBroadcastWithOriginalValue) { TEST_F(WhileLoopFusibleSinkingTest, PlumbSingleBroadcastWithOriginalValue) {
@ -518,11 +518,11 @@ TEST_F(WhileLoopFusibleSinkingTest, PlumbSingleBroadcastWithOriginalValue) {
HloInstruction* while_instr = FindInstruction(module.get(), "while"); HloInstruction* while_instr = FindInstruction(module.get(), "while");
ASSERT_NE(while_instr->original_value(), nullptr); ASSERT_NE(while_instr->original_value(), nullptr);
EXPECT_EQ(while_instr->original_value()->ToString(), EXPECT_EQ(while_instr->original_value()->ToString(),
"({\"while\" {0}}, {\"while\" {1}}, {\"while\" {2}}, {\"zero\"})"); R"(({"while" {0}}, {"while" {1}}, {"while" {2}}, {"zero"}))");
HloInstruction* while_init = while_instr->while_init(); HloInstruction* while_init = while_instr->while_init();
ASSERT_NE(while_init->original_value(), nullptr); ASSERT_NE(while_init->original_value(), nullptr);
EXPECT_EQ(while_init->original_value()->ToString(), EXPECT_EQ(while_init->original_value()->ToString(),
"({\"zero\"}, {\"zeros32\"}, {\"broadcast\"}, {\"zero\"})"); R"(({"zero"}, {"zeros32"}, {"broadcast"}, {"zero"}))");
} }
} // namespace } // namespace

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "xla/service/while_util.h" #include "xla/service/while_util.h"
#include <algorithm>
#include <cstdint> #include <cstdint>
#include <functional> #include <functional>
#include <iterator> #include <iterator>
@ -615,4 +616,70 @@ absl::Status WhileUtil::IncrementWhileLoopTripCount(
return induction_var->ReplaceAllUsesWith(decremented_induction_var); return induction_var->ReplaceAllUsesWith(decremented_induction_var);
} }
void AppendToWhileLoopOriginalValue(
HloInstruction* while_instr,
const HloInstruction::InstructionVector& new_while_input_tuple_elements) {
auto append_to_original_value = [&](HloInstruction* instr,
int64_t next_index) {
std::shared_ptr<OriginalValue> old_original_value = instr->original_value();
if (old_original_value != nullptr &&
old_original_value->IsCompatibleWith(instr->shape())) {
return;
}
// Returns if neither the instruction nor any of its new tuple elements have
// an original value.
if (old_original_value == nullptr) {
bool has_original_value = false;
std::for_each(new_while_input_tuple_elements.begin(),
new_while_input_tuple_elements.end(),
[&has_original_value](const HloInstruction* instr) {
has_original_value |=
(instr->original_value() != nullptr &&
!instr->original_value()->IsEmpty());
});
if (!has_original_value) {
return;
}
}
std::shared_ptr<OriginalValue> new_original_value =
std::make_shared<OriginalValue>(instr->shape());
if (old_original_value != nullptr) {
if (!old_original_value->IsTuple()) {
new_original_value->mutable_tree()->CopySubtreeFrom(
old_original_value->tree(), {}, {0});
} else {
for (auto& [shape_index, original_array] : old_original_value->tree()) {
*new_original_value->mutable_original_array(shape_index) =
original_array;
}
}
}
for (int64_t i = 0; i < new_while_input_tuple_elements.size(); ++i) {
if (new_while_input_tuple_elements[i]->original_value() != nullptr) {
new_original_value->mutable_tree()->CopySubtreeFrom(
new_while_input_tuple_elements[i]->original_value()->tree(), {},
{next_index + i});
}
}
instr->set_original_value(new_original_value);
};
if (while_instr->opcode() != HloOpcode::kWhile) {
return;
}
const Shape& while_shape = while_instr->shape();
if (!while_shape.IsTuple()) {
return;
}
// Calculates the start index for the new tuple elements in the new original
// value.
int64_t next_index =
while_shape.tuple_shapes().size() - new_while_input_tuple_elements.size();
append_to_original_value(while_instr->while_init(), next_index);
append_to_original_value(while_instr, next_index);
}
} // namespace xla } // namespace xla

View File

@ -155,6 +155,15 @@ class WhileUtil {
static absl::Status IncrementWhileLoopTripCount( static absl::Status IncrementWhileLoopTripCount(
const HloInstruction& while_instruction, int32_t increment); const HloInstruction& while_instruction, int32_t increment);
}; };
// This is a helper function to update the original value after some
// transformations append new elements to the while input tuple (or turn it into
// a tuple if it was not one before). It appends the original values of the
// new elements after existing children of the root node of the old original
// value. This is done for both the input and output of the loop respectively.
void AppendToWhileLoopOriginalValue(
HloInstruction* while_instr,
const HloInstruction::InstructionVector& new_while_input_tuple_elements);
} // namespace xla } // namespace xla
#endif // XLA_SERVICE_WHILE_UTIL_H_ #endif // XLA_SERVICE_WHILE_UTIL_H_