Grappler memory optimization: allow inputs to gradients with non-standard names to be recomputed

Includes Python tests for name-scoped gradients.

PiperOrigin-RevId: 163720208
commit 6263539a15
parent b876065afe
Author: Allen Lavoie (committed by TensorFlower Gardener)
Date: 2017-07-31 11:17:49 -07:00
6 changed files with 152 additions and 49 deletions
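Before the file-by-file diff, a minimal Python sketch of the new knob in action, modeled on the tests added in this commit (the graph construction is illustrative; the new surface is only the memory_optimizer_target_node_name_prefix field):

import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.grappler import tf_optimizer

graph = tf.Graph()
with graph.as_default():
  image = tf.random_normal([14, 12, 12, 1])  # batch_size=14, image_dim=12
  conv = tf.layers.conv2d(image, filters=1, kernel_size=3)
  loss = tf.reduce_mean(tf.nn.relu(2. * conv))
  # Gradients built inside a name scope get names like
  # "optimizer/gradients/...", which the previously hard-coded "gradients/"
  # prefix did not match.
  with tf.name_scope('optimizer'):
    train_op = tf.train.AdamOptimizer(0.001).minimize(loss)
  metagraph = tf.train.export_meta_graph()

config = rewriter_config_pb2.RewriterConfig(
    memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS,
    # The field added by this commit; see rewriter_config.proto below.
    memory_optimizer_target_node_name_prefix='optimizer/gradients/')
rewritten_graph_def = tf_optimizer.OptimizeGraph(config, metagraph)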

tensorflow/core/grappler/optimizers/memory_optimizer.cc

@@ -42,7 +42,6 @@ const char* kRecomputeTriggerNodePrefix = "RecomputeTrigger";
 // Attribute which may be added to nodes to manually allow them to be
 // recomputed.
 const char* kRecomputeHint = "_recompute_hint";
-const char* kRecomputationTargetNamePrefix = "gradients/";
 
 // Ops which we wouldn't mind recomputing to save memory.
 // TODO(allenl): Replace this list with a cost model.
@@ -57,18 +56,11 @@ std::unordered_set<string> GetCheapToRecomputeOps() {
   return cheap_ops;
 }
 
-// Nodes whose inputs we may want to recompute (i.e. gradients).
-// TODO(allenl): Rather than blindly recomputing gradient inputs, use a static
-// schedule (grappler::EstimateEarliestExecutionTimes) to recompute only nodes
-// whose outputs will sit around for a while.
-bool IsTargetOp(const NodeDef& node) {
-  return node.name().find(kRecomputationTargetNamePrefix) == 0;
-}
-
 // Find recomputable ops which feed into target nodes.
 std::unordered_set<const NodeDef*> FindCandidateRecomputeNodes(
     const NodeMap& node_map, const GraphDef* graph,
-    const std::function<bool(const NodeDef&)>& is_candidate) {
+    const std::function<bool(const NodeDef&)>& is_candidate,
+    const std::function<bool(const NodeDef&)>& is_target) {
   std::unordered_set<const NodeDef*> candidate_recompute_nodes;
   for (const auto& node : graph->node()) {
     if (!is_candidate(node)) {
@@ -78,7 +70,7 @@ std::unordered_set<const NodeDef*> FindCandidateRecomputeNodes(
     for (const NodeDef* output : node_map.GetOutputs(node.name())) {
       // It only makes sense to recompute this if it feeds into a target
       // node. We expand this to dependencies in GetOpGroupsToRecompute.
-      if (IsTargetOp(*output)) {
+      if (is_target(*output)) {
        has_target_output = true;
        break;
      }
@@ -90,7 +82,7 @@ std::unordered_set<const NodeDef*> FindCandidateRecomputeNodes(
     for (const string& input_name : node.input()) {
       // Don't recompute nodes which depend on target nodes.
       const NodeDef* input_node = node_map.GetNode(input_name);
-      if (IsTargetOp(*input_node)) {
+      if (is_target(*input_node)) {
        has_target_input = true;
        break;
      }
@@ -147,11 +139,12 @@ struct RecomputedSubGraph {
 // Find groups of ops to recompute together based on `should_recompute`.
 std::vector<RecomputedSubGraph> GetOpGroupsToRecompute(
     const GraphDef* graph, const NodeMap& node_map,
-    const std::function<bool(const NodeDef&)>& should_recompute) {
+    const std::function<bool(const NodeDef&)>& should_recompute,
+    const std::function<bool(const NodeDef&)>& is_target) {
   std::unordered_set<const NodeDef*> visited_nodes;
   std::vector<RecomputedSubGraph> subgraphs_to_recompute;
   std::unordered_set<const NodeDef*> candidate_recompute_nodes =
-      FindCandidateRecomputeNodes(node_map, graph, should_recompute);
+      FindCandidateRecomputeNodes(node_map, graph, should_recompute, is_target);
   for (const NodeDef* recompute_node : candidate_recompute_nodes) {
     if (visited_nodes.count(recompute_node) > 0) {
       continue;
@@ -171,7 +164,7 @@ std::vector<RecomputedSubGraph> GetOpGroupsToRecompute(
   for (const NodeDef* recompute_node : unpruned_recompute_nodes) {
     bool inserted_feed = false;
     for (NodeDef* output : node_map.GetOutputs(recompute_node->name())) {
-      if (IsTargetOp(*output)) {
+      if (is_target(*output)) {
         current_recomputation.target_nodes.insert(output);
         if (!inserted_feed) {
           // Keep track of nodes which feed directly into a target node. These
@@ -416,6 +409,7 @@ void RecomputeSubgraph(
 }
 
 void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
+                                const string& recomputation_targets_name_prefix,
                                 GraphDef* graph, const GrapplerItem& item) {
   // The topological numberings and NodeMap will be stale as soon as we start
   // modifying the graph in RecomputeSubgraph. However, RecomputeSubgraph only
@@ -433,6 +427,17 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
   for (const auto& feed : item.feed) {
     feeds.insert(NodeName(feed.first));
   }
+  std::function<bool(const NodeDef&)> is_target =
+      [&recomputation_targets_name_prefix](const NodeDef& node) {
+        // Nodes whose inputs we may want to recompute. Typically targets will
+        // be gradients (recomputation_targets_name_prefix="gradients/"),
+        // although the prefix is configurable since gradients may be created in
+        // a name scope.
+        // TODO(allenl): Use a static schedule
+        // (grappler::EstimateEarliestExecutionTimes) to recompute only nodes
+        // whose outputs will sit around for a while.
+        return node.name().find(recomputation_targets_name_prefix) == 0;
+      };
   if (optimization_level == RewriterConfig::HEURISTICS) {
     // TODO(allenl): Handle ResNet-like architectures better. Right now all of
     // the cheap forward ops get grouped into a single subgraph which must
@@ -442,17 +447,20 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
         GetCheapToRecomputeOps();
     recomputed_subgraphs = GetOpGroupsToRecompute(
         graph, node_map,
-        [&cheap_to_recompute_ops, &feeds](const NodeDef& node) {
-          return !IsTargetOp(node) && feeds.count(node.name()) == 0 &&
+        [&cheap_to_recompute_ops, &feeds, &is_target](const NodeDef& node) {
+          return !is_target(node) && feeds.count(node.name()) == 0 &&
                  (cheap_to_recompute_ops.count(node.op()) > 0 ||
                   node.attr().count(kRecomputeHint) > 0);
-        });
+        },
+        is_target);
   } else if (optimization_level == RewriterConfig::MANUAL) {
-    recomputed_subgraphs =
-        GetOpGroupsToRecompute(graph, node_map, [&feeds](const NodeDef& node) {
-          return !IsTargetOp(node) && feeds.count(node.name()) == 0 &&
+    recomputed_subgraphs = GetOpGroupsToRecompute(
+        graph, node_map,
+        [&feeds, &is_target](const NodeDef& node) {
+          return !is_target(node) && feeds.count(node.name()) == 0 &&
                  node.attr().count(kRecomputeHint) > 0;
-        });
+        },
+        is_target);
   }
   if (!recomputed_subgraphs.empty()) {
     std::unordered_map<const NodeDef*, int> topological_numbering;
@@ -598,7 +606,9 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                  GraphDef* optimized_graph) {
   *optimized_graph = item.graph;
 
-  RecomputationRewritingPass(optimization_level_, optimized_graph, item);
+  RecomputationRewritingPass(optimization_level_,
+                             recomputation_targets_name_prefix_,
+                             optimized_graph, item);
 
   // Figure out what needs to be swapped;
   std::unordered_map<NodeDef*, SwapInfo> nodes_to_swap;
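A note on the MANUAL path above: it only considers nodes carrying the "_recompute_hint" attribute (kRecomputeHint), and only the attribute's presence is checked, not its value. A sketch, not part of this change, of attaching that annotation to a GraphDef proto from Python (node names are placeholders):

def add_recompute_hint(graph_def, node_name):
  """Marks `node_name` in `graph_def` (a GraphDef proto) for MANUAL mode."""
  for node in graph_def.node:
    if node.name == node_name:
      # Presence of "_recompute_hint" is what the optimizer checks; the
      # integer value itself is unused (the C++ test below sets it to 0).
      node.attr['_recompute_hint'].i = 0
      return
  raise ValueError('Node %s not found' % node_name)

# Example usage (hypothetical node name):
# add_recompute_hint(metagraph.graph_def, 'conv1/Conv2D')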

tensorflow/core/grappler/optimizers/memory_optimizer.h

@@ -25,8 +25,16 @@ namespace grappler {
 // Swap tensors in and out of device memory.
 class MemoryOptimizer : public GraphOptimizer {
  public:
-  explicit MemoryOptimizer(RewriterConfig::MemOptType optimization_level)
-      : optimization_level_(optimization_level) {}
+  // optimization_level: Controls the level of autonomy for the memory
+  //   optimizer. See RewriterConfig::memory_optimization.
+  // recomputation_targets_name_prefix: Name prefix for potential outputs of
+  //   recomputations. See
+  //   RewriterConfig::memory_optimizer_target_node_name_prefix.
+  explicit MemoryOptimizer(
+      RewriterConfig::MemOptType optimization_level,
+      const string& recomputation_targets_name_prefix = "gradients/")
+      : optimization_level_(optimization_level),
+        recomputation_targets_name_prefix_(recomputation_targets_name_prefix) {}
   ~MemoryOptimizer() override {}
 
   string name() const override { return "memory_optimizer"; };
@@ -39,6 +47,7 @@ class MemoryOptimizer : public GraphOptimizer {
 
  private:
   RewriterConfig::MemOptType optimization_level_;
+  string recomputation_targets_name_prefix_;
 };
 
 }  // end namespace grappler

tensorflow/core/grappler/optimizers/memory_optimizer_test.cc

@@ -101,7 +101,8 @@ TEST_F(RecomputeSubgraphTest, TwoInputSubgraphs) {
   Output a = ops::Variable(s.WithOpName("a"), {2, 3, 4}, DT_FLOAT);
   Output b = ops::Variable(s.WithOpName("b"), {2, 3, 4}, DT_FLOAT);
-  Output d = ops::AddN(s.WithOpName("gradients/two_subgraph_inputs"), {a, b});
+  Output d = ops::AddN(
+      s.WithOpName("some_name_scope/gradients/two_subgraph_inputs"), {a, b});
 
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -112,7 +113,8 @@ TEST_F(RecomputeSubgraphTest, TwoInputSubgraphs) {
   (*pre_transform_node_map.GetNode("b")->mutable_attr())["_recompute_hint"]
       .set_i(0);
 
-  MemoryOptimizer optimizer(RewriterConfig::MANUAL);
+  MemoryOptimizer optimizer(RewriterConfig::MANUAL,
+                            "some_name_scope/gradients");
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);

tensorflow/core/grappler/optimizers/meta_optimizer.cc

@@ -67,8 +67,16 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
         std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
   }
   if (cfg_.memory_optimization() > 1) {
-    optimizers.push_back(std::unique_ptr<GraphOptimizer>(
-        new MemoryOptimizer(cfg_.memory_optimization())));
+    if (cfg_.memory_optimizer_target_node_name_prefix().empty()) {
+      optimizers.push_back(std::unique_ptr<GraphOptimizer>(
+          // Use the default target node name prefix "gradients/"
+          new MemoryOptimizer(cfg_.memory_optimization())));
+    } else {
+      optimizers.push_back(
+          std::unique_ptr<GraphOptimizer>(new MemoryOptimizer(
+              cfg_.memory_optimization(),
+              cfg_.memory_optimizer_target_node_name_prefix())));
+    }
   }
   if (cfg_.auto_parallel().enable()) {
     optimizers.push_back(std::unique_ptr<GraphOptimizer>(

tensorflow/core/protobuf/rewriter_config.proto

@@ -41,13 +41,25 @@ message RewriterConfig {
     // Driven by manual op-level annotations.
     MANUAL = 2;
     // Driven by heuristics. The behavior of these heuristics is subject to
-    // change. Currently includes an experimental recomputation heuristic.
+    // change. Currently includes an experimental recomputation
+    // heuristic. Manual annotations are respected, but additional nodes are
+    // selected automatically.
     HEURISTICS = 3;
   }
 
   // Configures memory optimization passes through the meta-optimizer. Has no
   // effect on manually requested memory optimization passes in the optimizers
   // field.
   MemOptType memory_optimization = 4;
+  // The prefix for nodes which are valid outputs of recomputations. Inputs to
+  // nodes with this name prefix may be recomputed (subject either to manual
+  // annotation of those input nodes or to manual annotation and heuristics
+  // depending on memory_optimization), but the prefixed nodes themselves will
+  // not be recomputed. Typically this will be "gradients/", indicating that
+  // activations from the forward pass of a graph may be recomputed as inputs to
+  // gradients, but may be adjusted if gradients are inside a name scope or if
+  // inputs to non-gradients should be recomputed. Defaults to "gradients/" if
+  // empty or not set.
+  string memory_optimizer_target_node_name_prefix = 6;
 
   // Configures AutoParallel optimization passes either through the
   // meta-optimizer or when manually specified through the optimizers field.
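The matching semantics of the new field follow node.name().find(prefix) == 0 in memory_optimizer.cc above, i.e. a plain starts-with test against the full node name. A small illustration (the node names are made up):

def is_recomputation_target(node_name, prefix='gradients/'):
  # Mirrors node.name().find(prefix) == 0 from memory_optimizer.cc.
  return node_name.startswith(prefix)

assert is_recomputation_target('gradients/conv/BiasAddGrad')
# A name-scoped gradient does not match the default prefix...
assert not is_recomputation_target('optimizer/gradients/conv/BiasAddGrad')
# ...which is why the prefix is now configurable.
assert is_recomputation_target('optimizer/gradients/conv/BiasAddGrad',
                               prefix='optimizer/gradients/')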

tensorflow/python/grappler/memory_optimizer_test.py

@@ -89,9 +89,13 @@ class MemoryOptimizerSwapTest(test.TestCase):
 
 class MemoryOptimizerRecomputeTest(test.TestCase):
+  """Tests the Python interface to recomputation rewrites.
 
-  def _RunGraphWithConfig(self, config, batch_size=14, image_dim=12):
-    """Run a simple layered graph with conv, an intermediate op, and a ReLU."""
+  See core/grappler/optimizers/memory_optimizer_test.cc for functional tests.
+  """
+
+  def _GetMetaGraph(self, batch_size=14, image_dim=12,
+                    optimizer_scope_name=''):
+    """A simple layered graph with conv, an intermediate op, and a ReLU."""
     graph = ops.Graph()
     with graph.as_default():
       random_seed.set_random_seed(1)
@@ -106,31 +110,89 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
       current_activation = 2. * after_conv
       current_activation = nn.relu(current_activation)
       loss = math_ops.reduce_mean(current_activation)
-      optimizer = train.AdamOptimizer(0.001)
-      train_op = optimizer.minimize(loss)
+      with ops.name_scope(optimizer_scope_name):
+        optimizer = train.AdamOptimizer(0.001)
+        train_op = optimizer.minimize(loss)
       init_op = variables.global_variables_initializer()
-    with session.Session(config=config, graph=graph) as sess:
-      sess.run(init_op)
-      sess.run(train_op)
-      sess.run(train_op)
-      return sess.run(loss)
+      metagraph = train.export_meta_graph()
+    return (metagraph, init_op.name, train_op.name, loss.name)
 
-  def _GetMemoryOptimizerConfig(self):
+  def testRewritingDefaultGradientNames(self):
+    """Tests that rewriting occurs with default gradient names."""
+    (original_metagraph, _, _, _) = self._GetMetaGraph()
+    rewritten_graph_def = tf_optimizer.OptimizeGraph(
+        rewriter_config_pb2.RewriterConfig(
+            memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS),
+        original_metagraph)
+    self.assertGreater(
+        len(rewritten_graph_def.node),
+        len(original_metagraph.graph_def.node))
+    self.assertEqual(
+        0,
+        len([node for node in original_metagraph.graph_def.node
+             if 'Recomputed/' in node.name]))
+    self.assertEqual(
+        20,  # Two per layer
+        len([node for node in rewritten_graph_def.node
+             if 'Recomputed/' in node.name]))
+
+  def testRewritingNameScopedGradientNames(self):
+    """Tests that rewriting occurs with non-standard gradient names."""
+    (original_metagraph, _, _, _) = self._GetMetaGraph(
+        optimizer_scope_name='optimizer')
+    rewritten_graph_def = tf_optimizer.OptimizeGraph(
+        rewriter_config_pb2.RewriterConfig(
+            memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS,
+            memory_optimizer_target_node_name_prefix='optimizer/gradients/'),
+        original_metagraph)
+    self.assertGreater(
+        len(rewritten_graph_def.node),
+        len(original_metagraph.graph_def.node))
+    self.assertEqual(
+        0,
+        len([node for node in original_metagraph.graph_def.node
+             if 'Recomputed/' in node.name]))
+    self.assertEqual(
+        20,  # Two per layer
+        len([node for node in rewritten_graph_def.node
+             if 'Recomputed/' in node.name]))
+
+  def _GetMemoryOptimizerSessionConfig(self):
     rewrite_options = rewriter_config_pb2.RewriterConfig(
         memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS)
     graph_options = config_pb2.GraphOptions(rewrite_options=rewrite_options)
     return config_pb2.ConfigProto(graph_options=graph_options)
 
-  def testRecomputationRewritingNoErrors(self):
-    """Tests that there are no errors when we request a memory optimizer pass.
+  def _RunMetaGraphWithConfig(
+      self, config, metagraph, init_op_name, train_op_name, loss_op_name):
+    graph = ops.Graph()
+    with graph.as_default():
+      train.import_meta_graph(metagraph)
+      init_op = graph.get_operation_by_name(init_op_name)
+      train_op = graph.get_operation_by_name(train_op_name)
+      loss_op = graph.get_tensor_by_name(loss_op_name)
+      with session.Session(config=config, graph=graph) as sess:
+        sess.run(init_op)
+        sess.run(train_op)
+        sess.run(train_op)
+        return sess.run(loss_op)
 
-    Does not test that the memory optimizer actually runs. See
-    core/grappler/optimizers/memory_optimizer_test.cc for a functional test of
-    the graph rewriting.
-    """
-    original_loss = self._RunGraphWithConfig(config_pb2.ConfigProto())
-    memory_optimized_loss = self._RunGraphWithConfig(
-        config=self._GetMemoryOptimizerConfig())
+  def testRecomputationRewritingNoErrors(self):
+    """Tests that graph output is not significantly different with rewriting."""
+    (original_metagraph, init_op_name, train_op_name, loss_op_name
+    ) = self._GetMetaGraph()
+    original_loss = self._RunMetaGraphWithConfig(
+        config=config_pb2.ConfigProto(),
+        metagraph=original_metagraph,
+        init_op_name=init_op_name,
+        train_op_name=train_op_name,
+        loss_op_name=loss_op_name)
+    memory_optimized_loss = self._RunMetaGraphWithConfig(
+        config=self._GetMemoryOptimizerSessionConfig(),
+        metagraph=original_metagraph,
+        init_op_name=init_op_name,
+        train_op_name=train_op_name,
+        loss_op_name=loss_op_name)
     self.assertAllClose(original_loss, memory_optimized_loss, rtol=1e-4)