Speed up topological sort by avoiding copies. The speedup is about 10-20%.

PiperOrigin-RevId: 163800134
This commit is contained in:
Yao Zhang 2017-08-01 00:33:15 -07:00 committed by TensorFlower Gardener
parent 6446895aa6
commit 618f913bbd

View File

@ -27,37 +27,49 @@ namespace grappler {
// For details, see https://en.wikipedia.org/wiki/Topological_sorting // For details, see https://en.wikipedia.org/wiki/Topological_sorting
// Reading note: this is a side-by-side diff. On each line the text before the
// change appears first and the text after the change appears second; a line
// holding a single statement exists on one side only.
//
// TopologicalSort reorders the nodes of `graph` in place. It repeatedly takes
// a node from a work list of "ready" nodes and marks each of its fanouts as
// having one more ready input; a fanout joins the work list once its count
// equals its input_size(). The graph is only rewritten when every node was
// placed; otherwise it is left untouched.
void TopologicalSort(GraphDef* graph) { void TopologicalSort(GraphDef* graph) {
NodeMap node_map(graph); NodeMap node_map(graph);
// Work list of ready nodes. Before: a deque that is popped from the front.
// After: a vector that is never popped; `front` is the next index to process
// and `back` counts the entries pushed so far, so processed nodes stay in the
// vector in the order they became ready.
std::deque<const NodeDef*> ready_nodes; std::vector<NodeDef*> ready_nodes;
ready_nodes.reserve(graph->node_size());
int front = 0;
int back = 0;
// Per-node count of inputs already considered ready.
std::unordered_map<const NodeDef*, int> ready_inputs; std::unordered_map<const NodeDef*, int> ready_inputs;
// After the change, nodes are fetched through mutable_node(i) so the stored
// pointers are non-const and can be swapped out at the end.
for (const NodeDef& node : graph->node()) { for (int i = 0; i < graph->node_size(); i++) {
if (node.input_size() == 0) { auto node = graph->mutable_node(i);
// Nodes without inputs seed the work list.
ready_nodes.push_back(&node); if (node->input_size() == 0) {
ready_nodes.push_back(node);
back++;
} }
// Merge nodes start with their NextIteration inputs already counted as ready;
// presumably this keeps loop back-edges from blocking the sort — TODO confirm.
// NOTE(review): the result of node_map.GetNode(input) is dereferenced without
// a null check; assumes every listed input resolves to a node — confirm.
if (node.op() == "Merge") { if (IsMerge(*node)) {
ready_inputs[&node] = 0; ready_inputs[node] = 0;
for (const auto& input : node.input()) { for (const auto& input : node->input()) {
if (IsNextIteration(*node_map.GetNode(input))) { if (IsNextIteration(*node_map.GetNode(input))) {
ready_inputs[&node]++; ready_inputs[node]++;
} }
} }
} else { } else {
ready_inputs[&node] = 0; ready_inputs[node] = 0;
} }
} }
// Before the change, each ready node was copied into `sorted_graph` as it was
// processed (the copy this commit removes).
GraphDef sorted_graph;
// Process the work list until no unprocessed ready node remains.
while (!ready_nodes.empty()) { while (front != back) {
auto ready_node = ready_nodes.front(); auto ready_node = ready_nodes[front];
*sorted_graph.add_node() = *ready_node;
// NOTE(review): each fanout is incremented once per appearance in GetOutputs.
// If GetOutputs reports a consumer only once even when it reads this node
// through several inputs, that consumer's count could never reach
// input_size() — confirm against NodeMap.
for (const auto& fanout : node_map.GetOutputs(ready_node->name())) { for (const auto& fanout : node_map.GetOutputs(ready_node->name())) {
ready_inputs[fanout]++; ready_inputs[fanout]++;
if (ready_inputs[fanout] == fanout->input_size()) { if (ready_inputs[fanout] == fanout->input_size()) {
ready_nodes.push_back(fanout); ready_nodes.push_back(fanout);
back++;
} }
} }
ready_nodes.pop_front(); front++;
} }
// Commit the new order only if every node was placed. After the change, `back`
// equals ready_nodes.size() (each push_back is paired with back++), so indexing
// ready_nodes[i] for i < node_size() is in range inside this branch. Node
// contents are swapped into `new_graph` instead of copied, and the repeated
// field is then swapped back into `graph`.
if (sorted_graph.node_size() == graph->node_size()) {
graph->mutable_node()->Swap(sorted_graph.mutable_node()); if (back == graph->node_size()) {
GraphDef new_graph;
new_graph.mutable_node()->Reserve(graph->node_size());
for (int i = 0; i < graph->node_size(); i++) {
auto new_node = new_graph.add_node();
new_node->Swap(ready_nodes[i]);
}
graph->mutable_node()->Swap(new_graph.mutable_node());
} }
} }