mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[XLA][Numerics][HLO Original Value] Support original values for some cases in while loop simplifier pass
This updates the original value of a while loop if any unused parameters are removed. PiperOrigin-RevId: 825221785
This commit is contained in:
parent
2feb74eeff
commit
66078903f7
|
|
@ -37,6 +37,7 @@ limitations under the License.
|
|||
#include "xla/hlo/ir/hlo_instruction.h"
|
||||
#include "xla/hlo/ir/hlo_instructions.h"
|
||||
#include "xla/hlo/ir/hlo_opcode.h"
|
||||
#include "xla/hlo/ir/hlo_original_value.h"
|
||||
#include "xla/hlo/ir/hlo_print_options.h"
|
||||
#include "xla/hlo/transforms/simplifiers/hlo_dce.h"
|
||||
#include "xla/hlo/utils/hlo_query.h"
|
||||
|
|
@ -150,6 +151,35 @@ static absl::StatusOr<HloInstruction*> RemoveDeadTupleIndices(
|
|||
HloInstruction* while_op, absl::flat_hash_set<int64_t>& used_tuple_indices,
|
||||
std::optional<absl::flat_hash_map<int32_t, int32_t>>
|
||||
dead_to_surviving_index = std::nullopt) {
|
||||
auto copy_original_value =
|
||||
[&](const HloInstruction* src_instruction,
|
||||
HloInstruction* dest_instruction,
|
||||
const absl::flat_hash_map<int64_t, int64_t>& old_to_new_tuple_idx) {
|
||||
std::shared_ptr<OriginalValue> original_value =
|
||||
src_instruction->original_value();
|
||||
if (!original_value) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t src_tuple_size =
|
||||
src_instruction->shape().tuple_shapes().size(),
|
||||
dest_tuple_size =
|
||||
dest_instruction->shape().tuple_shapes().size();
|
||||
std::shared_ptr<OriginalValue> old_original_value =
|
||||
src_instruction->original_value();
|
||||
std::shared_ptr<xla::OriginalValue> new_original_value =
|
||||
std::make_shared<xla::OriginalValue>(dest_instruction->shape());
|
||||
for (const auto& [old_idx, new_idx] : old_to_new_tuple_idx) {
|
||||
if (old_idx < 0 || old_idx >= src_tuple_size || new_idx < 0 ||
|
||||
new_idx >= dest_tuple_size) {
|
||||
return;
|
||||
}
|
||||
new_original_value->mutable_tree()->CopySubtreeFrom(
|
||||
old_original_value->tree(), {old_idx}, {new_idx});
|
||||
}
|
||||
dest_instruction->set_original_value(new_original_value);
|
||||
};
|
||||
|
||||
// Build up maps from the old/new to the new/old tuple indices.
|
||||
std::vector<int64_t> new_to_old_tuple_idx(used_tuple_indices.begin(),
|
||||
used_tuple_indices.end());
|
||||
|
|
@ -275,6 +305,9 @@ static absl::StatusOr<HloInstruction*> RemoveDeadTupleIndices(
|
|||
CopyFrontendAttributes(while_op, new_while_op);
|
||||
CopyMetadata(while_op, new_while_op);
|
||||
|
||||
copy_original_value(while_init, new_while_init, old_to_new_tuple_idx);
|
||||
copy_original_value(while_op, new_while_op, old_to_new_tuple_idx);
|
||||
|
||||
// Create a tuple op that recreates the output of the old while op. That is,
|
||||
// we transform to
|
||||
//
|
||||
|
|
|
|||
|
|
@ -1481,5 +1481,59 @@ ENTRY %main (arg.0: f32[3], arg.1: f32[2]) -> (f32[3], f32[2], f32[2], f32[3]) {
|
|||
op::GetTupleElement(op::While(), 0)));
|
||||
}
|
||||
|
||||
TEST_F(WhileLoopSimplifierTest, RemoveDeadTupleIndicesWithOriginalValue) {
|
||||
const std::string hlo_string = R"(
|
||||
HloModule dus
|
||||
|
||||
%while.body (arg_tuple: (f32[3], f32[2], f32[2], f32[3], s32[])) -> (f32[3], f32[2], f32[2], f32[3], s32[]) {
|
||||
%arg_tuple = (f32[3], f32[2], f32[2], f32[3], s32[]) parameter(0)
|
||||
%get-tuple-element.0 = f32[3] get-tuple-element(%arg_tuple), index=0
|
||||
%get-tuple-element.1 = f32[2] get-tuple-element(%arg_tuple), index=1
|
||||
%get-tuple-element.2 = f32[2] get-tuple-element(%arg_tuple), index=2
|
||||
%get-tuple-element.3 = f32[3] get-tuple-element(%arg_tuple), index=3
|
||||
%get-tuple-element.4 = s32[] get-tuple-element(%arg_tuple), index=4
|
||||
%constant.1 = s32[] constant(1)
|
||||
%constant.v0 = f32[1] constant({0.0})
|
||||
%constant.v1 = f32[1] constant({1.0})
|
||||
%dynamic-update-slice.0 = f32[3] dynamic-update-slice(%get-tuple-element.0, %constant.v0, s32[] %constant.1)
|
||||
%dynamic-update-slice.3 = f32[3] dynamic-update-slice(%get-tuple-element.3, %constant.v0, s32[] %constant.1)
|
||||
%add = add(s32[] %get-tuple-element.4, s32[] %constant.1)
|
||||
ROOT %tuple = tuple(%dynamic-update-slice.0, %get-tuple-element.1, %get-tuple-element.2, %dynamic-update-slice.3, %add)
|
||||
}
|
||||
|
||||
%while.condition (arg_tuple.cond:(f32[3], f32[2], f32[2], f32[3], s32[])) -> pred[] {
|
||||
%arg_tuple.cond = (f32[3], f32[2], f32[2], f32[3], s32[]) parameter(0)
|
||||
%get-tuple-element.cond = s32[] get-tuple-element(%arg_tuple.cond), index=4
|
||||
%constant.3 = s32[] constant(3)
|
||||
ROOT %compare = pred[] compare(s32[] %get-tuple-element.cond, s32[] %constant.3), direction=LT
|
||||
}
|
||||
|
||||
ENTRY %main (arg.0: f32[3], arg.1: f32[2]) -> (f32[3], f32[2], f32[2], f32[3]) {
|
||||
%constant.0 = s32[] constant(0)
|
||||
%arg.0 = f32[3] parameter(0)
|
||||
%arg.1 = f32[2] parameter(1)
|
||||
%input = tuple(%arg.0, %arg.1, %arg.1, %arg.0, %constant.0), origin={({"arg.0"}, {"arg.1"}, {"arg.1"}, {"arg.0"}, {"constant.0"})}
|
||||
%while = while(%input), condition=%while.condition, body=%while.body, origin={({"while.116" {0}}, {"while.116" {1}}, {"while.116" {2}}, {"while.116" {3}}, {"while.116" {4}})}
|
||||
%get-tuple-element.out0 = f32[3] get-tuple-element(%while), index=0
|
||||
%get-tuple-element.out1 = f32[2] get-tuple-element(%while), index=1
|
||||
%get-tuple-element.out2 = f32[2] get-tuple-element(%while), index=2
|
||||
%get-tuple-element.out3 = f32[3] get-tuple-element(%while), index=3
|
||||
ROOT %root = tuple(%get-tuple-element.out0, %get-tuple-element.out1, %get-tuple-element.out2, %get-tuple-element.out3)
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
ASSERT_TRUE(WhileLoopSimplifier().Run(module.get()).value());
|
||||
HloInstruction* while_instr = FindFirstWhile(module.get());
|
||||
ASSERT_NE(while_instr->original_value(), nullptr);
|
||||
EXPECT_EQ(while_instr->original_value()->ToString(),
|
||||
R"(({"while.116" {0}}, {"while.116" {1}}, {"while.116" {4}}))");
|
||||
HloInstruction* while_init = while_instr->while_init();
|
||||
ASSERT_NE(while_init->original_value(), nullptr);
|
||||
EXPECT_EQ(while_init->original_value()->ToString(),
|
||||
R"(({"arg.0"}, {"arg.1"}, {"constant.0"}))");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user