[fx/graph_drawer] Add skip_node_names_in_args option, default to True (#73815)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73815

Adds `skip_node_names_in_args` (default `True`), which omits node names from the args/kwargs shown in node labels when drawing the graph.

Test Plan:
Default (`skip_node_names_in_args=True`): {F707455583}
Vs. `skip_node_names_in_args=False`: {F707046375}

Reviewed By: wushirong

Differential Revision: D34659144

fbshipit-source-id: 9f0bd7bee98dc1ca8eecdabc960804564d83777b
(cherry picked from commit a0ed64b51f0187115586f4001dc81148c7ed18b9)
This commit is contained in:
parent 2c3509606d
commit e99e3fa580
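The new option is consumed at construction time. A minimal usage sketch, assuming pydot is installed; `MyModule` here is a hypothetical example module, not part of the PR:

# Hypothetical example module; FxGraphDrawer and skip_node_names_in_args
# come from the signature introduced in this diff.
import torch
import torch.fx
from torch.fx.passes.graph_drawer import FxGraphDrawer

class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1.0

gm = torch.fx.symbolic_trace(MyModule())

# Default: node names are stripped from args/kwargs in node labels.
drawer = FxGraphDrawer(gm, "my_module")

# Opt out to keep node names in the labels (the pre-PR behavior).
verbose = FxGraphDrawer(gm, "my_module", skip_node_names_in_args=False)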
torch/fx/passes/graph_drawer.py
@@ -60,9 +60,19 @@ if HAS_PYDOT:
                 f.write(g.get_dot_graph().create_svg())
         """

-        def __init__(self, graph_module: torch.fx.GraphModule, name: str, ignore_getattr: bool = False):
+        def __init__(
+            self,
+            graph_module: torch.fx.GraphModule,
+            name: str,
+            ignore_getattr: bool = False,
+            skip_node_names_in_args: bool = True,
+        ):
             self._name = name
-            self._dot_graphs = {name: self._to_dot(graph_module, name, ignore_getattr)}
+            self._dot_graphs = {
+                name: self._to_dot(
+                    graph_module, name, ignore_getattr, skip_node_names_in_args
+                )
+            }

             for node in graph_module.graph.nodes:
                 if node.op != "call_module":
@@ -73,7 +83,12 @@ if HAS_PYDOT:
                 if not isinstance(leaf_node, torch.fx.GraphModule):
                     continue

-                self._dot_graphs[f"{name}_{node.target}"] = self._to_dot(leaf_node, f"{name}_{node.target}", ignore_getattr)
+                self._dot_graphs[f"{name}_{node.target}"] = self._to_dot(
+                    leaf_node,
+                    f"{name}_{node.target}",
+                    ignore_getattr,
+                    skip_node_names_in_args,
+                )

         def get_dot_graph(self, submod_name=None) -> pydot.Dot:
             if submod_name is None:
@@ -129,15 +144,32 @@ if HAS_PYDOT:

             return _get_qualified_name(target)

-        def _get_node_label(self, module: torch.fx.GraphModule, node: torch.fx.Node) -> str:
+        def _get_node_label(
+            self,
+            module: torch.fx.GraphModule,
+            node: torch.fx.Node,
+            skip_node_names_in_args: bool,
+        ) -> str:
             def _get_str_for_args_kwargs(arg):
                 if isinstance(arg, tuple):
-                    s = r",\n".join(_format_arg(a, max_list_len=10) for a in arg)
-                    return fr"(\l{s},\n)".replace("{", r"\{").replace("}", r"\}")
-                if isinstance(arg, dict):
-                    s = r",\n".join(f"{k}: {_format_arg(v, max_list_len=10)}" for k, v in arg.items())
-                    return fr"{{\l{s},\n}}".replace("{", r"\{").replace("}", r"\}")
-                return _format_arg(arg).replace("{", r"\{").replace("}", r"\}")
+                    prefix, suffix = r"|args=(\l", r",\n)\l"
+                    arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg]
+                elif isinstance(arg, dict):
+                    prefix, suffix = r"|kwargs={\l", r",\n}\l"
+                    arg_strs_list = [
+                        f"{k}: {_format_arg(v, max_list_len=8)}"
+                        for k, v in arg.items()
+                    ]
+                else:  # Fall back to nothing in unexpected case.
+                    return ""
+
+                # Strip out node names if requested.
+                if skip_node_names_in_args:
+                    arg_strs_list = [a for a in arg_strs_list if "%" not in a]
+                if len(arg_strs_list) == 0:
+                    return ""
+                arg_strs = prefix + r",\n".join(arg_strs_list) + suffix
+                return arg_strs.replace("{", r"\{").replace("}", r"\}")


             label = "{" + f"name=%{node.name}|op_code={node.op}\n"
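The reworked helper builds the full list of argument strings first and only filters afterward; since `_format_arg` renders a reference to another node as `%name`, dropping entries containing `%` removes node-name arguments while keeping literals. A standalone sketch of just that filtering step (not the library code; the input strings are assumed to be what `_format_arg` would produce):

# Standalone sketch of the "%"-filter above.
def filter_node_names(arg_strs, skip_node_names_in_args=True):
    if skip_node_names_in_args:
        # Node references render as "%name", so "%" marks a node-name arg.
        arg_strs = [a for a in arg_strs if "%" not in a]
    return arg_strs

print(filter_node_names(["%x", "3", "(4, 5)"]))                  # ['3', '(4, 5)']
print(filter_node_names(["%x"], skip_node_names_in_args=False))  # ['%x']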
@@ -154,9 +186,9 @@ if HAS_PYDOT:
             else:
                 label += f"|target={self._typename(node.target)}" + r"\n"
             if len(node.args) > 0:
-                label += f"|args={_get_str_for_args_kwargs(node.args)}" + r"\l"
+                label += _get_str_for_args_kwargs(node.args)
             if len(node.kwargs) > 0:
-                label += f"|kwargs={_get_str_for_args_kwargs(node.kwargs)}" + r"\l"
+                label += _get_str_for_args_kwargs(node.kwargs)
             label += f"|num_users={len(node.users)}" + r"\n"

             tensor_meta = node.meta.get('tensor_meta')
@@ -221,7 +253,13 @@ if HAS_PYDOT:
         def _get_tensor_label(self, t: torch.Tensor) -> str:
             return str(t.dtype) + str(list(t.shape)) + r"\n"

-        def _to_dot(self, graph_module: torch.fx.GraphModule, name: str, ignore_getattr: bool) -> pydot.Dot:
+        def _to_dot(
+            self,
+            graph_module: torch.fx.GraphModule,
+            name: str,
+            ignore_getattr: bool,
+            skip_node_names_in_args: bool,
+        ) -> pydot.Dot:
             """
             Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph
             """
@@ -233,7 +271,7 @@ if HAS_PYDOT:

             style = self._get_node_style(node)
             dot_node = pydot.Node(
-                node.name, label=self._get_node_label(graph_module, node), **style
+                node.name, label=self._get_node_label(graph_module, node, skip_node_names_in_args), **style
             )
             dot_graph.add_node(dot_node)

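Putting it together, rendering follows the class docstring quoted in the first hunk. A hedged end-to-end sketch, assuming pydot and Graphviz are available; `fn` is a stand-in function, not from the PR:

import torch
import torch.fx
from torch.fx.passes.graph_drawer import FxGraphDrawer

def fn(x):
    return torch.relu(x) + 1.0

g = FxGraphDrawer(torch.fx.symbolic_trace(fn), "fn")
with open("fn.svg", "wb") as f:  # create_svg() returns bytes
    f.write(g.get_dot_graph().create_svg())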