mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
[XLA] If an op has a single "large" operand, we want to fuse this op into some of its consumers, even if we can't fuse into all of them.
PiperOrigin-RevId: 157779106
This commit is contained in:
parent
2ee09b873a
commit
5bc685d7f1
|
|
@ -151,7 +151,26 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
if (std::all_of(hlo->users().begin(), hlo->users().end(),
|
// An "effectively unary" operation is one that has one "large"
|
||||||
|
// input with the others being negligible in terms of memory usage.
|
||||||
|
// We use "has a smaller true rank than the output" as a heuristic
|
||||||
|
// for "negligible" memory usage.
|
||||||
|
auto effectively_unary = [](HloInstruction* hlo) {
|
||||||
|
if (hlo->operands().size() == 1) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
auto output_rank = ShapeUtil::TrueRank(hlo->shape());
|
||||||
|
return std::count_if(
|
||||||
|
hlo->operands().begin(), hlo->operands().end(),
|
||||||
|
[output_rank](HloInstruction* operand) {
|
||||||
|
return ((operand->opcode() != HloOpcode::kBroadcast) &&
|
||||||
|
ShapeUtil::TrueRank(operand->shape()) >=
|
||||||
|
output_rank);
|
||||||
|
}) <= 1;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (effectively_unary(hlo) ||
|
||||||
|
std::all_of(hlo->users().begin(), hlo->users().end(),
|
||||||
user_fusable_into_hlo)) {
|
user_fusable_into_hlo)) {
|
||||||
all_consumers_fusable.insert(hlo);
|
all_consumers_fusable.insert(hlo);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -156,21 +156,67 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) {
|
||||||
|
|
||||||
TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
|
TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
|
||||||
HloComputation::Builder builder(TestName());
|
HloComputation::Builder builder(TestName());
|
||||||
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
|
auto shape = ShapeUtil::MakeShape(F32, {16, 16});
|
||||||
0, ShapeUtil::MakeShape(F32, {16, 16}), "0"));
|
auto param0 =
|
||||||
HloInstruction* unary1 = builder.AddInstruction(HloInstruction::CreateUnary(
|
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
|
||||||
ShapeUtil::MakeShape(S32, {}), HloOpcode::kFloor, param0));
|
auto param1 =
|
||||||
builder.AddInstruction(HloInstruction::CreateSend(unary1, 0));
|
builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
|
||||||
HloInstruction* unary2 = builder.AddInstruction(HloInstruction::CreateUnary(
|
HloInstruction* binary1 = builder.AddInstruction(
|
||||||
ShapeUtil::MakeShape(S32, {}), HloOpcode::kAbs, unary1));
|
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
|
||||||
|
builder.AddInstruction(HloInstruction::CreateSend(binary1, 0));
|
||||||
|
HloInstruction* unary = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
|
||||||
|
|
||||||
auto module = MakeUnique<HloModule>(TestName());
|
auto module = MakeUnique<HloModule>(TestName());
|
||||||
auto computation = module->AddEntryComputation(builder.Build());
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
EXPECT_EQ(unary2, computation->root_instruction());
|
EXPECT_EQ(unary, computation->root_instruction());
|
||||||
EXPECT_FALSE(
|
EXPECT_FALSE(
|
||||||
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
|
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
|
||||||
.Run(module.get())
|
.Run(module.get())
|
||||||
.ValueOrDie());
|
.ValueOrDie());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(InstructionFusionTest, AllowUnaryDuplication) {
|
||||||
|
HloComputation::Builder builder(TestName());
|
||||||
|
auto shape = ShapeUtil::MakeShape(F32, {16, 16});
|
||||||
|
auto param0 =
|
||||||
|
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
|
||||||
|
HloInstruction* unary1 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateUnary(shape, HloOpcode::kFloor, param0));
|
||||||
|
builder.AddInstruction(HloInstruction::CreateSend(unary1, 0));
|
||||||
|
HloInstruction* unary2 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1));
|
||||||
|
|
||||||
|
auto module = MakeUnique<HloModule>(TestName());
|
||||||
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
|
EXPECT_EQ(unary2, computation->root_instruction());
|
||||||
|
EXPECT_TRUE(
|
||||||
|
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
|
||||||
|
.Run(module.get())
|
||||||
|
.ValueOrDie());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) {
|
||||||
|
auto shape = ShapeUtil::MakeShape(F32, {16, 16});
|
||||||
|
auto small_shape = ShapeUtil::MakeShape(F32, {16});
|
||||||
|
HloComputation::Builder builder(TestName());
|
||||||
|
auto param0 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateParameter(0, small_shape, "0"));
|
||||||
|
auto param1 =
|
||||||
|
builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
|
||||||
|
HloInstruction* binary1 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
|
||||||
|
builder.AddInstruction(HloInstruction::CreateSend(binary1, 0));
|
||||||
|
HloInstruction* unary = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
|
||||||
|
|
||||||
|
auto module = MakeUnique<HloModule>(TestName());
|
||||||
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
|
EXPECT_EQ(unary, computation->root_instruction());
|
||||||
|
EXPECT_TRUE(
|
||||||
|
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
|
||||||
|
.Run(module.get())
|
||||||
|
.ValueOrDie());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user