mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Handle the absence of a fresh eval checkpoint in run_local.
It is ~unexpected condition for an eval checkpoint to not be available after a train call to the estimator. There is a corner case when it is possible, but that's going to be resolved soon. This case is handled for continuous (distributed) evaluation differently. Instead of erroring out, we skip evaluation runs. That behavior is captured in the `test_skip_evaluation_due_to_ckpt` test. PiperOrigin-RevId: 170919925
This commit is contained in:
parent
435b31b9fc
commit
d0c76cd188
|
|
@ -392,6 +392,10 @@ class _TrainingExecutor(object):
|
|||
|
||||
metrics = evaluator.evaluate_and_export()
|
||||
|
||||
if not metrics:
|
||||
# This is unexpected. Training should always end with a new checkpoint.
|
||||
raise RuntimeError('There was no new checkpoint after the training.')
|
||||
|
||||
if _should_stop_local_train(metrics[ops.GraphKeys.GLOBAL_STEP]):
|
||||
break
|
||||
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ _INVALID_NAME_MSG = '`name` must be string'
|
|||
_INVALID_EVAL_DELAY_SECS_MSG = 'Must specify delay_secs >= 0'
|
||||
_INVALID_EVAL_THROTTLE_SECS_MSG = 'Must specify throttle_secs >= 0'
|
||||
_INVALID_ESTIMATOR_MSG = '`estimator` must have type `tf.estimator.Estimator`'
|
||||
_STALE_CHECKPOINT_MSG = 'There was no new checkpoint after the training.'
|
||||
_INVALID_EXPORT_STRATEGY_MSG = '`export_strategies` must be an ExportStrategy'
|
||||
_DUPLICATE_STRATEGY_NAMES_MSG = '`export_strategies` must have unique names.'
|
||||
_INVALID_TRAIN_SPEC_MSG = '`train_spec` must have type `tf.estimator.TrainSpec`'
|
||||
|
|
@ -1024,6 +1025,27 @@ class TrainingExecutorRunLocalTest(test.TestCase):
|
|||
self.assertEqual(3, mock_est.evaluate.call_count)
|
||||
self.assertEqual(3, mock_est.times_export_fn_was_called)
|
||||
|
||||
def test_handles_no_new_checkpoint_found(self):
|
||||
mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
|
||||
mock_est.latest_checkpoint.return_value = (
|
||||
'no_new_checkpoints_after_the_first_train_step')
|
||||
train_spec = training.TrainSpec(
|
||||
input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
|
||||
eval_spec = training.EvalSpec(
|
||||
input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100)
|
||||
# It was going to be called 3 times.
|
||||
mock_est.evaluate.side_effect = [{
|
||||
_GLOBAL_STEP_KEY: train_spec.max_steps - 100
|
||||
}, {
|
||||
_GLOBAL_STEP_KEY: train_spec.max_steps - 50
|
||||
}, {
|
||||
_GLOBAL_STEP_KEY: train_spec.max_steps
|
||||
}]
|
||||
|
||||
executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
|
||||
with self.assertRaisesRegexp(RuntimeError, _STALE_CHECKPOINT_MSG):
|
||||
executor.run_local()
|
||||
|
||||
def test_train_and_evaluate_args(self):
|
||||
mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
|
||||
mock_est.latest_checkpoint.return_value = 'checkpoint_path/'
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user