mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
[XLA] DOT dumper: Handle fusion nodes nested inside other nodes (e.g. map).
PiperOrigin-RevId: 173752314
This commit is contained in:
parent
8ec7540e00
commit
e7645b6295
|
|
@ -360,6 +360,21 @@ class HloDotDumper {
|
|||
string GetInstructionNodeInlinedOperands(const HloInstruction* instr);
|
||||
void AddInstructionIncomingEdges(const HloInstruction* instr);
|
||||
|
||||
// For most instructions, GetNodeForEdge(instr) returns instr.
|
||||
//
|
||||
// The exception is fusion nodes. For these, we walk up the chain of nested
|
||||
// fusion nodes starting at instr until we reach a node that either (a) isn't
|
||||
// a fusion node, or (b) is a fusion node for which
|
||||
// ShouldShowFusionSubcomputation is false.
|
||||
//
|
||||
// We do this because fusion nodes are expanded inline -- if
|
||||
// ShouldShowFusionSubcomputation is true, the fusion node won't be present in
|
||||
// the graph.
|
||||
//
|
||||
// In general when you want to draw an edge from A to B, you should actually
|
||||
// draw an edge from GetNodeForEdge(A) to GetNodeForEdge(B).
|
||||
const HloInstruction* GetNodeForEdge(const HloInstruction* instr);
|
||||
|
||||
// If instr has just one computation and it's trivial (e.g. "return param0 +
|
||||
// param1"), returns a string you can put into the node's body that names the
|
||||
// subcomputation, e.g. "Subcomputation: <b>add</b>".
|
||||
|
|
@ -595,16 +610,15 @@ tooltip = " ";
|
|||
// belongs to a fusion node, it's drawn in place of the fusion instruction,
|
||||
// so there's no need to link those.
|
||||
if (parent_instr->opcode() != HloOpcode::kFusion) {
|
||||
VLOG(2) << "Edge: from " << subcomp->root_instruction()->name() << " to "
|
||||
<< parent_instr->name() << " as " << next_edge_id_;
|
||||
edge_ids_.insert(
|
||||
{{subcomp->root_instruction(), parent_instr}, next_edge_id_++});
|
||||
const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction());
|
||||
VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name()
|
||||
<< " as " << next_edge_id_;
|
||||
edge_ids_.insert({{from, parent_instr}, next_edge_id_++});
|
||||
const char* edge_fmt =
|
||||
R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
|
||||
edges_.push_back(
|
||||
Printf(edge_fmt, InstructionId(subcomp->root_instruction()),
|
||||
InstructionId(parent_instr), SubcomputationId(subcomp),
|
||||
subcomp->name(), parent_instr->name()));
|
||||
edges_.push_back(Printf(
|
||||
edge_fmt, InstructionId(from), InstructionId(parent_instr),
|
||||
SubcomputationId(subcomp), subcomp->name(), parent_instr->name()));
|
||||
}
|
||||
|
||||
string computation =
|
||||
|
|
@ -633,15 +647,7 @@ string HloDotDumper::DumpComputation(const HloComputation* comp) {
|
|||
}
|
||||
|
||||
string HloDotDumper::DumpRootTag() {
|
||||
HloInstruction* from = computation_->root_instruction();
|
||||
|
||||
// Fusion nodes are expanded inline, so if root is an expanded fusion node,
|
||||
// walk up the graph until we find a node that isn't.
|
||||
while (from->opcode() == HloOpcode::kFusion &&
|
||||
ShouldShowFusionSubcomputation(from)) {
|
||||
from = from->fused_expression_root();
|
||||
}
|
||||
|
||||
const HloInstruction* from = GetNodeForEdge(computation_->root_instruction());
|
||||
auto from_id = InstructionId(from);
|
||||
|
||||
if (!filter_.Show(from)) {
|
||||
|
|
@ -1080,13 +1086,8 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
|
|||
void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
|
||||
auto add_edge = [&](const HloInstruction* from, const HloInstruction* to,
|
||||
int64 operand_num, bool control_edge = false) {
|
||||
// Fusion nodes' subcomputations are displayed inline, so if 'from' is a
|
||||
// fusion node and the node's subcomputation is shown, we draw our edge
|
||||
// starting at the fusion node's root instead of at the fusion node itself.
|
||||
if (from->opcode() == HloOpcode::kFusion &&
|
||||
ShouldShowFusionSubcomputation(from)) {
|
||||
from = from->fused_expression_root();
|
||||
}
|
||||
from = GetNodeForEdge(from);
|
||||
|
||||
if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
|
||||
ShouldMergeIntoUsers(from)) {
|
||||
return;
|
||||
|
|
@ -1154,6 +1155,15 @@ string HloDotDumper::GetInstructionTrivialComputationStr(
|
|||
return Join(lines, "<br/>");
|
||||
}
|
||||
|
||||
const HloInstruction* HloDotDumper::GetNodeForEdge(
|
||||
const HloInstruction* instr) {
|
||||
while (instr->opcode() == HloOpcode::kFusion &&
|
||||
ShouldShowFusionSubcomputation(instr)) {
|
||||
instr = instr->fused_expression_root();
|
||||
}
|
||||
return instr;
|
||||
}
|
||||
|
||||
tensorflow::mutex& RendererMutex() {
|
||||
static tensorflow::mutex* mu = new tensorflow::mutex;
|
||||
return *mu;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user