mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Fix Keras Callbacks logs sync
This commit is contained in:
parent
fb37439d64
commit
cbbd1cd765
|
|
@ -234,6 +234,8 @@ class CallbackList:
|
|||
|
||||
# Performance optimization: determines if batch hooks need to be called.
|
||||
# pylint: disable=protected-access
|
||||
self._supports_tf_logs = all(
|
||||
getattr(cb, '_supports_tf_logs', False) for cb in self.callbacks)
|
||||
self._should_call_train_batch_hooks = any(
|
||||
cb._implements_train_batch_hooks() for cb in self.callbacks)
|
||||
self._should_call_test_batch_hooks = any(
|
||||
|
|
@ -272,6 +274,14 @@ class CallbackList:
|
|||
self._history = History()
|
||||
self.callbacks.append(self._history)
|
||||
|
||||
def _process_logs(self, logs):
|
||||
"""Turns tensors into numpy arrays or Python scalars if necessary."""
|
||||
if logs is None:
|
||||
return {}
|
||||
if not self._supports_tf_logs:
|
||||
return tf_utils.sync_to_numpy_or_python_type(logs)
|
||||
return logs
|
||||
|
||||
def append(self, callback):
|
||||
self.callbacks.append(callback)
|
||||
|
||||
|
|
@ -347,19 +357,13 @@ class CallbackList:
|
|||
|
||||
def _call_batch_hook_helper(self, hook_name, batch, logs):
|
||||
"""Helper function for `on_*_batch_*` methods."""
|
||||
logs = logs or {}
|
||||
numpy_logs = None
|
||||
if self._check_timing:
|
||||
start_time = time.time()
|
||||
|
||||
logs = self._process_logs(logs)
|
||||
for callback in self.callbacks:
|
||||
hook = getattr(callback, hook_name)
|
||||
if getattr(callback, '_supports_tf_logs', False):
|
||||
hook(batch, logs)
|
||||
else:
|
||||
if numpy_logs is None: # Only convert once.
|
||||
numpy_logs = tf_utils.sync_to_numpy_or_python_type(logs)
|
||||
hook(batch, numpy_logs)
|
||||
hook(batch, logs)
|
||||
|
||||
if self._check_timing:
|
||||
if hook_name not in self._hook_times:
|
||||
|
|
@ -402,15 +406,9 @@ class CallbackList:
|
|||
logs: Dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
"""
|
||||
logs = logs or {}
|
||||
numpy_logs = None
|
||||
logs = self._process_logs(logs)
|
||||
for callback in self.callbacks:
|
||||
if getattr(callback, '_supports_tf_logs', False):
|
||||
callback.on_epoch_begin(epoch, logs)
|
||||
else:
|
||||
if numpy_logs is None: # Only convert once.
|
||||
numpy_logs = tf_utils.sync_to_numpy_or_python_type(logs)
|
||||
callback.on_epoch_begin(epoch, numpy_logs)
|
||||
callback.on_epoch_begin(epoch, logs)
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
"""Calls the `on_epoch_end` methods of its callbacks.
|
||||
|
|
@ -423,15 +421,9 @@ class CallbackList:
|
|||
validation epoch if validation is performed. Validation result keys
|
||||
are prefixed with `val_`.
|
||||
"""
|
||||
logs = logs or {}
|
||||
numpy_logs = None
|
||||
logs = self._process_logs(logs)
|
||||
for callback in self.callbacks:
|
||||
if getattr(callback, '_supports_tf_logs', False):
|
||||
callback.on_epoch_end(epoch, logs)
|
||||
else:
|
||||
if numpy_logs is None: # Only convert once.
|
||||
numpy_logs = tf_utils.sync_to_numpy_or_python_type(logs)
|
||||
callback.on_epoch_end(epoch, numpy_logs)
|
||||
callback.on_epoch_end(epoch, logs)
|
||||
|
||||
def on_train_batch_begin(self, batch, logs=None):
|
||||
"""Calls the `on_train_batch_begin` methods of its callbacks.
|
||||
|
|
@ -506,15 +498,9 @@ class CallbackList:
|
|||
logs: Dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
"""
|
||||
logs = logs or {}
|
||||
numpy_logs = None
|
||||
logs = self._process_logs(logs)
|
||||
for callback in self.callbacks:
|
||||
if getattr(callback, '_supports_tf_logs', False):
|
||||
callback.on_train_begin(logs)
|
||||
else:
|
||||
if numpy_logs is None: # Only convert once.
|
||||
numpy_logs = tf_utils.sync_to_numpy_or_python_type(logs)
|
||||
callback.on_train_begin(numpy_logs)
|
||||
callback.on_train_begin(logs)
|
||||
|
||||
def on_train_end(self, logs=None):
|
||||
"""Calls the `on_train_end` methods of its callbacks.
|
||||
|
|
@ -523,15 +509,9 @@ class CallbackList:
|
|||
logs: Dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
"""
|
||||
logs = logs or {}
|
||||
numpy_logs = None
|
||||
logs = self._process_logs(logs)
|
||||
for callback in self.callbacks:
|
||||
if getattr(callback, '_supports_tf_logs', False):
|
||||
callback.on_train_end(logs)
|
||||
else:
|
||||
if numpy_logs is None: # Only convert once.
|
||||
numpy_logs = tf_utils.sync_to_numpy_or_python_type(logs)
|
||||
callback.on_train_end(numpy_logs)
|
||||
callback.on_train_end(logs)
|
||||
|
||||
def on_test_begin(self, logs=None):
|
||||
"""Calls the `on_test_begin` methods of its callbacks.
|
||||
|
|
@ -540,15 +520,9 @@ class CallbackList:
|
|||
logs: Dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
"""
|
||||
logs = logs or {}
|
||||
numpy_logs = None
|
||||
logs = self._process_logs(logs)
|
||||
for callback in self.callbacks:
|
||||
if getattr(callback, '_supports_tf_logs', False):
|
||||
callback.on_test_begin(logs)
|
||||
else:
|
||||
if numpy_logs is None: # Only convert once.
|
||||
numpy_logs = tf_utils.sync_to_numpy_or_python_type(logs)
|
||||
callback.on_test_begin(numpy_logs)
|
||||
callback.on_test_begin(logs)
|
||||
|
||||
def on_test_end(self, logs=None):
|
||||
"""Calls the `on_test_end` methods of its callbacks.
|
||||
|
|
@ -557,15 +531,9 @@ class CallbackList:
|
|||
logs: Dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
"""
|
||||
logs = logs or {}
|
||||
numpy_logs = None
|
||||
logs = self._process_logs(logs)
|
||||
for callback in self.callbacks:
|
||||
if getattr(callback, '_supports_tf_logs', False):
|
||||
callback.on_test_end(logs)
|
||||
else:
|
||||
if numpy_logs is None: # Only convert once.
|
||||
numpy_logs = tf_utils.sync_to_numpy_or_python_type(logs)
|
||||
callback.on_test_end(numpy_logs)
|
||||
callback.on_test_end(logs)
|
||||
|
||||
def on_predict_begin(self, logs=None):
|
||||
"""Calls the 'on_predict_begin` methods of its callbacks.
|
||||
|
|
@ -574,15 +542,9 @@ class CallbackList:
|
|||
logs: Dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
"""
|
||||
logs = logs or {}
|
||||
numpy_logs = None
|
||||
logs = self._process_logs(logs)
|
||||
for callback in self.callbacks:
|
||||
if getattr(callback, '_supports_tf_logs', False):
|
||||
callback.on_predict_begin(logs)
|
||||
else:
|
||||
if numpy_logs is None: # Only convert once.
|
||||
numpy_logs = tf_utils.sync_to_numpy_or_python_type(logs)
|
||||
callback.on_predict_begin(numpy_logs)
|
||||
callback.on_predict_begin(logs)
|
||||
|
||||
def on_predict_end(self, logs=None):
|
||||
"""Calls the `on_predict_end` methods of its callbacks.
|
||||
|
|
@ -591,15 +553,9 @@ class CallbackList:
|
|||
logs: Dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
"""
|
||||
logs = logs or {}
|
||||
numpy_logs = None
|
||||
logs = self._process_logs(logs)
|
||||
for callback in self.callbacks:
|
||||
if getattr(callback, '_supports_tf_logs', False):
|
||||
callback.on_predict_end(logs)
|
||||
else:
|
||||
if numpy_logs is None: # Only convert once.
|
||||
numpy_logs = tf_utils.sync_to_numpy_or_python_type(logs)
|
||||
callback.on_predict_end(numpy_logs)
|
||||
callback.on_predict_end(logs)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.callbacks)
|
||||
|
|
|
|||
|
|
@ -76,6 +76,13 @@ INPUT_DIM = 3
|
|||
NUM_HIDDEN = 5
|
||||
BATCH_SIZE = 5
|
||||
|
||||
CALLBACK_HOOKS = [
|
||||
'on_batch_begin', 'on_batch_end', 'on_epoch_begin', 'on_epoch_end',
|
||||
'on_predict_batch_begin', 'on_predict_batch_end', 'on_predict_begin',
|
||||
'on_predict_end', 'on_test_batch_begin', 'on_test_batch_end',
|
||||
'on_test_begin', 'on_test_end', 'on_train_batch_begin',
|
||||
'on_train_batch_end', 'on_train_begin', 'on_train_end'
|
||||
]
|
||||
|
||||
class Counter(keras.callbacks.Callback):
|
||||
"""Counts the number of times each callback method was run.
|
||||
|
|
@ -87,14 +94,7 @@ class Counter(keras.callbacks.Callback):
|
|||
|
||||
def __init__(self):
|
||||
self.method_counts = collections.defaultdict(int)
|
||||
methods_to_count = [
|
||||
'on_batch_begin', 'on_batch_end', 'on_epoch_begin', 'on_epoch_end',
|
||||
'on_predict_batch_begin', 'on_predict_batch_end', 'on_predict_begin',
|
||||
'on_predict_end', 'on_test_batch_begin', 'on_test_batch_end',
|
||||
'on_test_begin', 'on_test_end', 'on_train_batch_begin',
|
||||
'on_train_batch_end', 'on_train_begin', 'on_train_end'
|
||||
]
|
||||
for method_name in methods_to_count:
|
||||
for method_name in CALLBACK_HOOKS:
|
||||
setattr(self, method_name,
|
||||
self.wrap_with_counts(method_name, getattr(self, method_name)))
|
||||
|
||||
|
|
@ -107,6 +107,17 @@ class Counter(keras.callbacks.Callback):
|
|||
return _call_and_count
|
||||
|
||||
|
||||
class CallAllHooks(keras.callbacks.Callback):
|
||||
"""A callback that calls self._run for all hooks"""
|
||||
|
||||
def __init__(self):
|
||||
for method_name in CALLBACK_HOOKS:
|
||||
setattr(self, method_name, self._run)
|
||||
|
||||
def _run(self, *args, logs=None):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _get_numpy():
|
||||
return np.ones((10, 10)), np.ones((10, 1))
|
||||
|
||||
|
|
@ -1720,6 +1731,56 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
|
|||
model.evaluate(x, y, batch_size=10, callbacks=[my_cb], verbose=0)
|
||||
model.predict(x, batch_size=10, callbacks=[my_cb], verbose=0)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
def test_logs_conversion(self):
|
||||
assert_dict_equal = self.assertDictEqual
|
||||
|
||||
class MutateNumpyLogs(CallAllHooks):
|
||||
def _run(self, *args, logs=None):
|
||||
logs = logs or args[-1]
|
||||
logs["numpy"] = 1
|
||||
|
||||
class MutateTensorFlowLogs(CallAllHooks):
|
||||
def __init__(self):
|
||||
super(MutateTensorFlowLogs, self).__init__()
|
||||
self._supports_tf_logs = True
|
||||
|
||||
def _run(self, *args, logs=None):
|
||||
logs = logs or args[-1]
|
||||
logs["tf"] = 2
|
||||
|
||||
class AssertNumpyLogs(CallAllHooks):
|
||||
def _run(self, *args, logs=None):
|
||||
logs = logs or args[-1]
|
||||
assert_dict_equal(logs, {"all": 0, "numpy": 1, "tf": 2})
|
||||
|
||||
class AssertTensorFlowLogs(AssertNumpyLogs):
|
||||
def __init__(self):
|
||||
super(AssertTensorFlowLogs, self).__init__()
|
||||
self._supports_tf_logs = True
|
||||
|
||||
cb_list = keras.callbacks.CallbackList([
|
||||
MutateNumpyLogs(),
|
||||
MutateTensorFlowLogs(),
|
||||
AssertNumpyLogs(),
|
||||
AssertTensorFlowLogs()])
|
||||
|
||||
assert len(cb_list.callbacks) == 4
|
||||
cb_list.on_epoch_begin(0, logs={"all": 0})
|
||||
cb_list.on_epoch_end(0, logs={"all": 0})
|
||||
cb_list.on_predict_batch_begin(0, logs={"all": 0})
|
||||
cb_list.on_predict_batch_end(0, logs={"all": 0})
|
||||
cb_list.on_predict_begin(logs={"all": 0})
|
||||
cb_list.on_predict_end(logs={"all": 0})
|
||||
cb_list.on_test_batch_begin(0, logs={"all": 0})
|
||||
cb_list.on_test_batch_end(0, logs={"all": 0})
|
||||
cb_list.on_test_begin(logs={"all": 0})
|
||||
cb_list.on_test_end(logs={"all": 0})
|
||||
cb_list.on_train_batch_begin(0, logs={"all": 0})
|
||||
cb_list.on_train_batch_end(0, logs={"all": 0})
|
||||
cb_list.on_train_begin(logs={"all": 0})
|
||||
cb_list.on_train_end(logs={"all": 0})
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
def test_implements_batch_hooks_override(self):
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user