Changes to get inlined graph and proper names after JIT updates (#30244)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30244

This makes several small changes to the TensorBoard graph-parsing methods to accommodate the recent changes to the PyTorch JIT trace/graph (a sketch of the affected flow appears below this list):
- Inline the graph so that information is available for all nodes
- Assign and propagate scope names to GetAttr nodes
- Prune all useless GetAttr nodes (any whose output type is a ClassType; tensors and primitives are kept)
- Create output nodes so that output tensor shapes can be examined
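
The end-to-end flow these changes affect can be sketched as follows. This is a minimal illustration reconstructed from the test and the expected proto below (the `myLinear` module with a `Linear(3, 5)` submodule named `l` mirrors the test model; it is not itself part of this diff):

```python
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

class myLinear(nn.Module):
    """Toy model matching the expected proto: a single Linear submodule 'l'."""
    def __init__(self):
        super(myLinear, self).__init__()
        self.l = nn.Linear(3, 5)

    def forward(self, x):
        return self.l(x)

dummy_input = (torch.zeros(1, 3),)

# add_graph traces the model, inlines the JIT graph
# (torch._C._jit_pass_inline), and parses it into TensorBoard node protos.
with SummaryWriter() as w:
    w.add_graph(myLinear(), dummy_input)
```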

Reviewed By: sanekmelnikov

Differential Revision: D18556323

fbshipit-source-id: b73a809bacfa554c3fe9c4ae3563525f57539874
Authored by Jonathan Reynolds on 2019-11-21 16:56:42 -08:00, committed by Facebook Github Bot
parent 983728489a
commit 0c04763d59
3 changed files with 227 additions and 19 deletions


@@ -0,0 +1,152 @@
node {
  name: "input/input"
  op: "IO Node"
  attr {
    key: "_output_shapes"
    value {
      list {
        shape {
          dim {
            size: 1
          }
          dim {
            size: 3
          }
        }
      }
    }
  }
  attr {
    key: "attr"
    value {
      s: ""
    }
  }
}
node {
  name: "output/output.1"
  op: "IO Node"
  input: "myLinear/Linear[l]/22"
  attr {
    key: "_output_shapes"
    value {
      list {
        shape {
          dim {
            size: 1
          }
          dim {
            size: 5
          }
        }
      }
    }
  }
  attr {
    key: "attr"
    value {
      s: ""
    }
  }
}
node {
  name: "myLinear/Linear[l]/bias/17"
  op: "prim::GetAttr"
  input: "myLinear/Linear[l]/weight/14"
  attr {
    key: "attr"
    value {
      s: "{ name : bias }"
    }
  }
}
node {
  name: "myLinear/Linear[l]/weight/18"
  op: "prim::GetAttr"
  input: "myLinear/Linear[l]/weight/14"
  attr {
    key: "attr"
    value {
      s: "{ name : weight }"
    }
  }
}
node {
  name: "myLinear/Linear[l]/19"
  op: "aten::t"
  input: "myLinear/Linear[l]/weight/18"
  attr {
    key: "_output_shapes"
    value {
      list {
        shape {
          dim {
            size: 3
          }
          dim {
            size: 5
          }
        }
      }
    }
  }
  attr {
    key: "attr"
    value {
      s: "{}"
    }
  }
}
node {
  name: "myLinear/Linear[l]/20"
  op: "prim::Constant"
  attr {
    key: "attr"
    value {
      s: "{ value : 1}"
    }
  }
}
node {
  name: "myLinear/Linear[l]/21"
  op: "prim::Constant"
  attr {
    key: "attr"
    value {
      s: "{ value : 1}"
    }
  }
}
node {
  name: "myLinear/Linear[l]/22"
  op: "aten::addmm"
  input: "myLinear/Linear[l]/bias/17"
  input: "input/input"
  input: "myLinear/Linear[l]/19"
  input: "myLinear/Linear[l]/20"
  input: "myLinear/Linear[l]/21"
  attr {
    key: "_output_shapes"
    value {
      list {
        shape {
          dim {
            size: 1
          }
          dim {
            size: 5
          }
        }
      }
    }
  }
  attr {
    key: "attr"
    value {
      s: "{}"
    }
  }
}
versions {
  producer: 22
}


@@ -76,6 +76,7 @@ if TEST_TENSORBOARD:
     from torch.utils.tensorboard._utils import _prepare_video, convert_to_HWC
     from torch.utils.tensorboard._convert_np import make_np
     from torch.utils.tensorboard import _caffe2_graph as c2_graph
+    from torch.utils.tensorboard._pytorch_graph import graph
     from google.protobuf import text_format
     from PIL import Image
@@ -502,6 +503,9 @@ class TestTensorBoardPytorchGraph(BaseTestCase):
         with self.createSummaryWriter() as w:
             w.add_graph(myLinear(), dummy_input)
+        graphdef, _ = graph(myLinear(), dummy_input)
+        self.assertTrue(compare_proto(graphdef, self))
+
     def test_mlp_graph(self):
         dummy_input = (torch.zeros(2, 1, 28, 28),)


@@ -18,6 +18,8 @@ methods_OP = ['attributeNames', 'hasMultipleOutputs', 'hasUses', 'inputs',
 # But the below are sufficient for now.
 methods_IO = ['node', 'offset', 'debugName']
+GETATTR_KIND = 'prim::GetAttr'
+CLASSTYPE_KIND = 'ClassType'

 class NodeBase(object):
     def __init__(self, debugName=None, inputs=None, scope=None, tensor_size=None, op_type='UnSpecified', attributes=''):
@@ -124,14 +126,6 @@
             self.nodes_io[x.debugName] = x
         if isinstance(x, NodePyOP):
             self.nodes_op.append(x)
-            for node_output, outputSize in zip(x.outputs, x.outputstensor_size):
-                self.scope_name_appeared.append(x.scopeName)
-                self.nodes_io[node_output] = NodeBase(node_output,
-                                                      x.inputs,
-                                                      x.scopeName,
-                                                      outputSize,
-                                                      op_type=x.kind,
-                                                      attributes=x.attributes)

     def printall(self):
         print('all nodes')
@@ -146,6 +140,18 @@
             self.shallowest_scope_name = fullscope.split('/')[0]

     def populate_namespace_from_OP_to_IO(self):
+        for node in self.nodes_op:
+            for node_output, outputSize in zip(node.outputs, node.outputstensor_size):
+                self.scope_name_appeared.append(node.scopeName)
+                self.nodes_io[node_output] = NodeBase(node_output,
+                                                      node.inputs,
+                                                      node.scopeName,
+                                                      outputSize,
+                                                      op_type=node.kind,
+                                                      attributes=node.attributes)
+
+        self.find_common_root()
+
         for node in self.nodes_op:
             for input_node_id in node.inputs:
                 self.unique_name_to_scoped_name[input_node_id] = node.scopeName + '/' + input_node_id
@@ -184,13 +190,14 @@
         return nodes

-def parse(graph, args=None, omit_useless_nodes=True):
+def parse(graph, trace, args=None, omit_useless_nodes=True):
     """This method parses an optimized PyTorch model graph and produces
     a list of nodes and node stats for eventual conversion to TensorBoard
     protobuf format.

     Args:
-      graph (PyTorch module): The model to be parsed.
+      graph (PyTorch module): The model graph to be parsed.
+      trace (PyTorch JIT TracedModule): The model trace to be parsed.
       args (tuple): input tensor[s] for the model.
       omit_useless_nodes (boolean): Whether to remove nodes from the graph.
     """
@@ -198,22 +205,66 @@ def parse(graph, args=None, omit_useless_nodes=True):
     scope = {}
     nodes_py = GraphPy()
-    for i, node in enumerate(graph.inputs()):
+    for node in graph.inputs():
         if omit_useless_nodes:
             if len(node.uses()) == 0:  # number of user of the node (= number of outputs/ fanout)
                 continue

-        if i < n_inputs:
+        if node.type().kind() != CLASSTYPE_KIND:
             nodes_py.append(NodePyIO(node, 'input'))
-        else:
-            nodes_py.append(NodePyIO(node))  # parameter

+    attr_to_scope = dict()
     for node in graph.nodes():
-        nodes_py.append(NodePyOP(node))
+        if node.kind() == GETATTR_KIND:
+            attr_name = node.s('name')
+            parent = node.input().node()
+            if parent.kind() == GETATTR_KIND:  # If the parent node is not the top-level "self" node
+                parent_attr_name = parent.s('name')
+                parent_scope = attr_to_scope[parent_attr_name]
+                attr_scope = parent_scope.split('/')[-1]
+                attr_to_scope[attr_name] = '{}/{}.{}'.format(parent_scope, attr_scope, attr_name)
+            else:
+                attr_to_scope[attr_name] = '__module.{}'.format(attr_name)
+            # We don't need classtype nodes; scope will provide this information
+            if node.output().type().kind() != CLASSTYPE_KIND:
+                node_py = NodePyOP(node)
+                node_py.scopeName = attr_to_scope[attr_name]
+                nodes_py.append(node_py)
+        else:
+            nodes_py.append(NodePyOP(node))

-    for node in graph.outputs():  # must place last.
-        NodePyIO(node, 'output')
-    nodes_py.find_common_root()
+    for i, node in enumerate(graph.outputs()):  # Create sink nodes for output ops
+        node_py = NodePyIO(node, 'output')
+        node_py.debugName = "output.{}".format(i + 1)
+        node_py.inputs = [node.debugName()]
+        nodes_py.append(node_py)
+
+    def parse_traced_name(module_name):
+        prefix = 'TracedModule['
+        suffix = ']'
+        if module_name.startswith(prefix) and module_name.endswith(suffix):
+            module_name = module_name[len(prefix):-len(suffix)]
+        return module_name
+
+    alias_to_name = dict()
+    base_name = parse_traced_name(trace._name)
+    for name, module in trace.named_modules(prefix='__module'):
+        mod_name = parse_traced_name(module._name)
+        attr_name = name.split('.')[-1]
+        alias_to_name[name] = '{}[{}]'.format(mod_name, attr_name)
+
+    for node in nodes_py.nodes_op:
+        module_aliases = node.scopeName.split('/')
+        replacements = [
+            alias_to_name[alias]
+            if alias in alias_to_name
+            else alias.split('.')[-1]
+            for alias in module_aliases
+        ]
+        node.scopeName = base_name
+        if any(replacements):
+            node.scopeName += '/' + '/'.join(replacements)
+
     nodes_py.populate_namespace_from_OP_to_IO()
     return nodes_py.to_proto()
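
To make the renaming concrete, here is a hedged walk-through for the `myLinear` test model. The module and attribute names are reconstructed from the expected proto above, and the `TracedModule[...]` wrapper in `_name` is an assumption about the traced-module naming that `parse_traced_name` strips:

```python
def parse_traced_name(module_name):
    # Same helper as in the hunk above: strip the 'TracedModule[...]' wrapper.
    prefix, suffix = 'TracedModule[', ']'
    if module_name.startswith(prefix) and module_name.endswith(suffix):
        module_name = module_name[len(prefix):-len(suffix)]
    return module_name

base_name = parse_traced_name('TracedModule[myLinear]')    # -> 'myLinear'
alias_to_name = {'__module.l': 'Linear[l]'}                # built from trace.named_modules()

# A GetAttr node for the weight was scoped '__module.l/l.weight' via attr_to_scope;
# each known alias is replaced by its readable name, and unknown segments keep
# only their last dotted component.
scope = '__module.l/l.weight'
replacements = [alias_to_name.get(a, a.split('.')[-1]) for a in scope.split('/')]
print(base_name + '/' + '/'.join(replacements))            # myLinear/Linear[l]/weight
```

This matches node names in the expected proto such as "myLinear/Linear[l]/weight/18" (the trailing /18 is the node's debugName, appended elsewhere).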
@@ -233,6 +284,7 @@ def graph(model, args, verbose=False):
     try:
         trace = torch.jit.trace(model, args)
         graph = trace.graph
+        torch._C._jit_pass_inline(graph)
     except RuntimeError as e:
         print(e)
         print('Error occurs, No graph saved')
@@ -240,7 +292,7 @@
     if verbose:
         print(graph)
-    list_of_nodes = parse(graph, args)
+    list_of_nodes = parse(graph, trace, args)
     # We are hardcoding that this was run on CPU even though it might have actually
     # run on GPU. Note this is what is shown in TensorBoard and has no bearing
     # on actual execution.