Add more recovery functionality to MonitoredSession.run_step_fn.

The current implementation wouldn't recover from one of `_PREEMPTION_ERRORS` raised during a fetch through the raw session that is made available to the step_fn.

The changelist presents a way to map the desired functionality onto the hierarchy of _MonitoredSession > (possibly!) _RecoverableSession > _CoordinatedSession > _HookedSession.

PiperOrigin-RevId: 174053865
This commit is contained in:
Igor Saprykin 2017-10-31 10:20:54 -07:00 committed by TensorFlower Gardener
parent b2ff3ad966
commit 5f1a66ccb4
2 changed files with 387 additions and 12 deletions

View File

@ -496,7 +496,6 @@ class _MonitoredSession(object):
self._sess = _RecoverableSession(self._coordinated_creator)
else:
self._sess = self._coordinated_creator.create_session()
self._stop_requested_in_step_fn = False
@property
def graph(self):
@ -576,11 +575,12 @@ class _MonitoredSession(object):
' `self` and `step_context` arguments if it\'s an instance'
' method. Got {} instead.'.format(step_fn_arguments))
try:
return step_fn(_MonitoredSession.StepContext(self._tf_sess(), self.run))
except StopIteration:
self._stop_requested_in_step_fn = True
raise
# `self._sess` is either `_RecoverableSession` or a `_CoordinatedSession`.
# Setting `run_with_hooks` to `None` will cause `run_with_hooks` to be
# `_CoordinatedSession.run` downstream in either case. This allows
# `_PREEMPTION_ERRORS` to propagate from within `step_fn` to
# `_RecoverableSession.run_step_fn`.
return self._sess.run_step_fn(step_fn, self._tf_sess(), run_with_hooks=None)
class StepContext(object):
"""Control flow instrument for the `step_fn` from `run_step_fn()`.
@ -620,8 +620,7 @@ class _MonitoredSession(object):
raise StopIteration('step_fn has requested the iterations to stop.')
def should_stop(self):
return (self._sess is None or self._sess.should_stop() or
self._stop_requested_in_step_fn)
return self._sess is None or self._sess.should_stop()
def close(self):
self._close_internal()
@ -924,6 +923,13 @@ class _WrappedSession(object):
# Delegates to the wrapped session's `run`, forwarding every positional and
# keyword argument unchanged and returning its result.
def run(self, *args, **kwargs):
return self._sess.run(*args, **kwargs)
def run_step_fn(self, step_fn, raw_session, run_with_hooks):
  """Invokes `step_fn` with a newly built `StepContext`.

  Args:
    step_fn: Callable that accepts a single `StepContext` argument.
    raw_session: Session object handed to the `StepContext`.
    run_with_hooks: Callable handed to the `StepContext` for hooked runs.
      `_RecoverableSession` passes `_CoordinatedSession.run` here. When the
      value is falsy (`None` when called from `_CoordinatedSession`),
      `self.run` is used instead, which in that case is
      `_CoordinatedSession.run`.

  Returns:
    Whatever `step_fn` returns.
  """
  if not run_with_hooks:
    run_with_hooks = self.run
  step_context = _MonitoredSession.StepContext(raw_session, run_with_hooks)
  return step_fn(step_context)
class _RecoverableSession(_WrappedSession):
"""A wrapped session that recreates a session upon certain kinds of errors.
@ -996,6 +1002,22 @@ class _RecoverableSession(_WrappedSession):
self.close()
self._sess = None
def run_step_fn(self, step_fn, raw_session, run_with_hooks):
  """Runs `step_fn`, recreating the session after a preemption error.

  Loops until the wrapped session's `run_step_fn` returns. Whenever one of
  `_PREEMPTION_ERRORS` is raised, the current session is closed and dropped,
  and the next iteration creates a new one before retrying `step_fn`.

  Args:
    step_fn: Callable forwarded to the wrapped session's `run_step_fn`.
    raw_session: Session object forwarded unchanged on every attempt.
    run_with_hooks: Ignored on input; it is always replaced by the `run`
      method of the currently wrapped session.

  Returns:
    The value returned by the wrapped session's `run_step_fn`.
  """
  while True:
    try:
      if not self._sess:
        self._sess = self._create_session()
      wrapped = self._sess
      # Always route hooked runs through the current wrapped session, so a
      # session recreated after a preemption is the one that gets used.
      hooked_run = wrapped.run
      # NOTE(review): `raw_session` is passed through as-is on every retry,
      # even after a new session has been created -- confirm this is
      # intended.
      return wrapped.run_step_fn(step_fn, raw_session, hooked_run)
    except _PREEMPTION_ERRORS as preemption_error:
      logging.info(
          'An error was raised. This may be due to a preemption in '
          'a connected worker or parameter server. The current '
          'session will be closed and a new session will be '
          'created. Error: %s', preemption_error)
      self.close()
      self._sess = None
class _CoordinatedSession(_WrappedSession):
"""A wrapped session that works with a `tf.Coordinator`.

View File

@ -798,6 +798,214 @@ class RecoverableSessionTest(test.TestCase):
self.assertFalse(session.should_stop())
self.assertEqual(2, session_creator.number_of_sessions_created)
def test_step_fn_recovery_from_coordinator_exception_when_run_hooks(self):
# Runs two step_fns that fetch through `step_context.run_with_hooks` while a
# `StopCoordinatorWithException(calls_before_stopping=2)` hook is installed.
# Expects both fetches to return the fed value, `should_stop()` to stay
# False, and a second session to have been created by the end.
with self.test_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = monitored_session.MonitoredSession(
session_creator,
[StopCoordinatorWithException(calls_before_stopping=2)])
self.assertEqual(1, session_creator.number_of_sessions_created)
self.assertFalse(session.should_stop())
c = constant_op.constant(0)
v = array_ops.identity(c)
# Builds a step_fn that feeds `value` into `c` and fetches `v` with hooks.
def feed_step_fn(value):
def step_fn(step_context):
return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
return step_fn
# The coordinator will not abort during this call, since it's the call
# number 0.
self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
self.assertFalse(session.should_stop())
# The coordinator will abort during the next call, since it's the call
# number 1.
self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
# Even though the coordinator was asked to stop, the underlying session is
# recreated and is to be continued.
self.assertFalse(session.should_stop())
self.assertEqual(2, session_creator.number_of_sessions_created)
def test_recovery_from_non_preemption_in_coordinator_when_run_hooks(self):
# Same scenario as the preemption case, but the hook is configured to raise
# `errors_impl.UnknownError`. Expects the second fetch to still return its
# value, after which `should_stop()` is True, no second session has been
# created, and `close()` raises the `UnknownError`.
with self.test_session() as test_session:
session_creator = CountingSessionCreator(test_session)
hook = StopCoordinatorWithException(
calls_before_stopping=2,
exception_to_raise=errors_impl.UnknownError(
None, None, 'Some fatal exception inside the coordinator.'))
session = monitored_session.MonitoredSession(session_creator, [hook])
self.assertEqual(1, session_creator.number_of_sessions_created)
self.assertFalse(session.should_stop())
c = constant_op.constant(0)
v = array_ops.identity(c)
# Builds a step_fn that feeds `value` into `c` and fetches `v` with hooks.
def feed_step_fn(value):
def step_fn(step_context):
return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
return step_fn
# The coordinator will not abort during this call, since it's the call
# number 0.
self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
self.assertFalse(session.should_stop())
# The coordinator will abort during the next call, since it's the call
# number 1.
self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
# The coordinator was asked to stop due to non-redeemable error. Training
# should stop and the session should not be recreated.
self.assertTrue(session.should_stop())
self.assertEqual(1, session_creator.number_of_sessions_created)
# The error stored by the coordinator is expected to surface on close.
with self.assertRaises(errors_impl.UnknownError):
session.close()
def test_recovery_from_session_getting_stuck_when_run_hooks(self):
# Runs two step_fns through `run_with_hooks` with a
# `FailTrainingAfterCoordinatorStopped(calls_before_stopping=2)` hook.
# Expects both fetches to return the fed value, `should_stop()` to stay
# False, and a second session to have been created by the end.
with self.test_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = monitored_session.MonitoredSession(
session_creator,
[FailTrainingAfterCoordinatorStopped(calls_before_stopping=2)])
self.assertEqual(1, session_creator.number_of_sessions_created)
self.assertFalse(session.should_stop())
c = constant_op.constant(0)
v = array_ops.identity(c)
# Builds a step_fn that feeds `value` into `c` and fetches `v` with hooks.
def feed_step_fn(value):
def step_fn(step_context):
return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
return step_fn
# Training will not fail, since it's the call number 0.
self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
self.assertFalse(session.should_stop())
# Training will fail during the next call, since it's the call
# number 1.
self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
# Even though the coordinator stopped and training failed, the
# underlying session is recreated and training is to be continued.
self.assertFalse(session.should_stop())
self.assertEqual(2, session_creator.number_of_sessions_created)
def create_raw_session_with_failing_coordinator(self, session_creator, hook):
"""Return MonitoredSession that triggers coordinator failures."""
# `session_creator` and `hook` are passed straight to `MonitoredSession`;
# `hook` is the only hook installed.
session = monitored_session.MonitoredSession(session_creator, [hook])
# We would like to test a situation where during fetches through the
# raw session, the coordinator fails with an exception. To do that, we
# are going to use (raw_session + StopCoordinatorWithException) hook
# combination that is stored in
# `MonitoredSession._RecoverableSession._CoordinatedSession._sess`
# at this point:
# This overrides `_tf_sess` on this one instance only; the lambda re-reads
# the `_sess` chain each time it is called.
session._tf_sess = lambda: session._sess._sess._sess
# `run()` on such a session is equivalent to `run()` on the raw session
# with separate coordinator threads independently stopping with an
# exception.
return session
def test_step_fn_recovery_from_coordinator_exception_with_raw_session(self):
# Uses the session from `create_raw_session_with_failing_coordinator` and
# fetches through `step_context.session.run` instead of `run_with_hooks`.
# Expects both fetches to return the fed value, `should_stop()` to stay
# False, and a second session to have been created by the end.
with self.test_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = self.create_raw_session_with_failing_coordinator(
session_creator,
StopCoordinatorWithException(calls_before_stopping=2))
self.assertEqual(1, session_creator.number_of_sessions_created)
self.assertFalse(session.should_stop())
c = constant_op.constant(0)
v = array_ops.identity(c)
# Builds a step_fn that feeds `value` into `c` and fetches `v` through the
# session exposed on the step context.
def feed_step_fn(value):
def step_fn(step_context):
return step_context.session.run(fetches=v, feed_dict={c: value})
return step_fn
# The coordinator will not abort during this call, since it's the call
# number 0.
self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
self.assertFalse(session.should_stop())
# The coordinator will abort during the next call, since it's the call
# number 1.
self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
# Even though the coordinator was asked to stop, the underlying session is
# recreated and is to be continued.
self.assertFalse(session.should_stop())
self.assertEqual(2, session_creator.number_of_sessions_created)
def test_recovery_from_non_preemption_in_coordinator_with_raw_session(self):
# Uses the session from `create_raw_session_with_failing_coordinator` with
# a hook configured to raise `errors_impl.UnknownError`. Expects the second
# fetch to still return its value, after which `should_stop()` is True, no
# second session has been created, and `close()` raises the `UnknownError`.
with self.test_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = self.create_raw_session_with_failing_coordinator(
session_creator,
StopCoordinatorWithException(
calls_before_stopping=2,
exception_to_raise=errors_impl.UnknownError(
None, None, 'Some fatal exception inside the coordinator.')))
self.assertEqual(1, session_creator.number_of_sessions_created)
self.assertFalse(session.should_stop())
c = constant_op.constant(0)
v = array_ops.identity(c)
# NOTE(review): despite the "with_raw_session" test name, this step_fn
# fetches via `step_context.run_with_hooks`, whereas
# `test_step_fn_recovery_from_coordinator_exception_with_raw_session` uses
# `step_context.session.run` -- confirm which one is intended here.
def feed_step_fn(value):
def step_fn(step_context):
return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
return step_fn
# The coordinator will not abort during this call, since it's the call
# number 0.
self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
self.assertFalse(session.should_stop())
# The coordinator will abort during the next call, since it's the call
# number 1.
self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
# The coordinator was asked to stop due to non-redeemable error. Training
# should stop and the session should not be recreated.
self.assertTrue(session.should_stop())
self.assertEqual(1, session_creator.number_of_sessions_created)
# The error stored by the coordinator is expected to surface on close.
with self.assertRaises(errors_impl.UnknownError):
session.close()
def test_recovery_from_session_getting_stuck_with_raw_session(self):
# Uses the session from `create_raw_session_with_failing_coordinator` with
# a `FailTrainingAfterCoordinatorStopped(calls_before_stopping=2)` hook.
# Expects both fetches to return the fed value, `should_stop()` to stay
# False, and a second session to have been created by the end.
with self.test_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = self.create_raw_session_with_failing_coordinator(
session_creator,
FailTrainingAfterCoordinatorStopped(calls_before_stopping=2))
self.assertEqual(1, session_creator.number_of_sessions_created)
self.assertFalse(session.should_stop())
c = constant_op.constant(0)
v = array_ops.identity(c)
# NOTE(review): despite the "with_raw_session" test name, this step_fn
# fetches via `step_context.run_with_hooks`, whereas
# `test_step_fn_recovery_from_coordinator_exception_with_raw_session` uses
# `step_context.session.run` -- confirm which one is intended here.
def feed_step_fn(value):
def step_fn(step_context):
return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
return step_fn
# Training will not fail, since it's the call number 0.
self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
self.assertFalse(session.should_stop())
# Training will fail during the next call, since it's the call
# number 1.
self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
# Even though the coordinator stopped and training failed, the
# underlying session is recreated and training is to be continued.
self.assertFalse(session.should_stop())
self.assertEqual(2, session_creator.number_of_sessions_created)
class FakeSession(monitored_session._WrappedSession):
@ -1475,6 +1683,7 @@ class MonitoredSessionTest(test.TestCase):
def test_step_request_stop_without_a_with_block(self):
with ops.Graph().as_default():
was_stop_iteration_raised = False
def step_fn(step_context):
step_context.request_stop()
@ -1483,8 +1692,10 @@ class MonitoredSessionTest(test.TestCase):
try:
self.assertEqual(None, session.run_step_fn(step_fn))
except StopIteration:
pass
self.assertTrue(session.should_stop())
was_stop_iteration_raised = True
self.assertTrue(was_stop_iteration_raised)
self.assertFalse(session.should_stop())
def test_step_request_stop_in_a_loop(self):
with ops.Graph().as_default():
@ -1526,8 +1737,7 @@ class MonitoredSessionTest(test.TestCase):
class Model(object):
def step_fn(self, step_context):
value = step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
return value
return step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
with monitored_session.MonitoredSession() as session:
model = Model()
@ -1592,6 +1802,38 @@ class MonitoredSessionTest(test.TestCase):
with monitored_session.MonitoredSession(hooks=[Hook(self)]) as session:
self.assertEqual(0.3 + 0.5 + 0.7, session.run_step_fn(step_fn))
def test_step_fn_has_the_same_hooks_behavior_without_recovery(self):
# Checks the interaction between a step_fn and a hook under
# `SingularMonitoredSession`: the step_fn runs `stage_0` on the raw session,
# then `stage_1_1` with hooks, while the hook adds `stage_1_0` as a fetch and
# inspects `var` afterwards.
with ops.Graph().as_default():
var = resource_variable_ops.ResourceVariable(0.0)
# Each stage adds a distinct amount to `var`, so the running total shows
# which stages have executed.
stage_0 = state_ops.assign_add(var, 0.3)
stage_1_0 = state_ops.assign_add(var, 0.7)
# `stage_1_1` is created under a control dependency on `stage_1_0`.
with ops.control_dependencies([stage_1_0]):
stage_1_1 = state_ops.assign_add(var, 0.5)
stage_2 = state_ops.assign_add(var, 1.1)
class Hook(session_run_hook.SessionRunHook):
def __init__(self, testing):
# Keeps the test case so the hook can call its assertions.
self._testing = testing
def before_run(self, run_context):
# Requests `stage_1_0` as an additional fetch for the hooked run.
return session_run_hook.SessionRunArgs(fetches=stage_1_0)
def after_run(self, run_context, run_values):
# Expects stages 0, 1_0 and 1_1 to have been applied by now, and then
# applies `stage_2` through the session on the run context.
self._testing.assertNear(0.3 + 0.5 + 0.7,
run_context.session.run(var), 0.1)
self._testing.assertNear(0.3 + 0.5 + 0.7 + 1.1,
run_context.session.run(stage_2), 0.1)
def step_fn(step_context):
# The raw-session run applies only `stage_0`; the hooked run returns the
# value of `stage_1_1`.
self.assertNear(0.3, step_context.session.run(stage_0), 0.1)
return step_context.run_with_hooks(fetches=stage_1_1)
with monitored_session.SingularMonitoredSession(
hooks=[Hook(self)]) as session:
self.assertEqual(0.3 + 0.5 + 0.7, session.run_step_fn(step_fn))
def test_step_fn_with_hooks_and_request_stop(self):
with ops.Graph().as_default():
trace_the_hook = {'before_run': False, 'after_run': False}
@ -1615,6 +1857,117 @@ class MonitoredSessionTest(test.TestCase):
self.assertFalse(trace_the_hook['before_run'])
self.assertFalse(trace_the_hook['after_run'])
def test_recovers_from_an_exception_in_step_fn(self):
# The step_fn raises `AbortedError` on its first invocation only. Expects
# `MonitoredSession.run_step_fn` to still return the fetched value.
# `trace_the_exception` is a dict so the nested step_fn can mutate it.
trace_the_exception = {'run_already': False}
with ops.Graph().as_default():
c = array_ops.placeholder(dtypes.float32)
v = array_ops.identity(c)
def step_fn(step_context):
# Fail exactly once, before any fetch is attempted.
if not trace_the_exception['run_already']:
trace_the_exception['run_already'] = True
raise errors_impl.AbortedError(None, None, 'Abort')
return step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
with monitored_session.MonitoredSession() as session:
self.assertNear(3.2, session.run_step_fn(step_fn), 0.1)
self.assertTrue(trace_the_exception['run_already'])
def test_recovers_from_an_exception_in_step_fn_after_hooks(self):
# The step_fn raises `AbortedError` on its first invocation, but only after
# it has run a side-effecting op and a hooked fetch. Expects the whole
# step_fn body to be executed a second time.
trace_the_exception = {'run_already': False, 'side_effect_counter': 0}
with ops.Graph().as_default():
c = array_ops.placeholder(dtypes.float32)
v = array_ops.identity(c)
graph_state = variables.Variable(0.0)
graph_side_effect = state_ops.assign_add(graph_state, 0.31)
def step_fn(step_context):
# Count invocations in Python and in the graph (0.31 per invocation).
trace_the_exception['side_effect_counter'] += 1
step_context.session.run(graph_side_effect)
value = step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
# Fail exactly once, after the fetches above have completed.
if not trace_the_exception['run_already']:
trace_the_exception['run_already'] = True
raise errors_impl.AbortedError(None, None, 'Abort')
return value
with self.test_session() as test_session:
with monitored_session.MonitoredSession(
CountingSessionCreator(test_session)) as session:
session.run(variables.global_variables_initializer())
self.assertNear(3.2, session.run_step_fn(step_fn), 0.1)
self.assertTrue(trace_the_exception['run_already'])
# Make sure the rest of the body of the step_fn is re-executed upon
# AbortedError:
self.assertEqual(2, trace_the_exception['side_effect_counter'])
# Two invocations of `graph_side_effect` add up to 2 * 0.31.
self.assertNear(0.62, session.run(graph_state), 0.1)
def test_step_fn_doesnt_recover_when_it_wasnt_asked_to(self):
# With `SingularMonitoredSession`, an `AbortedError` raised by the step_fn
# is expected to propagate out of `run_step_fn` instead of being retried.
trace_the_exception = {'run_already': False}
with ops.Graph().as_default():
c = array_ops.placeholder(dtypes.float32)
v = array_ops.identity(c)
def step_fn(step_context):
# Fail exactly once, before any fetch is attempted.
if not trace_the_exception['run_already']:
trace_the_exception['run_already'] = True
raise errors_impl.AbortedError(None, None, 'Abort')
value = step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
return value
with monitored_session.SingularMonitoredSession() as session:
with self.assertRaisesRegexp(errors_impl.AbortedError, 'Abort'):
self.assertNear(3.2, session.run_step_fn(step_fn), 0.1)
# Presumably reached only if `run_step_fn` returned normally, which this
# test treats as a failure.
self.fail()
self.assertTrue(trace_the_exception['run_already'])
def test_step_fn_exception_from_before_run(self):
# The hook's `before_run` raises `AbortedError` the first time it is called.
# Expects `run_step_fn` to still return the fetched value and the step_fn
# body to have been executed twice.
trace_the_exception = {'run_already': False, 'side_effect_counter': 0}
with ops.Graph().as_default():
c = array_ops.placeholder(dtypes.float32)
v = array_ops.identity(c)
vv = constant_op.constant(3.2)
graph_state = variables.Variable(0.0)
graph_side_effect = state_ops.assign_add(graph_state, 0.31)
class Hook(session_run_hook.SessionRunHook):
def __init__(self, testing):
# Keeps the test case so the hook can call its assertions.
self._testing = testing
def before_run(self, run_context):
# Fail exactly once; on later calls request `vv` as an extra fetch.
if not trace_the_exception['run_already']:
trace_the_exception['run_already'] = True
raise errors_impl.AbortedError(None, None, 'Abort')
return session_run_hook.SessionRunArgs(fetches=vv)
def after_run(self, run_context, run_values):
# `vv` is the constant 3.2 requested in `before_run`.
self._testing.assertNear(3.2, run_values.results, 0.1)
def step_fn(step_context):
# Count invocations in Python and in the graph (0.31 per invocation).
trace_the_exception['side_effect_counter'] += 1
step_context.session.run(graph_side_effect)
return step_context.run_with_hooks(fetches=v, feed_dict={c: 1.3})
with self.test_session() as test_session:
with monitored_session.MonitoredSession(
CountingSessionCreator(test_session),
hooks=[Hook(self)]) as session:
# Variables are initialized directly on `test_session`, not through the
# monitored session.
test_session.run(variables.global_variables_initializer())
self.assertNear(1.3, session.run_step_fn(step_fn), 0.1)
self.assertEqual(2, trace_the_exception['side_effect_counter'])
# Two invocations of `graph_side_effect` add up to 2 * 0.31.
self.assertNear(0.62, session.run(graph_state), 0.1)
class SingularMonitoredSessionTest(test.TestCase):
"""Tests SingularMonitoredSession."""