mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Speed up topological sort by avoiding copies. The speedup is about 10-20%.
PiperOrigin-RevId: 163800134
This commit is contained in:
parent
6446895aa6
commit
618f913bbd
|
|
@ -27,37 +27,49 @@ namespace grappler {
|
||||||
// For details, see https://en.wikipedia.org/wiki/Topological_sorting
|
// For details, see https://en.wikipedia.org/wiki/Topological_sorting
|
||||||
void TopologicalSort(GraphDef* graph) {
|
void TopologicalSort(GraphDef* graph) {
|
||||||
NodeMap node_map(graph);
|
NodeMap node_map(graph);
|
||||||
std::deque<const NodeDef*> ready_nodes;
|
std::vector<NodeDef*> ready_nodes;
|
||||||
|
ready_nodes.reserve(graph->node_size());
|
||||||
|
int front = 0;
|
||||||
|
int back = 0;
|
||||||
std::unordered_map<const NodeDef*, int> ready_inputs;
|
std::unordered_map<const NodeDef*, int> ready_inputs;
|
||||||
for (const NodeDef& node : graph->node()) {
|
for (int i = 0; i < graph->node_size(); i++) {
|
||||||
if (node.input_size() == 0) {
|
auto node = graph->mutable_node(i);
|
||||||
ready_nodes.push_back(&node);
|
if (node->input_size() == 0) {
|
||||||
|
ready_nodes.push_back(node);
|
||||||
|
back++;
|
||||||
}
|
}
|
||||||
if (node.op() == "Merge") {
|
if (IsMerge(*node)) {
|
||||||
ready_inputs[&node] = 0;
|
ready_inputs[node] = 0;
|
||||||
for (const auto& input : node.input()) {
|
for (const auto& input : node->input()) {
|
||||||
if (IsNextIteration(*node_map.GetNode(input))) {
|
if (IsNextIteration(*node_map.GetNode(input))) {
|
||||||
ready_inputs[&node]++;
|
ready_inputs[node]++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
ready_inputs[&node] = 0;
|
ready_inputs[node] = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
GraphDef sorted_graph;
|
|
||||||
while (!ready_nodes.empty()) {
|
while (front != back) {
|
||||||
auto ready_node = ready_nodes.front();
|
auto ready_node = ready_nodes[front];
|
||||||
*sorted_graph.add_node() = *ready_node;
|
|
||||||
for (const auto& fanout : node_map.GetOutputs(ready_node->name())) {
|
for (const auto& fanout : node_map.GetOutputs(ready_node->name())) {
|
||||||
ready_inputs[fanout]++;
|
ready_inputs[fanout]++;
|
||||||
if (ready_inputs[fanout] == fanout->input_size()) {
|
if (ready_inputs[fanout] == fanout->input_size()) {
|
||||||
ready_nodes.push_back(fanout);
|
ready_nodes.push_back(fanout);
|
||||||
|
back++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ready_nodes.pop_front();
|
front++;
|
||||||
}
|
}
|
||||||
if (sorted_graph.node_size() == graph->node_size()) {
|
|
||||||
graph->mutable_node()->Swap(sorted_graph.mutable_node());
|
if (back == graph->node_size()) {
|
||||||
|
GraphDef new_graph;
|
||||||
|
new_graph.mutable_node()->Reserve(graph->node_size());
|
||||||
|
for (int i = 0; i < graph->node_size(); i++) {
|
||||||
|
auto new_node = new_graph.add_node();
|
||||||
|
new_node->Swap(ready_nodes[i]);
|
||||||
|
}
|
||||||
|
graph->mutable_node()->Swap(new_graph.mutable_node());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user