mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Automated g4 rollback of changelist 170892257
PiperOrigin-RevId: 170919783
This commit is contained in:
parent
66df43d09c
commit
435b31b9fc
|
|
@ -392,10 +392,6 @@ class _TrainingExecutor(object):
|
||||||
|
|
||||||
metrics = evaluator.evaluate_and_export()
|
metrics = evaluator.evaluate_and_export()
|
||||||
|
|
||||||
if not metrics:
|
|
||||||
# This is unexpected. Training should always end with a new checkpoint.
|
|
||||||
raise RuntimeError('There was no new checkpoint after the training.')
|
|
||||||
|
|
||||||
if _should_stop_local_train(metrics[ops.GraphKeys.GLOBAL_STEP]):
|
if _should_stop_local_train(metrics[ops.GraphKeys.GLOBAL_STEP]):
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,6 @@ _INVALID_NAME_MSG = '`name` must be string'
|
||||||
_INVALID_EVAL_DELAY_SECS_MSG = 'Must specify delay_secs >= 0'
|
_INVALID_EVAL_DELAY_SECS_MSG = 'Must specify delay_secs >= 0'
|
||||||
_INVALID_EVAL_THROTTLE_SECS_MSG = 'Must specify throttle_secs >= 0'
|
_INVALID_EVAL_THROTTLE_SECS_MSG = 'Must specify throttle_secs >= 0'
|
||||||
_INVALID_ESTIMATOR_MSG = '`estimator` must have type `tf.estimator.Estimator`'
|
_INVALID_ESTIMATOR_MSG = '`estimator` must have type `tf.estimator.Estimator`'
|
||||||
_STALE_CHECKPOINT_MSG = 'There was no new checkpoint after the training.'
|
|
||||||
_INVALID_EXPORT_STRATEGY_MSG = '`export_strategies` must be an ExportStrategy'
|
_INVALID_EXPORT_STRATEGY_MSG = '`export_strategies` must be an ExportStrategy'
|
||||||
_DUPLICATE_STRATEGY_NAMES_MSG = '`export_strategies` must have unique names.'
|
_DUPLICATE_STRATEGY_NAMES_MSG = '`export_strategies` must have unique names.'
|
||||||
_INVALID_TRAIN_SPEC_MSG = '`train_spec` must have type `tf.estimator.TrainSpec`'
|
_INVALID_TRAIN_SPEC_MSG = '`train_spec` must have type `tf.estimator.TrainSpec`'
|
||||||
|
|
@ -1025,27 +1024,6 @@ class TrainingExecutorRunLocalTest(test.TestCase):
|
||||||
self.assertEqual(3, mock_est.evaluate.call_count)
|
self.assertEqual(3, mock_est.evaluate.call_count)
|
||||||
self.assertEqual(3, mock_est.times_export_fn_was_called)
|
self.assertEqual(3, mock_est.times_export_fn_was_called)
|
||||||
|
|
||||||
def test_handles_no_new_checkpoint_found(self):
|
|
||||||
mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
|
|
||||||
mock_est.latest_checkpoint.return_value = (
|
|
||||||
'no_new_checkpoints_after_the_first_train_step')
|
|
||||||
train_spec = training.TrainSpec(
|
|
||||||
input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
|
|
||||||
eval_spec = training.EvalSpec(
|
|
||||||
input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100)
|
|
||||||
# It was going to be called 3 times.
|
|
||||||
mock_est.evaluate.side_effect = [{
|
|
||||||
_GLOBAL_STEP_KEY: train_spec.max_steps - 100
|
|
||||||
}, {
|
|
||||||
_GLOBAL_STEP_KEY: train_spec.max_steps - 50
|
|
||||||
}, {
|
|
||||||
_GLOBAL_STEP_KEY: train_spec.max_steps
|
|
||||||
}]
|
|
||||||
|
|
||||||
executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
|
|
||||||
with self.assertRaisesRegexp(RuntimeError, _STALE_CHECKPOINT_MSG):
|
|
||||||
executor.run_local()
|
|
||||||
|
|
||||||
def test_train_and_evaluate_args(self):
|
def test_train_and_evaluate_args(self):
|
||||||
mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
|
mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
|
||||||
mock_est.latest_checkpoint.return_value = 'checkpoint_path/'
|
mock_est.latest_checkpoint.return_value = 'checkpoint_path/'
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user