mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Speed up topological sort by avoiding copies. The speedup is about 10-20%.
PiperOrigin-RevId: 163800134
This commit is contained in:
parent
6446895aa6
commit
618f913bbd
|
|
@@ -27,37 +27,49 @@ namespace grappler {
|
|||
// For details, see https://en.wikipedia.org/wiki/Topological_sorting
|
||||
void TopologicalSort(GraphDef* graph) {
|
||||
NodeMap node_map(graph);
|
||||
std::deque<const NodeDef*> ready_nodes;
|
||||
std::vector<NodeDef*> ready_nodes;
|
||||
ready_nodes.reserve(graph->node_size());
|
||||
int front = 0;
|
||||
int back = 0;
|
||||
std::unordered_map<const NodeDef*, int> ready_inputs;
|
||||
for (const NodeDef& node : graph->node()) {
|
||||
if (node.input_size() == 0) {
|
||||
ready_nodes.push_back(&node);
|
||||
for (int i = 0; i < graph->node_size(); i++) {
|
||||
auto node = graph->mutable_node(i);
|
||||
if (node->input_size() == 0) {
|
||||
ready_nodes.push_back(node);
|
||||
back++;
|
||||
}
|
||||
if (node.op() == "Merge") {
|
||||
ready_inputs[&node] = 0;
|
||||
for (const auto& input : node.input()) {
|
||||
if (IsMerge(*node)) {
|
||||
ready_inputs[node] = 0;
|
||||
for (const auto& input : node->input()) {
|
||||
if (IsNextIteration(*node_map.GetNode(input))) {
|
||||
ready_inputs[&node]++;
|
||||
ready_inputs[node]++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ready_inputs[&node] = 0;
|
||||
ready_inputs[node] = 0;
|
||||
}
|
||||
}
|
||||
GraphDef sorted_graph;
|
||||
while (!ready_nodes.empty()) {
|
||||
auto ready_node = ready_nodes.front();
|
||||
*sorted_graph.add_node() = *ready_node;
|
||||
|
||||
while (front != back) {
|
||||
auto ready_node = ready_nodes[front];
|
||||
for (const auto& fanout : node_map.GetOutputs(ready_node->name())) {
|
||||
ready_inputs[fanout]++;
|
||||
if (ready_inputs[fanout] == fanout->input_size()) {
|
||||
ready_nodes.push_back(fanout);
|
||||
back++;
|
||||
}
|
||||
}
|
||||
ready_nodes.pop_front();
|
||||
front++;
|
||||
}
|
||||
if (sorted_graph.node_size() == graph->node_size()) {
|
||||
graph->mutable_node()->Swap(sorted_graph.mutable_node());
|
||||
|
||||
if (back == graph->node_size()) {
|
||||
GraphDef new_graph;
|
||||
new_graph.mutable_node()->Reserve(graph->node_size());
|
||||
for (int i = 0; i < graph->node_size(); i++) {
|
||||
auto new_node = new_graph.add_node();
|
||||
new_node->Swap(ready_nodes[i]);
|
||||
}
|
||||
graph->mutable_node()->Swap(new_graph.mutable_node());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user