[XLA] DOT dumper: Handle fusion nodes nested inside other nodes (e.g. map).

PiperOrigin-RevId: 173752314
This commit is contained in:
Justin Lebar 2017-10-27 22:43:46 -07:00 committed by TensorFlower Gardener
parent 8ec7540e00
commit e7645b6295

View File

@ -360,6 +360,21 @@ class HloDotDumper {
string GetInstructionNodeInlinedOperands(const HloInstruction* instr);
void AddInstructionIncomingEdges(const HloInstruction* instr);
// For most instructions, GetNodeForEdge(instr) returns instr.
//
// The exception is fusion nodes. For these, we walk up the chain of nested
// fusion nodes starting at instr until we reach a node that either (a) isn't
// a fusion node, or (b) is a fusion node for which
// ShouldShowFusionSubcomputation is false.
//
// We do this because fusion nodes are expanded inline -- if
// ShouldShowFusionSubcomputation is true, the fusion node won't be present in
// the graph.
//
// In general when you want to draw an edge from A to B, you should actually
// draw an edge from GetNodeForEdge(A) to GetNodeForEdge(B).
const HloInstruction* GetNodeForEdge(const HloInstruction* instr);
// If instr has just one computation and it's trivial (e.g. "return param0 +
// param1"), returns a string you can put into the node's body that names the
// subcomputation, e.g. "Subcomputation: <b>add</b>".
@ -595,16 +610,15 @@ tooltip = " ";
// belongs to a fusion node, it's drawn in place of the fusion instruction,
// so there's no need to link those.
if (parent_instr->opcode() != HloOpcode::kFusion) {
VLOG(2) << "Edge: from " << subcomp->root_instruction()->name() << " to "
<< parent_instr->name() << " as " << next_edge_id_;
edge_ids_.insert(
{{subcomp->root_instruction(), parent_instr}, next_edge_id_++});
const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction());
VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name()
<< " as " << next_edge_id_;
edge_ids_.insert({{from, parent_instr}, next_edge_id_++});
const char* edge_fmt =
R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
edges_.push_back(
Printf(edge_fmt, InstructionId(subcomp->root_instruction()),
InstructionId(parent_instr), SubcomputationId(subcomp),
subcomp->name(), parent_instr->name()));
edges_.push_back(Printf(
edge_fmt, InstructionId(from), InstructionId(parent_instr),
SubcomputationId(subcomp), subcomp->name(), parent_instr->name()));
}
string computation =
@ -633,15 +647,7 @@ string HloDotDumper::DumpComputation(const HloComputation* comp) {
}
string HloDotDumper::DumpRootTag() {
HloInstruction* from = computation_->root_instruction();
// Fusion nodes are expanded inline, so if root is an expanded fusion node,
// walk up the graph until we find a node that isn't.
while (from->opcode() == HloOpcode::kFusion &&
ShouldShowFusionSubcomputation(from)) {
from = from->fused_expression_root();
}
const HloInstruction* from = GetNodeForEdge(computation_->root_instruction());
auto from_id = InstructionId(from);
if (!filter_.Show(from)) {
@ -1080,13 +1086,8 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
auto add_edge = [&](const HloInstruction* from, const HloInstruction* to,
int64 operand_num, bool control_edge = false) {
// Fusion nodes' subcomputations are displayed inline, so if 'from' is a
// fusion node and the node's subcomputation is shown, we draw our edge
// starting at the fusion node's root instead of at the fusion node itself.
if (from->opcode() == HloOpcode::kFusion &&
ShouldShowFusionSubcomputation(from)) {
from = from->fused_expression_root();
}
from = GetNodeForEdge(from);
if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
ShouldMergeIntoUsers(from)) {
return;
@ -1154,6 +1155,15 @@ string HloDotDumper::GetInstructionTrivialComputationStr(
return Join(lines, "<br/>");
}
const HloInstruction* HloDotDumper::GetNodeForEdge(
const HloInstruction* instr) {
while (instr->opcode() == HloOpcode::kFusion &&
ShouldShowFusionSubcomputation(instr)) {
instr = instr->fused_expression_root();
}
return instr;
}
tensorflow::mutex& RendererMutex() {
static tensorflow::mutex* mu = new tensorflow::mutex;
return *mu;