Grappler memory optimization: allow inputs to gradients with non-standard names to be recomputed

Includes Python tests for name-scoped gradients.

PiperOrigin-RevId: 163720208
parent b876065afe
commit 6263539a15
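
In brief: gradients created under a name scope (for example by an optimizer) no longer defeat the recomputation rewrite, because the target-name prefix is now configurable. A minimal sketch of the new Python-level usage, mirroring the tests added at the end of this diff (the 'optimizer/gradients/' prefix is illustrative):

    import tensorflow as tf
    from tensorflow.core.protobuf import rewriter_config_pb2
    from tensorflow.python.grappler import tf_optimizer

    # Build (or import) a training graph first; export_meta_graph captures it.
    # ... model construction elided ...
    metagraph = tf.train.export_meta_graph()

    rewriter_config = rewriter_config_pb2.RewriterConfig(
        memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS,
        memory_optimizer_target_node_name_prefix='optimizer/gradients/')
    rewritten_graph_def = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
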
tensorflow/core/grappler/optimizers/memory_optimizer.cc

@@ -42,7 +42,6 @@ const char* kRecomputeTriggerNodePrefix = "RecomputeTrigger";
 // Attribute which may be added to nodes to manually allow them to be
 // recomputed.
 const char* kRecomputeHint = "_recompute_hint";
-const char* kRecomputationTargetNamePrefix = "gradients/";
 
 // Ops which we wouldn't mind recomputing to save memory.
 // TODO(allenl): Replace this list with a cost model.
@@ -57,18 +56,11 @@ std::unordered_set<string> GetCheapToRecomputeOps() {
   return cheap_ops;
 }
 
-// Nodes whose inputs we may want to recompute (i.e. gradients).
-// TODO(allenl): Rather than blindly recomputing gradient inputs, use a static
-// schedule (grappler::EstimateEarliestExecutionTimes) to recompute only nodes
-// whose outputs will sit around for a while.
-bool IsTargetOp(const NodeDef& node) {
-  return node.name().find(kRecomputationTargetNamePrefix) == 0;
-}
-
 // Find recomputable ops which feed into target nodes.
 std::unordered_set<const NodeDef*> FindCandidateRecomputeNodes(
     const NodeMap& node_map, const GraphDef* graph,
-    const std::function<bool(const NodeDef&)>& is_candidate) {
+    const std::function<bool(const NodeDef&)>& is_candidate,
+    const std::function<bool(const NodeDef&)>& is_target) {
   std::unordered_set<const NodeDef*> candidate_recompute_nodes;
   for (const auto& node : graph->node()) {
     if (!is_candidate(node)) {
@@ -78,7 +70,7 @@ std::unordered_set<const NodeDef*> FindCandidateRecomputeNodes(
     for (const NodeDef* output : node_map.GetOutputs(node.name())) {
       // It only makes sense to recompute this if it feeds into a target
       // node. We expand this to dependencies in GetOpGroupsToRecompute.
-      if (IsTargetOp(*output)) {
+      if (is_target(*output)) {
         has_target_output = true;
         break;
       }
@@ -90,7 +82,7 @@ std::unordered_set<const NodeDef*> FindCandidateRecomputeNodes(
     for (const string& input_name : node.input()) {
       // Don't recompute nodes which depend on target nodes.
       const NodeDef* input_node = node_map.GetNode(input_name);
-      if (IsTargetOp(*input_node)) {
+      if (is_target(*input_node)) {
         has_target_input = true;
         break;
       }
@@ -147,11 +139,12 @@ struct RecomputedSubGraph {
 // Find groups of ops to recompute together based on `should_recompute`.
 std::vector<RecomputedSubGraph> GetOpGroupsToRecompute(
     const GraphDef* graph, const NodeMap& node_map,
-    const std::function<bool(const NodeDef&)>& should_recompute) {
+    const std::function<bool(const NodeDef&)>& should_recompute,
+    const std::function<bool(const NodeDef&)>& is_target) {
   std::unordered_set<const NodeDef*> visited_nodes;
   std::vector<RecomputedSubGraph> subgraphs_to_recompute;
   std::unordered_set<const NodeDef*> candidate_recompute_nodes =
-      FindCandidateRecomputeNodes(node_map, graph, should_recompute);
+      FindCandidateRecomputeNodes(node_map, graph, should_recompute, is_target);
   for (const NodeDef* recompute_node : candidate_recompute_nodes) {
     if (visited_nodes.count(recompute_node) > 0) {
       continue;
@@ -171,7 +164,7 @@ std::vector<RecomputedSubGraph> GetOpGroupsToRecompute(
   for (const NodeDef* recompute_node : unpruned_recompute_nodes) {
     bool inserted_feed = false;
     for (NodeDef* output : node_map.GetOutputs(recompute_node->name())) {
-      if (IsTargetOp(*output)) {
+      if (is_target(*output)) {
         current_recomputation.target_nodes.insert(output);
         if (!inserted_feed) {
           // Keep track of nodes which feed directly into a target node. These
@@ -416,6 +409,7 @@ void RecomputeSubgraph(
 }
 
 void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
+                                const string& recomputation_targets_name_prefix,
                                 GraphDef* graph, const GrapplerItem& item) {
   // The topological numberings and NodeMap will be stale as soon as we start
   // modifying the graph in RecomputeSubgraph. However, RecomputeSubgraph only
@@ -433,6 +427,17 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
   for (const auto& feed : item.feed) {
     feeds.insert(NodeName(feed.first));
   }
+  std::function<bool(const NodeDef&)> is_target =
+      [&recomputation_targets_name_prefix](const NodeDef& node) {
+        // Nodes whose inputs we may want to recompute. Typically targets will
+        // be gradients (recomputation_targets_name_prefix="gradients/"),
+        // although the prefix is configurable since gradients may be created in
+        // a name scope.
+        // TODO(allenl): Use a static schedule
+        // (grappler::EstimateEarliestExecutionTimes) to recompute only nodes
+        // whose outputs will sit around for a while.
+        return node.name().find(recomputation_targets_name_prefix) == 0;
+      };
   if (optimization_level == RewriterConfig::HEURISTICS) {
     // TODO(allenl): Handle ResNet-like architectures better. Right now all of
     // the cheap forward ops get grouped into a single subgraph which must
@@ -442,17 +447,20 @@
         GetCheapToRecomputeOps();
     recomputed_subgraphs = GetOpGroupsToRecompute(
         graph, node_map,
-        [&cheap_to_recompute_ops, &feeds](const NodeDef& node) {
-          return !IsTargetOp(node) && feeds.count(node.name()) == 0 &&
+        [&cheap_to_recompute_ops, &feeds, &is_target](const NodeDef& node) {
+          return !is_target(node) && feeds.count(node.name()) == 0 &&
                  (cheap_to_recompute_ops.count(node.op()) > 0 ||
                   node.attr().count(kRecomputeHint) > 0);
-        });
+        },
+        is_target);
   } else if (optimization_level == RewriterConfig::MANUAL) {
-    recomputed_subgraphs =
-        GetOpGroupsToRecompute(graph, node_map, [&feeds](const NodeDef& node) {
-          return !IsTargetOp(node) && feeds.count(node.name()) == 0 &&
+    recomputed_subgraphs = GetOpGroupsToRecompute(
+        graph, node_map,
+        [&feeds, &is_target](const NodeDef& node) {
+          return !is_target(node) && feeds.count(node.name()) == 0 &&
                  node.attr().count(kRecomputeHint) > 0;
-        });
+        },
+        is_target);
   }
   if (!recomputed_subgraphs.empty()) {
     std::unordered_map<const NodeDef*, int> topological_numbering;
@@ -598,7 +606,9 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                  GraphDef* optimized_graph) {
   *optimized_graph = item.graph;
 
-  RecomputationRewritingPass(optimization_level_, optimized_graph, item);
+  RecomputationRewritingPass(optimization_level_,
+                             recomputation_targets_name_prefix_,
+                             optimized_graph, item);
 
   // Figure out what needs to be swapped;
   std::unordered_map<NodeDef*, SwapInfo> nodes_to_swap;
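
The new is_target predicate is a plain name-prefix match: node.name().find(prefix) == 0 is C++ for "starts with". A sketch of the same check in Python, with hypothetical node names:

    def is_target(node_name, prefix='optimizer/gradients/'):
      # Equivalent to node.name().find(recomputation_targets_name_prefix) == 0.
      return node_name.startswith(prefix)

    assert is_target('optimizer/gradients/layer_0/Conv2D_grad/Input')
    assert not is_target('gradients/layer_0/Conv2D_grad/Input')  # Wrong scope.
    assert not is_target('layer_0/Conv2D')  # Forward-pass op, never a target.
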

tensorflow/core/grappler/optimizers/memory_optimizer.h

@@ -25,8 +25,16 @@ namespace grappler {
 // Swap tensors in and out of device memory.
 class MemoryOptimizer : public GraphOptimizer {
  public:
-  explicit MemoryOptimizer(RewriterConfig::MemOptType optimization_level)
-      : optimization_level_(optimization_level) {}
+  // optimization_level: Controls the level of autonomy for the memory
+  //   optimizer. See RewriterConfig::memory_optimization.
+  // recomputation_targets_name_prefix: Name prefix for potential outputs of
+  //   recomputations. See
+  //   RewriterConfig::memory_optimizer_target_node_name_prefix.
+  explicit MemoryOptimizer(
+      RewriterConfig::MemOptType optimization_level,
+      const string& recomputation_targets_name_prefix = "gradients/")
+      : optimization_level_(optimization_level),
+        recomputation_targets_name_prefix_(recomputation_targets_name_prefix) {}
   ~MemoryOptimizer() override {}
 
   string name() const override { return "memory_optimizer"; };

@@ -39,6 +47,7 @@ class MemoryOptimizer : public GraphOptimizer {
 
  private:
   RewriterConfig::MemOptType optimization_level_;
+  string recomputation_targets_name_prefix_;
 };
 
 } // end namespace grappler

tensorflow/core/grappler/optimizers/memory_optimizer_test.cc

@@ -101,7 +101,8 @@ TEST_F(RecomputeSubgraphTest, TwoInputSubgraphs) {
 
   Output a = ops::Variable(s.WithOpName("a"), {2, 3, 4}, DT_FLOAT);
   Output b = ops::Variable(s.WithOpName("b"), {2, 3, 4}, DT_FLOAT);
-  Output d = ops::AddN(s.WithOpName("gradients/two_subgraph_inputs"), {a, b});
+  Output d = ops::AddN(
+      s.WithOpName("some_name_scope/gradients/two_subgraph_inputs"), {a, b});
 
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));

@@ -112,7 +113,8 @@
   (*pre_transform_node_map.GetNode("b")->mutable_attr())["_recompute_hint"]
       .set_i(0);
 
-  MemoryOptimizer optimizer(RewriterConfig::MANUAL);
+  MemoryOptimizer optimizer(RewriterConfig::MANUAL,
+                            "some_name_scope/gradients");
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
 

tensorflow/core/grappler/optimizers/meta_optimizer.cc

@@ -67,8 +67,16 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
         std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
   }
   if (cfg_.memory_optimization() > 1) {
-    optimizers.push_back(std::unique_ptr<GraphOptimizer>(
-        new MemoryOptimizer(cfg_.memory_optimization())));
+    if (cfg_.memory_optimizer_target_node_name_prefix().empty()) {
+      optimizers.push_back(std::unique_ptr<GraphOptimizer>(
+          // Use the default target node name prefix "gradients/"
+          new MemoryOptimizer(cfg_.memory_optimization())));
+    } else {
+      optimizers.push_back(
+          std::unique_ptr<GraphOptimizer>(new MemoryOptimizer(
+              cfg_.memory_optimization(),
+              cfg_.memory_optimizer_target_node_name_prefix())));
+    }
   }
   if (cfg_.auto_parallel().enable()) {
     optimizers.push_back(std::unique_ptr<GraphOptimizer>(
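
The dispatch above means an empty memory_optimizer_target_node_name_prefix falls back to the MemoryOptimizer constructor default of "gradients/". A sketch of the equivalent selection logic in Python (target_prefix is a hypothetical helper, not a TensorFlow API):

    from tensorflow.core.protobuf import rewriter_config_pb2

    def target_prefix(cfg):
      # Hypothetical mirror of the C++ branch above: an empty proto field
      # means "use the MemoryOptimizer default", i.e. 'gradients/'.
      return cfg.memory_optimizer_target_node_name_prefix or 'gradients/'

    cfg = rewriter_config_pb2.RewriterConfig(
        memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS)
    assert target_prefix(cfg) == 'gradients/'
    cfg.memory_optimizer_target_node_name_prefix = 'optimizer/gradients/'
    assert target_prefix(cfg) == 'optimizer/gradients/'
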

tensorflow/core/protobuf/rewriter_config.proto

@@ -41,13 +41,25 @@ message RewriterConfig {
     // Driven by manual op-level annotations.
     MANUAL = 2;
     // Driven by heuristics. The behavior of these heuristics is subject to
-    // change. Currently includes an experimental recomputation heuristic.
+    // change. Currently includes an experimental recomputation
+    // heuristic. Manual annotations are respected, but additional nodes are
+    // selected automatically.
     HEURISTICS = 3;
   }
   // Configures memory optimization passes through the meta-optimizer. Has no
   // effect on manually requested memory optimization passes in the optimizers
   // field.
   MemOptType memory_optimization = 4;
+  // The prefix for nodes which are valid outputs of recomputations. Inputs to
+  // nodes with this name prefix may be recomputed (subject either to manual
+  // annotation of those input nodes or to manual annotation and heuristics
+  // depending on memory_optimization), but the prefixed nodes themselves will
+  // not be recomputed. Typically this will be "gradients/", indicating that
+  // activations from the forward pass of a graph may be recomputed as inputs to
+  // gradients, but may be adjusted if gradients are inside a name scope or if
+  // inputs to non-gradients should be recomputed. Defaults to "gradients/" if
+  // empty or not set.
+  string memory_optimizer_target_node_name_prefix = 6;
 
   // Configures AutoParallel optimization passes either through the
   // meta-optimizer or when manually specified through the optimizers field.
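
To drive the new proto field through a Session rather than through tf_optimizer.OptimizeGraph, the plumbing would look roughly like _GetMemoryOptimizerSessionConfig in the test below, with the prefix field added (the prefix value is illustrative):

    from tensorflow.core.protobuf import config_pb2
    from tensorflow.core.protobuf import rewriter_config_pb2

    rewrite_options = rewriter_config_pb2.RewriterConfig(
        memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS,
        memory_optimizer_target_node_name_prefix='optimizer/gradients/')
    graph_options = config_pb2.GraphOptions(rewrite_options=rewrite_options)
    session_config = config_pb2.ConfigProto(graph_options=graph_options)
    # Passing session_config to session.Session(config=...) would then apply
    # the rewrite when the graph is first run.
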

tensorflow/python/grappler/memory_optimizer_test.py

@@ -89,9 +89,13 @@ class MemoryOptimizerSwapTest(test.TestCase):
 
 
 class MemoryOptimizerRecomputeTest(test.TestCase):
+  """Tests the Python interface to recomputation rewrites.
 
-  def _RunGraphWithConfig(self, config, batch_size=14, image_dim=12):
-    """Run a simple layered graph with conv, an intermediate op, and a ReLU."""
+  See core/grappler/optimizers/memory_optimizer_test.cc for functional tests.
+  """
+
+  def _GetMetaGraph(self, batch_size=14, image_dim=12, optimizer_scope_name=''):
+    """A simple layered graph with conv, an intermediate op, and a ReLU."""
     graph = ops.Graph()
     with graph.as_default():
       random_seed.set_random_seed(1)
@@ -106,31 +110,89 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
         current_activation = 2. * after_conv
         current_activation = nn.relu(current_activation)
       loss = math_ops.reduce_mean(current_activation)
-      optimizer = train.AdamOptimizer(0.001)
-      train_op = optimizer.minimize(loss)
+      with ops.name_scope(optimizer_scope_name):
+        optimizer = train.AdamOptimizer(0.001)
+        train_op = optimizer.minimize(loss)
       init_op = variables.global_variables_initializer()
-    with session.Session(config=config, graph=graph) as sess:
-      sess.run(init_op)
-      sess.run(train_op)
-      sess.run(train_op)
-      return sess.run(loss)
+      metagraph = train.export_meta_graph()
+    return (metagraph, init_op.name, train_op.name, loss.name)
 
-  def _GetMemoryOptimizerConfig(self):
+  def testRewritingDefaultGradientNames(self):
+    """Tests that rewriting occurs with default gradient names."""
+    (original_metagraph, _, _, _) = self._GetMetaGraph()
+    rewritten_graph_def = tf_optimizer.OptimizeGraph(
+        rewriter_config_pb2.RewriterConfig(
+            memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS),
+        original_metagraph)
+    self.assertGreater(
+        len(rewritten_graph_def.node),
+        len(original_metagraph.graph_def.node))
+    self.assertEqual(
+        0,
+        len([node for node in original_metagraph.graph_def.node
+             if 'Recomputed/' in node.name]))
+    self.assertEqual(
+        20,  # Two per layer
+        len([node for node in rewritten_graph_def.node
+             if 'Recomputed/' in node.name]))
+
+  def testRewritingNameScopedGradientNames(self):
+    """Tests that rewriting occurs with non-standard gradient names."""
+    (original_metagraph, _, _, _) = self._GetMetaGraph(
+        optimizer_scope_name='optimizer')
+    rewritten_graph_def = tf_optimizer.OptimizeGraph(
+        rewriter_config_pb2.RewriterConfig(
+            memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS,
+            memory_optimizer_target_node_name_prefix='optimizer/gradients/'),
+        original_metagraph)
+    self.assertGreater(
+        len(rewritten_graph_def.node),
+        len(original_metagraph.graph_def.node))
+    self.assertEqual(
+        0,
+        len([node for node in original_metagraph.graph_def.node
+             if 'Recomputed/' in node.name]))
+    self.assertEqual(
+        20,  # Two per layer
+        len([node for node in rewritten_graph_def.node
+             if 'Recomputed/' in node.name]))
+
+  def _GetMemoryOptimizerSessionConfig(self):
     rewrite_options = rewriter_config_pb2.RewriterConfig(
         memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS)
     graph_options = config_pb2.GraphOptions(rewrite_options=rewrite_options)
     return config_pb2.ConfigProto(graph_options=graph_options)
 
-  def testRecomputationRewritingNoErrors(self):
-    """Tests that there are no errors when we request a memory optimizer pass.
+  def _RunMetaGraphWithConfig(
+      self, config, metagraph, init_op_name, train_op_name, loss_op_name):
+    graph = ops.Graph()
+    with graph.as_default():
+      train.import_meta_graph(metagraph)
+      init_op = graph.get_operation_by_name(init_op_name)
+      train_op = graph.get_operation_by_name(train_op_name)
+      loss_op = graph.get_tensor_by_name(loss_op_name)
+      with session.Session(config=config, graph=graph) as sess:
+        sess.run(init_op)
+        sess.run(train_op)
+        sess.run(train_op)
+        return sess.run(loss_op)
 
-    Does not test that the memory optimizer actually runs. See
-    core/grappler/optimizers/memory_optimizer_test.cc for a functional test of
-    the graph rewriting.
-    """
-    original_loss = self._RunGraphWithConfig(config_pb2.ConfigProto())
-    memory_optimized_loss = self._RunGraphWithConfig(
-        config=self._GetMemoryOptimizerConfig())
+  def testRecomputationRewritingNoErrors(self):
+    """Tests that graph output is not significantly different with rewriting."""
+    (original_metagraph, init_op_name, train_op_name, loss_op_name
+    ) = self._GetMetaGraph()
+    original_loss = self._RunMetaGraphWithConfig(
+        config=config_pb2.ConfigProto(),
+        metagraph=original_metagraph,
+        init_op_name=init_op_name,
+        train_op_name=train_op_name,
+        loss_op_name=loss_op_name)
+    memory_optimized_loss = self._RunMetaGraphWithConfig(
+        config=self._GetMemoryOptimizerSessionConfig(),
+        metagraph=original_metagraph,
+        init_op_name=init_op_name,
+        train_op_name=train_op_name,
+        loss_op_name=loss_op_name)
     self.assertAllClose(original_loss, memory_optimized_loss, rtol=1e-4)
 
 