mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
[XLA] Handle higher-order HLOs (e.g. While) in CallInliner and test.
PiperOrigin-RevId: 168029345
This commit is contained in:
parent
8988ae365f
commit
f83f6b9ef1
|
|
@ -20,30 +20,7 @@ limitations under the License.
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
namespace {
|
||||||
StatusOr<bool> CallInliner::Run(HloModule* module) {
|
|
||||||
std::deque<HloInstruction*> work_queue;
|
|
||||||
|
|
||||||
// Seed the work queue with call instructions from the main computation.
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
module->entry_computation()->Accept([&](HloInstruction* hlo) {
|
|
||||||
if (hlo->opcode() == HloOpcode::kCall) {
|
|
||||||
work_queue.push_back(hlo);
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}));
|
|
||||||
|
|
||||||
VLOG(1) << "Work queue seeded with " << work_queue.size() << " entries.";
|
|
||||||
|
|
||||||
bool mutated = false;
|
|
||||||
while (!work_queue.empty()) {
|
|
||||||
mutated = true;
|
|
||||||
HloInstruction* call = work_queue.front();
|
|
||||||
work_queue.pop_front();
|
|
||||||
TF_RETURN_IF_ERROR(ReplaceWithInlinedBody(call, &work_queue));
|
|
||||||
}
|
|
||||||
return mutated;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Traverses the callee computation, inlining cloned nodes into the caller
|
// Traverses the callee computation, inlining cloned nodes into the caller
|
||||||
// computation and connecting them to producers/consumers appropriately.
|
// computation and connecting them to producers/consumers appropriately.
|
||||||
|
|
@ -141,6 +118,64 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
|
||||||
std::deque<HloInstruction*>* work_queue_;
|
std::deque<HloInstruction*>* work_queue_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
StatusOr<bool> CallInliner::Run(HloModule* module) {
|
||||||
|
std::deque<HloInstruction*> work_queue;
|
||||||
|
tensorflow::gtl::FlatSet<HloComputation*> seen;
|
||||||
|
|
||||||
|
auto scan_computation = [&work_queue,
|
||||||
|
&seen](HloComputation* computation) -> Status {
|
||||||
|
if (!seen.insert(computation).second) {
|
||||||
|
return Status::OK(); // Already seen.
|
||||||
|
}
|
||||||
|
return computation->Accept([&](HloInstruction* hlo) {
|
||||||
|
if (!hlo->called_computations().empty()) {
|
||||||
|
work_queue.push_back(hlo);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
// Seed the work queue with call instructions from the main computation.
|
||||||
|
TF_RETURN_IF_ERROR(scan_computation(module->entry_computation()));
|
||||||
|
|
||||||
|
VLOG(1) << "Work queue seeded with " << work_queue.size() << " entries.";
|
||||||
|
|
||||||
|
bool mutated = false;
|
||||||
|
while (!work_queue.empty()) {
|
||||||
|
HloInstruction* caller = work_queue.front();
|
||||||
|
work_queue.pop_front();
|
||||||
|
switch (caller->opcode()) {
|
||||||
|
case HloOpcode::kCall:
|
||||||
|
mutated = true;
|
||||||
|
TF_RETURN_IF_ERROR(ReplaceWithInlinedBody(caller, &work_queue));
|
||||||
|
break;
|
||||||
|
case HloOpcode::kWhile:
|
||||||
|
TF_RETURN_IF_ERROR(scan_computation(caller->while_condition()));
|
||||||
|
TF_RETURN_IF_ERROR(scan_computation(caller->while_body()));
|
||||||
|
break;
|
||||||
|
case HloOpcode::kSelectAndScatter:
|
||||||
|
TF_RETURN_IF_ERROR(scan_computation(caller->select()));
|
||||||
|
TF_RETURN_IF_ERROR(scan_computation(caller->scatter()));
|
||||||
|
break;
|
||||||
|
case HloOpcode::kMap:
|
||||||
|
case HloOpcode::kReduceWindow:
|
||||||
|
case HloOpcode::kReduce:
|
||||||
|
TF_RETURN_IF_ERROR(scan_computation(caller->to_apply()));
|
||||||
|
break;
|
||||||
|
case HloOpcode::kFusion:
|
||||||
|
// Fusion nodes don't represent true calls, but instead delimit a
|
||||||
|
// boundary for the backend-specific fusion capabilities.
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return Unimplemented("Unknown higher-order HLO opcode: %s",
|
||||||
|
caller->ToString().c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return mutated;
|
||||||
|
}
|
||||||
|
|
||||||
Status CallInliner::ReplaceWithInlinedBody(
|
Status CallInliner::ReplaceWithInlinedBody(
|
||||||
HloInstruction* call, std::deque<HloInstruction*>* work_queue) {
|
HloInstruction* call, std::deque<HloInstruction*>* work_queue) {
|
||||||
TF_RET_CHECK(call->opcode() == HloOpcode::kCall);
|
TF_RET_CHECK(call->opcode() == HloOpcode::kCall);
|
||||||
|
|
|
||||||
|
|
@ -73,5 +73,44 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
|
||||||
EXPECT_EQ(prior->literal().GetFirstElement<float>(), 24);
|
EXPECT_EQ(prior->literal().GetFirstElement<float>(), 24);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tests for referential transparency (a function that calls a function that
|
||||||
|
// returns false should be identical to just returning false).
|
||||||
|
TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
|
||||||
|
const Shape pred = ShapeUtil::MakeShape(PRED, {});
|
||||||
|
auto module = CreateNewModule();
|
||||||
|
|
||||||
|
// Create a lambda that calls a function that returns the false predicate.
|
||||||
|
// Note we also use this lambda twice by reference, just to make the test a
|
||||||
|
// little trickier.
|
||||||
|
HloComputation::Builder just_false(TestName() + ".false");
|
||||||
|
just_false.AddInstruction(
|
||||||
|
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||||
|
HloComputation* false_computation =
|
||||||
|
module->AddEmbeddedComputation(just_false.Build());
|
||||||
|
|
||||||
|
HloComputation::Builder call_false_builder(TestName() + ".call_false");
|
||||||
|
call_false_builder.AddInstruction(
|
||||||
|
HloInstruction::CreateCall(pred, {}, false_computation));
|
||||||
|
HloComputation* call_false =
|
||||||
|
module->AddEmbeddedComputation(call_false_builder.Build());
|
||||||
|
|
||||||
|
HloComputation::Builder outer(TestName() + ".outer");
|
||||||
|
HloInstruction* init_value = outer.AddInstruction(
|
||||||
|
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||||
|
outer.AddInstruction(
|
||||||
|
HloInstruction::CreateWhile(pred, call_false, call_false, init_value));
|
||||||
|
|
||||||
|
auto computation = module->AddEntryComputation(outer.Build());
|
||||||
|
|
||||||
|
CallInliner call_inliner;
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
|
||||||
|
ASSERT_TRUE(mutated);
|
||||||
|
EXPECT_THAT(
|
||||||
|
computation->root_instruction()->while_condition()->root_instruction(),
|
||||||
|
op::Constant());
|
||||||
|
EXPECT_THAT(computation->root_instruction()->while_body()->root_instruction(),
|
||||||
|
op::Constant());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user