Make Graph::IsValidNode public

It can be reimplemented with existing public APIs, but instead of doing so,
making this one public seems better.

PiperOrigin-RevId: 166407897
This commit is contained in:
Igor Ganichev 2017-08-24 16:00:28 -07:00 committed by TensorFlower Gardener
parent 0a2f40e92d
commit b2ce451502
3 changed files with 37 additions and 4 deletions

View File

@ -503,17 +503,17 @@ string Graph::NewName(StringPiece prefix) {
return strings::StrCat(prefix, "/_", name_counter_++);
}
Status Graph::IsValidNode(Node* node) const {
Status Graph::IsValidNode(const Node* node) const {
if (node == nullptr) {
return errors::InvalidArgument("Node is null");
}
const int id = node->id();
if (id < 0) {
return errors::InvalidArgument("node id ", id, "is less than zero");
return errors::InvalidArgument("node id ", id, " is less than zero");
}
if (static_cast<size_t>(id) >= nodes_.size()) {
return errors::InvalidArgument(
"node id ", id, "is >= than number of nodes in graph ", nodes_.size());
"node id ", id, " is >= than number of nodes in graph ", nodes_.size());
}
if (nodes_[id] != node) {
return errors::InvalidArgument("Node with id ", id,

View File

@ -516,10 +516,12 @@ class Graph {
node->assigned_device_name_index_ = InternDeviceName(device_name);
}
// Returns OK if `node` is non-null and belongs to this graph
Status IsValidNode(const Node* node) const;
// TODO(josh11b): uint64 hash() const;
private:
Status IsValidNode(Node* node) const;
// If cost_node is non-null, then cost accounting (in CostModel)
// will be associated with that node rather than the new one being
// created.

View File

@ -379,6 +379,37 @@ TEST_F(GraphTest, NewName) {
EXPECT_TRUE(StringPiece(a1).starts_with("A")) << a1;
}
TEST_F(GraphTest, IsValidNode) {
// Add 1 node to graph_
Node* g1_node1;
TF_CHECK_OK(NodeBuilder("g1_node1", "NoOp").Finalize(&graph_, &g1_node1));
// Add 2 nodes to graph2
Graph graph2(OpRegistry::Global());
Node* g2_node1;
Node* g2_node2;
TF_CHECK_OK(NodeBuilder("g2_node1", "NoOp").Finalize(&graph2, &g2_node1));
TF_CHECK_OK(NodeBuilder("g2_node2", "NoOp").Finalize(&graph2, &g2_node2));
// nullptr
Status s = graph_.IsValidNode(nullptr);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
EXPECT_EQ(string("Node is null"), s.error_message());
// node id_ is too high
s = graph_.IsValidNode(g2_node2);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
EXPECT_EQ(string("node id 3 is >= than number of nodes in graph 3"),
s.error_message());
// valid id_ but different ptr
s = graph_.IsValidNode(g2_node1);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
EXPECT_EQ(string("Node with id 2 is different from the passed in node. "
"Does it belong to a different graph?"),
s.error_message());
}
TEST_F(GraphTest, InputEdges) {
Node* a = FromNodeDef("A", "OneOutput", 0);
Node* b = FromNodeDef("B", "TwoInputsOneOutput", 2);