mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Merge pull request #31858 from tensorflow/revert-29458-cherrypicks_JXE4D
Revert "Deprecate `ModelCheckpoint.__init__`'s `load_weights_on_restart` argument and provide a warning message if used."
This commit is contained in:
commit
00fad90125
|
|
@ -835,6 +835,14 @@ class ModelCheckpoint(Callback):
|
|||
monitored metric may potentially be less reliable (it could reflect as
|
||||
little as 1 batch, since the metrics get reset every epoch). Defaults to
|
||||
`'epoch'`
|
||||
load_weights_on_restart: Whether the training should restore the model. If
|
||||
True, the model will attempt to load the checkpoint file from `filepath`
|
||||
at the start of `model.fit()`. This saves the need of manually calling
|
||||
`model.load_weights()` before `model.fit()`. In multi-worker distributed
|
||||
training, this provides fault-tolerance and loads the model
|
||||
automatically upon recovery of workers. The callback gives up loading if
|
||||
the filepath does not exist, and raises ValueError if format does not
|
||||
match. Defaults to False.
|
||||
**kwargs: Additional arguments for backwards compatibility. Possible key
|
||||
is `period`.
|
||||
"""
|
||||
|
|
@ -847,6 +855,7 @@ class ModelCheckpoint(Callback):
|
|||
save_weights_only=False,
|
||||
mode='auto',
|
||||
save_freq='epoch',
|
||||
load_weights_on_restart=False,
|
||||
**kwargs):
|
||||
super(ModelCheckpoint, self).__init__()
|
||||
self.monitor = monitor
|
||||
|
|
@ -855,20 +864,10 @@ class ModelCheckpoint(Callback):
|
|||
self.save_best_only = save_best_only
|
||||
self.save_weights_only = save_weights_only
|
||||
self.save_freq = save_freq
|
||||
self.load_weights_on_restart = load_weights_on_restart
|
||||
self.epochs_since_last_save = 0
|
||||
self._samples_seen_since_last_saving = 0
|
||||
|
||||
# Deprecated field `load_weights_on_restart` is for loading the checkpoint
|
||||
# file from `filepath` at the start of `model.fit()`
|
||||
# TODO(rchao): Remove the arg during next breaking release.
|
||||
if 'load_weights_on_restart' in kwargs:
|
||||
self.load_weights_on_restart = kwargs['load_weights_on_restart']
|
||||
logging.warning('`load_weights_on_restart` argument is deprecated. '
|
||||
'Please use `model.load_weights()` for loading weights '
|
||||
'before the start of `model.fit()`.')
|
||||
else:
|
||||
self.load_weights_on_restart = False
|
||||
|
||||
# Deprecated field `period` is for the number of epochs between which
|
||||
# the model is saved.
|
||||
if 'period' in kwargs:
|
||||
|
|
@ -913,41 +912,27 @@ class ModelCheckpoint(Callback):
|
|||
self.save_weights_only = True
|
||||
|
||||
def on_train_begin(self, logs=None):
|
||||
# TODO(rchao): Replace dc_context reference with
|
||||
if K.in_multi_worker_mode():
|
||||
# pylint: disable=protected-access
|
||||
# MultiWorkerTrainingState is used to manage the training state needed
|
||||
# for preemption-recovery of a worker in multi-worker training.
|
||||
self.model._training_state = (
|
||||
multi_worker_training_state.MultiWorkerTrainingState(
|
||||
self.model, self.filepath))
|
||||
self._training_state = self.model._training_state
|
||||
if self._training_state.restore():
|
||||
# If the training state needs to be and is successfully restored,
|
||||
# it is recovering from a previous failure (or preemption). In such
|
||||
# case, do not load the weights from user specified file path.
|
||||
return
|
||||
# TODO(rchao): Replace dc_context reference with
|
||||
# distributed_training_utils.should_current_worker_load_model() once
|
||||
# distributed_training_utils.py no longer depends on callbacks.py.
|
||||
if K.in_multi_worker_mode(
|
||||
) and not dc_context.get_current_worker_context().experimental_should_init:
|
||||
# For multi-worker training, it should not restore a model in certain
|
||||
# worker setting (e.g. non-chief worker in ParameterServerStrategy).
|
||||
return
|
||||
|
||||
# If this is not multi worker training, restoring is not needed, or
|
||||
# restoring failed, check if it should load weights on restart.
|
||||
if self.load_weights_on_restart:
|
||||
# In multi worker training, it only should if `experimental_should_init`
|
||||
# is True.
|
||||
# TODO(rchao): Reference `experimental_should_init` api from a util file.
|
||||
if not K.in_multi_worker_mode() or dc_context.get_current_worker_context(
|
||||
).experimental_should_init:
|
||||
filepath_to_load = (
|
||||
self._get_most_recently_modified_file_matching_pattern(
|
||||
self.filepath))
|
||||
if filepath_to_load is not None and os.path.exists(filepath_to_load):
|
||||
try:
|
||||
# `filepath` may contain placeholders such as `{epoch:02d}`, and
|
||||
# thus it attempts to load the most recently modified file with file
|
||||
# name matching the pattern.
|
||||
self.model.load_weights(filepath_to_load)
|
||||
except (IOError, ValueError) as e:
|
||||
raise ValueError('Error loading file from {}. Reason: {}'.format(
|
||||
filepath_to_load, e))
|
||||
filepath_to_load = self._get_most_recently_modified_file_matching_pattern(
|
||||
self.filepath)
|
||||
if (self.load_weights_on_restart and filepath_to_load is not None and
|
||||
os.path.exists(filepath_to_load)):
|
||||
try:
|
||||
# `filepath` may contain placeholders such as `{epoch:02d}`, and thus
|
||||
# it attempts to load the most recently modified file with file name
|
||||
# matching the pattern.
|
||||
self.model.load_weights(filepath_to_load)
|
||||
except (IOError, ValueError) as e:
|
||||
raise ValueError('Error loading file from {}. Reason: {}'.format(
|
||||
filepath_to_load, e))
|
||||
|
||||
def on_train_end(self, logs=None):
|
||||
logs = logs or {}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ tf_class {
|
|||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\', \'load_weights_on_restart\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\', \'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_batch_begin"
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ tf_class {
|
|||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\', \'load_weights_on_restart\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\', \'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_batch_begin"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user