mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
[XLA][Numerics][HLO Original Values] Handles original values of while loops in TPU reduce code motion pass
This updates the original value of a while loop when its input/output shape changes because the pass sinks qualifying reduce instructions into the loop body. PiperOrigin-RevId: 826618908
This commit is contained in:
parent
eef0661fc5
commit
a6e123761d
6
third_party/xla/xla/hlo/ir/hlo_instruction.h
vendored
6
third_party/xla/xla/hlo/ir/hlo_instruction.h
vendored
|
|
@ -2467,7 +2467,7 @@ class alignas(kInstructionTypeMask + 1) HloInstruction {
|
||||||
std::shared_ptr<OriginalValue> original_value() const;
|
std::shared_ptr<OriginalValue> original_value() const;
|
||||||
void set_original_value(std::shared_ptr<OriginalValue> original_value);
|
void set_original_value(std::shared_ptr<OriginalValue> original_value);
|
||||||
|
|
||||||
// Copy original value from the input instruction if the source and
|
// Copies original value from the input instruction if the source and
|
||||||
// destination shapes are compatible. This performs a deep copy if clone is
|
// destination shapes are compatible. This performs a deep copy if clone is
|
||||||
// set to true. Otherwise, it performs a shallow copy. Print a warning if the
|
// set to true. Otherwise, it performs a shallow copy. Print a warning if the
|
||||||
// shapes are not compatible and issue_warning is set to true.
|
// shapes are not compatible and issue_warning is set to true.
|
||||||
|
|
@ -2475,8 +2475,8 @@ class alignas(kInstructionTypeMask + 1) HloInstruction {
|
||||||
bool issue_warning = false);
|
bool issue_warning = false);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// Internal constructor for a given opcode/shape, other fields must be filled
|
// Internal constructor for a given opcode/shape, other fields must be
|
||||||
// by factory methods.
|
// filled by factory methods.
|
||||||
HloInstruction(HloOpcode opcode, const Shape& shape);
|
HloInstruction(HloOpcode opcode, const Shape& shape);
|
||||||
|
|
||||||
void RemoveAllOperands() { operands_.clear(); }
|
void RemoveAllOperands() { operands_.clear(); }
|
||||||
|
|
|
||||||
|
|
@ -129,6 +129,8 @@ class OriginalValue {
|
||||||
|
|
||||||
bool IsCompatibleWith(const Shape& shape) const;
|
bool IsCompatibleWith(const Shape& shape) const;
|
||||||
|
|
||||||
|
bool IsTuple() const { return tree().IsTuple(); }
|
||||||
|
|
||||||
bool operator==(const OriginalValue& other) const;
|
bool operator==(const OriginalValue& other) const;
|
||||||
|
|
||||||
bool operator!=(const OriginalValue& other) const {
|
bool operator!=(const OriginalValue& other) const {
|
||||||
|
|
|
||||||
3
third_party/xla/xla/service/BUILD
vendored
3
third_party/xla/xla/service/BUILD
vendored
|
|
@ -4639,6 +4639,7 @@ xla_cc_test(
|
||||||
srcs = ["while_util_test.cc"],
|
srcs = ["while_util_test.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":while_util",
|
":while_util",
|
||||||
|
"//xla:shape_util",
|
||||||
"//xla:util",
|
"//xla:util",
|
||||||
"//xla/hlo/ir:hlo",
|
"//xla/hlo/ir:hlo",
|
||||||
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
|
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
|
||||||
|
|
@ -4906,8 +4907,6 @@ cc_library(
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_absl//absl/strings:string_view",
|
"@com_google_absl//absl/strings:string_view",
|
||||||
"@local_tsl//tsl/platform:errors",
|
|
||||||
"@local_tsl//tsl/platform:statusor",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -42,8 +42,6 @@ limitations under the License.
|
||||||
#include "xla/shape_util.h"
|
#include "xla/shape_util.h"
|
||||||
#include "xla/util.h"
|
#include "xla/util.h"
|
||||||
#include "xla/xla_data.pb.h"
|
#include "xla/xla_data.pb.h"
|
||||||
#include "tsl/platform/errors.h"
|
|
||||||
#include "tsl/platform/statusor.h"
|
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
|
|
@ -97,29 +95,6 @@ absl::Status UpdateWhileUsesWithTuple(HloInstruction* while_instr,
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
void AppendOriginalValues(HloInstruction* instr,
|
|
||||||
const HloInstruction::InstructionVector& new_operands,
|
|
||||||
int64_t next_index) {
|
|
||||||
if (instr->original_value() != nullptr && !new_operands.empty()) {
|
|
||||||
std::shared_ptr<OriginalValue> old_original_value = instr->original_value(),
|
|
||||||
new_original_value =
|
|
||||||
std::make_shared<OriginalValue>(
|
|
||||||
instr->shape());
|
|
||||||
for (auto& [shape_index, original_array] : old_original_value->tree()) {
|
|
||||||
*new_original_value->mutable_tree()->mutable_element(shape_index) =
|
|
||||||
original_array;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int64_t i = 0; i < new_operands.size(); ++i) {
|
|
||||||
if (new_operands[i]->original_value() != nullptr) {
|
|
||||||
new_original_value->mutable_tree()->CopySubtreeFrom(
|
|
||||||
new_operands[i]->original_value()->tree(), {}, {next_index + i});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return instr->set_original_value(new_original_value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Appends the given new operand to while input and update loops computations
|
// Appends the given new operand to while input and update loops computations
|
||||||
// and shape accordingly and returns the gte instruction within the body that
|
// and shape accordingly and returns the gte instruction within the body that
|
||||||
// represents the new operand.
|
// represents the new operand.
|
||||||
|
|
@ -130,8 +105,6 @@ absl::StatusOr<HloInstruction*> AppendToWhileState(
|
||||||
ShapeUtil::AppendShapeToTuple(new_operand->shape(),
|
ShapeUtil::AppendShapeToTuple(new_operand->shape(),
|
||||||
while_input->mutable_shape());
|
while_input->mutable_shape());
|
||||||
while_input->AppendOperand(new_operand);
|
while_input->AppendOperand(new_operand);
|
||||||
AppendOriginalValues(while_input, {new_operand},
|
|
||||||
while_input->operand_count() - 1);
|
|
||||||
// Update the body computation.
|
// Update the body computation.
|
||||||
HloComputation* body = while_instr->while_body();
|
HloComputation* body = while_instr->while_body();
|
||||||
*body->parameter_instruction(0)->mutable_shape() = while_input->shape();
|
*body->parameter_instruction(0)->mutable_shape() = while_input->shape();
|
||||||
|
|
@ -149,8 +122,8 @@ absl::StatusOr<HloInstruction*> AppendToWhileState(
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
UpdateWhileUsesWithTuple(while_instr, while_input->operand_count() - 1));
|
UpdateWhileUsesWithTuple(while_instr, while_input->operand_count() - 1));
|
||||||
*while_instr->mutable_shape() = while_input->shape();
|
*while_instr->mutable_shape() = while_input->shape();
|
||||||
AppendOriginalValues(while_instr, {new_operand},
|
// The new body root tuple element has the same value as the new operand.
|
||||||
while_input->operand_count() - 1);
|
AppendToWhileLoopOriginalValue(while_instr, {new_operand});
|
||||||
|
|
||||||
return new_gte;
|
return new_gte;
|
||||||
}
|
}
|
||||||
|
|
@ -499,11 +472,11 @@ absl::StatusOr<bool> WhileLoopFusibleSinking::TrySinkingFusiblesIntoWhileLoop(
|
||||||
root->AppendOperand(new_operands[i]);
|
root->AppendOperand(new_operands[i]);
|
||||||
}
|
}
|
||||||
*(init_value->mutable_shape()) = parameter->shape();
|
*(init_value->mutable_shape()) = parameter->shape();
|
||||||
AppendOriginalValues(init_value, fusion->operands(),
|
|
||||||
next_index - fusion->operand_count());
|
|
||||||
*(while_instr->mutable_shape()) = parameter->shape();
|
*(while_instr->mutable_shape()) = parameter->shape();
|
||||||
AppendOriginalValues(while_instr, fusion->operands(),
|
//
|
||||||
next_index - fusion->operand_count());
|
// The new body root tuple elements have the same value as the fusion
|
||||||
|
// operands.
|
||||||
|
AppendToWhileLoopOriginalValue(while_instr, fusion->operands());
|
||||||
*(while_cond->parameter_instruction(0)->mutable_shape()) =
|
*(while_cond->parameter_instruction(0)->mutable_shape()) =
|
||||||
parameter->shape();
|
parameter->shape();
|
||||||
*(root->mutable_shape()) = parameter->shape();
|
*(root->mutable_shape()) = parameter->shape();
|
||||||
|
|
|
||||||
|
|
@ -471,11 +471,11 @@ ENTRY entry {
|
||||||
HloInstruction* while_instr = FindInstruction(module.get(), "while");
|
HloInstruction* while_instr = FindInstruction(module.get(), "while");
|
||||||
ASSERT_NE(while_instr->original_value(), nullptr);
|
ASSERT_NE(while_instr->original_value(), nullptr);
|
||||||
EXPECT_EQ(while_instr->original_value()->ToString(),
|
EXPECT_EQ(while_instr->original_value()->ToString(),
|
||||||
"({\"while\" {0}}, {\"while\" {1}}, {\"parameter\"})");
|
R"(({"while" {0}}, {"while" {1}}, {"parameter"}))");
|
||||||
HloInstruction* while_init = while_instr->while_init();
|
HloInstruction* while_init = while_instr->while_init();
|
||||||
ASSERT_NE(while_init->original_value(), nullptr);
|
ASSERT_NE(while_init->original_value(), nullptr);
|
||||||
EXPECT_EQ(while_init->original_value()->ToString(),
|
EXPECT_EQ(while_init->original_value()->ToString(),
|
||||||
"({\"constant\"}, {\"mask\"}, {\"parameter\"})");
|
R"(({"constant"}, {"mask"}, {"parameter"}))");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(WhileLoopFusibleSinkingTest, PlumbSingleBroadcastWithOriginalValue) {
|
TEST_F(WhileLoopFusibleSinkingTest, PlumbSingleBroadcastWithOriginalValue) {
|
||||||
|
|
@ -518,11 +518,11 @@ TEST_F(WhileLoopFusibleSinkingTest, PlumbSingleBroadcastWithOriginalValue) {
|
||||||
HloInstruction* while_instr = FindInstruction(module.get(), "while");
|
HloInstruction* while_instr = FindInstruction(module.get(), "while");
|
||||||
ASSERT_NE(while_instr->original_value(), nullptr);
|
ASSERT_NE(while_instr->original_value(), nullptr);
|
||||||
EXPECT_EQ(while_instr->original_value()->ToString(),
|
EXPECT_EQ(while_instr->original_value()->ToString(),
|
||||||
"({\"while\" {0}}, {\"while\" {1}}, {\"while\" {2}}, {\"zero\"})");
|
R"(({"while" {0}}, {"while" {1}}, {"while" {2}}, {"zero"}))");
|
||||||
HloInstruction* while_init = while_instr->while_init();
|
HloInstruction* while_init = while_instr->while_init();
|
||||||
ASSERT_NE(while_init->original_value(), nullptr);
|
ASSERT_NE(while_init->original_value(), nullptr);
|
||||||
EXPECT_EQ(while_init->original_value()->ToString(),
|
EXPECT_EQ(while_init->original_value()->ToString(),
|
||||||
"({\"zero\"}, {\"zeros32\"}, {\"broadcast\"}, {\"zero\"})");
|
R"(({"zero"}, {"zeros32"}, {"broadcast"}, {"zero"}))");
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
||||||
67
third_party/xla/xla/service/while_util.cc
vendored
67
third_party/xla/xla/service/while_util.cc
vendored
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "xla/service/while_util.h"
|
#include "xla/service/while_util.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
|
|
@ -615,4 +616,70 @@ absl::Status WhileUtil::IncrementWhileLoopTripCount(
|
||||||
return induction_var->ReplaceAllUsesWith(decremented_induction_var);
|
return induction_var->ReplaceAllUsesWith(decremented_induction_var);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Rebuilds the original values of a while loop's init operand and of the while
// instruction itself after new elements have been appended to the loop state.
//
// `while_instr` is the loop to update; the function does nothing if it is not
// a kWhile instruction or if its shape is not a tuple.
// `new_while_input_tuple_elements` are the instructions whose values were
// appended as the trailing elements of the loop state tuple. Their original
// values (when present) are copied into the matching trailing tuple indices.
//
// The new original value is built from each instruction's current shape, so
// the shapes of the init operand and of the while instruction must already
// include the appended elements when this is called.
void AppendToWhileLoopOriginalValue(
    HloInstruction* while_instr,
    const HloInstruction::InstructionVector& new_while_input_tuple_elements) {
  if (while_instr->opcode() != HloOpcode::kWhile) {
    return;
  }
  const Shape& while_shape = while_instr->shape();
  if (!while_shape.IsTuple()) {
    return;
  }

  // Index of the first appended element within the (already extended) tuple.
  const int64_t num_new_elements =
      static_cast<int64_t>(new_while_input_tuple_elements.size());
  const int64_t next_index =
      static_cast<int64_t>(while_shape.tuple_shapes().size()) -
      num_new_elements;
  if (next_index < 0) {
    // More new elements than tuple entries: there is no valid destination
    // index, so leave the original values untouched rather than index out of
    // range.
    return;
  }

  auto append_to_original_value = [&](HloInstruction* instr) {
    std::shared_ptr<OriginalValue> old_original_value = instr->original_value();

    // Nothing to do when the existing original value already matches the
    // instruction's current shape.
    if (old_original_value != nullptr &&
        old_original_value->IsCompatibleWith(instr->shape())) {
      return;
    }

    // Nothing to do when neither the instruction nor any of the new tuple
    // elements carries a (non-empty) original value.
    if (old_original_value == nullptr) {
      const bool any_new_element_has_original_value = std::any_of(
          new_while_input_tuple_elements.begin(),
          new_while_input_tuple_elements.end(),
          [](const HloInstruction* element) {
            const std::shared_ptr<OriginalValue> element_original_value =
                element->original_value();
            return element_original_value != nullptr &&
                   !element_original_value->IsEmpty();
          });
      if (!any_new_element_has_original_value) {
        return;
      }
    }

    std::shared_ptr<OriginalValue> new_original_value =
        std::make_shared<OriginalValue>(instr->shape());

    // Carry over whatever the instruction had before the state was extended.
    if (old_original_value != nullptr) {
      if (!old_original_value->IsTuple()) {
        // The old value was not a tuple: it becomes element 0 of the new
        // tuple-shaped original value.
        new_original_value->mutable_tree()->CopySubtreeFrom(
            old_original_value->tree(), {}, {0});
      } else {
        for (const auto& [shape_index, original_array] :
             old_original_value->tree()) {
          *new_original_value->mutable_original_array(shape_index) =
              original_array;
        }
      }
    }

    // Place each new element's original value at its trailing tuple index.
    for (int64_t i = 0; i < num_new_elements; ++i) {
      const std::shared_ptr<OriginalValue> element_original_value =
          new_while_input_tuple_elements[i]->original_value();
      if (element_original_value != nullptr) {
        new_original_value->mutable_tree()->CopySubtreeFrom(
            element_original_value->tree(), {}, {next_index + i});
      }
    }
    instr->set_original_value(new_original_value);
  };

  // Update the loop input first, then the loop output.
  append_to_original_value(while_instr->while_init());
  append_to_original_value(while_instr);
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
|
||||||
9
third_party/xla/xla/service/while_util.h
vendored
9
third_party/xla/xla/service/while_util.h
vendored
|
|
@ -155,6 +155,15 @@ class WhileUtil {
|
||||||
static absl::Status IncrementWhileLoopTripCount(
|
static absl::Status IncrementWhileLoopTripCount(
|
||||||
const HloInstruction& while_instruction, int32_t increment);
|
const HloInstruction& while_instruction, int32_t increment);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// This is a helper function to update the original value after some
|
||||||
|
// transformations append new elements to the while input tuple (or turn it into
|
||||||
|
// a tuple if it was not one before). It appends the original values of the
|
||||||
|
// new elements after existing children of the root node of the old original
|
||||||
|
// value. This is done for both the input and the output of the loop.
|
||||||
|
void AppendToWhileLoopOriginalValue(
|
||||||
|
HloInstruction* while_instr,
|
||||||
|
const HloInstruction::InstructionVector& new_while_input_tuple_elements);
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
#endif // XLA_SERVICE_WHILE_UTIL_H_
|
#endif // XLA_SERVICE_WHILE_UTIL_H_
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user