mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Identify frame ids for all nodes in a graph.
PiperOrigin-RevId: 166397615
This commit is contained in:
parent
989713f265
commit
c4a58e3fdd
|
|
@ -45,6 +45,11 @@ bool IsEnter(const NodeDef& node) {
|
||||||
return op == "Enter" || op == "RefEnter";
|
return op == "Enter" || op == "RefEnter";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsExit(const NodeDef& node) {
|
||||||
|
const auto& op = node.op();
|
||||||
|
return op == "Exit" || op == "RefExit";
|
||||||
|
}
|
||||||
|
|
||||||
bool IsIdentity(const NodeDef& node) {
|
bool IsIdentity(const NodeDef& node) {
|
||||||
const auto& op = node.op();
|
const auto& op = node.op();
|
||||||
return op == "Identity" || op == "RefIdentity";
|
return op == "Identity" || op == "RefIdentity";
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ bool IsConcat(const NodeDef& node);
|
||||||
bool IsConstant(const NodeDef& node);
|
bool IsConstant(const NodeDef& node);
|
||||||
bool IsDequeueOp(const NodeDef& node);
|
bool IsDequeueOp(const NodeDef& node);
|
||||||
bool IsEnter(const NodeDef& node);
|
bool IsEnter(const NodeDef& node);
|
||||||
|
bool IsExit(const NodeDef& node);
|
||||||
bool IsIdentity(const NodeDef& node);
|
bool IsIdentity(const NodeDef& node);
|
||||||
bool IsMerge(const NodeDef& node);
|
bool IsMerge(const NodeDef& node);
|
||||||
bool IsNextIteration(const NodeDef& node);
|
bool IsNextIteration(const NodeDef& node);
|
||||||
|
|
|
||||||
|
|
@ -71,3 +71,29 @@ cc_test(
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "frame",
|
||||||
|
srcs = ["frame.cc"],
|
||||||
|
hdrs = ["frame.h"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core/grappler:op_types",
|
||||||
|
"//tensorflow/core/grappler:utils",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "frame_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["frame_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":frame",
|
||||||
|
"//tensorflow/core:lib_proto_parsing",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
||||||
62
tensorflow/core/grappler/utils/frame.cc
Normal file
62
tensorflow/core/grappler/utils/frame.cc
Normal file
|
|
@ -0,0 +1,62 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/grappler/utils/frame.h"
|
||||||
|
#include <deque>
|
||||||
|
#include <stack>
|
||||||
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
|
#include "tensorflow/core/grappler/op_types.h"
|
||||||
|
#include "tensorflow/core/grappler/utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace grappler {
|
||||||
|
|
||||||
|
int IdentifyFrames(
|
||||||
|
const GraphDef& graph,
|
||||||
|
std::unordered_map<const NodeDef*, std::vector<int>>* frames) {
|
||||||
|
NodeMap node_map(const_cast<GraphDef*>(&graph));
|
||||||
|
std::deque<std::pair<const NodeDef*, std::vector<int>>> ready_nodes;
|
||||||
|
for (const NodeDef& node : graph.node()) {
|
||||||
|
if (node.input_size() == 0) {
|
||||||
|
std::vector<int> empty;
|
||||||
|
ready_nodes.emplace_back(&node, empty);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int frame_id = 0;
|
||||||
|
while (!ready_nodes.empty()) {
|
||||||
|
auto ready_node = ready_nodes.front();
|
||||||
|
for (const auto& fanout : node_map.GetOutputs(ready_node.first->name())) {
|
||||||
|
if (frames->count(fanout) < 1) {
|
||||||
|
std::vector<int> frame_ids = ready_node.second;
|
||||||
|
if (IsExit(*ready_node.first)) {
|
||||||
|
frame_ids.pop_back();
|
||||||
|
}
|
||||||
|
if (IsEnter(*fanout)) {
|
||||||
|
frame_ids.push_back(frame_id);
|
||||||
|
frame_id++;
|
||||||
|
}
|
||||||
|
ready_nodes.emplace_back(fanout, frame_ids);
|
||||||
|
} else {
|
||||||
|
CHECK(ready_node.second == (*frames)[fanout]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(*frames)[ready_node.first] = ready_node.second;
|
||||||
|
ready_nodes.pop_front();
|
||||||
|
}
|
||||||
|
return frame_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace grappler
|
||||||
|
} // namespace tensorflow
|
||||||
35
tensorflow/core/grappler/utils/frame.h
Normal file
35
tensorflow/core/grappler/utils/frame.h
Normal file
|
|
@ -0,0 +1,35 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace grappler {
|
||||||
|
|
||||||
|
// Returns the number of frames present in the graph, and populates
|
||||||
|
// the 'frames' argument with the collection of frames (denoted by their
|
||||||
|
// frame ids) in the outermost-to-innermost order. Frame ids are arbitrary.
|
||||||
|
int IdentifyFrames(
|
||||||
|
const GraphDef& graph,
|
||||||
|
std::unordered_map<const NodeDef*, std::vector<int>>* frames);
|
||||||
|
|
||||||
|
} // namespace grappler
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_
|
||||||
89
tensorflow/core/grappler/utils/frame_test.cc
Normal file
89
tensorflow/core/grappler/utils/frame_test.cc
Normal file
|
|
@ -0,0 +1,89 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/grappler/utils/frame.h"
|
||||||
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace grappler {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class IdentifyFramesTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
static NodeDef CreateNode(const string& name,
|
||||||
|
const std::vector<string>& inputs) {
|
||||||
|
return CreateNode(name, "", inputs);
|
||||||
|
}
|
||||||
|
static NodeDef CreateNode(const string& name, const string& op,
|
||||||
|
const std::vector<string>& inputs) {
|
||||||
|
NodeDef node;
|
||||||
|
node.set_name(name);
|
||||||
|
if (!op.empty()) {
|
||||||
|
node.set_op(op);
|
||||||
|
}
|
||||||
|
for (const string& input : inputs) {
|
||||||
|
node.add_input(input);
|
||||||
|
}
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(IdentifyFramesTest, WithLoop) {
|
||||||
|
GraphDef graph;
|
||||||
|
// Create a two-level nested loop
|
||||||
|
*graph.add_node() = CreateNode("0", {});
|
||||||
|
*graph.add_node() = CreateNode("1", "Enter", {"0"});
|
||||||
|
*graph.add_node() = CreateNode("2", {"1"});
|
||||||
|
*graph.add_node() = CreateNode("3", "Merge", {"2", "14"});
|
||||||
|
*graph.add_node() = CreateNode("4", {"3"});
|
||||||
|
*graph.add_node() = CreateNode("5", "Switch", {"4"});
|
||||||
|
*graph.add_node() = CreateNode("6", {"5"});
|
||||||
|
*graph.add_node() = CreateNode("7", "Enter", {"6"});
|
||||||
|
*graph.add_node() = CreateNode("8", {"7"});
|
||||||
|
*graph.add_node() = CreateNode("9", "Merge", {"8", "12"});
|
||||||
|
*graph.add_node() = CreateNode("10", {"9"});
|
||||||
|
*graph.add_node() = CreateNode("11", "Switch", {"10"});
|
||||||
|
*graph.add_node() = CreateNode("12", "NextIteration", {"11"});
|
||||||
|
*graph.add_node() = CreateNode("13", "Exit", {"11"});
|
||||||
|
*graph.add_node() = CreateNode("14", "NextIteration", {"13"});
|
||||||
|
*graph.add_node() = CreateNode("15", {"5"});
|
||||||
|
*graph.add_node() = CreateNode("16", "Exit", {"15"});
|
||||||
|
*graph.add_node() = CreateNode("17", {"16"});
|
||||||
|
|
||||||
|
std::unordered_map<const NodeDef*, std::vector<int>> frames;
|
||||||
|
int num_frames = IdentifyFrames(graph, &frames);
|
||||||
|
std::unordered_map<string, std::vector<int>> expected = {
|
||||||
|
{"0", {}}, {"1", {0}}, {"2", {0}}, {"3", {0}},
|
||||||
|
{"4", {0}}, {"5", {0}}, {"6", {0}}, {"7", {0, 1}},
|
||||||
|
{"8", {0, 1}}, {"9", {0, 1}}, {"10", {0, 1}}, {"11", {0, 1}},
|
||||||
|
{"12", {0, 1}}, {"13", {0, 1}}, {"14", {0}}, {"15", {0}},
|
||||||
|
{"16", {0}}, {"17", {}}};
|
||||||
|
EXPECT_EQ(num_frames, 2);
|
||||||
|
std::cout << "Number of frame: " << num_frames << std::endl;
|
||||||
|
for (const auto& node : frames) {
|
||||||
|
std::cout << node.first->name() << ": ";
|
||||||
|
for (int i = 0; i < node.second.size(); i++) {
|
||||||
|
EXPECT_EQ(expected[node.first->name()][i], node.second[i]);
|
||||||
|
std::cout << node.second[i] << " ";
|
||||||
|
}
|
||||||
|
std::cout << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace grappler
|
||||||
|
} // namespace tensorflow
|
||||||
Loading…
Reference in New Issue
Block a user