From a6e123761d01eaac4a842b6cd80f35af010808f7 Mon Sep 17 00:00:00 2001 From: Jian Cai Date: Fri, 31 Oct 2025 13:58:55 -0700 Subject: [PATCH] [XLA][Numerics][HLO Original Values] Handles original values of while loops in TPU reduce code motion pass This updates the original value of a while loop after its input/output shape gets changed after the pass sinks qualified reduce instructions into its body. PiperOrigin-RevId: 826618908 --- third_party/xla/xla/hlo/ir/hlo_instruction.h | 6 +- .../xla/xla/hlo/ir/hlo_original_value.h | 2 + third_party/xla/xla/service/BUILD | 3 +- .../xla/service/while_loop_fusible_sinking.cc | 39 ++--------- .../while_loop_fusible_sinking_test.cc | 8 +-- third_party/xla/xla/service/while_util.cc | 67 +++++++++++++++++++ third_party/xla/xla/service/while_util.h | 9 +++ 7 files changed, 92 insertions(+), 42 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index e1e12dbf2af..d62cb202339 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -2467,7 +2467,7 @@ class alignas(kInstructionTypeMask + 1) HloInstruction { std::shared_ptr<OriginalValue> original_value() const; void set_original_value(std::shared_ptr<OriginalValue> original_value); - // Copy original value from the input instruction if the source and + // Copies original value from the input instruction if the source and // destination shapes are compatible. This performs a deep copy if clone is // set to true. Otherwise, it performs a shallow copy. Print a warning if the // shapes are not compatible and issue_warning is set to true. @@ -2475,8 +2475,8 @@ class alignas(kInstructionTypeMask + 1) HloInstruction { bool issue_warning = false); protected: - // Internal constructor for a given opcode/shape, other fields must be filled - // by factory methods. + // Internal constructor for a given opcode/shape, other fields must be + // filled by factory methods.
HloInstruction(HloOpcode opcode, const Shape& shape); void RemoveAllOperands() { operands_.clear(); } diff --git a/third_party/xla/xla/hlo/ir/hlo_original_value.h b/third_party/xla/xla/hlo/ir/hlo_original_value.h index 039745fa372..5ccfebc6839 100644 --- a/third_party/xla/xla/hlo/ir/hlo_original_value.h +++ b/third_party/xla/xla/hlo/ir/hlo_original_value.h @@ -129,6 +129,8 @@ class OriginalValue { bool IsCompatibleWith(const Shape& shape) const; + bool IsTuple() const { return tree().IsTuple(); } + bool operator==(const OriginalValue& other) const; bool operator!=(const OriginalValue& other) const { diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index f1f6aaeb45d..d520201d541 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -4639,6 +4639,7 @@ xla_cc_test( srcs = ["while_util_test.cc"], deps = [ ":while_util", + "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", @@ -4906,8 +4907,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/while_loop_fusible_sinking.cc b/third_party/xla/xla/service/while_loop_fusible_sinking.cc index 852258d2944..e363c7e20cc 100644 --- a/third_party/xla/xla/service/while_loop_fusible_sinking.cc +++ b/third_party/xla/xla/service/while_loop_fusible_sinking.cc @@ -42,8 +42,6 @@ limitations under the License. 
#include "xla/shape_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" namespace xla { @@ -97,29 +95,6 @@ absl::Status UpdateWhileUsesWithTuple(HloInstruction* while_instr, return absl::OkStatus(); } -void AppendOriginalValues(HloInstruction* instr, - const HloInstruction::InstructionVector& new_operands, - int64_t next_index) { - if (instr->original_value() != nullptr && !new_operands.empty()) { - std::shared_ptr<OriginalValue> old_original_value = instr->original_value(), - new_original_value = - std::make_shared<OriginalValue>( - instr->shape()); - for (auto& [shape_index, original_array] : old_original_value->tree()) { - *new_original_value->mutable_tree()->mutable_element(shape_index) = - original_array; - } - - for (int64_t i = 0; i < new_operands.size(); ++i) { - if (new_operands[i]->original_value() != nullptr) { - new_original_value->mutable_tree()->CopySubtreeFrom( - new_operands[i]->original_value()->tree(), {}, {next_index + i}); - } - } - return instr->set_original_value(new_original_value); - } -} - // Appends the given new operand to while input and update loops computations // and shape accordingly and returns the gte instruction within the body that // represents the new operand. @@ -130,8 +105,6 @@ absl::StatusOr AppendToWhileState( ShapeUtil::AppendShapeToTuple(new_operand->shape(), while_input->mutable_shape()); while_input->AppendOperand(new_operand); - AppendOriginalValues(while_input, {new_operand}, - while_input->operand_count() - 1); // Update the body computation.
HloComputation* body = while_instr->while_body(); *body->parameter_instruction(0)->mutable_shape() = while_input->shape(); @@ -149,8 +122,8 @@ absl::StatusOr AppendToWhileState( TF_RETURN_IF_ERROR( UpdateWhileUsesWithTuple(while_instr, while_input->operand_count() - 1)); *while_instr->mutable_shape() = while_input->shape(); - AppendOriginalValues(while_instr, {new_operand}, - while_input->operand_count() - 1); + // The new body root tuple element has the same value as the new operand. + AppendToWhileLoopOriginalValue(while_instr, {new_operand}); return new_gte; } @@ -499,11 +472,11 @@ absl::StatusOr WhileLoopFusibleSinking::TrySinkingFusiblesIntoWhileLoop( root->AppendOperand(new_operands[i]); } *(init_value->mutable_shape()) = parameter->shape(); - AppendOriginalValues(init_value, fusion->operands(), - next_index - fusion->operand_count()); *(while_instr->mutable_shape()) = parameter->shape(); - AppendOriginalValues(while_instr, fusion->operands(), - next_index - fusion->operand_count()); + // + // The new body root tuple elements have the same value as the fusion + // operands. 
+ AppendToWhileLoopOriginalValue(while_instr, fusion->operands()); *(while_cond->parameter_instruction(0)->mutable_shape()) = parameter->shape(); *(root->mutable_shape()) = parameter->shape(); diff --git a/third_party/xla/xla/service/while_loop_fusible_sinking_test.cc b/third_party/xla/xla/service/while_loop_fusible_sinking_test.cc index b5637f997e7..f5baf95b30d 100644 --- a/third_party/xla/xla/service/while_loop_fusible_sinking_test.cc +++ b/third_party/xla/xla/service/while_loop_fusible_sinking_test.cc @@ -471,11 +471,11 @@ ENTRY entry { HloInstruction* while_instr = FindInstruction(module.get(), "while"); ASSERT_NE(while_instr->original_value(), nullptr); EXPECT_EQ(while_instr->original_value()->ToString(), - "({\"while\" {0}}, {\"while\" {1}}, {\"parameter\"})"); + R"(({"while" {0}}, {"while" {1}}, {"parameter"}))"); HloInstruction* while_init = while_instr->while_init(); ASSERT_NE(while_init->original_value(), nullptr); EXPECT_EQ(while_init->original_value()->ToString(), - "({\"constant\"}, {\"mask\"}, {\"parameter\"})"); + R"(({"constant"}, {"mask"}, {"parameter"}))"); } TEST_F(WhileLoopFusibleSinkingTest, PlumbSingleBroadcastWithOriginalValue) { @@ -518,11 +518,11 @@ TEST_F(WhileLoopFusibleSinkingTest, PlumbSingleBroadcastWithOriginalValue) { HloInstruction* while_instr = FindInstruction(module.get(), "while"); ASSERT_NE(while_instr->original_value(), nullptr); EXPECT_EQ(while_instr->original_value()->ToString(), - "({\"while\" {0}}, {\"while\" {1}}, {\"while\" {2}}, {\"zero\"})"); + R"(({"while" {0}}, {"while" {1}}, {"while" {2}}, {"zero"}))"); HloInstruction* while_init = while_instr->while_init(); ASSERT_NE(while_init->original_value(), nullptr); EXPECT_EQ(while_init->original_value()->ToString(), - "({\"zero\"}, {\"zeros32\"}, {\"broadcast\"}, {\"zero\"})"); + R"(({"zero"}, {"zeros32"}, {"broadcast"}, {"zero"}))"); } } // namespace diff --git a/third_party/xla/xla/service/while_util.cc b/third_party/xla/xla/service/while_util.cc index 
5149850241c..7e092d2bc26 100644 --- a/third_party/xla/xla/service/while_util.cc +++ b/third_party/xla/xla/service/while_util.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/while_util.h" +#include <algorithm> #include #include #include @@ -615,4 +616,70 @@ absl::Status WhileUtil::IncrementWhileLoopTripCount( return induction_var->ReplaceAllUsesWith(decremented_induction_var); } +void AppendToWhileLoopOriginalValue( + HloInstruction* while_instr, + const HloInstruction::InstructionVector& new_while_input_tuple_elements) { + auto append_to_original_value = [&](HloInstruction* instr, + int64_t next_index) { + std::shared_ptr<OriginalValue> old_original_value = instr->original_value(); + if (old_original_value != nullptr && + old_original_value->IsCompatibleWith(instr->shape())) { + return; + } + + // Returns if neither the instruction nor any of its new tuple elements have + // an original value. + if (old_original_value == nullptr) { + bool has_original_value = false; + std::for_each(new_while_input_tuple_elements.begin(), + new_while_input_tuple_elements.end(), + [&has_original_value](const HloInstruction* instr) { + has_original_value |= + (instr->original_value() != nullptr && + !instr->original_value()->IsEmpty()); + }); + if (!has_original_value) { + return; + } + } + + std::shared_ptr<OriginalValue> new_original_value = + std::make_shared<OriginalValue>(instr->shape()); + if (old_original_value != nullptr) { + if (!old_original_value->IsTuple()) { + new_original_value->mutable_tree()->CopySubtreeFrom( + old_original_value->tree(), {}, {0}); + } else { + for (auto& [shape_index, original_array] : old_original_value->tree()) { + *new_original_value->mutable_original_array(shape_index) = + original_array; + } + } + } + + for (int64_t i = 0; i < new_while_input_tuple_elements.size(); ++i) { + if (new_while_input_tuple_elements[i]->original_value() != nullptr) { + new_original_value->mutable_tree()->CopySubtreeFrom( + new_while_input_tuple_elements[i]->original_value()->tree(), {}, + {next_index
+ i}); + } + } + instr->set_original_value(new_original_value); + }; + + if (while_instr->opcode() != HloOpcode::kWhile) { + return; + } + const Shape& while_shape = while_instr->shape(); + if (!while_shape.IsTuple()) { + return; + } + // Calculates the start index for the new tuple elements in the new original + // value. + int64_t next_index = + while_shape.tuple_shapes().size() - new_while_input_tuple_elements.size(); + append_to_original_value(while_instr->while_init(), next_index); + append_to_original_value(while_instr, next_index); +} + } // namespace xla diff --git a/third_party/xla/xla/service/while_util.h b/third_party/xla/xla/service/while_util.h index c05234d28ae..7b2a820ced8 100644 --- a/third_party/xla/xla/service/while_util.h +++ b/third_party/xla/xla/service/while_util.h @@ -155,6 +155,15 @@ class WhileUtil { static absl::Status IncrementWhileLoopTripCount( const HloInstruction& while_instruction, int32_t increment); }; + +// This is a helper function to update the original value after some +// transformations append new elements to the while input tuple (or turn it into +// a tuple if it was not one before). It appends the original values of the +// new elements after existing children of the root node of the old original +// value. This is done for both the input and output of the loop respectively. +void AppendToWhileLoopOriginalValue( + HloInstruction* while_instr, + const HloInstruction::InstructionVector& new_while_input_tuple_elements); } // namespace xla #endif // XLA_SERVICE_WHILE_UTIL_H_