mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Add more recovery functionality to MonitoredSession.run_step_fn.
Current implemention wouldn't recover from one of `_PREEMPTION_ERRORS` during a fetch through the raw session that is made available to the step_fn. The changelist presents a way to map the desired functionality to the hiearchy of _MonitoredSession > (possibly!) _RecoverableSession > _CoordinatedSession > _HookedSession. PiperOrigin-RevId: 174053865
This commit is contained in:
parent
b2ff3ad966
commit
5f1a66ccb4
|
|
@ -496,7 +496,6 @@ class _MonitoredSession(object):
|
||||||
self._sess = _RecoverableSession(self._coordinated_creator)
|
self._sess = _RecoverableSession(self._coordinated_creator)
|
||||||
else:
|
else:
|
||||||
self._sess = self._coordinated_creator.create_session()
|
self._sess = self._coordinated_creator.create_session()
|
||||||
self._stop_requested_in_step_fn = False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def graph(self):
|
def graph(self):
|
||||||
|
|
@ -576,11 +575,12 @@ class _MonitoredSession(object):
|
||||||
' `self` and `step_context` arguments if it\'s an instance'
|
' `self` and `step_context` arguments if it\'s an instance'
|
||||||
' method. Got {} instead.'.format(step_fn_arguments))
|
' method. Got {} instead.'.format(step_fn_arguments))
|
||||||
|
|
||||||
try:
|
# `self._sess` is either `_RecoverableSession` or a `_CoordinatedSession`.
|
||||||
return step_fn(_MonitoredSession.StepContext(self._tf_sess(), self.run))
|
# Setting `run_with_hooks` to `None` will cause `run_with_hooks` to be
|
||||||
except StopIteration:
|
# `_CoordinatedSession.run` downstream in either case. This allows
|
||||||
self._stop_requested_in_step_fn = True
|
# `_PREEMPTION_ERRORS` to propage from within `step_fn` to
|
||||||
raise
|
# `_RecoverableSession.run_step_fn`.
|
||||||
|
return self._sess.run_step_fn(step_fn, self._tf_sess(), run_with_hooks=None)
|
||||||
|
|
||||||
class StepContext(object):
|
class StepContext(object):
|
||||||
"""Control flow instrument for the `step_fn` from `run_step_fn()`.
|
"""Control flow instrument for the `step_fn` from `run_step_fn()`.
|
||||||
|
|
@ -620,8 +620,7 @@ class _MonitoredSession(object):
|
||||||
raise StopIteration('step_fn has requested the iterations to stop.')
|
raise StopIteration('step_fn has requested the iterations to stop.')
|
||||||
|
|
||||||
def should_stop(self):
|
def should_stop(self):
|
||||||
return (self._sess is None or self._sess.should_stop() or
|
return self._sess is None or self._sess.should_stop()
|
||||||
self._stop_requested_in_step_fn)
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self._close_internal()
|
self._close_internal()
|
||||||
|
|
@ -924,6 +923,13 @@ class _WrappedSession(object):
|
||||||
def run(self, *args, **kwargs):
|
def run(self, *args, **kwargs):
|
||||||
return self._sess.run(*args, **kwargs)
|
return self._sess.run(*args, **kwargs)
|
||||||
|
|
||||||
|
def run_step_fn(self, step_fn, raw_session, run_with_hooks):
|
||||||
|
# `_RecoverableSession` sets `run_with_hooks` to `_CoordinatedSession.run`.
|
||||||
|
# It is `None` when called from `_CoordinatedSession`. In that case
|
||||||
|
# `self.run` is `_CoordinatedSession.run`.
|
||||||
|
run_with_hooks = run_with_hooks or self.run
|
||||||
|
return step_fn(_MonitoredSession.StepContext(raw_session, run_with_hooks))
|
||||||
|
|
||||||
|
|
||||||
class _RecoverableSession(_WrappedSession):
|
class _RecoverableSession(_WrappedSession):
|
||||||
"""A wrapped session that recreates a session upon certain kinds of errors.
|
"""A wrapped session that recreates a session upon certain kinds of errors.
|
||||||
|
|
@ -996,6 +1002,22 @@ class _RecoverableSession(_WrappedSession):
|
||||||
self.close()
|
self.close()
|
||||||
self._sess = None
|
self._sess = None
|
||||||
|
|
||||||
|
def run_step_fn(self, step_fn, raw_session, run_with_hooks):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
if not self._sess:
|
||||||
|
self._sess = self._create_session()
|
||||||
|
|
||||||
|
run_with_hooks = self._sess.run
|
||||||
|
return self._sess.run_step_fn(step_fn, raw_session, run_with_hooks)
|
||||||
|
except _PREEMPTION_ERRORS as e:
|
||||||
|
logging.info('An error was raised. This may be due to a preemption in '
|
||||||
|
'a connected worker or parameter server. The current '
|
||||||
|
'session will be closed and a new session will be '
|
||||||
|
'created. Error: %s', e)
|
||||||
|
self.close()
|
||||||
|
self._sess = None
|
||||||
|
|
||||||
|
|
||||||
class _CoordinatedSession(_WrappedSession):
|
class _CoordinatedSession(_WrappedSession):
|
||||||
"""A wrapped session that works with a `tf.Coordinator`.
|
"""A wrapped session that works with a `tf.Coordinator`.
|
||||||
|
|
|
||||||
|
|
@ -798,6 +798,214 @@ class RecoverableSessionTest(test.TestCase):
|
||||||
self.assertFalse(session.should_stop())
|
self.assertFalse(session.should_stop())
|
||||||
self.assertEqual(2, session_creator.number_of_sessions_created)
|
self.assertEqual(2, session_creator.number_of_sessions_created)
|
||||||
|
|
||||||
|
def test_step_fn_recovery_from_coordinator_exception_when_run_hooks(self):
|
||||||
|
with self.test_session() as test_session:
|
||||||
|
session_creator = CountingSessionCreator(test_session)
|
||||||
|
session = monitored_session.MonitoredSession(
|
||||||
|
session_creator,
|
||||||
|
[StopCoordinatorWithException(calls_before_stopping=2)])
|
||||||
|
|
||||||
|
self.assertEqual(1, session_creator.number_of_sessions_created)
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
|
||||||
|
c = constant_op.constant(0)
|
||||||
|
v = array_ops.identity(c)
|
||||||
|
|
||||||
|
def feed_step_fn(value):
|
||||||
|
def step_fn(step_context):
|
||||||
|
return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
|
||||||
|
return step_fn
|
||||||
|
|
||||||
|
# The coordinator will not abort during this call, since it's the call
|
||||||
|
# number 0.
|
||||||
|
self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
# The coordinator will abort during the next call, since it's the call
|
||||||
|
# number 1.
|
||||||
|
self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
|
||||||
|
# Even though the coordinator was asked to stop, the underlying session is
|
||||||
|
# recreated and is to be continued.
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
self.assertEqual(2, session_creator.number_of_sessions_created)
|
||||||
|
|
||||||
|
def test_recovery_from_non_preemption_in_coordinator_when_run_hooks(self):
|
||||||
|
with self.test_session() as test_session:
|
||||||
|
session_creator = CountingSessionCreator(test_session)
|
||||||
|
hook = StopCoordinatorWithException(
|
||||||
|
calls_before_stopping=2,
|
||||||
|
exception_to_raise=errors_impl.UnknownError(
|
||||||
|
None, None, 'Some fatal exception inside the coordinator.'))
|
||||||
|
session = monitored_session.MonitoredSession(session_creator, [hook])
|
||||||
|
|
||||||
|
self.assertEqual(1, session_creator.number_of_sessions_created)
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
|
||||||
|
c = constant_op.constant(0)
|
||||||
|
v = array_ops.identity(c)
|
||||||
|
|
||||||
|
def feed_step_fn(value):
|
||||||
|
def step_fn(step_context):
|
||||||
|
return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
|
||||||
|
return step_fn
|
||||||
|
|
||||||
|
# The coordinator will not abort during this call, since it's the call
|
||||||
|
# number 0.
|
||||||
|
self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
# The coordinator will abort during the next call, since it's the call
|
||||||
|
# number 1.
|
||||||
|
self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
|
||||||
|
# The coordinator was asked to stop due to non-redeemable error. Training
|
||||||
|
# should stop and the session should not be recreated.
|
||||||
|
self.assertTrue(session.should_stop())
|
||||||
|
self.assertEqual(1, session_creator.number_of_sessions_created)
|
||||||
|
with self.assertRaises(errors_impl.UnknownError):
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
def test_recovery_from_session_getting_stuck_when_run_hooks(self):
|
||||||
|
with self.test_session() as test_session:
|
||||||
|
session_creator = CountingSessionCreator(test_session)
|
||||||
|
session = monitored_session.MonitoredSession(
|
||||||
|
session_creator,
|
||||||
|
[FailTrainingAfterCoordinatorStopped(calls_before_stopping=2)])
|
||||||
|
|
||||||
|
self.assertEqual(1, session_creator.number_of_sessions_created)
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
|
||||||
|
c = constant_op.constant(0)
|
||||||
|
v = array_ops.identity(c)
|
||||||
|
|
||||||
|
def feed_step_fn(value):
|
||||||
|
def step_fn(step_context):
|
||||||
|
return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
|
||||||
|
return step_fn
|
||||||
|
|
||||||
|
# Training will not fail, since it's the call number 0.
|
||||||
|
self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
# Training will fail during the next call, since it's the call
|
||||||
|
# number 1.
|
||||||
|
self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
|
||||||
|
# Even though the coordinator stopped which and training failed, the
|
||||||
|
# underlying session is recreated and training is to be continued.
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
self.assertEqual(2, session_creator.number_of_sessions_created)
|
||||||
|
|
||||||
|
def create_raw_session_with_failing_coordinator(self, session_creator, hook):
|
||||||
|
"""Return MonitoredSession that triggers coordinator failures."""
|
||||||
|
session = monitored_session.MonitoredSession(session_creator, [hook])
|
||||||
|
# We would like to test a situation where during fetches through the
|
||||||
|
# raw session, the coordinator fails with an exception. To do that, we
|
||||||
|
# are going to use (raw_session + StopCoordinatorWithException) hook
|
||||||
|
# combination that is stored in
|
||||||
|
# `MonitoredSession._RecoverableSession._CoordinatedSession._sess`
|
||||||
|
# at this point:
|
||||||
|
session._tf_sess = lambda: session._sess._sess._sess
|
||||||
|
# `run()` on such a session is equivalent to `run()` on the raw session
|
||||||
|
# with separate coordinator threads independently stopping with an
|
||||||
|
# exception.
|
||||||
|
return session
|
||||||
|
|
||||||
|
def test_step_fn_recovery_from_coordinator_exception_with_raw_session(self):
|
||||||
|
with self.test_session() as test_session:
|
||||||
|
session_creator = CountingSessionCreator(test_session)
|
||||||
|
session = self.create_raw_session_with_failing_coordinator(
|
||||||
|
session_creator,
|
||||||
|
StopCoordinatorWithException(calls_before_stopping=2))
|
||||||
|
|
||||||
|
self.assertEqual(1, session_creator.number_of_sessions_created)
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
|
||||||
|
c = constant_op.constant(0)
|
||||||
|
v = array_ops.identity(c)
|
||||||
|
|
||||||
|
def feed_step_fn(value):
|
||||||
|
|
||||||
|
def step_fn(step_context):
|
||||||
|
return step_context.session.run(fetches=v, feed_dict={c: value})
|
||||||
|
|
||||||
|
return step_fn
|
||||||
|
|
||||||
|
# The coordinator will not abort during this call, since it's the call
|
||||||
|
# number 0.
|
||||||
|
self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
# The coordinator will abort during the next call, since it's the call
|
||||||
|
# number 1.
|
||||||
|
self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
|
||||||
|
# Even though the coordinator was asked to stop, the underlying session is
|
||||||
|
# recreated and is to be continued.
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
self.assertEqual(2, session_creator.number_of_sessions_created)
|
||||||
|
|
||||||
|
def test_recovery_from_non_preemption_in_coordinator_with_raw_session(self):
|
||||||
|
with self.test_session() as test_session:
|
||||||
|
session_creator = CountingSessionCreator(test_session)
|
||||||
|
session = self.create_raw_session_with_failing_coordinator(
|
||||||
|
session_creator,
|
||||||
|
StopCoordinatorWithException(
|
||||||
|
calls_before_stopping=2,
|
||||||
|
exception_to_raise=errors_impl.UnknownError(
|
||||||
|
None, None, 'Some fatal exception inside the coordinator.')))
|
||||||
|
|
||||||
|
self.assertEqual(1, session_creator.number_of_sessions_created)
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
|
||||||
|
c = constant_op.constant(0)
|
||||||
|
v = array_ops.identity(c)
|
||||||
|
|
||||||
|
def feed_step_fn(value):
|
||||||
|
|
||||||
|
def step_fn(step_context):
|
||||||
|
return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
|
||||||
|
|
||||||
|
return step_fn
|
||||||
|
|
||||||
|
# The coordinator will not abort during this call, since it's the call
|
||||||
|
# number 0.
|
||||||
|
self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
# The coordinator will abort during the next call, since it's the call
|
||||||
|
# number 1.
|
||||||
|
self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
|
||||||
|
# The coordinator was asked to stop due to non-redeemable error. Training
|
||||||
|
# should stop and the session should not be recreated.
|
||||||
|
self.assertTrue(session.should_stop())
|
||||||
|
self.assertEqual(1, session_creator.number_of_sessions_created)
|
||||||
|
with self.assertRaises(errors_impl.UnknownError):
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
def test_recovery_from_session_getting_stuck_with_raw_session(self):
|
||||||
|
with self.test_session() as test_session:
|
||||||
|
session_creator = CountingSessionCreator(test_session)
|
||||||
|
session = self.create_raw_session_with_failing_coordinator(
|
||||||
|
session_creator,
|
||||||
|
FailTrainingAfterCoordinatorStopped(calls_before_stopping=2))
|
||||||
|
|
||||||
|
self.assertEqual(1, session_creator.number_of_sessions_created)
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
|
||||||
|
c = constant_op.constant(0)
|
||||||
|
v = array_ops.identity(c)
|
||||||
|
|
||||||
|
def feed_step_fn(value):
|
||||||
|
|
||||||
|
def step_fn(step_context):
|
||||||
|
return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
|
||||||
|
|
||||||
|
return step_fn
|
||||||
|
|
||||||
|
# Training will not fail, since it's the call number 0.
|
||||||
|
self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
# Training will fail during the next call, since it's the call
|
||||||
|
# number 1.
|
||||||
|
self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
|
||||||
|
# Even though the coordinator stopped which and training failed, the
|
||||||
|
# underlying session is recreated and training is to be continued.
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
self.assertEqual(2, session_creator.number_of_sessions_created)
|
||||||
|
|
||||||
|
|
||||||
class FakeSession(monitored_session._WrappedSession):
|
class FakeSession(monitored_session._WrappedSession):
|
||||||
|
|
||||||
|
|
@ -1475,6 +1683,7 @@ class MonitoredSessionTest(test.TestCase):
|
||||||
|
|
||||||
def test_step_request_stop_without_a_with_block(self):
|
def test_step_request_stop_without_a_with_block(self):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
|
was_stop_iteration_raised = False
|
||||||
|
|
||||||
def step_fn(step_context):
|
def step_fn(step_context):
|
||||||
step_context.request_stop()
|
step_context.request_stop()
|
||||||
|
|
@ -1483,8 +1692,10 @@ class MonitoredSessionTest(test.TestCase):
|
||||||
try:
|
try:
|
||||||
self.assertEqual(None, session.run_step_fn(step_fn))
|
self.assertEqual(None, session.run_step_fn(step_fn))
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
pass
|
was_stop_iteration_raised = True
|
||||||
self.assertTrue(session.should_stop())
|
|
||||||
|
self.assertTrue(was_stop_iteration_raised)
|
||||||
|
self.assertFalse(session.should_stop())
|
||||||
|
|
||||||
def test_step_request_stop_in_a_loop(self):
|
def test_step_request_stop_in_a_loop(self):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
|
|
@ -1526,8 +1737,7 @@ class MonitoredSessionTest(test.TestCase):
|
||||||
class Model(object):
|
class Model(object):
|
||||||
|
|
||||||
def step_fn(self, step_context):
|
def step_fn(self, step_context):
|
||||||
value = step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
|
return step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
|
||||||
return value
|
|
||||||
|
|
||||||
with monitored_session.MonitoredSession() as session:
|
with monitored_session.MonitoredSession() as session:
|
||||||
model = Model()
|
model = Model()
|
||||||
|
|
@ -1592,6 +1802,38 @@ class MonitoredSessionTest(test.TestCase):
|
||||||
with monitored_session.MonitoredSession(hooks=[Hook(self)]) as session:
|
with monitored_session.MonitoredSession(hooks=[Hook(self)]) as session:
|
||||||
self.assertEqual(0.3 + 0.5 + 0.7, session.run_step_fn(step_fn))
|
self.assertEqual(0.3 + 0.5 + 0.7, session.run_step_fn(step_fn))
|
||||||
|
|
||||||
|
def test_step_fn_has_the_same_hooks_behavior_without_recovery(self):
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
var = resource_variable_ops.ResourceVariable(0.0)
|
||||||
|
|
||||||
|
stage_0 = state_ops.assign_add(var, 0.3)
|
||||||
|
stage_1_0 = state_ops.assign_add(var, 0.7)
|
||||||
|
with ops.control_dependencies([stage_1_0]):
|
||||||
|
stage_1_1 = state_ops.assign_add(var, 0.5)
|
||||||
|
stage_2 = state_ops.assign_add(var, 1.1)
|
||||||
|
|
||||||
|
class Hook(session_run_hook.SessionRunHook):
|
||||||
|
|
||||||
|
def __init__(self, testing):
|
||||||
|
self._testing = testing
|
||||||
|
|
||||||
|
def before_run(self, run_context):
|
||||||
|
return session_run_hook.SessionRunArgs(fetches=stage_1_0)
|
||||||
|
|
||||||
|
def after_run(self, run_context, run_values):
|
||||||
|
self._testing.assertNear(0.3 + 0.5 + 0.7,
|
||||||
|
run_context.session.run(var), 0.1)
|
||||||
|
self._testing.assertNear(0.3 + 0.5 + 0.7 + 1.1,
|
||||||
|
run_context.session.run(stage_2), 0.1)
|
||||||
|
|
||||||
|
def step_fn(step_context):
|
||||||
|
self.assertNear(0.3, step_context.session.run(stage_0), 0.1)
|
||||||
|
return step_context.run_with_hooks(fetches=stage_1_1)
|
||||||
|
|
||||||
|
with monitored_session.SingularMonitoredSession(
|
||||||
|
hooks=[Hook(self)]) as session:
|
||||||
|
self.assertEqual(0.3 + 0.5 + 0.7, session.run_step_fn(step_fn))
|
||||||
|
|
||||||
def test_step_fn_with_hooks_and_request_stop(self):
|
def test_step_fn_with_hooks_and_request_stop(self):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
trace_the_hook = {'before_run': False, 'after_run': False}
|
trace_the_hook = {'before_run': False, 'after_run': False}
|
||||||
|
|
@ -1615,6 +1857,117 @@ class MonitoredSessionTest(test.TestCase):
|
||||||
self.assertFalse(trace_the_hook['before_run'])
|
self.assertFalse(trace_the_hook['before_run'])
|
||||||
self.assertFalse(trace_the_hook['after_run'])
|
self.assertFalse(trace_the_hook['after_run'])
|
||||||
|
|
||||||
|
def test_recovers_from_an_exception_in_step_fn(self):
|
||||||
|
trace_the_exception = {'run_already': False}
|
||||||
|
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
c = array_ops.placeholder(dtypes.float32)
|
||||||
|
v = array_ops.identity(c)
|
||||||
|
|
||||||
|
def step_fn(step_context):
|
||||||
|
if not trace_the_exception['run_already']:
|
||||||
|
trace_the_exception['run_already'] = True
|
||||||
|
raise errors_impl.AbortedError(None, None, 'Abort')
|
||||||
|
|
||||||
|
return step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
|
||||||
|
|
||||||
|
with monitored_session.MonitoredSession() as session:
|
||||||
|
self.assertNear(3.2, session.run_step_fn(step_fn), 0.1)
|
||||||
|
self.assertTrue(trace_the_exception['run_already'])
|
||||||
|
|
||||||
|
def test_recovers_from_an_exception_in_step_fn_after_hooks(self):
|
||||||
|
trace_the_exception = {'run_already': False, 'side_effect_counter': 0}
|
||||||
|
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
c = array_ops.placeholder(dtypes.float32)
|
||||||
|
v = array_ops.identity(c)
|
||||||
|
graph_state = variables.Variable(0.0)
|
||||||
|
graph_side_effect = state_ops.assign_add(graph_state, 0.31)
|
||||||
|
|
||||||
|
def step_fn(step_context):
|
||||||
|
trace_the_exception['side_effect_counter'] += 1
|
||||||
|
step_context.session.run(graph_side_effect)
|
||||||
|
|
||||||
|
value = step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
|
||||||
|
|
||||||
|
if not trace_the_exception['run_already']:
|
||||||
|
trace_the_exception['run_already'] = True
|
||||||
|
raise errors_impl.AbortedError(None, None, 'Abort')
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
with self.test_session() as test_session:
|
||||||
|
with monitored_session.MonitoredSession(
|
||||||
|
CountingSessionCreator(test_session)) as session:
|
||||||
|
session.run(variables.global_variables_initializer())
|
||||||
|
|
||||||
|
self.assertNear(3.2, session.run_step_fn(step_fn), 0.1)
|
||||||
|
self.assertTrue(trace_the_exception['run_already'])
|
||||||
|
# Make sure the rest of the body of the step_fn is re-executed upon
|
||||||
|
# AbortedError:
|
||||||
|
self.assertEqual(2, trace_the_exception['side_effect_counter'])
|
||||||
|
self.assertNear(0.62, session.run(graph_state), 0.1)
|
||||||
|
|
||||||
|
def test_step_fn_doesnt_recover_when_it_wasnt_asked_to(self):
|
||||||
|
trace_the_exception = {'run_already': False}
|
||||||
|
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
c = array_ops.placeholder(dtypes.float32)
|
||||||
|
v = array_ops.identity(c)
|
||||||
|
|
||||||
|
def step_fn(step_context):
|
||||||
|
if not trace_the_exception['run_already']:
|
||||||
|
trace_the_exception['run_already'] = True
|
||||||
|
raise errors_impl.AbortedError(None, None, 'Abort')
|
||||||
|
|
||||||
|
value = step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
|
||||||
|
return value
|
||||||
|
|
||||||
|
with monitored_session.SingularMonitoredSession() as session:
|
||||||
|
with self.assertRaisesRegexp(errors_impl.AbortedError, 'Abort'):
|
||||||
|
self.assertNear(3.2, session.run_step_fn(step_fn), 0.1)
|
||||||
|
self.fail()
|
||||||
|
|
||||||
|
self.assertTrue(trace_the_exception['run_already'])
|
||||||
|
|
||||||
|
def test_step_fn_exception_from_before_run(self):
|
||||||
|
trace_the_exception = {'run_already': False, 'side_effect_counter': 0}
|
||||||
|
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
c = array_ops.placeholder(dtypes.float32)
|
||||||
|
v = array_ops.identity(c)
|
||||||
|
vv = constant_op.constant(3.2)
|
||||||
|
graph_state = variables.Variable(0.0)
|
||||||
|
graph_side_effect = state_ops.assign_add(graph_state, 0.31)
|
||||||
|
|
||||||
|
class Hook(session_run_hook.SessionRunHook):
|
||||||
|
|
||||||
|
def __init__(self, testing):
|
||||||
|
self._testing = testing
|
||||||
|
|
||||||
|
def before_run(self, run_context):
|
||||||
|
if not trace_the_exception['run_already']:
|
||||||
|
trace_the_exception['run_already'] = True
|
||||||
|
raise errors_impl.AbortedError(None, None, 'Abort')
|
||||||
|
return session_run_hook.SessionRunArgs(fetches=vv)
|
||||||
|
|
||||||
|
def after_run(self, run_context, run_values):
|
||||||
|
self._testing.assertNear(3.2, run_values.results, 0.1)
|
||||||
|
|
||||||
|
def step_fn(step_context):
|
||||||
|
trace_the_exception['side_effect_counter'] += 1
|
||||||
|
step_context.session.run(graph_side_effect)
|
||||||
|
return step_context.run_with_hooks(fetches=v, feed_dict={c: 1.3})
|
||||||
|
|
||||||
|
with self.test_session() as test_session:
|
||||||
|
with monitored_session.MonitoredSession(
|
||||||
|
CountingSessionCreator(test_session),
|
||||||
|
hooks=[Hook(self)]) as session:
|
||||||
|
test_session.run(variables.global_variables_initializer())
|
||||||
|
self.assertNear(1.3, session.run_step_fn(step_fn), 0.1)
|
||||||
|
self.assertEqual(2, trace_the_exception['side_effect_counter'])
|
||||||
|
self.assertNear(0.62, session.run(graph_state), 0.1)
|
||||||
|
|
||||||
|
|
||||||
class SingularMonitoredSessionTest(test.TestCase):
|
class SingularMonitoredSessionTest(test.TestCase):
|
||||||
"""Tests SingularMonitoredSession."""
|
"""Tests SingularMonitoredSession."""
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user