Small improvements to the arithmetic optimizer

PiperOrigin-RevId: 165760972
This commit is contained in:
Benoit Steiner 2017-08-18 15:24:50 -07:00 committed by TensorFlower Gardener
parent b6409594d3
commit a271c37db3
3 changed files with 7 additions and 4 deletions

View File

@ -177,7 +177,7 @@ void ArithmeticOptimizer::DedupComputations(GraphDef* optimized_graph) const {
if (rep == node) {
continue;
}
const std::set<NodeDef*> fanouts = map.GetOutputs(node->name());
const std::set<NodeDef*>& fanouts = map.GetOutputs(node->name());
for (NodeDef* fanout : fanouts) {
for (string& name : *fanout->mutable_input()) {
int position;
@ -190,7 +190,7 @@ void ArithmeticOptimizer::DedupComputations(GraphDef* optimized_graph) const {
} else {
name = strings::StrCat("^", rep->name());
}
map.UpdateOutput(nodename, fanout->name(), name);
map.AddOutput(rep->name(), fanout->name());
}
}
}

View File

@ -61,8 +61,9 @@ void NodeMap::AddOutput(const string& node, const string& output) {
void NodeMap::UpdateOutput(const string& node, const string& old_output,
const string& new_output) {
outputs_[node].erase(nodes_[old_output]);
outputs_[node].insert(nodes_[new_output]);
std::set<NodeDef*>& outputs = outputs_[node];
outputs.erase(nodes_[old_output]);
outputs.insert(nodes_[new_output]);
}
bool IsSameInput(const string& name1, const string& name2) {

View File

@ -125,6 +125,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
rewritten_graph_def = tf_optimizer.OptimizeGraph(
rewriter_config_pb2.RewriterConfig(
disable_model_pruning=True,
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS),
original_metagraph)
self.assertGreater(
@ -146,6 +147,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
rewritten_graph_def = tf_optimizer.OptimizeGraph(
rewriter_config_pb2.RewriterConfig(
disable_model_pruning=True,
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS,
memory_optimizer_target_node_name_prefix='optimizer/gradients/'),
original_metagraph)