[XLA] If an op has a single "large" operand, we want to fuse this op into some of its consumers, even if we can't fuse into all of them.
PiperOrigin-RevId: 157779106
parent 2ee09b873a
commit 5bc685d7f1
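For context, the sketch below is not part of the commit; it restates the "effectively unary" heuristic from the diff in isolation, using hypothetical stand-in types (Operand with a precomputed true rank, EffectivelyUnary) instead of the real HloInstruction/ShapeUtil API. An op counts as effectively unary when at most one operand is a non-broadcast whose true rank is at least the output's true rank.

// Illustrative sketch only; not the XLA implementation.
#include <algorithm>
#include <cassert>
#include <vector>

// Stand-in for an HLO operand: its "true rank" (rank ignoring degenerate,
// size-1 dimensions) and whether it is produced by a broadcast.
struct Operand {
  int true_rank;
  bool is_broadcast;
};

// An op is "effectively unary" if it is unary, or if at most one operand is
// a non-broadcast whose true rank is at least the output's true rank.
bool EffectivelyUnary(const std::vector<Operand>& operands, int output_rank) {
  if (operands.size() == 1) {
    return true;
  }
  return std::count_if(operands.begin(), operands.end(),
                       [output_rank](const Operand& operand) {
                         return !operand.is_broadcast &&
                                operand.true_rank >= output_rank;
                       }) <= 1;
}

int main() {
  // add(f32[16,16], f32[16]): only one "large" operand, so duplicating the
  // op into several consumers is acceptable (cf. AllowEffectiveUnaryDuplication).
  assert(EffectivelyUnary({{2, false}, {1, false}}, /*output_rank=*/2));
  // add(f32[16,16], f32[16,16]): two large operands, so it is not
  // effectively unary (cf. AvoidDuplicationIfNotAllFusable).
  assert(!EffectivelyUnary({{2, false}, {2, false}}, /*output_rank=*/2));
  return 0;
}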
@@ -151,7 +151,26 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
         return true;
       };
 
-      if (std::all_of(hlo->users().begin(), hlo->users().end(),
-                      user_fusable_into_hlo)) {
+      // An "effectively unary" operation is one that has one "large"
+      // input with the others being negligible in terms of memory usage.
+      // We use "has a smaller true rank than the output" as a heuristic
+      // for "negligible" memory usage.
+      auto effectively_unary = [](HloInstruction* hlo) {
+        if (hlo->operands().size() == 1) {
+          return true;
+        }
+        auto output_rank = ShapeUtil::TrueRank(hlo->shape());
+        return std::count_if(
+                   hlo->operands().begin(), hlo->operands().end(),
+                   [output_rank](HloInstruction* operand) {
+                     return ((operand->opcode() != HloOpcode::kBroadcast) &&
+                             ShapeUtil::TrueRank(operand->shape()) >=
+                                 output_rank);
+                   }) <= 1;
+      };
+
+      if (effectively_unary(hlo) ||
+          std::all_of(hlo->users().begin(), hlo->users().end(),
+                      user_fusable_into_hlo)) {
         all_consumers_fusable.insert(hlo);
       }
@@ -156,21 +156,67 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) {
 
 TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
   HloComputation::Builder builder(TestName());
-  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
-      0, ShapeUtil::MakeShape(F32, {16, 16}), "0"));
-  HloInstruction* unary1 = builder.AddInstruction(HloInstruction::CreateUnary(
-      ShapeUtil::MakeShape(S32, {}), HloOpcode::kFloor, param0));
-  builder.AddInstruction(HloInstruction::CreateSend(unary1, 0));
-  HloInstruction* unary2 = builder.AddInstruction(HloInstruction::CreateUnary(
-      ShapeUtil::MakeShape(S32, {}), HloOpcode::kAbs, unary1));
+  auto shape = ShapeUtil::MakeShape(F32, {16, 16});
+  auto param0 =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
+  auto param1 =
+      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
+  HloInstruction* binary1 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
+  builder.AddInstruction(HloInstruction::CreateSend(binary1, 0));
+  HloInstruction* unary = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
 
   auto module = MakeUnique<HloModule>(TestName());
   auto computation = module->AddEntryComputation(builder.Build());
-  EXPECT_EQ(unary2, computation->root_instruction());
+  EXPECT_EQ(unary, computation->root_instruction());
   EXPECT_FALSE(
       InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
          .Run(module.get())
          .ValueOrDie());
 }
 
+TEST_F(InstructionFusionTest, AllowUnaryDuplication) {
+  HloComputation::Builder builder(TestName());
+  auto shape = ShapeUtil::MakeShape(F32, {16, 16});
+  auto param0 =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
+  HloInstruction* unary1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kFloor, param0));
+  builder.AddInstruction(HloInstruction::CreateSend(unary1, 0));
+  HloInstruction* unary2 = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1));
+
+  auto module = MakeUnique<HloModule>(TestName());
+  auto computation = module->AddEntryComputation(builder.Build());
+  EXPECT_EQ(unary2, computation->root_instruction());
+  EXPECT_TRUE(
+      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
+          .Run(module.get())
+          .ValueOrDie());
+}
+
+TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) {
+  auto shape = ShapeUtil::MakeShape(F32, {16, 16});
+  auto small_shape = ShapeUtil::MakeShape(F32, {16});
+  HloComputation::Builder builder(TestName());
+  auto param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, small_shape, "0"));
+  auto param1 =
+      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
+  HloInstruction* binary1 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
+  builder.AddInstruction(HloInstruction::CreateSend(binary1, 0));
+  HloInstruction* unary = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
+
+  auto module = MakeUnique<HloModule>(TestName());
+  auto computation = module->AddEntryComputation(builder.Build());
+  EXPECT_EQ(unary, computation->root_instruction());
+  EXPECT_TRUE(
+      InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
+          .Run(module.get())
+          .ValueOrDie());
+}
+
 }  // namespace xla