mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Propagates constants through switch nodes.
PiperOrigin-RevId: 158163537
This commit is contained in:
parent
b01d4b9058
commit
e55f2e036d
|
|
@ -381,8 +381,14 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
|
|||
if (output_tensors.size() > 1) {
|
||||
node_name = strings::StrCat(node_name, "-", i);
|
||||
}
|
||||
outputs->push_back(CreateNodeDef(node_name, output_tensors[i]));
|
||||
delete output_tensors[i].tensor;
|
||||
if (output_tensors[i].tensor) {
|
||||
outputs->push_back(CreateNodeDef(node_name, output_tensors[i]));
|
||||
delete output_tensors[i].tensor;
|
||||
} else {
|
||||
// Create an empty NodeDef to identify dead outputs (e.g. the output of a
|
||||
// switch that's not selected by the switch predicate).
|
||||
outputs->push_back(NodeDef());
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -391,7 +397,14 @@ Status ConstantFolding::FoldNode(const NodeDef& node, GraphDef* output) {
|
|||
std::vector<NodeDef> const_nodes;
|
||||
TF_RETURN_IF_ERROR(EvaluateOneFoldable(node, &const_nodes));
|
||||
|
||||
NodeDef* constant_output = nullptr;
|
||||
for (const auto& const_node : const_nodes) {
|
||||
if (const_node.name().empty()) {
|
||||
// Dead output: we can't create a constant to encode its value, so we'll
|
||||
// just skip it. We'll preserve the edges that originate from that output
|
||||
// below to preserve the overall behavior of the graph wrt dead edges.
|
||||
continue;
|
||||
}
|
||||
NodeDef* added_node = output->add_node();
|
||||
*added_node = const_node;
|
||||
node_map_->AddNode(added_node->name(), added_node);
|
||||
|
|
@ -408,6 +421,11 @@ Status ConstantFolding::FoldNode(const NodeDef& node, GraphDef* output) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// All the constant nodes encoding output values have the same control
|
||||
// dependencies (since these are the control dependencies of the node we're
|
||||
// trying to fold). Record one such constant node.
|
||||
constant_output = added_node;
|
||||
}
|
||||
|
||||
auto outputs = node_map_->GetOutputs(node.name());
|
||||
|
|
@ -417,9 +435,21 @@ Status ConstantFolding::FoldNode(const NodeDef& node, GraphDef* output) {
|
|||
string node_name = ParseNodeName(output->input(i), &position);
|
||||
if (node_name == node.name()) {
|
||||
if (position < 0) {
|
||||
*output->mutable_input(i) = AsControlDependency(const_nodes[0]);
|
||||
} else {
|
||||
// Propagate control dependencies if possible. If not, we'll just
|
||||
// preserve the existing control dependencies.
|
||||
if (constant_output != nullptr) {
|
||||
*output->mutable_input(i) = AsControlDependency(*constant_output);
|
||||
}
|
||||
|
||||
} else if (position < const_nodes.size() &&
|
||||
!const_nodes[position].name().empty()) {
|
||||
// Replace alive outputs with the corresponding constant.
|
||||
*output->mutable_input(i) = const_nodes[position].name();
|
||||
} else {
|
||||
// Leave this edge alone.
|
||||
VLOG(1) << "Preserving edge from " << node.name() << ":" << position
|
||||
<< "[" << node.op() << "] to " << output->name() << ":" << i
|
||||
<< "[" << output->op() << "]";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -249,16 +249,28 @@ TEST_F(ConstantFoldingTest, SwitchNodes) {
|
|||
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
|
||||
ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
|
||||
ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
|
||||
ops::Switch s(scope.WithOpName("switch"), v_in, v_ctrl);
|
||||
ops::Rank rank(scope.WithOpName("rank"), s.output_false);
|
||||
ops::Identity i(scope.WithOpName("i"), s.output_true);
|
||||
ops::Switch s1(scope.WithOpName("switch"), v_in, v_ctrl);
|
||||
ops::Rank rank(scope.WithOpName("rank"), s1.output_false);
|
||||
ops::Identity i(scope.WithOpName("i"), s1.output_true);
|
||||
ops::Size size(scope.WithOpName("size"), i);
|
||||
ops::Square p1(scope.WithOpName("p1"), rank);
|
||||
ops::Square p2(scope.WithOpName("p2"), size);
|
||||
ops::Merge m(scope.WithOpName("m"), {p1.y, p2.y});
|
||||
|
||||
Output predicate =
|
||||
ops::Const(scope.WithOpName("false"), false, TensorShape({}));
|
||||
Output constant =
|
||||
ops::Const(scope.WithOpName("constant"), 1.0f, TensorShape({1}));
|
||||
ops::Switch s2(scope.WithOpName("switch2"), constant, predicate);
|
||||
ops::Identity statically_known(scope.WithOpName("i2"), s2.output_false);
|
||||
ops::Identity never_generated(scope.WithOpName("i3"), s2.output_true);
|
||||
ops::Merge m2(scope.WithOpName("m2"),
|
||||
{statically_known.output, never_generated.output});
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch.push_back("m");
|
||||
item.fetch.push_back("m2");
|
||||
|
||||
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
|
||||
|
||||
ConstantFolding fold;
|
||||
|
|
@ -277,6 +289,19 @@ TEST_F(ConstantFoldingTest, SwitchNodes) {
|
|||
EXPECT_EQ(1, node.input_size());
|
||||
EXPECT_EQ("^i", node.input(0));
|
||||
}
|
||||
if (node.name() == "ConstantFolding/switch2-0") {
|
||||
EXPECT_EQ("Const", node.op());
|
||||
EXPECT_EQ(0, node.input_size());
|
||||
}
|
||||
if (node.name() == "ConstantFolding/i2") {
|
||||
EXPECT_EQ("Const", node.op());
|
||||
EXPECT_EQ(0, node.input_size());
|
||||
}
|
||||
if (node.name() == "i3") {
|
||||
EXPECT_EQ("Identity", node.op());
|
||||
EXPECT_EQ(1, node.input_size());
|
||||
EXPECT_EQ("switch2:1", node.input(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user