Propagates constants through switch nodes.

PiperOrigin-RevId: 158163537
This commit is contained in:
Benoit Steiner 2017-06-06 11:09:54 -07:00 committed by TensorFlower Gardener
parent b01d4b9058
commit e55f2e036d
2 changed files with 62 additions and 7 deletions

View File

@ -381,8 +381,14 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
if (output_tensors.size() > 1) {
node_name = strings::StrCat(node_name, "-", i);
}
outputs->push_back(CreateNodeDef(node_name, output_tensors[i]));
delete output_tensors[i].tensor;
if (output_tensors[i].tensor) {
outputs->push_back(CreateNodeDef(node_name, output_tensors[i]));
delete output_tensors[i].tensor;
} else {
// Create an empty NodeDef to identify dead outputs (e.g. the output of a
// switch that's not selected by the switch predicate).
outputs->push_back(NodeDef());
}
}
return Status::OK();
}
@ -391,7 +397,14 @@ Status ConstantFolding::FoldNode(const NodeDef& node, GraphDef* output) {
std::vector<NodeDef> const_nodes;
TF_RETURN_IF_ERROR(EvaluateOneFoldable(node, &const_nodes));
NodeDef* constant_output = nullptr;
for (const auto& const_node : const_nodes) {
if (const_node.name().empty()) {
// Dead output: we can't create a constant to encode its value, so we'll
// just skip it. We'll preserve the edges that originate from that output
// below to preserve the overall behavior of the graph wrt dead edges.
continue;
}
NodeDef* added_node = output->add_node();
*added_node = const_node;
node_map_->AddNode(added_node->name(), added_node);
@ -408,6 +421,11 @@ Status ConstantFolding::FoldNode(const NodeDef& node, GraphDef* output) {
}
}
}
// All the constant nodes encoding output values have the same control
// dependencies (since these are the control dependencies of the node we're
// trying to fold). Record one such constant node.
constant_output = added_node;
}
auto outputs = node_map_->GetOutputs(node.name());
@ -417,9 +435,21 @@ Status ConstantFolding::FoldNode(const NodeDef& node, GraphDef* output) {
string node_name = ParseNodeName(output->input(i), &position);
if (node_name == node.name()) {
if (position < 0) {
*output->mutable_input(i) = AsControlDependency(const_nodes[0]);
} else {
// Propagate control dependencies if possible. If not, we'll just
// preserve the existing control dependencies.
if (constant_output != nullptr) {
*output->mutable_input(i) = AsControlDependency(*constant_output);
}
} else if (position < const_nodes.size() &&
!const_nodes[position].name().empty()) {
// Replace alive outputs with the corresponding constant.
*output->mutable_input(i) = const_nodes[position].name();
} else {
// Leave this edge alone.
VLOG(1) << "Preserving edge from " << node.name() << ":" << position
<< "[" << node.op() << "] to " << output->name() << ":" << i
<< "[" << output->op() << "]";
}
}
}

View File

@ -249,16 +249,28 @@ TEST_F(ConstantFoldingTest, SwitchNodes) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
ops::Switch s(scope.WithOpName("switch"), v_in, v_ctrl);
ops::Rank rank(scope.WithOpName("rank"), s.output_false);
ops::Identity i(scope.WithOpName("i"), s.output_true);
ops::Switch s1(scope.WithOpName("switch"), v_in, v_ctrl);
ops::Rank rank(scope.WithOpName("rank"), s1.output_false);
ops::Identity i(scope.WithOpName("i"), s1.output_true);
ops::Size size(scope.WithOpName("size"), i);
ops::Square p1(scope.WithOpName("p1"), rank);
ops::Square p2(scope.WithOpName("p2"), size);
ops::Merge m(scope.WithOpName("m"), {p1.y, p2.y});
Output predicate =
ops::Const(scope.WithOpName("false"), false, TensorShape({}));
Output constant =
ops::Const(scope.WithOpName("constant"), 1.0f, TensorShape({1}));
ops::Switch s2(scope.WithOpName("switch2"), constant, predicate);
ops::Identity statically_known(scope.WithOpName("i2"), s2.output_false);
ops::Identity never_generated(scope.WithOpName("i3"), s2.output_true);
ops::Merge m2(scope.WithOpName("m2"),
{statically_known.output, never_generated.output});
GrapplerItem item;
item.fetch.push_back("m");
item.fetch.push_back("m2");
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
ConstantFolding fold;
@ -277,6 +289,19 @@ TEST_F(ConstantFoldingTest, SwitchNodes) {
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("^i", node.input(0));
}
if (node.name() == "ConstantFolding/switch2-0") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ(0, node.input_size());
}
if (node.name() == "ConstantFolding/i2") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ(0, node.input_size());
}
if (node.name() == "i3") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("switch2:1", node.input(0));
}
}
}