mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Changes to get inlined graph and proper names after JIT updates (#30244)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30244

This makes several small changes to the TensorBoard graph-parsing methods to address the recent changes to the PyTorch JIT trace/graph:

- Inline the graph to get information for all nodes
- Assign and propagate scope names to GetAttr nodes
- Prune all useless GetAttr nodes (any with a ClassType output type; tensors and primitives are kept)
- Create output nodes so the output tensor shape can be examined

Reviewed By: sanekmelnikov
Differential Revision: D18556323
fbshipit-source-id: b73a809bacfa554c3fe9c4ae3563525f57539874
This commit is contained in:
parent
983728489a
commit
0c04763d59
152
test/expect/TestTensorBoard.test_pytorch_graph.expect
Normal file
152
test/expect/TestTensorBoard.test_pytorch_graph.expect
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
node {
|
||||
name: "input/input"
|
||||
op: "IO Node"
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
dim {
|
||||
size: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "attr"
|
||||
value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "output/output.1"
|
||||
op: "IO Node"
|
||||
input: "myLinear/Linear[l]/22"
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
dim {
|
||||
size: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "attr"
|
||||
value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "myLinear/Linear[l]/bias/17"
|
||||
op: "prim::GetAttr"
|
||||
input: "myLinear/Linear[l]/weight/14"
|
||||
attr {
|
||||
key: "attr"
|
||||
value {
|
||||
s: "{ name : bias }"
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "myLinear/Linear[l]/weight/18"
|
||||
op: "prim::GetAttr"
|
||||
input: "myLinear/Linear[l]/weight/14"
|
||||
attr {
|
||||
key: "attr"
|
||||
value {
|
||||
s: "{ name : weight }"
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "myLinear/Linear[l]/19"
|
||||
op: "aten::t"
|
||||
input: "myLinear/Linear[l]/weight/18"
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
dim {
|
||||
size: 3
|
||||
}
|
||||
dim {
|
||||
size: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "attr"
|
||||
value {
|
||||
s: "{}"
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "myLinear/Linear[l]/20"
|
||||
op: "prim::Constant"
|
||||
attr {
|
||||
key: "attr"
|
||||
value {
|
||||
s: "{ value : 1}"
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "myLinear/Linear[l]/21"
|
||||
op: "prim::Constant"
|
||||
attr {
|
||||
key: "attr"
|
||||
value {
|
||||
s: "{ value : 1}"
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "myLinear/Linear[l]/22"
|
||||
op: "aten::addmm"
|
||||
input: "myLinear/Linear[l]/bias/17"
|
||||
input: "input/input"
|
||||
input: "myLinear/Linear[l]/19"
|
||||
input: "myLinear/Linear[l]/20"
|
||||
input: "myLinear/Linear[l]/21"
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
dim {
|
||||
size: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "attr"
|
||||
value {
|
||||
s: "{}"
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 22
|
||||
}
|
||||
|
|
@ -76,6 +76,7 @@ if TEST_TENSORBOARD:
|
|||
from torch.utils.tensorboard._utils import _prepare_video, convert_to_HWC
|
||||
from torch.utils.tensorboard._convert_np import make_np
|
||||
from torch.utils.tensorboard import _caffe2_graph as c2_graph
|
||||
from torch.utils.tensorboard._pytorch_graph import graph
|
||||
from google.protobuf import text_format
|
||||
from PIL import Image
|
||||
|
||||
|
|
@ -502,6 +503,9 @@ class TestTensorBoardPytorchGraph(BaseTestCase):
|
|||
with self.createSummaryWriter() as w:
|
||||
w.add_graph(myLinear(), dummy_input)
|
||||
|
||||
graphdef, _ = graph(myLinear(), dummy_input)
|
||||
self.assertTrue(compare_proto(graphdef, self))
|
||||
|
||||
def test_mlp_graph(self):
|
||||
dummy_input = (torch.zeros(2, 1, 28, 28),)
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ methods_OP = ['attributeNames', 'hasMultipleOutputs', 'hasUses', 'inputs',
|
|||
# But the below are sufficient for now.
|
||||
methods_IO = ['node', 'offset', 'debugName']
|
||||
|
||||
GETATTR_KIND = 'prim::GetAttr'
|
||||
CLASSTYPE_KIND = 'ClassType'
|
||||
|
||||
class NodeBase(object):
|
||||
def __init__(self, debugName=None, inputs=None, scope=None, tensor_size=None, op_type='UnSpecified', attributes=''):
|
||||
|
|
@ -124,14 +126,6 @@ class GraphPy(object):
|
|||
self.nodes_io[x.debugName] = x
|
||||
if isinstance(x, NodePyOP):
|
||||
self.nodes_op.append(x)
|
||||
for node_output, outputSize in zip(x.outputs, x.outputstensor_size):
|
||||
self.scope_name_appeared.append(x.scopeName)
|
||||
self.nodes_io[node_output] = NodeBase(node_output,
|
||||
x.inputs,
|
||||
x.scopeName,
|
||||
outputSize,
|
||||
op_type=x.kind,
|
||||
attributes=x.attributes)
|
||||
|
||||
def printall(self):
|
||||
print('all nodes')
|
||||
|
|
@ -146,6 +140,18 @@ class GraphPy(object):
|
|||
self.shallowest_scope_name = fullscope.split('/')[0]
|
||||
|
||||
def populate_namespace_from_OP_to_IO(self):
|
||||
for node in self.nodes_op:
|
||||
for node_output, outputSize in zip(node.outputs, node.outputstensor_size):
|
||||
self.scope_name_appeared.append(node.scopeName)
|
||||
self.nodes_io[node_output] = NodeBase(node_output,
|
||||
node.inputs,
|
||||
node.scopeName,
|
||||
outputSize,
|
||||
op_type=node.kind,
|
||||
attributes=node.attributes)
|
||||
|
||||
self.find_common_root()
|
||||
|
||||
for node in self.nodes_op:
|
||||
for input_node_id in node.inputs:
|
||||
self.unique_name_to_scoped_name[input_node_id] = node.scopeName + '/' + input_node_id
|
||||
|
|
@ -184,13 +190,14 @@ class GraphPy(object):
|
|||
return nodes
|
||||
|
||||
|
||||
def parse(graph, args=None, omit_useless_nodes=True):
|
||||
def parse(graph, trace, args=None, omit_useless_nodes=True):
|
||||
"""This method parses an optimized PyTorch model graph and produces
|
||||
a list of nodes and node stats for eventual conversion to TensorBoard
|
||||
protobuf format.
|
||||
|
||||
Args:
|
||||
graph (PyTorch module): The model to be parsed.
|
||||
graph (PyTorch module): The model graph to be parsed.
|
||||
trace (PyTorch JIT TracedModule): The model trace to be parsed.
|
||||
args (tuple): input tensor[s] for the model.
|
||||
omit_useless_nodes (boolean): Whether to remove nodes from the graph.
|
||||
"""
|
||||
|
|
@ -198,22 +205,66 @@ def parse(graph, args=None, omit_useless_nodes=True):
|
|||
|
||||
scope = {}
|
||||
nodes_py = GraphPy()
|
||||
for i, node in enumerate(graph.inputs()):
|
||||
for node in graph.inputs():
|
||||
if omit_useless_nodes:
|
||||
if len(node.uses()) == 0: # number of user of the node (= number of outputs/ fanout)
|
||||
continue
|
||||
|
||||
if i < n_inputs:
|
||||
if node.type().kind() != CLASSTYPE_KIND:
|
||||
nodes_py.append(NodePyIO(node, 'input'))
|
||||
else:
|
||||
nodes_py.append(NodePyIO(node)) # parameter
|
||||
|
||||
attr_to_scope = dict()
|
||||
for node in graph.nodes():
|
||||
nodes_py.append(NodePyOP(node))
|
||||
if node.kind() == GETATTR_KIND:
|
||||
attr_name = node.s('name')
|
||||
parent = node.input().node()
|
||||
if parent.kind() == GETATTR_KIND: # If the parent node is not the top-level "self" node
|
||||
parent_attr_name = parent.s('name')
|
||||
parent_scope = attr_to_scope[parent_attr_name]
|
||||
attr_scope = parent_scope.split('/')[-1]
|
||||
attr_to_scope[attr_name] = '{}/{}.{}'.format(parent_scope, attr_scope, attr_name)
|
||||
else:
|
||||
attr_to_scope[attr_name] = '__module.{}'.format(attr_name)
|
||||
# We don't need classtype nodes; scope will provide this information
|
||||
if node.output().type().kind() != CLASSTYPE_KIND:
|
||||
node_py = NodePyOP(node)
|
||||
node_py.scopeName = attr_to_scope[attr_name]
|
||||
nodes_py.append(node_py)
|
||||
else:
|
||||
nodes_py.append(NodePyOP(node))
|
||||
|
||||
for i, node in enumerate(graph.outputs()): # Create sink nodes for output ops
|
||||
node_py = NodePyIO(node, 'output')
|
||||
node_py.debugName = "output.{}".format(i + 1)
|
||||
node_py.inputs = [node.debugName()]
|
||||
nodes_py.append(node_py)
|
||||
|
||||
def parse_traced_name(module_name):
|
||||
prefix = 'TracedModule['
|
||||
suffix = ']'
|
||||
if module_name.startswith(prefix) and module_name.endswith(suffix):
|
||||
module_name = module_name[len(prefix):-len(suffix)]
|
||||
return module_name
|
||||
|
||||
alias_to_name = dict()
|
||||
base_name = parse_traced_name(trace._name)
|
||||
for name, module in trace.named_modules(prefix='__module'):
|
||||
mod_name = parse_traced_name(module._name)
|
||||
attr_name = name.split('.')[-1]
|
||||
alias_to_name[name] = '{}[{}]'.format(mod_name, attr_name)
|
||||
|
||||
for node in nodes_py.nodes_op:
|
||||
module_aliases = node.scopeName.split('/')
|
||||
replacements = [
|
||||
alias_to_name[alias]
|
||||
if alias in alias_to_name
|
||||
else alias.split('.')[-1]
|
||||
for alias in module_aliases
|
||||
]
|
||||
node.scopeName = base_name
|
||||
if any(replacements):
|
||||
node.scopeName += '/' + '/'.join(replacements)
|
||||
|
||||
for node in graph.outputs(): # must place last.
|
||||
NodePyIO(node, 'output')
|
||||
nodes_py.find_common_root()
|
||||
nodes_py.populate_namespace_from_OP_to_IO()
|
||||
return nodes_py.to_proto()
|
||||
|
||||
|
|
@ -233,6 +284,7 @@ def graph(model, args, verbose=False):
|
|||
try:
|
||||
trace = torch.jit.trace(model, args)
|
||||
graph = trace.graph
|
||||
torch._C._jit_pass_inline(graph)
|
||||
except RuntimeError as e:
|
||||
print(e)
|
||||
print('Error occurs, No graph saved')
|
||||
|
|
@ -240,7 +292,7 @@ def graph(model, args, verbose=False):
|
|||
|
||||
if verbose:
|
||||
print(graph)
|
||||
list_of_nodes = parse(graph, args)
|
||||
list_of_nodes = parse(graph, trace, args)
|
||||
# We are hardcoding that this was run on CPU even though it might have actually
|
||||
# run on GPU. Note this is what is shown in TensorBoard and has no bearing
|
||||
# on actual execution.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user