mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Automated g4 rollback of changelist 170892257
PiperOrigin-RevId: 170919783
This commit is contained in:
parent
66df43d09c
commit
435b31b9fc
|
|
@ -392,10 +392,6 @@ class _TrainingExecutor(object):
|
|||
|
||||
metrics = evaluator.evaluate_and_export()
|
||||
|
||||
if not metrics:
|
||||
# This is unexpected. Training should always end with a new checkpoint.
|
||||
raise RuntimeError('There was no new checkpoint after the training.')
|
||||
|
||||
if _should_stop_local_train(metrics[ops.GraphKeys.GLOBAL_STEP]):
|
||||
break
|
||||
|
||||
|
|
|
|||
|
|
@ -50,7 +50,6 @@ _INVALID_NAME_MSG = '`name` must be string'
|
|||
_INVALID_EVAL_DELAY_SECS_MSG = 'Must specify delay_secs >= 0'
|
||||
_INVALID_EVAL_THROTTLE_SECS_MSG = 'Must specify throttle_secs >= 0'
|
||||
_INVALID_ESTIMATOR_MSG = '`estimator` must have type `tf.estimator.Estimator`'
|
||||
_STALE_CHECKPOINT_MSG = 'There was no new checkpoint after the training.'
|
||||
_INVALID_EXPORT_STRATEGY_MSG = '`export_strategies` must be an ExportStrategy'
|
||||
_DUPLICATE_STRATEGY_NAMES_MSG = '`export_strategies` must have unique names.'
|
||||
_INVALID_TRAIN_SPEC_MSG = '`train_spec` must have type `tf.estimator.TrainSpec`'
|
||||
|
|
@ -1025,27 +1024,6 @@ class TrainingExecutorRunLocalTest(test.TestCase):
|
|||
self.assertEqual(3, mock_est.evaluate.call_count)
|
||||
self.assertEqual(3, mock_est.times_export_fn_was_called)
|
||||
|
||||
def test_handles_no_new_checkpoint_found(self):
|
||||
mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
|
||||
mock_est.latest_checkpoint.return_value = (
|
||||
'no_new_checkpoints_after_the_first_train_step')
|
||||
train_spec = training.TrainSpec(
|
||||
input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
|
||||
eval_spec = training.EvalSpec(
|
||||
input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100)
|
||||
# It was going to be called 3 times.
|
||||
mock_est.evaluate.side_effect = [{
|
||||
_GLOBAL_STEP_KEY: train_spec.max_steps - 100
|
||||
}, {
|
||||
_GLOBAL_STEP_KEY: train_spec.max_steps - 50
|
||||
}, {
|
||||
_GLOBAL_STEP_KEY: train_spec.max_steps
|
||||
}]
|
||||
|
||||
executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
|
||||
with self.assertRaisesRegexp(RuntimeError, _STALE_CHECKPOINT_MSG):
|
||||
executor.run_local()
|
||||
|
||||
def test_train_and_evaluate_args(self):
|
||||
mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
|
||||
mock_est.latest_checkpoint.return_value = 'checkpoint_path/'
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user