mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Add 'log_progress' argument for tf.estimator.Estimator's evaluate function (#13695)
* Add argument for tf.estimator.Estimator's evaluate function * add log_progress argument to ._convert_eval_steps_to_hooks for TPU estimator * log only every 10th step if more than 100 iterations in _StopAfterNEvalsHook * ensure last step is logged and aim for 10 outputs total
This commit is contained in:
parent
07a91dac54
commit
fa9d8aab41
|
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
import math
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
|
|
@ -91,6 +92,9 @@ class _StopAfterNEvalsHook(session_run_hook.SessionRunHook):
|
|||
self._num_evals = num_evals
|
||||
self._evals_completed = None
|
||||
self._log_progress = log_progress
|
||||
# Reduce logging frequency if there are 20 or more evaluations.
|
||||
self._log_frequency = (1 if (num_evals is None or num_evals < 20)
|
||||
else math.floor(num_evals / 10.))
|
||||
|
||||
def _set_evals_completed_tensor(self, updated_eval_step):
|
||||
self._evals_completed = updated_eval_step
|
||||
|
|
@ -106,7 +110,9 @@ class _StopAfterNEvalsHook(session_run_hook.SessionRunHook):
|
|||
if self._num_evals is None:
|
||||
logging.info('Evaluation [%d]', evals_completed)
|
||||
else:
|
||||
logging.info('Evaluation [%d/%d]', evals_completed, self._num_evals)
|
||||
if ((evals_completed % self._log_frequency) == 0 or
|
||||
(self._num_evals == evals_completed)):
|
||||
logging.info('Evaluation [%d/%d]', evals_completed, self._num_evals)
|
||||
if self._num_evals is not None and evals_completed >= self._num_evals:
|
||||
run_context.request_stop()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user