mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Add more recovery functionality to MonitoredSession.run_step_fn.
Current implemention wouldn't recover from one of `_PREEMPTION_ERRORS` during a fetch through the raw session that is made available to the step_fn. The changelist presents a way to map the desired functionality to the hiearchy of _MonitoredSession > (possibly!) _RecoverableSession > _CoordinatedSession > _HookedSession. PiperOrigin-RevId: 174053865
This commit is contained in:
parent
b2ff3ad966
commit
5f1a66ccb4
|
|
@ -496,7 +496,6 @@ class _MonitoredSession(object):
|
|||
self._sess = _RecoverableSession(self._coordinated_creator)
|
||||
else:
|
||||
self._sess = self._coordinated_creator.create_session()
|
||||
self._stop_requested_in_step_fn = False
|
||||
|
||||
@property
|
||||
def graph(self):
|
||||
|
|
@ -576,11 +575,12 @@ class _MonitoredSession(object):
|
|||
' `self` and `step_context` arguments if it\'s an instance'
|
||||
' method. Got {} instead.'.format(step_fn_arguments))
|
||||
|
||||
try:
|
||||
return step_fn(_MonitoredSession.StepContext(self._tf_sess(), self.run))
|
||||
except StopIteration:
|
||||
self._stop_requested_in_step_fn = True
|
||||
raise
|
||||
# `self._sess` is either `_RecoverableSession` or a `_CoordinatedSession`.
|
||||
# Setting `run_with_hooks` to `None` will cause `run_with_hooks` to be
|
||||
# `_CoordinatedSession.run` downstream in either case. This allows
|
||||
# `_PREEMPTION_ERRORS` to propage from within `step_fn` to
|
||||
# `_RecoverableSession.run_step_fn`.
|
||||
return self._sess.run_step_fn(step_fn, self._tf_sess(), run_with_hooks=None)
|
||||
|
||||
class StepContext(object):
|
||||
"""Control flow instrument for the `step_fn` from `run_step_fn()`.
|
||||
|
|
@ -620,8 +620,7 @@ class _MonitoredSession(object):
|
|||
raise StopIteration('step_fn has requested the iterations to stop.')
|
||||
|
||||
def should_stop(self):
|
||||
return (self._sess is None or self._sess.should_stop() or
|
||||
self._stop_requested_in_step_fn)
|
||||
return self._sess is None or self._sess.should_stop()
|
||||
|
||||
def close(self):
|
||||
self._close_internal()
|
||||
|
|
@ -924,6 +923,13 @@ class _WrappedSession(object):
|
|||
def run(self, *args, **kwargs):
|
||||
return self._sess.run(*args, **kwargs)
|
||||
|
||||
def run_step_fn(self, step_fn, raw_session, run_with_hooks):
|
||||
# `_RecoverableSession` sets `run_with_hooks` to `_CoordinatedSession.run`.
|
||||
# It is `None` when called from `_CoordinatedSession`. In that case
|
||||
# `self.run` is `_CoordinatedSession.run`.
|
||||
run_with_hooks = run_with_hooks or self.run
|
||||
return step_fn(_MonitoredSession.StepContext(raw_session, run_with_hooks))
|
||||
|
||||
|
||||
class _RecoverableSession(_WrappedSession):
|
||||
"""A wrapped session that recreates a session upon certain kinds of errors.
|
||||
|
|
@ -996,6 +1002,22 @@ class _RecoverableSession(_WrappedSession):
|
|||
self.close()
|
||||
self._sess = None
|
||||
|
||||
def run_step_fn(self, step_fn, raw_session, run_with_hooks):
|
||||
while True:
|
||||
try:
|
||||
if not self._sess:
|
||||
self._sess = self._create_session()
|
||||
|
||||
run_with_hooks = self._sess.run
|
||||
return self._sess.run_step_fn(step_fn, raw_session, run_with_hooks)
|
||||
except _PREEMPTION_ERRORS as e:
|
||||
logging.info('An error was raised. This may be due to a preemption in '
|
||||
'a connected worker or parameter server. The current '
|
||||
'session will be closed and a new session will be '
|
||||
'created. Error: %s', e)
|
||||
self.close()
|
||||
self._sess = None
|
||||
|
||||
|
||||
class _CoordinatedSession(_WrappedSession):
|
||||
"""A wrapped session that works with a `tf.Coordinator`.
|
||||
|
|
|
|||
|
|
@ -798,6 +798,214 @@ class RecoverableSessionTest(test.TestCase):
|
|||
self.assertFalse(session.should_stop())
|
||||
self.assertEqual(2, session_creator.number_of_sessions_created)
|
||||
|
||||
def test_step_fn_recovery_from_coordinator_exception_when_run_hooks(self):
|
||||
with self.test_session() as test_session:
|
||||
session_creator = CountingSessionCreator(test_session)
|
||||
session = monitored_session.MonitoredSession(
|
||||
session_creator,
|
||||
[StopCoordinatorWithException(calls_before_stopping=2)])
|
||||
|
||||
self.assertEqual(1, session_creator.number_of_sessions_created)
|
||||
self.assertFalse(session.should_stop())
|
||||
|
||||
c = constant_op.constant(0)
|
||||
v = array_ops.identity(c)
|
||||
|
||||
def feed_step_fn(value):
|
||||
def step_fn(step_context):
|
||||
return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
|
||||
return step_fn
|
||||
|
||||
# The coordinator will not abort during this call, since it's the call
|
||||
# number 0.
|
||||
self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
|
||||
self.assertFalse(session.should_stop())
|
||||
# The coordinator will abort during the next call, since it's the call
|
||||
# number 1.
|
||||
self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
|
||||
# Even though the coordinator was asked to stop, the underlying session is
|
||||
# recreated and is to be continued.
|
||||
self.assertFalse(session.should_stop())
|
||||
self.assertEqual(2, session_creator.number_of_sessions_created)
|
||||
|
||||
def test_recovery_from_non_preemption_in_coordinator_when_run_hooks(self):
|
||||
with self.test_session() as test_session:
|
||||
session_creator = CountingSessionCreator(test_session)
|
||||
hook = StopCoordinatorWithException(
|
||||
calls_before_stopping=2,
|
||||
exception_to_raise=errors_impl.UnknownError(
|
||||
None, None, 'Some fatal exception inside the coordinator.'))
|
||||
session = monitored_session.MonitoredSession(session_creator, [hook])
|
||||
|
||||
self.assertEqual(1, session_creator.number_of_sessions_created)
|
||||
self.assertFalse(session.should_stop())
|
||||
|
||||
c = constant_op.constant(0)
|
||||
v = array_ops.identity(c)
|
||||
|
||||
def feed_step_fn(value):
|
||||
def step_fn(step_context):
|
||||
return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
|
||||
return step_fn
|
||||
|
||||
# The coordinator will not abort during this call, since it's the call
|
||||
# number 0.
|
||||
self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
|
||||
self.assertFalse(session.should_stop())
|
||||
# The coordinator will abort during the next call, since it's the call
|
||||
# number 1.
|
||||
self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
|
||||
# The coordinator was asked to stop due to non-redeemable error. Training
|
||||
# should stop and the session should not be recreated.
|
||||
self.assertTrue(session.should_stop())
|
||||
self.assertEqual(1, session_creator.number_of_sessions_created)
|
||||
with self.assertRaises(errors_impl.UnknownError):
|
||||
session.close()
|
||||
|
||||
def test_recovery_from_session_getting_stuck_when_run_hooks(self):
|
||||
with self.test_session() as test_session:
|
||||
session_creator = CountingSessionCreator(test_session)
|
||||
session = monitored_session.MonitoredSession(
|
||||
session_creator,
|
||||
[FailTrainingAfterCoordinatorStopped(calls_before_stopping=2)])
|
||||
|
||||
self.assertEqual(1, session_creator.number_of_sessions_created)
|
||||
self.assertFalse(session.should_stop())
|
||||
|
||||
c = constant_op.constant(0)
|
||||
v = array_ops.identity(c)
|
||||
|
||||
def feed_step_fn(value):
|
||||
def step_fn(step_context):
|
||||
return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
|
||||
return step_fn
|
||||
|
||||
# Training will not fail, since it's the call number 0.
|
||||
self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
|
||||
self.assertFalse(session.should_stop())
|
||||
# Training will fail during the next call, since it's the call
|
||||
# number 1.
|
||||
self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
|
||||
# Even though the coordinator stopped which and training failed, the
|
||||
# underlying session is recreated and training is to be continued.
|
||||
self.assertFalse(session.should_stop())
|
||||
self.assertEqual(2, session_creator.number_of_sessions_created)
|
||||
|
||||
def create_raw_session_with_failing_coordinator(self, session_creator, hook):
|
||||
"""Return MonitoredSession that triggers coordinator failures."""
|
||||
session = monitored_session.MonitoredSession(session_creator, [hook])
|
||||
# We would like to test a situation where during fetches through the
|
||||
# raw session, the coordinator fails with an exception. To do that, we
|
||||
# are going to use (raw_session + StopCoordinatorWithException) hook
|
||||
# combination that is stored in
|
||||
# `MonitoredSession._RecoverableSession._CoordinatedSession._sess`
|
||||
# at this point:
|
||||
session._tf_sess = lambda: session._sess._sess._sess
|
||||
# `run()` on such a session is equivalent to `run()` on the raw session
|
||||
# with separate coordinator threads independently stopping with an
|
||||
# exception.
|
||||
return session
|
||||
|
||||
def test_step_fn_recovery_from_coordinator_exception_with_raw_session(self):
|
||||
with self.test_session() as test_session:
|
||||
session_creator = CountingSessionCreator(test_session)
|
||||
session = self.create_raw_session_with_failing_coordinator(
|
||||
session_creator,
|
||||
StopCoordinatorWithException(calls_before_stopping=2))
|
||||
|
||||
self.assertEqual(1, session_creator.number_of_sessions_created)
|
||||
self.assertFalse(session.should_stop())
|
||||
|
||||
c = constant_op.constant(0)
|
||||
v = array_ops.identity(c)
|
||||
|
||||
def feed_step_fn(value):
|
||||
|
||||
def step_fn(step_context):
|
||||
return step_context.session.run(fetches=v, feed_dict={c: value})
|
||||
|
||||
return step_fn
|
||||
|
||||
# The coordinator will not abort during this call, since it's the call
|
||||
# number 0.
|
||||
self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
|
||||
self.assertFalse(session.should_stop())
|
||||
# The coordinator will abort during the next call, since it's the call
|
||||
# number 1.
|
||||
self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
|
||||
# Even though the coordinator was asked to stop, the underlying session is
|
||||
# recreated and is to be continued.
|
||||
self.assertFalse(session.should_stop())
|
||||
self.assertEqual(2, session_creator.number_of_sessions_created)
|
||||
|
||||
def test_recovery_from_non_preemption_in_coordinator_with_raw_session(self):
|
||||
with self.test_session() as test_session:
|
||||
session_creator = CountingSessionCreator(test_session)
|
||||
session = self.create_raw_session_with_failing_coordinator(
|
||||
session_creator,
|
||||
StopCoordinatorWithException(
|
||||
calls_before_stopping=2,
|
||||
exception_to_raise=errors_impl.UnknownError(
|
||||
None, None, 'Some fatal exception inside the coordinator.')))
|
||||
|
||||
self.assertEqual(1, session_creator.number_of_sessions_created)
|
||||
self.assertFalse(session.should_stop())
|
||||
|
||||
c = constant_op.constant(0)
|
||||
v = array_ops.identity(c)
|
||||
|
||||
def feed_step_fn(value):
|
||||
|
||||
def step_fn(step_context):
|
||||
return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
|
||||
|
||||
return step_fn
|
||||
|
||||
# The coordinator will not abort during this call, since it's the call
|
||||
# number 0.
|
||||
self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
|
||||
self.assertFalse(session.should_stop())
|
||||
# The coordinator will abort during the next call, since it's the call
|
||||
# number 1.
|
||||
self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
|
||||
# The coordinator was asked to stop due to non-redeemable error. Training
|
||||
# should stop and the session should not be recreated.
|
||||
self.assertTrue(session.should_stop())
|
||||
self.assertEqual(1, session_creator.number_of_sessions_created)
|
||||
with self.assertRaises(errors_impl.UnknownError):
|
||||
session.close()
|
||||
|
||||
def test_recovery_from_session_getting_stuck_with_raw_session(self):
|
||||
with self.test_session() as test_session:
|
||||
session_creator = CountingSessionCreator(test_session)
|
||||
session = self.create_raw_session_with_failing_coordinator(
|
||||
session_creator,
|
||||
FailTrainingAfterCoordinatorStopped(calls_before_stopping=2))
|
||||
|
||||
self.assertEqual(1, session_creator.number_of_sessions_created)
|
||||
self.assertFalse(session.should_stop())
|
||||
|
||||
c = constant_op.constant(0)
|
||||
v = array_ops.identity(c)
|
||||
|
||||
def feed_step_fn(value):
|
||||
|
||||
def step_fn(step_context):
|
||||
return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
|
||||
|
||||
return step_fn
|
||||
|
||||
# Training will not fail, since it's the call number 0.
|
||||
self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
|
||||
self.assertFalse(session.should_stop())
|
||||
# Training will fail during the next call, since it's the call
|
||||
# number 1.
|
||||
self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
|
||||
# Even though the coordinator stopped which and training failed, the
|
||||
# underlying session is recreated and training is to be continued.
|
||||
self.assertFalse(session.should_stop())
|
||||
self.assertEqual(2, session_creator.number_of_sessions_created)
|
||||
|
||||
|
||||
class FakeSession(monitored_session._WrappedSession):
|
||||
|
||||
|
|
@ -1475,6 +1683,7 @@ class MonitoredSessionTest(test.TestCase):
|
|||
|
||||
def test_step_request_stop_without_a_with_block(self):
|
||||
with ops.Graph().as_default():
|
||||
was_stop_iteration_raised = False
|
||||
|
||||
def step_fn(step_context):
|
||||
step_context.request_stop()
|
||||
|
|
@ -1483,8 +1692,10 @@ class MonitoredSessionTest(test.TestCase):
|
|||
try:
|
||||
self.assertEqual(None, session.run_step_fn(step_fn))
|
||||
except StopIteration:
|
||||
pass
|
||||
self.assertTrue(session.should_stop())
|
||||
was_stop_iteration_raised = True
|
||||
|
||||
self.assertTrue(was_stop_iteration_raised)
|
||||
self.assertFalse(session.should_stop())
|
||||
|
||||
def test_step_request_stop_in_a_loop(self):
|
||||
with ops.Graph().as_default():
|
||||
|
|
@ -1526,8 +1737,7 @@ class MonitoredSessionTest(test.TestCase):
|
|||
class Model(object):
|
||||
|
||||
def step_fn(self, step_context):
|
||||
value = step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
|
||||
return value
|
||||
return step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
|
||||
|
||||
with monitored_session.MonitoredSession() as session:
|
||||
model = Model()
|
||||
|
|
@ -1592,6 +1802,38 @@ class MonitoredSessionTest(test.TestCase):
|
|||
with monitored_session.MonitoredSession(hooks=[Hook(self)]) as session:
|
||||
self.assertEqual(0.3 + 0.5 + 0.7, session.run_step_fn(step_fn))
|
||||
|
||||
def test_step_fn_has_the_same_hooks_behavior_without_recovery(self):
|
||||
with ops.Graph().as_default():
|
||||
var = resource_variable_ops.ResourceVariable(0.0)
|
||||
|
||||
stage_0 = state_ops.assign_add(var, 0.3)
|
||||
stage_1_0 = state_ops.assign_add(var, 0.7)
|
||||
with ops.control_dependencies([stage_1_0]):
|
||||
stage_1_1 = state_ops.assign_add(var, 0.5)
|
||||
stage_2 = state_ops.assign_add(var, 1.1)
|
||||
|
||||
class Hook(session_run_hook.SessionRunHook):
|
||||
|
||||
def __init__(self, testing):
|
||||
self._testing = testing
|
||||
|
||||
def before_run(self, run_context):
|
||||
return session_run_hook.SessionRunArgs(fetches=stage_1_0)
|
||||
|
||||
def after_run(self, run_context, run_values):
|
||||
self._testing.assertNear(0.3 + 0.5 + 0.7,
|
||||
run_context.session.run(var), 0.1)
|
||||
self._testing.assertNear(0.3 + 0.5 + 0.7 + 1.1,
|
||||
run_context.session.run(stage_2), 0.1)
|
||||
|
||||
def step_fn(step_context):
|
||||
self.assertNear(0.3, step_context.session.run(stage_0), 0.1)
|
||||
return step_context.run_with_hooks(fetches=stage_1_1)
|
||||
|
||||
with monitored_session.SingularMonitoredSession(
|
||||
hooks=[Hook(self)]) as session:
|
||||
self.assertEqual(0.3 + 0.5 + 0.7, session.run_step_fn(step_fn))
|
||||
|
||||
def test_step_fn_with_hooks_and_request_stop(self):
|
||||
with ops.Graph().as_default():
|
||||
trace_the_hook = {'before_run': False, 'after_run': False}
|
||||
|
|
@ -1615,6 +1857,117 @@ class MonitoredSessionTest(test.TestCase):
|
|||
self.assertFalse(trace_the_hook['before_run'])
|
||||
self.assertFalse(trace_the_hook['after_run'])
|
||||
|
||||
def test_recovers_from_an_exception_in_step_fn(self):
|
||||
trace_the_exception = {'run_already': False}
|
||||
|
||||
with ops.Graph().as_default():
|
||||
c = array_ops.placeholder(dtypes.float32)
|
||||
v = array_ops.identity(c)
|
||||
|
||||
def step_fn(step_context):
|
||||
if not trace_the_exception['run_already']:
|
||||
trace_the_exception['run_already'] = True
|
||||
raise errors_impl.AbortedError(None, None, 'Abort')
|
||||
|
||||
return step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
|
||||
|
||||
with monitored_session.MonitoredSession() as session:
|
||||
self.assertNear(3.2, session.run_step_fn(step_fn), 0.1)
|
||||
self.assertTrue(trace_the_exception['run_already'])
|
||||
|
||||
def test_recovers_from_an_exception_in_step_fn_after_hooks(self):
|
||||
trace_the_exception = {'run_already': False, 'side_effect_counter': 0}
|
||||
|
||||
with ops.Graph().as_default():
|
||||
c = array_ops.placeholder(dtypes.float32)
|
||||
v = array_ops.identity(c)
|
||||
graph_state = variables.Variable(0.0)
|
||||
graph_side_effect = state_ops.assign_add(graph_state, 0.31)
|
||||
|
||||
def step_fn(step_context):
|
||||
trace_the_exception['side_effect_counter'] += 1
|
||||
step_context.session.run(graph_side_effect)
|
||||
|
||||
value = step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
|
||||
|
||||
if not trace_the_exception['run_already']:
|
||||
trace_the_exception['run_already'] = True
|
||||
raise errors_impl.AbortedError(None, None, 'Abort')
|
||||
|
||||
return value
|
||||
|
||||
with self.test_session() as test_session:
|
||||
with monitored_session.MonitoredSession(
|
||||
CountingSessionCreator(test_session)) as session:
|
||||
session.run(variables.global_variables_initializer())
|
||||
|
||||
self.assertNear(3.2, session.run_step_fn(step_fn), 0.1)
|
||||
self.assertTrue(trace_the_exception['run_already'])
|
||||
# Make sure the rest of the body of the step_fn is re-executed upon
|
||||
# AbortedError:
|
||||
self.assertEqual(2, trace_the_exception['side_effect_counter'])
|
||||
self.assertNear(0.62, session.run(graph_state), 0.1)
|
||||
|
||||
def test_step_fn_doesnt_recover_when_it_wasnt_asked_to(self):
|
||||
trace_the_exception = {'run_already': False}
|
||||
|
||||
with ops.Graph().as_default():
|
||||
c = array_ops.placeholder(dtypes.float32)
|
||||
v = array_ops.identity(c)
|
||||
|
||||
def step_fn(step_context):
|
||||
if not trace_the_exception['run_already']:
|
||||
trace_the_exception['run_already'] = True
|
||||
raise errors_impl.AbortedError(None, None, 'Abort')
|
||||
|
||||
value = step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
|
||||
return value
|
||||
|
||||
with monitored_session.SingularMonitoredSession() as session:
|
||||
with self.assertRaisesRegexp(errors_impl.AbortedError, 'Abort'):
|
||||
self.assertNear(3.2, session.run_step_fn(step_fn), 0.1)
|
||||
self.fail()
|
||||
|
||||
self.assertTrue(trace_the_exception['run_already'])
|
||||
|
||||
def test_step_fn_exception_from_before_run(self):
|
||||
trace_the_exception = {'run_already': False, 'side_effect_counter': 0}
|
||||
|
||||
with ops.Graph().as_default():
|
||||
c = array_ops.placeholder(dtypes.float32)
|
||||
v = array_ops.identity(c)
|
||||
vv = constant_op.constant(3.2)
|
||||
graph_state = variables.Variable(0.0)
|
||||
graph_side_effect = state_ops.assign_add(graph_state, 0.31)
|
||||
|
||||
class Hook(session_run_hook.SessionRunHook):
|
||||
|
||||
def __init__(self, testing):
|
||||
self._testing = testing
|
||||
|
||||
def before_run(self, run_context):
|
||||
if not trace_the_exception['run_already']:
|
||||
trace_the_exception['run_already'] = True
|
||||
raise errors_impl.AbortedError(None, None, 'Abort')
|
||||
return session_run_hook.SessionRunArgs(fetches=vv)
|
||||
|
||||
def after_run(self, run_context, run_values):
|
||||
self._testing.assertNear(3.2, run_values.results, 0.1)
|
||||
|
||||
def step_fn(step_context):
|
||||
trace_the_exception['side_effect_counter'] += 1
|
||||
step_context.session.run(graph_side_effect)
|
||||
return step_context.run_with_hooks(fetches=v, feed_dict={c: 1.3})
|
||||
|
||||
with self.test_session() as test_session:
|
||||
with monitored_session.MonitoredSession(
|
||||
CountingSessionCreator(test_session),
|
||||
hooks=[Hook(self)]) as session:
|
||||
test_session.run(variables.global_variables_initializer())
|
||||
self.assertNear(1.3, session.run_step_fn(step_fn), 0.1)
|
||||
self.assertEqual(2, trace_the_exception['side_effect_counter'])
|
||||
self.assertNear(0.62, session.run(graph_state), 0.1)
|
||||
|
||||
|
||||
class SingularMonitoredSessionTest(test.TestCase):
|
||||
"""Tests SingularMonitoredSession."""
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user