mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Handle control flow logic properly:
* Don't fold enter/exit nodes since that can interact badly with frames * Create proper control dependencies on switch nodes PiperOrigin-RevId: 158066691
This commit is contained in:
parent
9e6899720a
commit
93c57c6e4f
|
|
@ -35,6 +35,11 @@ bool IsDequeueOp(const NodeDef& node) {
|
|||
op == "QueueDequeueUpToV2" || op == "QueueDequeueUpTo";
|
||||
}
|
||||
|
||||
bool IsIdentity(const NodeDef& node) {
|
||||
const auto& op = node.op();
|
||||
return op == "Identity";
|
||||
}
|
||||
|
||||
bool IsMerge(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "Merge";
|
||||
|
|
@ -52,6 +57,11 @@ bool IsReduction(const NodeDef& node) {
|
|||
op == "Mean" || op == "Any" || op == "All";
|
||||
}
|
||||
|
||||
bool IsSwitch(const NodeDef& node) {
|
||||
const auto& op = node.op();
|
||||
return op == "Switch";
|
||||
}
|
||||
|
||||
bool IsTranspose(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "Transpose";
|
||||
|
|
|
|||
|
|
@ -24,9 +24,11 @@ namespace grappler {
|
|||
bool IsConcat(const NodeDef& node);
|
||||
bool IsConstant(const NodeDef& node);
|
||||
bool IsDequeueOp(const NodeDef& node);
|
||||
bool IsIdentity(const NodeDef& node);
|
||||
bool IsMerge(const NodeDef& node);
|
||||
bool IsPlaceholder(const NodeDef& node);
|
||||
bool IsReduction(const NodeDef& node);
|
||||
bool IsSwitch(const NodeDef& node);
|
||||
bool IsTranspose(const NodeDef& node);
|
||||
bool IsVariable(const NodeDef& node);
|
||||
|
||||
|
|
|
|||
|
|
@ -99,17 +99,67 @@ Status NumOutputs(const NodeDef& node, int* num_outputs) {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
string AsControlDependency(const NodeDef& node) {
|
||||
return strings::StrCat("^", node.name());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
ConstantFolding::ConstantFolding() {
|
||||
ops_to_preserve_ =
|
||||
std::regex("Placeholder.*|Const|.*Save.*|.*Restore.*|.*Reader");
|
||||
ops_to_preserve_ = std::regex(
|
||||
"Placeholder.*|Const|.*Save.*|.*Restore.*|.*Reader|Enter|Exit|"
|
||||
"NextIteration");
|
||||
}
|
||||
|
||||
string ConstantFolding::AddControlDependency(const string& input_name) {
|
||||
const NodeDef* node = node_map_->GetNode(input_name);
|
||||
if (!IsSwitch(*node)) {
|
||||
return AsControlDependency(*node);
|
||||
} else {
|
||||
// We can't anchor control dependencies directly on the switch node: unlike
|
||||
// other nodes only one of the outputs of the switch node will be generated
|
||||
// when the switch node is executed, and we need to make sure the control
|
||||
// dependency is only triggered when the corresponding output is triggered.
|
||||
// We start by looking for an identity node connected to the output of the
|
||||
// switch node, and use it to anchor the control dependency.
|
||||
auto outputs = node_map_->GetOutputs(node->name());
|
||||
for (const NodeDef* node : outputs) {
|
||||
if (IsIdentity(*node)) {
|
||||
CHECK_EQ(1, node->input_size());
|
||||
if (IsSameInput(node->input(0), input_name)) {
|
||||
return AsControlDependency(*node);
|
||||
}
|
||||
}
|
||||
}
|
||||
// We haven't found an existing node where we can anchor the control
|
||||
// dependency: add a new identity node.
|
||||
int position = 0;
|
||||
string ctrl_dep_name = ParseNodeName(input_name, &position);
|
||||
strings::StrAppend(&ctrl_dep_name, "_", position);
|
||||
ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl);
|
||||
const DataType output_type = node->attr().at("T").type();
|
||||
|
||||
NodeDef* added_node = graph_.add_node();
|
||||
added_node->set_name(ctrl_dep_name);
|
||||
added_node->set_op("Identity");
|
||||
(*added_node->mutable_attr())["T"].set_type(output_type);
|
||||
*added_node->add_input() = input_name;
|
||||
node_map_->AddNode(added_node->name(), added_node);
|
||||
node_map_->AddOutput(node->name(), added_node->name());
|
||||
return AsControlDependency(*added_node);
|
||||
}
|
||||
}
|
||||
|
||||
Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) {
|
||||
GraphProperties properties(item);
|
||||
TF_RETURN_IF_ERROR(properties.InferStatically());
|
||||
for (NodeDef& node : *graph_.mutable_node()) {
|
||||
// We may add some nodes to the graph to encode control dependencies: there is
|
||||
// no need to process these, so only iterate over the nodes of the input
|
||||
// graph.
|
||||
const int node_count = graph_.node_size();
|
||||
for (int i = 0; i < node_count; ++i) {
|
||||
NodeDef& node = *graph_.mutable_node(i);
|
||||
const string op = node.op();
|
||||
if (op != "Shape" && op != "Size" && op != "Rank") {
|
||||
continue;
|
||||
|
|
@ -179,9 +229,12 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) {
|
|||
value.AsProtoTensorContent(
|
||||
(*node.mutable_attr())["value"].mutable_tensor());
|
||||
|
||||
// Turn the inputs into control dependencies.
|
||||
// Turn the inputs into control dependencies: this is needed to ensure
|
||||
// that the constant value will only be generated in the cases where the
|
||||
// shape/rank/size would have been generated in the original graph.
|
||||
string ctrl_dep = AddControlDependency(node.input(0));
|
||||
CHECK_EQ(1, node.input_size());
|
||||
node.set_input(0, strings::StrCat("^", NodeName(node.input(0))));
|
||||
node.set_input(0, ctrl_dep);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -364,8 +417,7 @@ Status ConstantFolding::FoldNode(const NodeDef& node, GraphDef* output) {
|
|||
string node_name = ParseNodeName(output->input(i), &position);
|
||||
if (node_name == node.name()) {
|
||||
if (position < 0) {
|
||||
*output->mutable_input(i) =
|
||||
strings::StrCat("^", const_nodes[0].name());
|
||||
*output->mutable_input(i) = AsControlDependency(const_nodes[0]);
|
||||
} else {
|
||||
*output->mutable_input(i) = const_nodes[position].name();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ namespace tensorflow {
|
|||
namespace grappler {
|
||||
|
||||
const char kConstantFoldingConst[] = "ConstantFolding";
|
||||
const char kConstantFoldingCtrl[] = "ConstantFoldingCtrl";
|
||||
|
||||
// Contant folding optimization for a graph.
|
||||
class ConstantFolding : public GraphOptimizer {
|
||||
|
|
@ -43,6 +44,7 @@ class ConstantFolding : public GraphOptimizer {
|
|||
const GraphDef& optimize_output, double result) override;
|
||||
|
||||
private:
|
||||
string AddControlDependency(const string& input_name);
|
||||
Status MaterializeShapes(const GrapplerItem& item);
|
||||
|
||||
bool IsFoldable(const NodeDef& node) const;
|
||||
|
|
|
|||
|
|
@ -245,6 +245,41 @@ TEST_F(ConstantFoldingTest, ShapeMaterialization) {
|
|||
EXPECT_EQ(3, found);
|
||||
}
|
||||
|
||||
TEST_F(ConstantFoldingTest, SwitchNodes) {
|
||||
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
|
||||
ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
|
||||
ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
|
||||
ops::Switch s(scope.WithOpName("switch"), v_in, v_ctrl);
|
||||
ops::Rank rank(scope.WithOpName("rank"), s.output_false);
|
||||
ops::Identity i(scope.WithOpName("i"), s.output_true);
|
||||
ops::Size size(scope.WithOpName("size"), i);
|
||||
ops::Square p1(scope.WithOpName("p1"), rank);
|
||||
ops::Square p2(scope.WithOpName("p2"), size);
|
||||
ops::Merge m(scope.WithOpName("m"), {p1.y, p2.y});
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch.push_back("m");
|
||||
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
|
||||
|
||||
ConstantFolding fold;
|
||||
GraphDef output;
|
||||
Status status = fold.Optimize(nullptr, item, &output);
|
||||
TF_EXPECT_OK(status);
|
||||
|
||||
for (const auto& node : output.node()) {
|
||||
if (node.name() == "rank") {
|
||||
EXPECT_EQ("Const", node.op());
|
||||
EXPECT_EQ(1, node.input_size());
|
||||
EXPECT_EQ("^ConstantFoldingCtrl/switch_0", node.input(0));
|
||||
}
|
||||
if (node.name() == "size") {
|
||||
EXPECT_EQ("Const", node.op());
|
||||
EXPECT_EQ(1, node.input_size());
|
||||
EXPECT_EQ("^i", node.input(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ConstantFoldingTest, NoOpReduction) {
|
||||
// Build a simple graph with a reduction that can be reduced to the identity.
|
||||
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
|
||||
|
|
|
|||
|
|
@ -34,13 +34,21 @@ NodeMap::NodeMap(GraphDef* graph) : graph_(graph) {
|
|||
}
|
||||
}
|
||||
|
||||
NodeDef* NodeMap::GetNode(const string& name) {
|
||||
NodeDef* NodeMap::GetNode(const string& name) const {
|
||||
string node_name = NodeName(name);
|
||||
return nodes_[node_name];
|
||||
auto it = nodes_.find(node_name);
|
||||
if (it == nodes_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::set<NodeDef*> NodeMap::GetOutputs(const string& node_name) {
|
||||
return outputs_[node_name];
|
||||
const std::set<NodeDef*>& NodeMap::GetOutputs(const string& node_name) const {
|
||||
auto it = outputs_.find(node_name);
|
||||
if (it == outputs_.end()) {
|
||||
return empty_set_;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
void NodeMap::AddNode(const string& name, NodeDef* node) {
|
||||
|
|
@ -57,6 +65,17 @@ void NodeMap::UpdateOutput(const string& node, const string& old_output,
|
|||
outputs_[node].insert(nodes_[new_output]);
|
||||
}
|
||||
|
||||
bool IsSameInput(const string& name1, const string& name2) {
|
||||
if (name1 == name2) {
|
||||
return true;
|
||||
}
|
||||
int position1;
|
||||
string node1 = ParseNodeName(name1, &position1);
|
||||
int position2;
|
||||
string node2 = ParseNodeName(name2, &position2);
|
||||
return (position1 == position2) && (node1 == node2);
|
||||
}
|
||||
|
||||
string ParseNodeName(const string& name, int* position) {
|
||||
// Strip the prefix '^' (if any), and strip the trailing ":{digits} (if any)
|
||||
// to get a node name.
|
||||
|
|
|
|||
|
|
@ -31,8 +31,8 @@ namespace grappler {
|
|||
class NodeMap {
|
||||
public:
|
||||
explicit NodeMap(GraphDef* graph);
|
||||
NodeDef* GetNode(const string& name);
|
||||
std::set<NodeDef*> GetOutputs(const string& node_name);
|
||||
NodeDef* GetNode(const string& name) const;
|
||||
const std::set<NodeDef*>& GetOutputs(const string& node_name) const;
|
||||
// This method doesn't record the outputs of the added node; the outputs need
|
||||
// to be explictly added by the AddOutput method.
|
||||
void AddNode(const string& name, NodeDef* node);
|
||||
|
|
@ -42,6 +42,7 @@ class NodeMap {
|
|||
|
||||
private:
|
||||
GraphDef* graph_;
|
||||
std::set<NodeDef*> empty_set_;
|
||||
std::unordered_map<string, NodeDef*> nodes_;
|
||||
std::unordered_map<string, std::set<NodeDef*>> outputs_;
|
||||
};
|
||||
|
|
@ -50,6 +51,9 @@ class NodeMap {
|
|||
// the ^ character.
|
||||
bool IsControlInput(const string& name);
|
||||
|
||||
// True iff 'name1' and 'name2' refer to the same input.
|
||||
bool IsSameInput(const string& name1, const string& name2);
|
||||
|
||||
// Return the node name corresponding to 'name' if name is valid, or the empty
|
||||
// string otherwise.
|
||||
string NodeName(const string& name);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user