mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18598 ghimport-source-id: c74597e5e7437e94a43c163cee0639b20d0d0c6a Stack from [ghstack](https://github.com/ezyang/ghstack): * **#18598 Turn on F401: Unused import warning.** This was requested by someone at Facebook; this lint is turned on for Facebook by default. "Sure, why not." I had to noqa a number of imports in __init__. Hypothetically we're supposed to use __all__ in this case, but I was too lazy to fix it. Left for future work. Be careful! flake8-2 and flake8-3 behave differently with respect to import resolution for # type: comments. flake8-3 will report an import unused; flake8-2 will not. For now, I just noqa'd all these sites. All the changes were done by hand. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Differential Revision: D14687478 fbshipit-source-id: 30d532381e914091aadfa0d2a5a89404819663e3
87 lines
2.2 KiB
Python
87 lines
2.2 KiB
Python
#!/usr/bin/env python3
|
|
|
|
"""
|
|
This module encapsulates dependencies on pygraphviz
|
|
"""
|
|
|
|
import colorsys
|
|
|
|
import cimodel.lib.conf_tree as conf_tree
|
|
|
|
|
|
def rgb2hex(rgb_tuple):
    """
    Format a sequence of colour channels as a ``#``-prefixed hex string.

    Each channel (expected to lie in [0, 1], e.g. the RGB triple produced by
    ``colorsys.hsv_to_rgb``) is scaled by 255, truncated to an int, and
    written as at least two lowercase hex digits. One group is emitted per
    channel, so a 3-tuple yields ``#rrggbb``.
    """
    channel_hex = ["%02x" % int(channel * 255) for channel in rgb_tuple]
    return "#" + "".join(channel_hex)
|
|
|
|
|
|
def handle_missing_graphviz(f):
    """
    Decorator that degrades gracefully when pygraphviz is unavailable.

    If ``pygraphviz`` imports successfully, ``f`` is returned unchanged.
    Otherwise ``f`` is replaced by a wrapper that ignores whatever arguments
    it is given and returns a stub object. The stub's ``draw()`` method
    accepts any arguments and does nothing, so callers can always call
    ``.draw(...)`` on the result.

    ``ImportError`` is caught rather than only its subclass
    ``ModuleNotFoundError``: a pygraphviz that is installed but whose native
    Graphviz library cannot be loaded may fail with a plain ``ImportError``.
    That case should also fall back to the stub instead of crashing at import
    time.
    """
    try:
        import pygraphviz  # noqa: F401
    except ImportError:
        pass
    else:
        return f

    import functools

    class FakeGraph:
        """Stand-in for ``pygraphviz.AGraph`` whose ``draw()`` is a no-op."""

        def draw(self, *args, **kwargs):
            pass

    # Accept any call signature so decorated functions with more than one
    # parameter also work; keep f's name/docstring for introspection.
    @functools.wraps(f)
    def fake_graph_factory(*args, **kwargs):
        return FakeGraph()

    return fake_graph_factory
|
|
|
|
|
|
@handle_missing_graphviz
def generate_graph(toplevel_config_node):
    """
    Build and return a pygraphviz ``AGraph`` of the config tree rooted at
    ``toplevel_config_node``.

    Two traversals are made:

    1. ``conf_tree.dfs`` is walked only to find the greatest node depth
       (never less than 0). Hues are then spread over the depth range as
       ``depth / (max_depth + 1)``.
    2. ``conf_tree.dfs_recurse`` drives two callbacks. One adds a filled node
       coloured by depth (hue) and by ``sibling_index / (sibling_count - 1)``
       (saturation, between 0.1 and 0.6; an only child gets 0.6). The other
       adds a parent -> child edge. Exactly when and how often each callback
       fires is defined in ``conf_tree`` (not visible here).

    If pygraphviz cannot be imported, ``handle_missing_graphviz`` replaces
    this function with one returning a stub whose ``draw()`` does nothing.
    """
    # Pass 1: collect depths; the leading 0 keeps the floor the original
    # running-max loop started from.
    depths = [0] + [config.get_depth() for config in conf_tree.dfs(toplevel_config_node)]
    max_depth = max(depths)

    # Imported lazily: only reached when the decorator found pygraphviz.
    from pygraphviz import AGraph

    dot = AGraph()

    sat_lo, sat_hi = 0.1, 0.6

    # Parameter names are kept as-is in case dfs_recurse passes them by keyword.
    def node_discovery_callback(node, sibling_index, sibling_count):
        depth = node.get_depth()

        if sibling_count > 1:
            fraction = sibling_index / float(sibling_count - 1)
        else:
            fraction = 1
        saturation = sat_lo + (sat_hi - sat_lo) * fraction

        # TODO Use a hash of the node label to determine the color
        hue = depth / float(max_depth + 1)

        dot.add_node(
            node.get_node_key(),
            label=node.get_label(),
            style="filled",
            fillcolor=rgb2hex(colorsys.hsv_to_rgb(hue, saturation, 1)),
            penwidth=3,
            # Outline uses the same hue/saturation at HSV value 0.9, i.e. a
            # slightly darker shade than the fill.
            color=rgb2hex(colorsys.hsv_to_rgb(hue, saturation, 0.9)),
        )

    def child_callback(node, child):
        dot.add_edge((node.get_node_key(), child.get_node_key()))

    # The second argument is a callback this renderer does not need; it is a
    # no-op here (its exact role is defined in conf_tree).
    conf_tree.dfs_recurse(
        toplevel_config_node,
        lambda x: None,
        node_discovery_callback,
        child_callback,
    )
    return dot
|