pytorch/torch/package/_digraph.py
Michael Suo 53c21172c0 [package] add simple graph data structure (#57337)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57337

Add a really simple graph data structure for tracking dependencies. API
based on networkx, but I didn't want to require the dependency.

Differential Revision: D28114186

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Pulled By: suo

fbshipit-source-id: 802fd067017e493a48d6672538080e61d249accd
2021-05-05 17:57:00 -07:00

73 lines
2.1 KiB
Python

class DiGraph:
    """Really simple unweighted directed graph data structure to track dependencies.

    The API is pretty much the same as networkx so if you add something just
    copy their API.
    """

    def __init__(self):
        # Dict of node -> dict of arbitrary attributes.
        self._node = {}
        # Nested dict of node -> successor node -> ``True`` placeholder.
        # (edge data is not implemented, so only the keys carry meaning)
        self._succ = {}

    def add_node(self, n, **kwargs):
        """Add a node to the graph, or update the attributes of an existing one.

        Args:
            n: the node. Can be any object that is a valid dict key.
            **kwargs: any attributes you want to attach to the node. If ``n``
                is already in the graph, these are merged into its existing
                attributes (overwriting attributes with the same name).
        """
        if n not in self._node:
            self._node[n] = kwargs
            self._succ[n] = {}
        else:
            self._node[n].update(kwargs)

    def add_edge(self, u, v):
        """Add an edge to the graph from node ``u`` to node ``v``.

        ``u`` and ``v`` will be created (with no attributes) if they do not
        already exist. Attributes of existing nodes are left untouched, and
        adding an edge that is already present has no further effect.
        """
        # Reuse add_node so node creation lives in exactly one place; with no
        # kwargs it leaves the attributes of an existing node unchanged.
        self.add_node(u)
        self.add_node(v)
        # Only the key matters; the value is a placeholder.
        self._succ[u][v] = True

    def successors(self, n):
        """Return an iterator over the successor nodes of ``n``.

        Raises:
            ValueError: if ``n`` is not a node in the graph. The original
                ``KeyError`` is chained as the cause.
        """
        try:
            return iter(self._succ[n])
        except KeyError as e:
            raise ValueError(f"The node {n} is not in the digraph.") from e

    @property
    def edges(self):
        """Return an iterator over all edges ``(u, v)`` in the graph.

        This is a generator: edges are produced lazily, one ``(u, v)`` tuple
        at a time, for every node ``u`` and each of its successors ``v``.
        """
        for n, successors in self._succ.items():
            for succ in successors:
                yield n, succ

    @property
    def nodes(self):
        """Return a dictionary mapping every node to its attribute dict.

        The returned object is the graph's internal dict, not a copy.
        """
        return self._node

    def __iter__(self):
        """Iterate over the nodes."""
        return iter(self._node)

    def __contains__(self, n):
        """Return True if ``n`` is a node in the graph, False otherwise.

        Objects that cannot be dict keys (unhashable) are reported as not
        contained rather than raising ``TypeError``.
        """
        try:
            return n in self._node
        except TypeError:
            return False