mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
check invalid string type for dest_nodes in extract_sub_graph (#13057)
* BUG: check str type * TST: add unit test * CLN: remove list check * CLN: use warning * CLN: 2 indent * CLN: raise TypeError if not list * CLN: check string only
This commit is contained in:
parent
d2d42ee8b3
commit
fe3a2e65cc
|
|
@ -21,6 +21,7 @@ from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import copy
|
import copy
|
||||||
import re
|
import re
|
||||||
|
import six
|
||||||
|
|
||||||
from tensorflow.core.framework import attr_value_pb2
|
from tensorflow.core.framework import attr_value_pb2
|
||||||
from tensorflow.core.framework import graph_pb2
|
from tensorflow.core.framework import graph_pb2
|
||||||
|
|
@ -123,6 +124,9 @@ def extract_sub_graph(graph_def, dest_nodes):
|
||||||
if not isinstance(graph_def, graph_pb2.GraphDef):
|
if not isinstance(graph_def, graph_pb2.GraphDef):
|
||||||
raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")
|
raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")
|
||||||
|
|
||||||
|
if isinstance(dest_nodes, six.string_types):
|
||||||
|
raise TypeError("dest_nodes must be a list.")
|
||||||
|
|
||||||
edges = {} # Keyed by the dest node name.
|
edges = {} # Keyed by the dest node name.
|
||||||
name_to_node_map = {} # Keyed by node name.
|
name_to_node_map = {} # Keyed by node name.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -188,6 +188,13 @@ class DeviceFunctionsTest(test.TestCase):
|
||||||
self.assertEqual("n3", sub_graph.node[2].name)
|
self.assertEqual("n3", sub_graph.node[2].name)
|
||||||
self.assertEqual("n5", sub_graph.node[3].name)
|
self.assertEqual("n5", sub_graph.node[3].name)
|
||||||
|
|
||||||
|
def testExtractSubGraphWithInvalidDestNodes(self):
|
||||||
|
graph_def = graph_pb2.GraphDef()
|
||||||
|
n1 = graph_def.node.add()
|
||||||
|
n1.name = "n1"
|
||||||
|
with self.assertRaisesRegexp(TypeError, "must be a list"):
|
||||||
|
graph_util.extract_sub_graph(graph_def, "n1")
|
||||||
|
|
||||||
def testConvertVariablesToConstsWithFunctions(self):
|
def testConvertVariablesToConstsWithFunctions(self):
|
||||||
@function.Defun(dtypes.float32)
|
@function.Defun(dtypes.float32)
|
||||||
def plus_one(x):
|
def plus_one(x):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user