mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[XLA][Numerics][HLO Value Tracking] Handle original values in while loop fusible sinking pass
This reconstructs the original value for while loops with a rewritten input/output shape during the pass. PiperOrigin-RevId: 822465131
This commit is contained in:
parent
add51a87c3
commit
95d3b6fe36
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
|
#include <memory>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|
@ -34,6 +35,7 @@ limitations under the License.
|
||||||
#include "xla/hlo/ir/hlo_computation.h"
|
#include "xla/hlo/ir/hlo_computation.h"
|
||||||
#include "xla/hlo/ir/hlo_instruction.h"
|
#include "xla/hlo/ir/hlo_instruction.h"
|
||||||
#include "xla/hlo/ir/hlo_opcode.h"
|
#include "xla/hlo/ir/hlo_opcode.h"
|
||||||
|
#include "xla/hlo/ir/hlo_original_value.h"
|
||||||
#include "xla/hlo/utils/hlo_query.h"
|
#include "xla/hlo/utils/hlo_query.h"
|
||||||
#include "xla/service/pattern_matcher.h"
|
#include "xla/service/pattern_matcher.h"
|
||||||
#include "xla/service/while_util.h"
|
#include "xla/service/while_util.h"
|
||||||
|
|
@ -95,6 +97,29 @@ absl::Status UpdateWhileUsesWithTuple(HloInstruction* while_instr,
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AppendOriginalValues(HloInstruction* instr,
|
||||||
|
const HloInstruction::InstructionVector& new_operands,
|
||||||
|
int64_t next_index) {
|
||||||
|
if (instr->original_value() != nullptr && !new_operands.empty()) {
|
||||||
|
std::shared_ptr<OriginalValue> old_original_value = instr->original_value(),
|
||||||
|
new_original_value =
|
||||||
|
std::make_shared<OriginalValue>(
|
||||||
|
instr->shape());
|
||||||
|
for (auto& [shape_index, original_array] : old_original_value->tree()) {
|
||||||
|
*new_original_value->mutable_tree()->mutable_element(shape_index) =
|
||||||
|
original_array;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < new_operands.size(); ++i) {
|
||||||
|
if (new_operands[i]->original_value() != nullptr) {
|
||||||
|
new_original_value->mutable_tree()->CopySubtreeFrom(
|
||||||
|
new_operands[i]->original_value()->tree(), {}, {next_index + i});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return instr->set_original_value(new_original_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Appends the given new operand to while input and update loops computations
|
// Appends the given new operand to while input and update loops computations
|
||||||
// and shape accordingly and returns the gte instruction within the body that
|
// and shape accordingly and returns the gte instruction within the body that
|
||||||
// represents the new operand.
|
// represents the new operand.
|
||||||
|
|
@ -105,6 +130,8 @@ absl::StatusOr<HloInstruction*> AppendToWhileState(
|
||||||
ShapeUtil::AppendShapeToTuple(new_operand->shape(),
|
ShapeUtil::AppendShapeToTuple(new_operand->shape(),
|
||||||
while_input->mutable_shape());
|
while_input->mutable_shape());
|
||||||
while_input->AppendOperand(new_operand);
|
while_input->AppendOperand(new_operand);
|
||||||
|
AppendOriginalValues(while_input, {new_operand},
|
||||||
|
while_input->operand_count() - 1);
|
||||||
// Update the body computation.
|
// Update the body computation.
|
||||||
HloComputation* body = while_instr->while_body();
|
HloComputation* body = while_instr->while_body();
|
||||||
*body->parameter_instruction(0)->mutable_shape() = while_input->shape();
|
*body->parameter_instruction(0)->mutable_shape() = while_input->shape();
|
||||||
|
|
@ -122,6 +149,9 @@ absl::StatusOr<HloInstruction*> AppendToWhileState(
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
UpdateWhileUsesWithTuple(while_instr, while_input->operand_count() - 1));
|
UpdateWhileUsesWithTuple(while_instr, while_input->operand_count() - 1));
|
||||||
*while_instr->mutable_shape() = while_input->shape();
|
*while_instr->mutable_shape() = while_input->shape();
|
||||||
|
AppendOriginalValues(while_instr, {new_operand},
|
||||||
|
while_input->operand_count() - 1);
|
||||||
|
|
||||||
return new_gte;
|
return new_gte;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -459,6 +489,7 @@ absl::StatusOr<bool> WhileLoopFusibleSinking::TrySinkingFusiblesIntoWhileLoop(
|
||||||
HloInstruction* parameter = while_body->parameter_instruction(0);
|
HloInstruction* parameter = while_body->parameter_instruction(0);
|
||||||
int64_t next_index = init_value->operand_count();
|
int64_t next_index = init_value->operand_count();
|
||||||
new_operands.resize(fusion->operand_count());
|
new_operands.resize(fusion->operand_count());
|
||||||
|
|
||||||
for (int64_t i = 0; i < fusion->operand_count(); ++i) {
|
for (int64_t i = 0; i < fusion->operand_count(); ++i) {
|
||||||
init_value->AppendOperand(fusion->mutable_operand(i));
|
init_value->AppendOperand(fusion->mutable_operand(i));
|
||||||
parameter->mutable_shape()->mutable_tuple_shapes()->push_back(
|
parameter->mutable_shape()->mutable_tuple_shapes()->push_back(
|
||||||
|
|
@ -468,10 +499,15 @@ absl::StatusOr<bool> WhileLoopFusibleSinking::TrySinkingFusiblesIntoWhileLoop(
|
||||||
root->AppendOperand(new_operands[i]);
|
root->AppendOperand(new_operands[i]);
|
||||||
}
|
}
|
||||||
*(init_value->mutable_shape()) = parameter->shape();
|
*(init_value->mutable_shape()) = parameter->shape();
|
||||||
|
AppendOriginalValues(init_value, fusion->operands(),
|
||||||
|
next_index - fusion->operand_count());
|
||||||
*(while_instr->mutable_shape()) = parameter->shape();
|
*(while_instr->mutable_shape()) = parameter->shape();
|
||||||
|
AppendOriginalValues(while_instr, fusion->operands(),
|
||||||
|
next_index - fusion->operand_count());
|
||||||
*(while_cond->parameter_instruction(0)->mutable_shape()) =
|
*(while_cond->parameter_instruction(0)->mutable_shape()) =
|
||||||
parameter->shape();
|
parameter->shape();
|
||||||
*(root->mutable_shape()) = parameter->shape();
|
*(root->mutable_shape()) = parameter->shape();
|
||||||
|
|
||||||
auto cloned_fusion = while_body->AddInstruction(
|
auto cloned_fusion = while_body->AddInstruction(
|
||||||
fusion->CloneWithNewOperands(fusion->shape(), new_operands));
|
fusion->CloneWithNewOperands(fusion->shape(), new_operands));
|
||||||
TF_RETURN_IF_ERROR(fusion->parent()->RemoveInstruction(fusion));
|
TF_RETURN_IF_ERROR(fusion->parent()->RemoveInstruction(fusion));
|
||||||
|
|
@ -539,6 +575,7 @@ absl::StatusOr<bool> WhileLoopFusibleSinking::Run(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return changed;
|
return changed;
|
||||||
}
|
}
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||||
#include <gmock/gmock.h>
|
#include <gmock/gmock.h>
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include "absl/log/check.h"
|
#include "absl/log/check.h"
|
||||||
|
#include "xla/hlo/ir/hlo_instruction.h"
|
||||||
#include "xla/hlo/ir/hlo_module.h"
|
#include "xla/hlo/ir/hlo_module.h"
|
||||||
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
|
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
|
||||||
#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h"
|
#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h"
|
||||||
|
|
@ -429,5 +430,100 @@ TEST_F(WhileLoopFusibleSinkingTest, TestNoPlumbWithUnknonwnTripCount) {
|
||||||
EXPECT_FALSE(changed);
|
EXPECT_FALSE(changed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(WhileLoopFusibleSinkingTest, SinkMaskWithOriginalValue) {
|
||||||
|
const char* const hlo_string = R"(
|
||||||
|
HloModule ModuleWithWhile
|
||||||
|
|
||||||
|
body {
|
||||||
|
p_body = (f32[5,7],f32[5,7]) parameter(0)
|
||||||
|
p_body.0 = get-tuple-element(p_body), index=0
|
||||||
|
p_body.1 = get-tuple-element(p_body), index=1
|
||||||
|
|
||||||
|
add.0 = add(p_body.0, p_body.1)
|
||||||
|
ROOT root = tuple(add.0, p_body.1)
|
||||||
|
}
|
||||||
|
|
||||||
|
condition {
|
||||||
|
p_cond = (f32[5,7],f32[5,7]) parameter(0)
|
||||||
|
ROOT result = pred[] constant(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY entry {
|
||||||
|
const_0 = f32[5,7] parameter(0), origin={{"constant"}}
|
||||||
|
p = f32[5] parameter(1), origin={{"parameter"}}
|
||||||
|
a = f32[5,7] iota(), iota_dimension=0
|
||||||
|
b = f32[5,7] iota(), iota_dimension=1
|
||||||
|
c = add(a, b)
|
||||||
|
d = f32[5,7] broadcast(p), dimensions={0}
|
||||||
|
mask = multiply(c,d), origin={{"mask"}}
|
||||||
|
while_init = tuple(const_0, mask), origin={({"constant"}, {"mask"})}
|
||||||
|
ROOT while = while(while_init), condition=condition, body=body, origin={({"while" {0}}, {"while" {1}})}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||||
|
WhileLoopFusibleSinking{}.Run(module.get()));
|
||||||
|
ASSERT_TRUE(changed);
|
||||||
|
|
||||||
|
HloInstruction* while_instr = FindInstruction(module.get(), "while");
|
||||||
|
ASSERT_NE(while_instr->original_value(), nullptr);
|
||||||
|
EXPECT_EQ(while_instr->original_value()->ToString(),
|
||||||
|
"({\"while\" {0}}, {\"while\" {1}}, {\"parameter\"})");
|
||||||
|
HloInstruction* while_init = while_instr->while_init();
|
||||||
|
ASSERT_NE(while_init->original_value(), nullptr);
|
||||||
|
EXPECT_EQ(while_init->original_value()->ToString(),
|
||||||
|
"({\"constant\"}, {\"mask\"}, {\"parameter\"})");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(WhileLoopFusibleSinkingTest, PlumbSingleBroadcastWithOriginalValue) {
|
||||||
|
const std::string hlo_string_before = R"(
|
||||||
|
HloModule test
|
||||||
|
|
||||||
|
loop.body {
|
||||||
|
loop_var.1 = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) parameter(0)
|
||||||
|
get-tuple-element.1 = s32[]{:T(128)} get-tuple-element(loop_var.1), index=0
|
||||||
|
get-tuple-element.2 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} get-tuple-element(loop_var.1), index=1
|
||||||
|
get-tuple-element.3 = s32[4,3,5]{2,1,0} get-tuple-element(loop_var.1), index=2
|
||||||
|
bitcast.12855 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} bitcast(get-tuple-element.3)
|
||||||
|
add.40974 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} add(get-tuple-element.2, bitcast.12855)
|
||||||
|
constant.1 = s32[]{:T(128)} constant(1)
|
||||||
|
idx = s32[]{:T(128)} add(get-tuple-element.1, constant.1)
|
||||||
|
ROOT tuple = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) tuple(idx, add.40974, get-tuple-element.3)
|
||||||
|
}
|
||||||
|
|
||||||
|
loop.condition {
|
||||||
|
loop_var.2 = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) parameter(0)
|
||||||
|
get-tuple-element.3 = s32[]{:T(128)} get-tuple-element(loop_var.2), index=0
|
||||||
|
constant.2 = s32[]{:T(128)} constant(4)
|
||||||
|
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY %main {
|
||||||
|
param.1 = s32[4,3,5]{2,1,0} iota(), iota_dimension=0
|
||||||
|
zero = s32[]{:T(128)} constant(0), origin={{"zero"}}
|
||||||
|
zeros32 = s32[]{:T(128)} constant(0), origin={{"zeros32"}}
|
||||||
|
broadcast = s32[1,1,1,4,3,5]{5,4,3,2,1,0} broadcast(zeros32), origin={{"broadcast"}}
|
||||||
|
input = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) tuple(zero, broadcast, param.1), origin={({"zero"}, {"zeros32"}, {"broadcast"})}
|
||||||
|
ROOT while = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) while(input), condition=loop.condition, body=loop.body, origin={({"while" {0}}, {"while" {1}}, {"while" {2}})}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string_before));
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||||
|
WhileLoopFusibleSinking{}.Run(module.get()));
|
||||||
|
EXPECT_TRUE(changed);
|
||||||
|
HloInstruction* while_instr = FindInstruction(module.get(), "while");
|
||||||
|
ASSERT_NE(while_instr->original_value(), nullptr);
|
||||||
|
EXPECT_EQ(while_instr->original_value()->ToString(),
|
||||||
|
"({\"while\" {0}}, {\"while\" {1}}, {\"while\" {2}}, {\"zero\"})");
|
||||||
|
HloInstruction* while_init = while_instr->while_init();
|
||||||
|
ASSERT_NE(while_init->original_value(), nullptr);
|
||||||
|
EXPECT_EQ(while_init->original_value()->ToString(),
|
||||||
|
"({\"zero\"}, {\"zeros32\"}, {\"broadcast\"}, {\"zero\"})");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user