mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Add type error to start_queue_runners if given session is not a tf.Session. Due to semver, we suppress the error if a MonitoredSession is provided.
PiperOrigin-RevId: 157748375
This commit is contained in:
parent
7106f9fac3
commit
7ad0d0698a
|
|
@ -22,6 +22,7 @@ import threading
|
|||
import weakref
|
||||
|
||||
from tensorflow.core.protobuf import queue_runner_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
|
@ -401,6 +402,10 @@ def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
|
|||
collection: A `GraphKey` specifying the graph collection to
|
||||
get the queue runners from. Defaults to `GraphKeys.QUEUE_RUNNERS`.
|
||||
|
||||
Raises:
|
||||
ValueError: if `sess` is None and there isn't any default session.
|
||||
TypeError: if `sess` is not a `tf.Session` object.
|
||||
|
||||
Returns:
|
||||
A list of threads.
|
||||
"""
|
||||
|
|
@ -410,6 +415,15 @@ def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
|
|||
raise ValueError("Cannot start queue runners: No default session is "
|
||||
"registered. Use `with sess.as_default()` or pass an "
|
||||
"explicit session to tf.start_queue_runners(sess=sess)")
|
||||
|
||||
if not isinstance(sess, session.SessionInterface):
|
||||
# Following check is due to backward compatibility. (b/62061352)
|
||||
if sess.__class__.__name__ in [
|
||||
"MonitoredSession", "SingularMonitoredSession"]:
|
||||
return []
|
||||
raise TypeError("sess must be a `tf.Session` object. "
|
||||
"Given class: {}".format(sess.__class__))
|
||||
|
||||
with sess.graph.as_default():
|
||||
threads = []
|
||||
for qr in ops.get_collection(collection):
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from tensorflow.python.ops import data_flow_ops
|
|||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import coordinator
|
||||
from tensorflow.python.training import monitored_session
|
||||
from tensorflow.python.training import queue_runner_impl
|
||||
|
||||
|
||||
|
|
@ -247,6 +248,33 @@ class QueueRunnerTest(test.TestCase):
|
|||
# The variable should be 3.
|
||||
self.assertEqual(3, var.eval())
|
||||
|
||||
def testStartQueueRunnersRaisesIfNotASession(self):
|
||||
zero64 = constant_op.constant(0, dtype=dtypes.int64)
|
||||
var = variables.Variable(zero64)
|
||||
count_up_to = var.count_up_to(3)
|
||||
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
|
||||
init_op = variables.global_variables_initializer()
|
||||
qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
|
||||
queue_runner_impl.add_queue_runner(qr)
|
||||
with self.test_session():
|
||||
init_op.run()
|
||||
with self.assertRaisesRegexp(TypeError, "tf.Session"):
|
||||
queue_runner_impl.start_queue_runners("NotASession")
|
||||
|
||||
def testStartQueueRunnersIgnoresMonitoredSession(self):
|
||||
zero64 = constant_op.constant(0, dtype=dtypes.int64)
|
||||
var = variables.Variable(zero64)
|
||||
count_up_to = var.count_up_to(3)
|
||||
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
|
||||
init_op = variables.global_variables_initializer()
|
||||
qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
|
||||
queue_runner_impl.add_queue_runner(qr)
|
||||
with self.test_session():
|
||||
init_op.run()
|
||||
threads = queue_runner_impl.start_queue_runners(
|
||||
monitored_session.MonitoredSession())
|
||||
self.assertFalse(threads)
|
||||
|
||||
def testStartQueueRunnersNonDefaultGraph(self):
|
||||
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
|
||||
graph = ops.Graph()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user