mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Small improvements to the arithmetic optimizer
PiperOrigin-RevId: 165760972
This commit is contained in:
parent
b6409594d3
commit
a271c37db3
|
|
@ -177,7 +177,7 @@ void ArithmeticOptimizer::DedupComputations(GraphDef* optimized_graph) const {
|
|||
if (rep == node) {
|
||||
continue;
|
||||
}
|
||||
const std::set<NodeDef*> fanouts = map.GetOutputs(node->name());
|
||||
const std::set<NodeDef*>& fanouts = map.GetOutputs(node->name());
|
||||
for (NodeDef* fanout : fanouts) {
|
||||
for (string& name : *fanout->mutable_input()) {
|
||||
int position;
|
||||
|
|
@ -190,7 +190,7 @@ void ArithmeticOptimizer::DedupComputations(GraphDef* optimized_graph) const {
|
|||
} else {
|
||||
name = strings::StrCat("^", rep->name());
|
||||
}
|
||||
map.UpdateOutput(nodename, fanout->name(), name);
|
||||
map.AddOutput(rep->name(), fanout->name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -61,8 +61,9 @@ void NodeMap::AddOutput(const string& node, const string& output) {
|
|||
|
||||
void NodeMap::UpdateOutput(const string& node, const string& old_output,
|
||||
const string& new_output) {
|
||||
outputs_[node].erase(nodes_[old_output]);
|
||||
outputs_[node].insert(nodes_[new_output]);
|
||||
std::set<NodeDef*>& outputs = outputs_[node];
|
||||
outputs.erase(nodes_[old_output]);
|
||||
outputs.insert(nodes_[new_output]);
|
||||
}
|
||||
|
||||
bool IsSameInput(const string& name1, const string& name2) {
|
||||
|
|
|
|||
|
|
@ -125,6 +125,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
|
|||
rewritten_graph_def = tf_optimizer.OptimizeGraph(
|
||||
rewriter_config_pb2.RewriterConfig(
|
||||
disable_model_pruning=True,
|
||||
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
|
||||
memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS),
|
||||
original_metagraph)
|
||||
self.assertGreater(
|
||||
|
|
@ -146,6 +147,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
|
|||
rewritten_graph_def = tf_optimizer.OptimizeGraph(
|
||||
rewriter_config_pb2.RewriterConfig(
|
||||
disable_model_pruning=True,
|
||||
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
|
||||
memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS,
|
||||
memory_optimizer_target_node_name_prefix='optimizer/gradients/'),
|
||||
original_metagraph)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user