Handle control flow logic properly:

* Don't fold enter/exit nodes since that can interact badly with frames
 * Create proper control dependencies on switch nodes

PiperOrigin-RevId: 158066691
This commit is contained in:
Benoit Steiner 2017-06-05 14:59:21 -07:00 committed by TensorFlower Gardener
parent 9e6899720a
commit 93c57c6e4f
7 changed files with 137 additions and 13 deletions

View File

@ -35,6 +35,11 @@ bool IsDequeueOp(const NodeDef& node) {
op == "QueueDequeueUpToV2" || op == "QueueDequeueUpTo"; op == "QueueDequeueUpToV2" || op == "QueueDequeueUpTo";
} }
bool IsIdentity(const NodeDef& node) {
const auto& op = node.op();
return op == "Identity";
}
bool IsMerge(const NodeDef& node) { bool IsMerge(const NodeDef& node) {
const auto op = node.op(); const auto op = node.op();
return op == "Merge"; return op == "Merge";
@ -52,6 +57,11 @@ bool IsReduction(const NodeDef& node) {
op == "Mean" || op == "Any" || op == "All"; op == "Mean" || op == "Any" || op == "All";
} }
bool IsSwitch(const NodeDef& node) {
const auto& op = node.op();
return op == "Switch";
}
bool IsTranspose(const NodeDef& node) { bool IsTranspose(const NodeDef& node) {
const auto op = node.op(); const auto op = node.op();
return op == "Transpose"; return op == "Transpose";

View File

@ -24,9 +24,11 @@ namespace grappler {
bool IsConcat(const NodeDef& node); bool IsConcat(const NodeDef& node);
bool IsConstant(const NodeDef& node); bool IsConstant(const NodeDef& node);
bool IsDequeueOp(const NodeDef& node); bool IsDequeueOp(const NodeDef& node);
bool IsIdentity(const NodeDef& node);
bool IsMerge(const NodeDef& node); bool IsMerge(const NodeDef& node);
bool IsPlaceholder(const NodeDef& node); bool IsPlaceholder(const NodeDef& node);
bool IsReduction(const NodeDef& node); bool IsReduction(const NodeDef& node);
bool IsSwitch(const NodeDef& node);
bool IsTranspose(const NodeDef& node); bool IsTranspose(const NodeDef& node);
bool IsVariable(const NodeDef& node); bool IsVariable(const NodeDef& node);

View File

@ -99,17 +99,67 @@ Status NumOutputs(const NodeDef& node, int* num_outputs) {
} }
return Status::OK(); return Status::OK();
} }
string AsControlDependency(const NodeDef& node) {
return strings::StrCat("^", node.name());
}
} // namespace } // namespace
ConstantFolding::ConstantFolding() { ConstantFolding::ConstantFolding() {
ops_to_preserve_ = ops_to_preserve_ = std::regex(
std::regex("Placeholder.*|Const|.*Save.*|.*Restore.*|.*Reader"); "Placeholder.*|Const|.*Save.*|.*Restore.*|.*Reader|Enter|Exit|"
"NextIteration");
}
string ConstantFolding::AddControlDependency(const string& input_name) {
const NodeDef* node = node_map_->GetNode(input_name);
if (!IsSwitch(*node)) {
return AsControlDependency(*node);
} else {
// We can't anchor control dependencies directly on the switch node: unlike
// other nodes only one of the outputs of the switch node will be generated
// when the switch node is executed, and we need to make sure the control
// dependency is only triggered when the corresponding output is triggered.
// We start by looking for an identity node connected to the output of the
// switch node, and use it to anchor the control dependency.
auto outputs = node_map_->GetOutputs(node->name());
for (const NodeDef* node : outputs) {
if (IsIdentity(*node)) {
CHECK_EQ(1, node->input_size());
if (IsSameInput(node->input(0), input_name)) {
return AsControlDependency(*node);
}
}
}
// We haven't found an existing node where we can anchor the control
// dependency: add a new identity node.
int position = 0;
string ctrl_dep_name = ParseNodeName(input_name, &position);
strings::StrAppend(&ctrl_dep_name, "_", position);
ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl);
const DataType output_type = node->attr().at("T").type();
NodeDef* added_node = graph_.add_node();
added_node->set_name(ctrl_dep_name);
added_node->set_op("Identity");
(*added_node->mutable_attr())["T"].set_type(output_type);
*added_node->add_input() = input_name;
node_map_->AddNode(added_node->name(), added_node);
node_map_->AddOutput(node->name(), added_node->name());
return AsControlDependency(*added_node);
}
} }
Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) { Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) {
GraphProperties properties(item); GraphProperties properties(item);
TF_RETURN_IF_ERROR(properties.InferStatically()); TF_RETURN_IF_ERROR(properties.InferStatically());
for (NodeDef& node : *graph_.mutable_node()) { // We may add some nodes to the graph to encode control dependencies: there is
// no need to process these, so only iterate over the nodes of the input
// graph.
const int node_count = graph_.node_size();
for (int i = 0; i < node_count; ++i) {
NodeDef& node = *graph_.mutable_node(i);
const string op = node.op(); const string op = node.op();
if (op != "Shape" && op != "Size" && op != "Rank") { if (op != "Shape" && op != "Size" && op != "Rank") {
continue; continue;
@ -179,9 +229,12 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) {
value.AsProtoTensorContent( value.AsProtoTensorContent(
(*node.mutable_attr())["value"].mutable_tensor()); (*node.mutable_attr())["value"].mutable_tensor());
// Turn the inputs into control dependencies. // Turn the inputs into control dependencies: this is needed to ensure
// that the constant value will only be generated in the cases where the
// shape/rank/size would have been generated in the original graph.
string ctrl_dep = AddControlDependency(node.input(0));
CHECK_EQ(1, node.input_size()); CHECK_EQ(1, node.input_size());
node.set_input(0, strings::StrCat("^", NodeName(node.input(0)))); node.set_input(0, ctrl_dep);
} }
} }
} }
@ -364,8 +417,7 @@ Status ConstantFolding::FoldNode(const NodeDef& node, GraphDef* output) {
string node_name = ParseNodeName(output->input(i), &position); string node_name = ParseNodeName(output->input(i), &position);
if (node_name == node.name()) { if (node_name == node.name()) {
if (position < 0) { if (position < 0) {
*output->mutable_input(i) = *output->mutable_input(i) = AsControlDependency(const_nodes[0]);
strings::StrCat("^", const_nodes[0].name());
} else { } else {
*output->mutable_input(i) = const_nodes[position].name(); *output->mutable_input(i) = const_nodes[position].name();
} }

View File

@ -26,6 +26,7 @@ namespace tensorflow {
namespace grappler { namespace grappler {
const char kConstantFoldingConst[] = "ConstantFolding"; const char kConstantFoldingConst[] = "ConstantFolding";
const char kConstantFoldingCtrl[] = "ConstantFoldingCtrl";
// Contant folding optimization for a graph. // Contant folding optimization for a graph.
class ConstantFolding : public GraphOptimizer { class ConstantFolding : public GraphOptimizer {
@ -43,6 +44,7 @@ class ConstantFolding : public GraphOptimizer {
const GraphDef& optimize_output, double result) override; const GraphDef& optimize_output, double result) override;
private: private:
string AddControlDependency(const string& input_name);
Status MaterializeShapes(const GrapplerItem& item); Status MaterializeShapes(const GrapplerItem& item);
bool IsFoldable(const NodeDef& node) const; bool IsFoldable(const NodeDef& node) const;

View File

@ -245,6 +245,41 @@ TEST_F(ConstantFoldingTest, ShapeMaterialization) {
EXPECT_EQ(3, found); EXPECT_EQ(3, found);
} }
TEST_F(ConstantFoldingTest, SwitchNodes) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
ops::Switch s(scope.WithOpName("switch"), v_in, v_ctrl);
ops::Rank rank(scope.WithOpName("rank"), s.output_false);
ops::Identity i(scope.WithOpName("i"), s.output_true);
ops::Size size(scope.WithOpName("size"), i);
ops::Square p1(scope.WithOpName("p1"), rank);
ops::Square p2(scope.WithOpName("p2"), size);
ops::Merge m(scope.WithOpName("m"), {p1.y, p2.y});
GrapplerItem item;
item.fetch.push_back("m");
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
ConstantFolding fold;
GraphDef output;
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
for (const auto& node : output.node()) {
if (node.name() == "rank") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("^ConstantFoldingCtrl/switch_0", node.input(0));
}
if (node.name() == "size") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("^i", node.input(0));
}
}
}
TEST_F(ConstantFoldingTest, NoOpReduction) { TEST_F(ConstantFoldingTest, NoOpReduction) {
// Build a simple graph with a reduction that can be reduced to the identity. // Build a simple graph with a reduction that can be reduced to the identity.
tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

View File

@ -34,13 +34,21 @@ NodeMap::NodeMap(GraphDef* graph) : graph_(graph) {
} }
} }
NodeDef* NodeMap::GetNode(const string& name) { NodeDef* NodeMap::GetNode(const string& name) const {
string node_name = NodeName(name); string node_name = NodeName(name);
return nodes_[node_name]; auto it = nodes_.find(node_name);
if (it == nodes_.end()) {
return nullptr;
}
return it->second;
} }
std::set<NodeDef*> NodeMap::GetOutputs(const string& node_name) { const std::set<NodeDef*>& NodeMap::GetOutputs(const string& node_name) const {
return outputs_[node_name]; auto it = outputs_.find(node_name);
if (it == outputs_.end()) {
return empty_set_;
}
return it->second;
} }
void NodeMap::AddNode(const string& name, NodeDef* node) { void NodeMap::AddNode(const string& name, NodeDef* node) {
@ -57,6 +65,17 @@ void NodeMap::UpdateOutput(const string& node, const string& old_output,
outputs_[node].insert(nodes_[new_output]); outputs_[node].insert(nodes_[new_output]);
} }
bool IsSameInput(const string& name1, const string& name2) {
if (name1 == name2) {
return true;
}
int position1;
string node1 = ParseNodeName(name1, &position1);
int position2;
string node2 = ParseNodeName(name2, &position2);
return (position1 == position2) && (node1 == node2);
}
string ParseNodeName(const string& name, int* position) { string ParseNodeName(const string& name, int* position) {
// Strip the prefix '^' (if any), and strip the trailing ":{digits} (if any) // Strip the prefix '^' (if any), and strip the trailing ":{digits} (if any)
// to get a node name. // to get a node name.

View File

@ -31,8 +31,8 @@ namespace grappler {
class NodeMap { class NodeMap {
public: public:
explicit NodeMap(GraphDef* graph); explicit NodeMap(GraphDef* graph);
NodeDef* GetNode(const string& name); NodeDef* GetNode(const string& name) const;
std::set<NodeDef*> GetOutputs(const string& node_name); const std::set<NodeDef*>& GetOutputs(const string& node_name) const;
// This method doesn't record the outputs of the added node; the outputs need // This method doesn't record the outputs of the added node; the outputs need
// to be explictly added by the AddOutput method. // to be explictly added by the AddOutput method.
void AddNode(const string& name, NodeDef* node); void AddNode(const string& name, NodeDef* node);
@ -42,6 +42,7 @@ class NodeMap {
private: private:
GraphDef* graph_; GraphDef* graph_;
std::set<NodeDef*> empty_set_;
std::unordered_map<string, NodeDef*> nodes_; std::unordered_map<string, NodeDef*> nodes_;
std::unordered_map<string, std::set<NodeDef*>> outputs_; std::unordered_map<string, std::set<NodeDef*>> outputs_;
}; };
@ -50,6 +51,9 @@ class NodeMap {
// the ^ character. // the ^ character.
bool IsControlInput(const string& name); bool IsControlInput(const string& name);
// True iff 'name1' and 'name2' refer to the same input.
bool IsSameInput(const string& name1, const string& name2);
// Return the node name corresponding to 'name' if name is valid, or the empty // Return the node name corresponding to 'name' if name is valid, or the empty
// string otherwise. // string otherwise.
string NodeName(const string& name); string NodeName(const string& name);