mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Make Graph::IsValidNode public
It can be reimplemented with existing public APIs, but instead of doing so, making this one public seems better. PiperOrigin-RevId: 166407897
This commit is contained in:
parent
0a2f40e92d
commit
b2ce451502
|
|
@ -503,17 +503,17 @@ string Graph::NewName(StringPiece prefix) {
|
|||
return strings::StrCat(prefix, "/_", name_counter_++);
|
||||
}
|
||||
|
||||
Status Graph::IsValidNode(Node* node) const {
|
||||
Status Graph::IsValidNode(const Node* node) const {
|
||||
if (node == nullptr) {
|
||||
return errors::InvalidArgument("Node is null");
|
||||
}
|
||||
const int id = node->id();
|
||||
if (id < 0) {
|
||||
return errors::InvalidArgument("node id ", id, "is less than zero");
|
||||
return errors::InvalidArgument("node id ", id, " is less than zero");
|
||||
}
|
||||
if (static_cast<size_t>(id) >= nodes_.size()) {
|
||||
return errors::InvalidArgument(
|
||||
"node id ", id, "is >= than number of nodes in graph ", nodes_.size());
|
||||
"node id ", id, " is >= than number of nodes in graph ", nodes_.size());
|
||||
}
|
||||
if (nodes_[id] != node) {
|
||||
return errors::InvalidArgument("Node with id ", id,
|
||||
|
|
|
|||
|
|
@ -516,10 +516,12 @@ class Graph {
|
|||
node->assigned_device_name_index_ = InternDeviceName(device_name);
|
||||
}
|
||||
|
||||
// Returns OK if `node` is non-null and belongs to this graph
|
||||
Status IsValidNode(const Node* node) const;
|
||||
|
||||
// TODO(josh11b): uint64 hash() const;
|
||||
|
||||
private:
|
||||
Status IsValidNode(Node* node) const;
|
||||
// If cost_node is non-null, then cost accounting (in CostModel)
|
||||
// will be associated with that node rather than the new one being
|
||||
// created.
|
||||
|
|
|
|||
|
|
@ -379,6 +379,37 @@ TEST_F(GraphTest, NewName) {
|
|||
EXPECT_TRUE(StringPiece(a1).starts_with("A")) << a1;
|
||||
}
|
||||
|
||||
TEST_F(GraphTest, IsValidNode) {
|
||||
// Add 1 node to graph_
|
||||
Node* g1_node1;
|
||||
TF_CHECK_OK(NodeBuilder("g1_node1", "NoOp").Finalize(&graph_, &g1_node1));
|
||||
|
||||
// Add 2 nodes to graph2
|
||||
Graph graph2(OpRegistry::Global());
|
||||
Node* g2_node1;
|
||||
Node* g2_node2;
|
||||
TF_CHECK_OK(NodeBuilder("g2_node1", "NoOp").Finalize(&graph2, &g2_node1));
|
||||
TF_CHECK_OK(NodeBuilder("g2_node2", "NoOp").Finalize(&graph2, &g2_node2));
|
||||
|
||||
// nullptr
|
||||
Status s = graph_.IsValidNode(nullptr);
|
||||
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
|
||||
EXPECT_EQ(string("Node is null"), s.error_message());
|
||||
|
||||
// node id_ is too high
|
||||
s = graph_.IsValidNode(g2_node2);
|
||||
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
|
||||
EXPECT_EQ(string("node id 3 is >= than number of nodes in graph 3"),
|
||||
s.error_message());
|
||||
|
||||
// valid id_ but different ptr
|
||||
s = graph_.IsValidNode(g2_node1);
|
||||
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
|
||||
EXPECT_EQ(string("Node with id 2 is different from the passed in node. "
|
||||
"Does it belong to a different graph?"),
|
||||
s.error_message());
|
||||
}
|
||||
|
||||
TEST_F(GraphTest, InputEdges) {
|
||||
Node* a = FromNodeDef("A", "OneOutput", 0);
|
||||
Node* b = FromNodeDef("B", "TwoInputsOneOutput", 2);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user