"""Draw caffe2 networks as graphviz graphs using pydot.

When run as a script, reads a serialized PlanDef or NetDef from the file
given as the first command-line argument, writes one ``.dot`` file per
network and tries to convert each of them to PDF with the ``dot`` binary.
"""

from collections import defaultdict
from pycaffe2 import utils
import os
import sys
import subprocess

try:
    import pydot
except ImportError:
    print('Cannot import pydot, which is required for drawing a network. This '
          'can usually be installed in python with "pip install pydot". Also, '
          'pydot requires graphviz to convert dot files to pdf: in ubuntu, this '
          'can usually be installed with "sudo apt-get install graphviz".')
    print('net_drawer will not run correctly. Please install the correct '
          'dependencies.')
    pydot = None

from caffe2.proto import caffe2_pb2
from google.protobuf import text_format

# Keyword arguments passed to pydot.Node for operator nodes.
OP_STYLE = {'shape': 'box', 'color': '#0F9D58', 'style': 'filled',
            'fontcolor': '#FFFFFF'}
# Keyword arguments passed to pydot.Node for blob nodes.
BLOB_STYLE = {'shape': 'octagon'}


def _MakeOpNode(op, op_id):
    """Return the pydot node for operator `op` at position `op_id`.

    The node name is '<name>/<type> (op#<id>)' when the operator has a
    non-empty name and '<type> (op#<id>)' otherwise.
    """
    if op.name:
        node_name = '%s/%s (op#%d)' % (op.name, op.type, op_id)
    else:
        node_name = '%s (op#%d)' % (op.type, op_id)
    return pydot.Node(node_name, **OP_STYLE)


def GetPydotGraph(operators, name, rankdir='LR'):
    """Build a pydot graph with one node per operator and per blob version.

    Args:
        operators: iterable of operators; each must expose `name`, `type`,
            `input` and `output`.
        name: name given to the pydot graph.
        rankdir: graphviz `rankdir` attribute of the graph.

    Returns:
        The populated `pydot.Dot` graph.
    """
    graph = pydot.Dot(name, rankdir=rankdir)
    # Most recent node created for each blob name.
    pydot_nodes = {}
    # Number of times each blob name has been overwritten so far; used as a
    # suffix so that every version of a blob gets a distinct node name.
    pydot_node_counts = defaultdict(int)
    for op_id, op in enumerate(operators):
        op_node = _MakeOpNode(op, op_id)
        graph.add_node(op_node)
        for input_name in op.input:
            if input_name not in pydot_nodes:
                # First time this blob is seen: it is an external input.
                input_node = pydot.Node(
                    input_name + str(pydot_node_counts[input_name]),
                    label=input_name, **BLOB_STYLE)
                pydot_nodes[input_name] = input_node
            else:
                input_node = pydot_nodes[input_name]
            graph.add_node(input_node)
            graph.add_edge(pydot.Edge(input_node, op_node))
        for output_name in op.output:
            if output_name in pydot_nodes:
                # We are overwriting an existing blob, so we need to update
                # the count to get a fresh node name.
                pydot_node_counts[output_name] += 1
            output_node = pydot.Node(
                output_name + str(pydot_node_counts[output_name]),
                label=output_name, **BLOB_STYLE)
            pydot_nodes[output_name] = output_node
            graph.add_node(output_node)
            graph.add_edge(pydot.Edge(op_node, output_node))
    return graph


def GetPydotGraphMinimal(operators, name, rankdir='LR'):
    """Different from GetPydotGraph, hide all blob nodes and only show op nodes.

    An edge is drawn from the operator that most recently produced a blob to
    every later operator that consumes it. Inputs that no earlier operator
    produced do not appear in the graph.

    Args:
        operators: iterable of operators; each must expose `name`, `type`,
            `input` and `output`.
        name: name given to the pydot graph.
        rankdir: graphviz `rankdir` attribute of the graph.

    Returns:
        The populated `pydot.Dot` graph.
    """
    graph = pydot.Dot(name, rankdir=rankdir)
    # Maps each blob name to the node of the last operator that wrote it.
    blob_parents = {}
    for op_id, op in enumerate(operators):
        op_node = _MakeOpNode(op, op_id)
        graph.add_node(op_node)
        for input_name in op.input:
            if input_name in blob_parents:
                graph.add_edge(pydot.Edge(blob_parents[input_name], op_node))
        for output_name in op.output:
            blob_parents[output_name] = op_node
    return graph


def GetOperatorMapForPlan(plan_def):
    """Map a graph name to the operator list of every network in `plan_def`.

    The key is '<plan name>_<net name>' for networks whose `name` field is
    set and '<plan name>_network_<index>' otherwise.
    """
    operator_map = {}
    for net_id, net in enumerate(plan_def.network):
        if net.HasField('name'):
            operator_map[plan_def.name + "_" + net.name] = net.op
        else:
            operator_map[plan_def.name + "_network_%d" % net_id] = net.op
    return operator_map


def main():
    """Draw every network found in the proto file named by sys.argv[1].

    For each network a '<graph name>.dot' file is written, then `dot` is
    invoked to render '<graph name>.pdf'. If the `dot` binary cannot be
    executed, a hint is printed and no pdf file is left behind.
    """
    with open(sys.argv[1], 'r') as fid:
        content = fid.read()
    # NOTE(review): GetContentFromProtoString is assumed to parse `content`
    # as one of the given message types and return the result of the matching
    # callable - confirm against pycaffe2.utils.
    graphs = utils.GetContentFromProtoString(
        content, {
            caffe2_pb2.PlanDef: GetOperatorMapForPlan,
            caffe2_pb2.NetDef: lambda x: {x.name: x.op},
        })
    # items() rather than iteritems() so this runs on Python 2 and Python 3.
    for key, operators in graphs.items():
        graph = GetPydotGraph(operators, key)
        filename = graph.get_name() + '.dot'
        graph.write(filename, format='raw')
        pdf_filename = filename[:-3] + 'pdf'
        dot_missing = False
        with open(pdf_filename, 'w') as fid:
            try:
                subprocess.call(['dot', '-Tpdf', filename], stdout=fid)
            except OSError:
                dot_missing = True
        if dot_missing:
            # The pdf was opened before dot was launched, so drop the empty
            # file instead of leaving a corrupt pdf next to the .dot file.
            os.remove(pdf_filename)
            print('pydot requires graphviz to convert dot files to pdf: in ubuntu '
                  'this can usually be installed with "sudo apt-get install '
                  'graphviz". We have generated the .dot file but will not '
                  'generate pdf file for now due to missing graphviz binaries.')


if __name__ == '__main__':
    main()