pytorch/torch/package/_digraph.py
John Reese f625bb4bc9 [codemod][usort] apply import merging for fbcode (1 of 11) (#78973)
Summary:
Applies new import merging and sorting from µsort v1.0.

When merging imports, µsort will make a best-effort to move associated
comments to match merged elements, but there are known limitations due to
the dynamic nature of Python and developer tooling. These changes should
not produce any dangerous runtime changes, but may require touch-ups to
satisfy linters and other tooling.

Note that µsort uses case-insensitive, lexicographical sorting, which
results in a different ordering compared to isort. This provides a more
consistent sorting order, matching the case-insensitive order used when
sorting import statements by module name, and ensures that "frog", "FROG",
and "Frog" always sort next to each other.

For details on µsort's sorting and merging semantics, see the user guide:
https://usort.readthedocs.io/en/stable/guide.html#sorting

Test Plan: S271899

Reviewed By: lisroach

Differential Revision: D36402110

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78973
Approved by: https://github.com/osalpekar
2022-06-06 23:44:28 +00:00

174 lines
5.5 KiB
Python

from collections import deque
from typing import List, Set
class DiGraph:
    """Really simple unweighted directed graph data structure to track dependencies.

    The API is pretty much the same as networkx so if you add something just
    copy their API.
    """

    def __init__(self):
        # Dict of node -> dict of arbitrary attributes
        self._node = {}
        # Nested dict of node -> successor node -> True.
        # (didn't implement edge data; the inner dict acts as an ordered set)
        self._succ = {}
        # Nested dict of node -> predecessor node -> True.
        self._pred = {}

        # Keep track of the order in which nodes are added to
        # the graph (node -> insertion index). Used by ``first_path``.
        self._node_order = {}
        self._insertion_idx = 0

    def add_node(self, n, **kwargs):
        """Add a node to the graph, or update the attributes of an existing one.

        Args:
            n: the node. Can be any object that is a valid dict key.
            **kwargs: any attributes you want to attach to the node. If the
                node already exists they are merged into its attribute dict.
        """
        if n not in self._node:
            self._node[n] = kwargs
            self._succ[n] = {}
            self._pred[n] = {}
            self._node_order[n] = self._insertion_idx
            self._insertion_idx += 1
        else:
            self._node[n].update(kwargs)

    def add_edge(self, u, v):
        """Add an edge to graph between nodes ``u`` and ``v``

        ``u`` and ``v`` will be created if they do not already exist.
        """
        # add nodes
        self.add_node(u)
        self.add_node(v)

        # add the edge
        self._succ[u][v] = True
        self._pred[v][u] = True

    def successors(self, n):
        """Returns an iterator over successor nodes of n.

        Raises:
            ValueError: if ``n`` is not in the graph.
        """
        try:
            return iter(self._succ[n])
        except KeyError as e:
            raise ValueError(f"The node {n} is not in the digraph.") from e

    def predecessors(self, n):
        """Returns an iterator over predecessors nodes of n.

        Raises:
            ValueError: if ``n`` is not in the graph.
        """
        try:
            return iter(self._pred[n])
        except KeyError as e:
            raise ValueError(f"The node {n} is not in the digraph.") from e

    @property
    def edges(self):
        """Returns an iterator over all edges (u, v) in the graph"""
        for n, successors in self._succ.items():
            for succ in successors:
                yield n, succ

    @property
    def nodes(self):
        """Returns a dictionary of all nodes to their attributes.

        This is the graph's internal dict, not a copy.
        """
        return self._node

    def __iter__(self):
        """Iterate over the nodes."""
        return iter(self._node)

    def __contains__(self, n):
        """Returns True if ``n`` is a node in the graph, False otherwise."""
        try:
            return n in self._node
        except TypeError:
            # Unhashable objects can never be nodes.
            return False

    def _transitive_closure(self, src, neighbors) -> Set[str]:
        """Breadth-first collect ``src`` plus every node reachable via ``neighbors``.

        ``neighbors`` is ``self.successors`` or ``self.predecessors``; both
        raise ValueError when asked about a node that is not in the graph.
        """
        # Wrap src in a container: ``set(src)`` / ``deque(src)`` would iterate a
        # string node name character by character.
        result = {src}
        working_set = deque([src])
        while working_set:
            cur = working_set.popleft()
            for n in neighbors(cur):
                if n not in result:
                    result.add(n)
                    working_set.append(n)
        return result

    def forward_transitive_closure(self, src: str) -> Set[str]:
        """Returns a set of nodes that are reachable from src (src included).

        Raises:
            ValueError: if ``src`` is not in the graph.
        """
        return self._transitive_closure(src, self.successors)

    def backward_transitive_closure(self, src: str) -> Set[str]:
        """Returns a set of nodes that are reachable from src in reverse direction (src included).

        Raises:
            ValueError: if ``src`` is not in the graph.
        """
        return self._transitive_closure(src, self.predecessors)

    def all_paths(self, src: str, dst: str) -> str:
        """Returns the dot representation of the subgraph of all paths from src to dst.

        If ``dst`` is not reachable from ``src``, the dot representation of an
        empty graph is returned.

        Raises:
            ValueError: if ``src`` is not in the graph.
        """
        result_graph = DiGraph()
        # First compute forward transitive closure of src (all things reachable from src).
        forward_reachable_from_src = self.forward_transitive_closure(src)

        if dst not in forward_reachable_from_src:
            # Return a dot string here too, so the return type is always str.
            return result_graph.to_dot()

        # Second walk the reverse dependencies of dst, adding each node to
        # the output graph iff it is also present in forward_reachable_from_src.
        # we don't use backward_transitive_closure for optimization purposes
        visited = {dst}
        working_set = deque([dst])
        while working_set:
            cur = working_set.popleft()
            for n in self.predecessors(cur):
                if n in forward_reachable_from_src:
                    result_graph.add_edge(n, cur)
                    # only explore further if its reachable from src, and only
                    # once per node so that cycles terminate.
                    if n not in visited:
                        visited.add(n)
                        working_set.append(n)

        return result_graph.to_dot()

    def first_path(self, dst: str) -> List[str]:
        """Returns a list of nodes that show the first path that resulted in dst being added to the graph.

        Starting at ``dst``, repeatedly step to the predecessor with the lowest
        insertion index. Stop at a node with no predecessors, or just before
        revisiting a node (possible when the graph has a cycle). The result is
        ordered from that first node to ``dst``.

        Raises:
            KeyError: if ``dst`` is not in the graph.
        """
        path = []
        seen = set()
        cur = dst
        while cur not in seen:
            path.append(cur)
            seen.add(cur)
            candidates = self._pred[cur]
            if not candidates:
                break
            # Every predecessor was registered through add_node, so it has an index.
            cur = min(candidates, key=self._node_order.__getitem__)
        return list(reversed(path))

    def to_dot(self) -> str:
        """Returns the dot representation of the graph.

        Only edges are emitted; node names are embedded verbatim inside double
        quotes (no escaping).

        Returns:
            A dot representation of the graph.
        """
        edges = "\n".join(f'"{f}" -> "{t}";' for f, t in self.edges)
        return f"""\
digraph G {{
rankdir = LR;
node [shape=box];
{edges}
}}
"""