Merge pull request #31858 from tensorflow/revert-29458-cherrypicks_JXE4D

Revert "Deprecate `ModelCheckpoint.__init__`'s `load_weights_on_restart` argument and provide a warning message if used."
commit 00fad90125
Mihai Maruseac, 2019-08-22 09:25:44 -07:00, committed by GitHub
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
3 changed files with 32 additions and 47 deletions


@@ -835,6 +835,14 @@ class ModelCheckpoint(Callback):
monitored metric may potentially be less reliable (it could reflect as
little as 1 batch, since the metrics get reset every epoch). Defaults to
`'epoch'`
load_weights_on_restart: Whether the training should restore the model. If
True, the model will attempt to load the checkpoint file from `filepath`
at the start of `model.fit()`. This saves the need of manually calling
`model.load_weights()` before `model.fit()`. In multi-worker distributed
training, this provides fault-tolerance and loads the model
automatically upon recovery of workers. The callback gives up loading if
the filepath does not exist, and raises ValueError if format does not
match. Defaults to False.
**kwargs: Additional arguments for backwards compatibility. Possible key
is `period`.
"""
@@ -847,6 +855,7 @@ class ModelCheckpoint(Callback):
save_weights_only=False,
mode='auto',
save_freq='epoch',
load_weights_on_restart=False,
**kwargs):
super(ModelCheckpoint, self).__init__()
self.monitor = monitor
@@ -855,20 +864,10 @@ class ModelCheckpoint(Callback):
self.save_best_only = save_best_only
self.save_weights_only = save_weights_only
self.save_freq = save_freq
self.load_weights_on_restart = load_weights_on_restart
self.epochs_since_last_save = 0
self._samples_seen_since_last_saving = 0
# Deprecated field `load_weights_on_restart` is for loading the checkpoint
# file from `filepath` at the start of `model.fit()`
# TODO(rchao): Remove the arg during next breaking release.
if 'load_weights_on_restart' in kwargs:
self.load_weights_on_restart = kwargs['load_weights_on_restart']
logging.warning('`load_weights_on_restart` argument is deprecated. '
'Please use `model.load_weights()` for loading weights '
'before the start of `model.fit()`.')
else:
self.load_weights_on_restart = False
# Deprecated field `period` is for the number of epochs between which
# the model is saved.
if 'period' in kwargs:
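The block removed in this hunk implemented the deprecation path through `**kwargs`. A self-contained sketch of that pattern, using a stand-in class name (`CheckpointSketch`) and `warnings.warn` in place of TensorFlow's `logging.warning`, purely for illustration:

    import warnings


    class CheckpointSketch(object):
      """Illustrative stand-in; the real class is keras.callbacks.ModelCheckpoint."""

      def __init__(self, filepath, **kwargs):
        self.filepath = filepath
        # Deprecated path: accept the old argument only through **kwargs and
        # warn, mirroring the block this revert removes in favor of an
        # explicit `load_weights_on_restart=False` keyword argument.
        if 'load_weights_on_restart' in kwargs:
          warnings.warn(
              '`load_weights_on_restart` argument is deprecated. Please use '
              '`model.load_weights()` for loading weights before the start '
              'of `model.fit()`.')
          self.load_weights_on_restart = kwargs['load_weights_on_restart']
        else:
          self.load_weights_on_restart = False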
@@ -913,41 +912,27 @@ class ModelCheckpoint(Callback):
self.save_weights_only = True
def on_train_begin(self, logs=None):
# TODO(rchao): Replace dc_context reference with
if K.in_multi_worker_mode():
# pylint: disable=protected-access
# MultiWorkerTrainingState is used to manage the training state needed
# for preemption-recovery of a worker in multi-worker training.
self.model._training_state = (
multi_worker_training_state.MultiWorkerTrainingState(
self.model, self.filepath))
self._training_state = self.model._training_state
if self._training_state.restore():
# If the training state needs to be and is successfully restored,
# it is recovering from a previous failure (or preemption). In such
# case, do not load the weights from user specified file path.
return
# TODO(rchao): Replace dc_context reference with
# distributed_training_utils.should_current_worker_load_model() once
# distributed_training_utils.py no longer depends on callbacks.py.
if K.in_multi_worker_mode(
) and not dc_context.get_current_worker_context().experimental_should_init:
# For multi-worker training, it should not restore a model in certain
# worker setting (e.g. non-chief worker in ParameterServerStrategy).
return
# If this is not multi worker training, restoring is not needed, or
# restoring failed, check if it should load weights on restart.
if self.load_weights_on_restart:
# In multi worker training, it only should if `experimental_should_init`
# is True.
# TODO(rchao): Reference `experimental_should_init` api from a util file.
if not K.in_multi_worker_mode() or dc_context.get_current_worker_context(
).experimental_should_init:
filepath_to_load = (
self._get_most_recently_modified_file_matching_pattern(
self.filepath))
if filepath_to_load is not None and os.path.exists(filepath_to_load):
try:
# `filepath` may contain placeholders such as `{epoch:02d}`, and
# thus it attempts to load the most recently modified file with file
# name matching the pattern.
self.model.load_weights(filepath_to_load)
except (IOError, ValueError) as e:
raise ValueError('Error loading file from {}. Reason: {}'.format(
filepath_to_load, e))
filepath_to_load = self._get_most_recently_modified_file_matching_pattern(
self.filepath)
if (self.load_weights_on_restart and filepath_to_load is not None and
os.path.exists(filepath_to_load)):
try:
# `filepath` may contain placeholders such as `{epoch:02d}`, and thus
# it attempts to load the most recently modified file with file name
# matching the pattern.
self.model.load_weights(filepath_to_load)
except (IOError, ValueError) as e:
raise ValueError('Error loading file from {}. Reason: {}'.format(
filepath_to_load, e))
def on_train_end(self, logs=None):
logs = logs or {}
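Both versions of `on_train_begin` above rely on `_get_most_recently_modified_file_matching_pattern` to resolve a `filepath` that may contain placeholders such as `{epoch:02d}`. The following is a rough, illustrative approximation of that idea, not the helper's actual implementation, assuming placeholders can be mapped to a glob:

    import glob
    import os
    import re


    def most_recent_checkpoint(filepath_pattern):
      """Returns the most recently modified file matching the pattern, or None.

      Illustrative only: replaces `{...}` placeholders (e.g. `{epoch:02d}`)
      with `*` and picks the newest match by modification time.
      """
      glob_pattern = re.sub(r'\{[^}]*\}', '*', filepath_pattern)
      candidates = glob.glob(glob_pattern)
      if not candidates:
        return None
      return max(candidates, key=os.path.getmtime)


    # e.g. most_recent_checkpoint('/tmp/ckpt-{epoch:02d}.h5') might return
    # '/tmp/ckpt-07.h5' if that file was written most recently.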


@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\'], "
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\', \'load_weights_on_restart\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\', \'False\'], "
}
member_method {
name: "on_batch_begin"


@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\'], "
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\', \'load_weights_on_restart\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\', \'False\'], "
}
member_method {
name: "on_batch_begin"