Fix Keras Callbacks logs sync

Lukas Geiger 2021-03-19 00:53:39 +01:00
parent fb37439d64
commit cbbd1cd765
2 changed files with 97 additions and 80 deletions
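
For orientation: the core of this change is that CallbackList now decides once, at construction time, whether every callback can handle raw tf logs, and converts the logs dict a single time per hook call instead of once per callback. A minimal, self-contained sketch of that pattern (not the Keras implementation; SimpleCallbackList and the float() conversion are placeholders for the real tf_utils.sync_to_numpy_or_python_type):

class SimpleCallbackList:

  def __init__(self, callbacks):
    self.callbacks = callbacks
    # True only if *every* callback opts into receiving raw tf logs.
    self._supports_tf_logs = all(
        getattr(cb, '_supports_tf_logs', False) for cb in callbacks)

  def _process_logs(self, logs):
    """Converts tensor logs once, instead of per callback."""
    if logs is None:
      return {}
    if not self._supports_tf_logs:
      # Placeholder for tf_utils.sync_to_numpy_or_python_type(logs).
      return {k: float(v) for k, v in logs.items()}
    return logs

  def on_epoch_end(self, epoch, logs=None):
    logs = self._process_logs(logs)
    for cb in self.callbacks:
      # Every callback receives the same dict object, so in-place mutations
      # made by one callback are visible to the callbacks that follow.
      cb.on_epoch_end(epoch, logs)

The diff below applies the same idea to every on_* hook of CallbackList.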


@@ -234,6 +234,8 @@ class CallbackList:
# Performance optimization: determines if batch hooks need to be called.
# pylint: disable=protected-access
self._supports_tf_logs = all(
getattr(cb, '_supports_tf_logs', False) for cb in self.callbacks)
self._should_call_train_batch_hooks = any(
cb._implements_train_batch_hooks() for cb in self.callbacks)
self._should_call_test_batch_hooks = any(
@@ -272,6 +274,14 @@ class CallbackList:
self._history = History()
self.callbacks.append(self._history)
def _process_logs(self, logs):
"""Turns tensors into numpy arrays or Python scalars if necessary."""
if logs is None:
return {}
if not self._supports_tf_logs:
return tf_utils.sync_to_numpy_or_python_type(logs)
return logs
def append(self, callback):
self.callbacks.append(callback)
@@ -347,19 +357,13 @@ class CallbackList:
def _call_batch_hook_helper(self, hook_name, batch, logs):
"""Helper function for `on_*_batch_*` methods."""
logs = logs or {}
numpy_logs = None
if self._check_timing:
start_time = time.time()
logs = self._process_logs(logs)
for callback in self.callbacks:
hook = getattr(callback, hook_name)
if getattr(callback, '_supports_tf_logs', False):
hook(batch, logs)
else:
if numpy_logs is None: # Only convert once.
numpy_logs = tf_utils.sync_to_numpy_or_python_type(logs)
hook(batch, numpy_logs)
hook(batch, logs)
if self._check_timing:
if hook_name not in self._hook_times:
@@ -402,15 +406,9 @@ class CallbackList:
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = logs or {}
numpy_logs = None
logs = self._process_logs(logs)
for callback in self.callbacks:
if getattr(callback, '_supports_tf_logs', False):
callback.on_epoch_begin(epoch, logs)
else:
if numpy_logs is None: # Only convert once.
numpy_logs = tf_utils.sync_to_numpy_or_python_type(logs)
callback.on_epoch_begin(epoch, numpy_logs)
callback.on_epoch_begin(epoch, logs)
def on_epoch_end(self, epoch, logs=None):
"""Calls the `on_epoch_end` methods of its callbacks.
@@ -423,15 +421,9 @@ class CallbackList:
validation epoch if validation is performed. Validation result keys
are prefixed with `val_`.
"""
logs = logs or {}
numpy_logs = None
logs = self._process_logs(logs)
for callback in self.callbacks:
if getattr(callback, '_supports_tf_logs', False):
callback.on_epoch_end(epoch, logs)
else:
if numpy_logs is None: # Only convert once.
numpy_logs = tf_utils.sync_to_numpy_or_python_type(logs)
callback.on_epoch_end(epoch, numpy_logs)
callback.on_epoch_end(epoch, logs)
def on_train_batch_begin(self, batch, logs=None):
"""Calls the `on_train_batch_begin` methods of its callbacks.
@@ -506,15 +498,9 @@ class CallbackList:
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = logs or {}
numpy_logs = None
logs = self._process_logs(logs)
for callback in self.callbacks:
if getattr(callback, '_supports_tf_logs', False):
callback.on_train_begin(logs)
else:
if numpy_logs is None: # Only convert once.
numpy_logs = tf_utils.sync_to_numpy_or_python_type(logs)
callback.on_train_begin(numpy_logs)
callback.on_train_begin(logs)
def on_train_end(self, logs=None):
"""Calls the `on_train_end` methods of its callbacks.
@@ -523,15 +509,9 @@ class CallbackList:
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = logs or {}
numpy_logs = None
logs = self._process_logs(logs)
for callback in self.callbacks:
if getattr(callback, '_supports_tf_logs', False):
callback.on_train_end(logs)
else:
if numpy_logs is None: # Only convert once.
numpy_logs = tf_utils.sync_to_numpy_or_python_type(logs)
callback.on_train_end(numpy_logs)
callback.on_train_end(logs)
def on_test_begin(self, logs=None):
"""Calls the `on_test_begin` methods of its callbacks.
@@ -540,15 +520,9 @@ class CallbackList:
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = logs or {}
numpy_logs = None
logs = self._process_logs(logs)
for callback in self.callbacks:
if getattr(callback, '_supports_tf_logs', False):
callback.on_test_begin(logs)
else:
if numpy_logs is None: # Only convert once.
numpy_logs = tf_utils.sync_to_numpy_or_python_type(logs)
callback.on_test_begin(numpy_logs)
callback.on_test_begin(logs)
def on_test_end(self, logs=None):
"""Calls the `on_test_end` methods of its callbacks.
@@ -557,15 +531,9 @@ class CallbackList:
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = logs or {}
numpy_logs = None
logs = self._process_logs(logs)
for callback in self.callbacks:
if getattr(callback, '_supports_tf_logs', False):
callback.on_test_end(logs)
else:
if numpy_logs is None: # Only convert once.
numpy_logs = tf_utils.sync_to_numpy_or_python_type(logs)
callback.on_test_end(numpy_logs)
callback.on_test_end(logs)
def on_predict_begin(self, logs=None):
"""Calls the 'on_predict_begin` methods of its callbacks.
@@ -574,15 +542,9 @@ class CallbackList:
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = logs or {}
numpy_logs = None
logs = self._process_logs(logs)
for callback in self.callbacks:
if getattr(callback, '_supports_tf_logs', False):
callback.on_predict_begin(logs)
else:
if numpy_logs is None: # Only convert once.
numpy_logs = tf_utils.sync_to_numpy_or_python_type(logs)
callback.on_predict_begin(numpy_logs)
callback.on_predict_begin(logs)
def on_predict_end(self, logs=None):
"""Calls the `on_predict_end` methods of its callbacks.
@@ -591,15 +553,9 @@ class CallbackList:
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = logs or {}
numpy_logs = None
logs = self._process_logs(logs)
for callback in self.callbacks:
if getattr(callback, '_supports_tf_logs', False):
callback.on_predict_end(logs)
else:
if numpy_logs is None: # Only convert once.
numpy_logs = tf_utils.sync_to_numpy_or_python_type(logs)
callback.on_predict_end(numpy_logs)
callback.on_predict_end(logs)
def __iter__(self):
return iter(self.callbacks)
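
For reference, a usage sketch of the private opt-in flag that the consolidated _process_logs relies on: a callback sets _supports_tf_logs = True to declare that it can consume tensor-valued logs, and with this change the conversion is skipped only when every callback in the list sets the flag. This is a private Keras flag, shown for illustration rather than as a supported public API; the sketch assumes TensorFlow 2.x.

import tensorflow as tf

class TensorAwareLogger(tf.keras.callbacks.Callback):

  def __init__(self):
    super(TensorAwareLogger, self).__init__()
    # Private flag checked by CallbackList: this callback accepts raw tf logs.
    self._supports_tf_logs = True

  def on_epoch_end(self, epoch, logs=None):
    loss = (logs or {}).get('loss')
    # Logs may still arrive as tensors, so convert defensively before printing.
    if tf.is_tensor(loss):
      loss = float(loss.numpy())
    print('epoch %d: loss=%s' % (epoch, loss))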


@@ -76,6 +76,13 @@ INPUT_DIM = 3
NUM_HIDDEN = 5
BATCH_SIZE = 5
CALLBACK_HOOKS = [
'on_batch_begin', 'on_batch_end', 'on_epoch_begin', 'on_epoch_end',
'on_predict_batch_begin', 'on_predict_batch_end', 'on_predict_begin',
'on_predict_end', 'on_test_batch_begin', 'on_test_batch_end',
'on_test_begin', 'on_test_end', 'on_train_batch_begin',
'on_train_batch_end', 'on_train_begin', 'on_train_end'
]
class Counter(keras.callbacks.Callback):
"""Counts the number of times each callback method was run.
@@ -87,14 +94,7 @@ class Counter(keras.callbacks.Callback):
def __init__(self):
self.method_counts = collections.defaultdict(int)
methods_to_count = [
'on_batch_begin', 'on_batch_end', 'on_epoch_begin', 'on_epoch_end',
'on_predict_batch_begin', 'on_predict_batch_end', 'on_predict_begin',
'on_predict_end', 'on_test_batch_begin', 'on_test_batch_end',
'on_test_begin', 'on_test_end', 'on_train_batch_begin',
'on_train_batch_end', 'on_train_begin', 'on_train_end'
]
for method_name in methods_to_count:
for method_name in CALLBACK_HOOKS:
setattr(self, method_name,
self.wrap_with_counts(method_name, getattr(self, method_name)))
@@ -107,6 +107,17 @@ class Counter(keras.callbacks.Callback):
return _call_and_count
class CallAllHooks(keras.callbacks.Callback):
"""A callback that calls self._run for all hooks"""
def __init__(self):
for method_name in CALLBACK_HOOKS:
setattr(self, method_name, self._run)
def _run(self, *args, logs=None):
raise NotImplementedError
def _get_numpy():
return np.ones((10, 10)), np.ones((10, 1))
@@ -1720,6 +1731,56 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
model.evaluate(x, y, batch_size=10, callbacks=[my_cb], verbose=0)
model.predict(x, batch_size=10, callbacks=[my_cb], verbose=0)
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def test_logs_conversion(self):
assert_dict_equal = self.assertDictEqual
class MutateNumpyLogs(CallAllHooks):
def _run(self, *args, logs=None):
logs = logs or args[-1]
logs["numpy"] = 1
class MutateTensorFlowLogs(CallAllHooks):
def __init__(self):
super(MutateTensorFlowLogs, self).__init__()
self._supports_tf_logs = True
def _run(self, *args, logs=None):
logs = logs or args[-1]
logs["tf"] = 2
class AssertNumpyLogs(CallAllHooks):
def _run(self, *args, logs=None):
logs = logs or args[-1]
assert_dict_equal(logs, {"all": 0, "numpy": 1, "tf": 2})
class AssertTensorFlowLogs(AssertNumpyLogs):
def __init__(self):
super(AssertTensorFlowLogs, self).__init__()
self._supports_tf_logs = True
cb_list = keras.callbacks.CallbackList([
MutateNumpyLogs(),
MutateTensorFlowLogs(),
AssertNumpyLogs(),
AssertTensorFlowLogs()])
assert len(cb_list.callbacks) == 4
cb_list.on_epoch_begin(0, logs={"all": 0})
cb_list.on_epoch_end(0, logs={"all": 0})
cb_list.on_predict_batch_begin(0, logs={"all": 0})
cb_list.on_predict_batch_end(0, logs={"all": 0})
cb_list.on_predict_begin(logs={"all": 0})
cb_list.on_predict_end(logs={"all": 0})
cb_list.on_test_batch_begin(0, logs={"all": 0})
cb_list.on_test_batch_end(0, logs={"all": 0})
cb_list.on_test_begin(logs={"all": 0})
cb_list.on_test_end(logs={"all": 0})
cb_list.on_train_batch_begin(0, logs={"all": 0})
cb_list.on_train_batch_end(0, logs={"all": 0})
cb_list.on_train_begin(logs={"all": 0})
cb_list.on_train_end(logs={"all": 0})
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def test_implements_batch_hooks_override(self):
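
As a complement to test_logs_conversion above, a small standalone check of the behavior this commit fixes (the class names here are hypothetical, and it relies on the same private _supports_tf_logs flag): a key added by a plain NumPy-logs callback is now also visible to a later tf-logs-aware callback, because both receive the same converted dict.

import tensorflow as tf

class NumpyMutator(tf.keras.callbacks.Callback):

  def on_epoch_end(self, epoch, logs=None):
    logs['extra'] = 1  # mutates the shared, already-converted logs dict

class TFLogsChecker(tf.keras.callbacks.Callback):

  def __init__(self):
    super(TFLogsChecker, self).__init__()
    self._supports_tf_logs = True

  def on_epoch_end(self, epoch, logs=None):
    # With this fix, the key added by NumpyMutator is visible here as well.
    assert logs.get('extra') == 1

cb_list = tf.keras.callbacks.CallbackList([NumpyMutator(), TFLogsChecker()])
cb_list.on_epoch_end(0, logs={'loss': 0.5})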