[fx/graph_drawer] Add skip_node_names_in_args option, default to True (#73815)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73815

Add `skip_node_names_in_args` (default=`True`) which will skip including node names in args/kwargs during graph drawing.

Test Plan:
Default (`skip_node_names_in_args=True`):

{F707455583}

Vs. `skip_node_names_in_args=False`:

{F707046375}

Reviewed By: wushirong

Differential Revision: D34659144

fbshipit-source-id: 9f0bd7bee98dc1ca8eecdabc960804564d83777b
(cherry picked from commit a0ed64b51f0187115586f4001dc81148c7ed18b9)
This commit is contained in:
Jordan Fix 2022-03-07 17:38:13 -08:00 committed by PyTorch MergeBot
parent 2c3509606d
commit e99e3fa580

View File

@ -60,9 +60,19 @@ if HAS_PYDOT:
f.write(g.get_dot_graph().create_svg())
"""
def __init__(self, graph_module: torch.fx.GraphModule, name: str, ignore_getattr: bool = False):
def __init__(
self,
graph_module: torch.fx.GraphModule,
name: str,
ignore_getattr: bool = False,
skip_node_names_in_args: bool = True,
):
self._name = name
self._dot_graphs = {name: self._to_dot(graph_module, name, ignore_getattr)}
self._dot_graphs = {
name: self._to_dot(
graph_module, name, ignore_getattr, skip_node_names_in_args
)
}
for node in graph_module.graph.nodes:
if node.op != "call_module":
@ -73,7 +83,12 @@ if HAS_PYDOT:
if not isinstance(leaf_node, torch.fx.GraphModule):
continue
self._dot_graphs[f"{name}_{node.target}"] = self._to_dot(leaf_node, f"{name}_{node.target}", ignore_getattr)
self._dot_graphs[f"{name}_{node.target}"] = self._to_dot(
leaf_node,
f"{name}_{node.target}",
ignore_getattr,
skip_node_names_in_args,
)
def get_dot_graph(self, submod_name=None) -> pydot.Dot:
if submod_name is None:
@ -129,15 +144,32 @@ if HAS_PYDOT:
return _get_qualified_name(target)
def _get_node_label(self, module: torch.fx.GraphModule, node: torch.fx.Node) -> str:
def _get_node_label(
self,
module: torch.fx.GraphModule,
node: torch.fx.Node,
skip_node_names_in_args: bool,
) -> str:
def _get_str_for_args_kwargs(arg):
if isinstance(arg, tuple):
s = r",\n".join(_format_arg(a, max_list_len=10) for a in arg)
return fr"(\l{s},\n)".replace("{", r"\{").replace("}", r"\}")
if isinstance(arg, dict):
s = r",\n".join(f"{k}: {_format_arg(v, max_list_len=10)}" for k, v in arg.items())
return fr"{{\l{s},\n}}".replace("{", r"\{").replace("}", r"\}")
return _format_arg(arg).replace("{", r"\{").replace("}", r"\}")
prefix, suffix = r"|args=(\l", r",\n)\l"
arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg]
elif isinstance(arg, dict):
prefix, suffix = r"|kwargs={\l", r",\n}\l"
arg_strs_list = [
f"{k}: {_format_arg(v, max_list_len=8)}"
for k, v in arg.items()
]
else: # Fall back to nothing in unexpected case.
return ""
# Strip out node names if requested.
if skip_node_names_in_args:
arg_strs_list = [a for a in arg_strs_list if "%" not in a]
if len(arg_strs_list) == 0:
return ""
arg_strs = prefix + r",\n".join(arg_strs_list) + suffix
return arg_strs.replace("{", r"\{").replace("}", r"\}")
label = "{" + f"name=%{node.name}|op_code={node.op}\n"
@ -154,9 +186,9 @@ if HAS_PYDOT:
else:
label += f"|target={self._typename(node.target)}" + r"\n"
if len(node.args) > 0:
label += f"|args={_get_str_for_args_kwargs(node.args)}" + r"\l"
label += _get_str_for_args_kwargs(node.args)
if len(node.kwargs) > 0:
label += f"|kwargs={_get_str_for_args_kwargs(node.kwargs)}" + r"\l"
label += _get_str_for_args_kwargs(node.kwargs)
label += f"|num_users={len(node.users)}" + r"\n"
tensor_meta = node.meta.get('tensor_meta')
@ -221,7 +253,13 @@ if HAS_PYDOT:
def _get_tensor_label(self, t: torch.Tensor) -> str:
return str(t.dtype) + str(list(t.shape)) + r"\n"
def _to_dot(self, graph_module: torch.fx.GraphModule, name: str, ignore_getattr: bool) -> pydot.Dot:
def _to_dot(
self,
graph_module: torch.fx.GraphModule,
name: str,
ignore_getattr: bool,
skip_node_names_in_args: bool,
) -> pydot.Dot:
"""
Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph
"""
@ -233,7 +271,7 @@ if HAS_PYDOT:
style = self._get_node_style(node)
dot_node = pydot.Node(
node.name, label=self._get_node_label(graph_module, node), **style
node.name, label=self._get_node_label(graph_module, node, skip_node_names_in_args), **style
)
dot_graph.add_node(dot_node)