Grappler memory optimization: allow inputs to gradients with non-standard names to be recomputed.

Includes Python tests for name-scoped gradients.

PiperOrigin-RevId: 163720208

This commit is contained in:
parent b876065afe
commit 6263539a15
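For orientation, the new RewriterConfig field added by this commit can be driven end to end from Python. Below is a minimal sketch based on the Python test added in this commit; `metagraph` stands in for a MetaGraphDef exported via train.export_meta_graph(), as that test does, and the prefix value is illustrative:

    # Sketch (not part of the commit): rewrite a graph whose gradients live
    # under the name scope "optimizer/" rather than the default "gradients/".
    from tensorflow.core.protobuf import rewriter_config_pb2
    from tensorflow.python.grappler import tf_optimizer

    config = rewriter_config_pb2.RewriterConfig(
        memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS,
        memory_optimizer_target_node_name_prefix='optimizer/gradients/')
    rewritten_graph_def = tf_optimizer.OptimizeGraph(config, metagraph)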
tensorflow/core/grappler/optimizers/memory_optimizer.cc

@@ -42,7 +42,6 @@ const char* kRecomputeTriggerNodePrefix = "RecomputeTrigger";
 // Attribute which may be added to nodes to manually allow them to be
 // recomputed.
 const char* kRecomputeHint = "_recompute_hint";
-const char* kRecomputationTargetNamePrefix = "gradients/";
 
 // Ops which we wouldn't mind recomputing to save memory.
 // TODO(allenl): Replace this list with a cost model.
@@ -57,18 +56,11 @@ std::unordered_set<string> GetCheapToRecomputeOps() {
   return cheap_ops;
 }
 
-// Nodes whose inputs we may want to recompute (i.e. gradients).
-// TODO(allenl): Rather than blindly recomputing gradient inputs, use a static
-// schedule (grappler::EstimateEarliestExecutionTimes) to recompute only nodes
-// whose outputs will sit around for a while.
-bool IsTargetOp(const NodeDef& node) {
-  return node.name().find(kRecomputationTargetNamePrefix) == 0;
-}
-
 // Find recomputable ops which feed into target nodes.
 std::unordered_set<const NodeDef*> FindCandidateRecomputeNodes(
     const NodeMap& node_map, const GraphDef* graph,
-    const std::function<bool(const NodeDef&)>& is_candidate) {
+    const std::function<bool(const NodeDef&)>& is_candidate,
+    const std::function<bool(const NodeDef&)>& is_target) {
   std::unordered_set<const NodeDef*> candidate_recompute_nodes;
   for (const auto& node : graph->node()) {
     if (!is_candidate(node)) {
@@ -78,7 +70,7 @@ std::unordered_set<const NodeDef*> FindCandidateRecomputeNodes(
     for (const NodeDef* output : node_map.GetOutputs(node.name())) {
       // It only makes sense to recompute this if it feeds into a target
       // node. We expand this to dependencies in GetOpGroupsToRecompute.
-      if (IsTargetOp(*output)) {
+      if (is_target(*output)) {
        has_target_output = true;
         break;
       }
@@ -90,7 +82,7 @@ std::unordered_set<const NodeDef*> FindCandidateRecomputeNodes(
     for (const string& input_name : node.input()) {
       // Don't recompute nodes which depend on target nodes.
       const NodeDef* input_node = node_map.GetNode(input_name);
-      if (IsTargetOp(*input_node)) {
+      if (is_target(*input_node)) {
         has_target_input = true;
         break;
       }
@@ -147,11 +139,12 @@ struct RecomputedSubGraph {
 // Find groups of ops to recompute together based on `should_recompute`.
 std::vector<RecomputedSubGraph> GetOpGroupsToRecompute(
     const GraphDef* graph, const NodeMap& node_map,
-    const std::function<bool(const NodeDef&)>& should_recompute) {
+    const std::function<bool(const NodeDef&)>& should_recompute,
+    const std::function<bool(const NodeDef&)>& is_target) {
   std::unordered_set<const NodeDef*> visited_nodes;
   std::vector<RecomputedSubGraph> subgraphs_to_recompute;
   std::unordered_set<const NodeDef*> candidate_recompute_nodes =
-      FindCandidateRecomputeNodes(node_map, graph, should_recompute);
+      FindCandidateRecomputeNodes(node_map, graph, should_recompute, is_target);
   for (const NodeDef* recompute_node : candidate_recompute_nodes) {
     if (visited_nodes.count(recompute_node) > 0) {
       continue;
@@ -171,7 +164,7 @@ std::vector<RecomputedSubGraph> GetOpGroupsToRecompute(
   for (const NodeDef* recompute_node : unpruned_recompute_nodes) {
     bool inserted_feed = false;
     for (NodeDef* output : node_map.GetOutputs(recompute_node->name())) {
-      if (IsTargetOp(*output)) {
+      if (is_target(*output)) {
         current_recomputation.target_nodes.insert(output);
         if (!inserted_feed) {
           // Keep track of nodes which feed directly into a target node. These
@@ -416,6 +409,7 @@ void RecomputeSubgraph(
 }
 
 void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
+                                const string& recomputation_targets_name_prefix,
                                 GraphDef* graph, const GrapplerItem& item) {
   // The topological numberings and NodeMap will be stale as soon as we start
   // modifying the graph in RecomputeSubgraph. However, RecomputeSubgraph only
@@ -433,6 +427,17 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
   for (const auto& feed : item.feed) {
     feeds.insert(NodeName(feed.first));
   }
+  std::function<bool(const NodeDef&)> is_target =
+      [&recomputation_targets_name_prefix](const NodeDef& node) {
+        // Nodes whose inputs we may want to recompute. Typically targets will
+        // be gradients (recomputation_targets_name_prefix="gradients/"),
+        // although the prefix is configurable since gradients may be created in
+        // a name scope.
+        // TODO(allenl): Use a static schedule
+        // (grappler::EstimateEarliestExecutionTimes) to recompute only nodes
+        // whose outputs will sit around for a while.
+        return node.name().find(recomputation_targets_name_prefix) == 0;
+      };
   if (optimization_level == RewriterConfig::HEURISTICS) {
     // TODO(allenl): Handle ResNet-like architectures better. Right now all of
     // the cheap forward ops get grouped into a single subgraph which must
@@ -442,17 +447,20 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
         GetCheapToRecomputeOps();
     recomputed_subgraphs = GetOpGroupsToRecompute(
         graph, node_map,
-        [&cheap_to_recompute_ops, &feeds](const NodeDef& node) {
-          return !IsTargetOp(node) && feeds.count(node.name()) == 0 &&
+        [&cheap_to_recompute_ops, &feeds, &is_target](const NodeDef& node) {
+          return !is_target(node) && feeds.count(node.name()) == 0 &&
                  (cheap_to_recompute_ops.count(node.op()) > 0 ||
                   node.attr().count(kRecomputeHint) > 0);
-        });
+        },
+        is_target);
   } else if (optimization_level == RewriterConfig::MANUAL) {
-    recomputed_subgraphs =
-        GetOpGroupsToRecompute(graph, node_map, [&feeds](const NodeDef& node) {
-          return !IsTargetOp(node) && feeds.count(node.name()) == 0 &&
+    recomputed_subgraphs = GetOpGroupsToRecompute(
+        graph, node_map,
+        [&feeds, &is_target](const NodeDef& node) {
+          return !is_target(node) && feeds.count(node.name()) == 0 &&
                  node.attr().count(kRecomputeHint) > 0;
-        });
+        },
+        is_target);
   }
   if (!recomputed_subgraphs.empty()) {
     std::unordered_map<const NodeDef*, int> topological_numbering;
@@ -598,7 +606,9 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                  GraphDef* optimized_graph) {
   *optimized_graph = item.graph;
 
-  RecomputationRewritingPass(optimization_level_, optimized_graph, item);
+  RecomputationRewritingPass(optimization_level_,
+                             recomputation_targets_name_prefix_,
+                             optimized_graph, item);
 
   // Figure out what needs to be swapped;
   std::unordered_map<NodeDef*, SwapInfo> nodes_to_swap;
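The is_target lambda above replaces the removed IsTargetOp helper; node.name().find(prefix) == 0 is simply a starts-with check on the node name. A Python rendering of the same predicate, purely for illustration (node names hypothetical):

    def is_target(node_name, recomputation_targets_name_prefix='gradients/'):
      # C++ string::find(prefix) == 0 is equivalent to startswith(prefix).
      return node_name.startswith(recomputation_targets_name_prefix)

    assert is_target('gradients/Relu_grad/ReluGrad')
    # Name-scoped gradients only match once the prefix is configured:
    assert not is_target('optimizer/gradients/Relu_grad/ReluGrad')
    assert is_target('optimizer/gradients/Relu_grad/ReluGrad',
                     recomputation_targets_name_prefix='optimizer/gradients/')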
tensorflow/core/grappler/optimizers/memory_optimizer.h

@@ -25,8 +25,16 @@ namespace grappler {
 // Swap tensors in and out of device memory.
 class MemoryOptimizer : public GraphOptimizer {
  public:
-  explicit MemoryOptimizer(RewriterConfig::MemOptType optimization_level)
-      : optimization_level_(optimization_level) {}
+  // optimization_level: Controls the level of autonomy for the memory
+  //   optimizer. See RewriterConfig::memory_optimization.
+  // recomputation_targets_name_prefix: Name prefix for potential outputs of
+  //   recomputations. See
+  //   RewriterConfig::memory_optimizer_target_node_name_prefix.
+  explicit MemoryOptimizer(
+      RewriterConfig::MemOptType optimization_level,
+      const string& recomputation_targets_name_prefix = "gradients/")
+      : optimization_level_(optimization_level),
+        recomputation_targets_name_prefix_(recomputation_targets_name_prefix) {}
   ~MemoryOptimizer() override {}
 
   string name() const override { return "memory_optimizer"; };
@@ -39,6 +47,7 @@ class MemoryOptimizer : public GraphOptimizer {
 
  private:
   RewriterConfig::MemOptType optimization_level_;
+  string recomputation_targets_name_prefix_;
 };
 
 }  // end namespace grappler
tensorflow/core/grappler/optimizers/memory_optimizer_test.cc

@@ -101,7 +101,8 @@ TEST_F(RecomputeSubgraphTest, TwoInputSubgraphs) {
 
   Output a = ops::Variable(s.WithOpName("a"), {2, 3, 4}, DT_FLOAT);
   Output b = ops::Variable(s.WithOpName("b"), {2, 3, 4}, DT_FLOAT);
-  Output d = ops::AddN(s.WithOpName("gradients/two_subgraph_inputs"), {a, b});
+  Output d = ops::AddN(
+      s.WithOpName("some_name_scope/gradients/two_subgraph_inputs"), {a, b});
 
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -112,7 +113,8 @@ TEST_F(RecomputeSubgraphTest, TwoInputSubgraphs) {
   (*pre_transform_node_map.GetNode("b")->mutable_attr())["_recompute_hint"]
       .set_i(0);
 
-  MemoryOptimizer optimizer(RewriterConfig::MANUAL);
+  MemoryOptimizer optimizer(RewriterConfig::MANUAL,
+                            "some_name_scope/gradients");
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
 
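The MANUAL path is driven by the _recompute_hint attribute, which this C++ test attaches through mutable_attr(). A rough Python counterpart that annotates a NodeDef inside a GraphDef proto; the helper name and lookup loop are hypothetical, only the attribute key comes from the commit:

    def add_recompute_hint(graph_def, node_name):
      """Marks one node of a graph_pb2.GraphDef as manually recomputable."""
      for node in graph_def.node:
        if node.name == node_name:
          # Mirrors ["_recompute_hint"].set_i(0) in the C++ test above.
          node.attr['_recompute_hint'].i = 0
          return
      raise ValueError('No node named %r' % node_name)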
tensorflow/core/grappler/optimizers/meta_optimizer.cc

@@ -67,8 +67,16 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
         std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
   }
   if (cfg_.memory_optimization() > 1) {
-    optimizers.push_back(std::unique_ptr<GraphOptimizer>(
-        new MemoryOptimizer(cfg_.memory_optimization())));
+    if (cfg_.memory_optimizer_target_node_name_prefix().empty()) {
+      optimizers.push_back(std::unique_ptr<GraphOptimizer>(
+          // Use the default target node name prefix "gradients/"
+          new MemoryOptimizer(cfg_.memory_optimization())));
+    } else {
+      optimizers.push_back(
+          std::unique_ptr<GraphOptimizer>(new MemoryOptimizer(
+              cfg_.memory_optimization(),
+              cfg_.memory_optimizer_target_node_name_prefix())));
+    }
   }
   if (cfg_.auto_parallel().enable()) {
     optimizers.push_back(std::unique_ptr<GraphOptimizer>(
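With meta_optimizer.cc now threading the prefix through, a session-level configuration might look like the following sketch, which combines the _GetMemoryOptimizerSessionConfig helper from the Python test below with the new field; the prefix value is illustrative:

    from tensorflow.core.protobuf import config_pb2, rewriter_config_pb2

    rewrite_options = rewriter_config_pb2.RewriterConfig(
        memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS,
        # Left empty or unset, the meta-optimizer falls back to "gradients/".
        memory_optimizer_target_node_name_prefix='optimizer/gradients/')
    graph_options = config_pb2.GraphOptions(rewrite_options=rewrite_options)
    session_config = config_pb2.ConfigProto(graph_options=graph_options)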
tensorflow/core/protobuf/rewriter_config.proto

@@ -41,13 +41,25 @@ message RewriterConfig {
     // Driven by manual op-level annotations.
     MANUAL = 2;
     // Driven by heuristics. The behavior of these heuristics is subject to
-    // change. Currently includes an experimental recomputation heuristic.
+    // change. Currently includes an experimental recomputation
+    // heuristic. Manual annotations are respected, but additional nodes are
+    // selected automatically.
     HEURISTICS = 3;
   }
   // Configures memory optimization passes through the meta-optimizer. Has no
   // effect on manually requested memory optimization passes in the optimizers
   // field.
   MemOptType memory_optimization = 4;
+  // The prefix for nodes which are valid outputs of recomputations. Inputs to
+  // nodes with this name prefix may be recomputed (subject either to manual
+  // annotation of those input nodes or to manual annotation and heuristics
+  // depending on memory_optimization), but the prefixed nodes themselves will
+  // not be recomputed. Typically this will be "gradients/", indicating that
+  // activations from the forward pass of a graph may be recomputed as inputs to
+  // gradients, but may be adjusted if gradients are inside a name scope or if
+  // inputs to non-gradients should be recomputed. Defaults to "gradients/" if
+  // empty or not set.
+  string memory_optimizer_target_node_name_prefix = 6;
 
   // Configures AutoParallel optimization passes either through the
   // meta-optimizer or when manually specified through the optimizers field.
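The name-scope caveat in the comment above is exactly what the new Python tests exercise: building the training op inside a name scope prefixes every gradient node, so the default "gradients/" prefix no longer matches. A minimal sketch against the TF 1.x-era graph-mode API (variable and scope names illustrative):

    import tensorflow as tf  # 1.x-era graph-mode API assumed

    x = tf.Variable(1.0)
    loss = x * x
    with tf.name_scope('optimizer'):
      # tf.gradients creates its ops under a "gradients" scope nested inside
      # the current scope, yielding names like "optimizer/gradients/...".
      grads = tf.gradients(loss, [x])
    print(grads[0].op.name)  # e.g. "optimizer/gradients/..."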
tensorflow/python/grappler/memory_optimizer_test.py

@@ -89,9 +89,13 @@ class MemoryOptimizerSwapTest(test.TestCase):
 
 
 class MemoryOptimizerRecomputeTest(test.TestCase):
+  """Tests the Python interface to recomputation rewrites.
 
-  def _RunGraphWithConfig(self, config, batch_size=14, image_dim=12):
-    """Run a simple layered graph with conv, an intermediate op, and a ReLU."""
+  See core/grappler/optimizers/memory_optimizer_test.cc for functional tests.
+  """
+
+  def _GetMetaGraph(self, batch_size=14, image_dim=12, optimizer_scope_name=''):
+    """A simple layered graph with conv, an intermediate op, and a ReLU."""
     graph = ops.Graph()
     with graph.as_default():
       random_seed.set_random_seed(1)
@@ -106,31 +110,89 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
           current_activation = 2. * after_conv
           current_activation = nn.relu(current_activation)
       loss = math_ops.reduce_mean(current_activation)
-      optimizer = train.AdamOptimizer(0.001)
-      train_op = optimizer.minimize(loss)
+      with ops.name_scope(optimizer_scope_name):
+        optimizer = train.AdamOptimizer(0.001)
+        train_op = optimizer.minimize(loss)
       init_op = variables.global_variables_initializer()
-    with session.Session(config=config, graph=graph) as sess:
-      sess.run(init_op)
-      sess.run(train_op)
-      sess.run(train_op)
-      return sess.run(loss)
+      metagraph = train.export_meta_graph()
+    return (metagraph, init_op.name, train_op.name, loss.name)
 
-  def _GetMemoryOptimizerConfig(self):
+  def testRewritingDefaultGradientNames(self):
+    """Tests that rewriting occurs with default gradient names."""
+    (original_metagraph, _, _, _) = self._GetMetaGraph()
+    rewritten_graph_def = tf_optimizer.OptimizeGraph(
+        rewriter_config_pb2.RewriterConfig(
+            memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS),
+        original_metagraph)
+    self.assertGreater(
+        len(rewritten_graph_def.node),
+        len(original_metagraph.graph_def.node))
+    self.assertEqual(
+        0,
+        len([node for node in original_metagraph.graph_def.node
+             if 'Recomputed/' in node.name]))
+    self.assertEqual(
+        20,  # Two per layer
+        len([node for node in rewritten_graph_def.node
+             if 'Recomputed/' in node.name]))
+
+  def testRewritingNameScopedGradientNames(self):
+    """Tests that rewriting occurs with non-standard gradient names."""
+    (original_metagraph, _, _, _) = self._GetMetaGraph(
+        optimizer_scope_name='optimizer')
+    rewritten_graph_def = tf_optimizer.OptimizeGraph(
+        rewriter_config_pb2.RewriterConfig(
+            memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS,
+            memory_optimizer_target_node_name_prefix='optimizer/gradients/'),
+        original_metagraph)
+    self.assertGreater(
+        len(rewritten_graph_def.node),
+        len(original_metagraph.graph_def.node))
+    self.assertEqual(
+        0,
+        len([node for node in original_metagraph.graph_def.node
+             if 'Recomputed/' in node.name]))
+    self.assertEqual(
+        20,  # Two per layer
+        len([node for node in rewritten_graph_def.node
+             if 'Recomputed/' in node.name]))
+
+  def _GetMemoryOptimizerSessionConfig(self):
     rewrite_options = rewriter_config_pb2.RewriterConfig(
         memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS)
     graph_options = config_pb2.GraphOptions(rewrite_options=rewrite_options)
     return config_pb2.ConfigProto(graph_options=graph_options)
 
-  def testRecomputationRewritingNoErrors(self):
-    """Tests that there are no errors when we request a memory optimizer pass.
+  def _RunMetaGraphWithConfig(
+      self, config, metagraph, init_op_name, train_op_name, loss_op_name):
+    graph = ops.Graph()
+    with graph.as_default():
+      train.import_meta_graph(metagraph)
+      init_op = graph.get_operation_by_name(init_op_name)
+      train_op = graph.get_operation_by_name(train_op_name)
+      loss_op = graph.get_tensor_by_name(loss_op_name)
+      with session.Session(config=config, graph=graph) as sess:
+        sess.run(init_op)
+        sess.run(train_op)
+        sess.run(train_op)
+        return sess.run(loss_op)
 
-    Does not test that the memory optimizer actually runs. See
-    core/grappler/optimizers/memory_optimizer_test.cc for a functional test of
-    the graph rewriting.
-    """
-    original_loss = self._RunGraphWithConfig(config_pb2.ConfigProto())
-    memory_optimized_loss = self._RunGraphWithConfig(
-        config=self._GetMemoryOptimizerConfig())
+  def testRecomputationRewritingNoErrors(self):
+    """Tests that graph output is not significantly different with rewriting."""
+    (original_metagraph, init_op_name, train_op_name, loss_op_name
+    ) = self._GetMetaGraph()
+    original_loss = self._RunMetaGraphWithConfig(
+        config=config_pb2.ConfigProto(),
+        metagraph=original_metagraph,
+        init_op_name=init_op_name,
+        train_op_name=train_op_name,
+        loss_op_name=loss_op_name)
+    memory_optimized_loss = self._RunMetaGraphWithConfig(
+        config=self._GetMemoryOptimizerSessionConfig(),
+        metagraph=original_metagraph,
+        init_op_name=init_op_name,
+        train_op_name=train_op_name,
+        loss_op_name=loss_op_name)
     self.assertAllClose(original_loss, memory_optimized_loss, rtol=1e-4)