[XLA] If an op has a single "large" operand, we want to fuse this op into some of its consumers, even if we can't fuse into all of them.

PiperOrigin-RevId: 157779106
This commit is contained in:
A. Unique TensorFlower 2017-06-01 16:24:40 -07:00 committed by TensorFlower Gardener
parent 2ee09b873a
commit 5bc685d7f1
2 changed files with 74 additions and 9 deletions

View File

@ -151,7 +151,26 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
return true;
};
if (std::all_of(hlo->users().begin(), hlo->users().end(),
// An "effectively unary" operation is one that has one "large"
// input with the others being negligible in terms of memory usage.
// We use "has a smaller true rank than the output" as a heuristic
// for "negligible" memory usage.
auto effectively_unary = [](HloInstruction* hlo) {
if (hlo->operands().size() == 1) {
return true;
}
auto output_rank = ShapeUtil::TrueRank(hlo->shape());
return std::count_if(
hlo->operands().begin(), hlo->operands().end(),
[output_rank](HloInstruction* operand) {
return ((operand->opcode() != HloOpcode::kBroadcast) &&
ShapeUtil::TrueRank(operand->shape()) >=
output_rank);
}) <= 1;
};
if (effectively_unary(hlo) ||
std::all_of(hlo->users().begin(), hlo->users().end(),
user_fusable_into_hlo)) {
all_consumers_fusable.insert(hlo);
}

View File

@ -156,21 +156,67 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) {
// Builds add(param0, param1) where both parameters have the same F32 {16, 16}
// shape as the result, then feeds the add into both a send and an abs (the
// abs is the computation root). The pass is expected to report that it
// changed nothing.
//
// NOTE(review): presumably the add is left alone because the send cannot
// absorb it and neither operand of the add is negligible -- confirm against
// the fusion heuristics in the pass.
TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
  HloComputation::Builder builder(TestName());
  auto shape = ShapeUtil::MakeShape(F32, {16, 16});
  auto param0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
  auto param1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
  HloInstruction* binary1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
  // Second consumer of the add, next to the abs below.
  builder.AddInstruction(HloInstruction::CreateSend(binary1, 0));
  HloInstruction* unary = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));

  auto module = MakeUnique<HloModule>(TestName());
  auto computation = module->AddEntryComputation(builder.Build());
  // The abs was added last, so it must be the root of the computation.
  EXPECT_EQ(unary, computation->root_instruction());
  // Run() returns whether the module was changed; no fusion is expected.
  EXPECT_FALSE(
      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
          .Run(module.get())
          .ValueOrDie());
}
// Builds floor(param0) and feeds it into both a send and an abs (the abs is
// the computation root). Because the floor has a single operand, the pass is
// expected to report a change even though the floor has more than one
// consumer.
TEST_F(InstructionFusionTest, AllowUnaryDuplication) {
  const auto matrix_shape = ShapeUtil::MakeShape(F32, {16, 16});

  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "0"));
  HloInstruction* floor_op = builder.AddInstruction(
      HloInstruction::CreateUnary(matrix_shape, HloOpcode::kFloor, input));
  // Second consumer of the floor, next to the abs below.
  builder.AddInstruction(HloInstruction::CreateSend(floor_op, 0));
  HloInstruction* abs_op = builder.AddInstruction(
      HloInstruction::CreateUnary(matrix_shape, HloOpcode::kAbs, floor_op));

  auto module = MakeUnique<HloModule>(TestName());
  auto entry = module->AddEntryComputation(builder.Build());
  // The abs was added last, so it must be the root.
  EXPECT_EQ(abs_op, entry->root_instruction());

  InstructionFusion fusion(InstructionFusion::IsExpensive,
                           /*may_duplicate=*/true);
  const bool changed = fusion.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
}
// Builds add(param0, param1) where param0 is an F32 {16} vector and param1
// has the same F32 {16, 16} shape as the result, then feeds the add into both
// a send and an abs (the abs is the computation root). Only one operand of
// the add matches the output shape, and the pass is expected to report a
// change.
TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) {
  HloComputation::Builder builder(TestName());
  const auto full_shape = ShapeUtil::MakeShape(F32, {16, 16});
  const auto vector_shape = ShapeUtil::MakeShape(F32, {16});

  HloInstruction* small_input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape, "0"));
  HloInstruction* large_input = builder.AddInstruction(
      HloInstruction::CreateParameter(1, full_shape, "1"));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      full_shape, HloOpcode::kAdd, small_input, large_input));
  // Second consumer of the add, next to the abs below.
  builder.AddInstruction(HloInstruction::CreateSend(sum, 0));
  HloInstruction* abs_op = builder.AddInstruction(
      HloInstruction::CreateUnary(full_shape, HloOpcode::kAbs, sum));

  auto module = MakeUnique<HloModule>(TestName());
  auto entry = module->AddEntryComputation(builder.Build());
  // The abs was added last, so it must be the root.
  EXPECT_EQ(abs_op, entry->root_instruction());

  InstructionFusion fusion(InstructionFusion::IsExpensive,
                           /*may_duplicate=*/true);
  const bool changed = fusion.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
}
} // namespace xla