mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Adds save and restore methods to tfe.Network
Save just saves the variables to a checkpoint. Restore either restores immediately or defers the restoration to variable creation time with a custom getter. PiperOrigin-RevId: 173703075
This commit is contained in:
parent
9158f974a3
commit
d7cffe9c03
|
|
@ -19,11 +19,17 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import os
|
||||
import weakref
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.estimator import util as estimator_util
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.layers import base
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.training import checkpoint_utils
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
from tensorflow.python.training import training_util
|
||||
|
||||
# pylint: disable=protected-access
|
||||
# Explanation for protected-access disable: Network has lots of same-class and
|
||||
|
|
@ -31,6 +37,151 @@ from tensorflow.python.ops import variable_scope
|
|||
# functions in base.py which should be reused.
|
||||
|
||||
|
||||
_DeferredRestoration = collections.namedtuple(
|
||||
|
||||
"_DeferredRestoration",
|
||||
[
|
||||
# The map_func to use (either user-specified or the default).
|
||||
"map_func",
|
||||
# Boolean, True if the user specified an explicit map_func, for error
|
||||
# messages.
|
||||
"map_func_is_user",
|
||||
# A mapping from checkpoint names to initial values of not-yet-created
|
||||
# variables which should be restored. These values come from parsing a
|
||||
# checkpoint.
|
||||
"checkpointed_variables_to_restore",
|
||||
# A mapping from checkpoint name to variable objects of variables which
|
||||
# have already been restored, for error checking.
|
||||
"restored_variables",
|
||||
# The session to restore with (if in graph mode).
|
||||
"session",
|
||||
# Names of the Network where the restore was requested, for error
|
||||
# messages.
|
||||
"network_name",
|
||||
"network_scope_name"
|
||||
])
|
||||
|
||||
|
||||
def _default_naming_conflict_error_message(
|
||||
mapped_name, first_variable, second_variable,
|
||||
network_name, network_scope_name):
|
||||
return (
|
||||
("The default checkpoint variable name mapping strategy for Network "
|
||||
"'%s' resulted in a naming conflict. We attempted to strip off the "
|
||||
"variable prefix for the Network ('%s'), but this resulted in two "
|
||||
"variables named '%s' (originally '%s' and '%s'). This should only "
|
||||
"happen when using variable sharing (i.e. the Network contains Networks "
|
||||
"or Layers which were first added to another Network, and therefore "
|
||||
"have that Network's variable prefix). One solution is to pass "
|
||||
"`map_func=lambda n: n` to Network.save and Network.restore to use "
|
||||
"fully qualified variable names in the checkpoint, although this will "
|
||||
"require that the variable prefix of the Network being restored into "
|
||||
"is also '%s'. You may alternatively write an arbitrary mapping.")
|
||||
% (
|
||||
network_name, network_scope_name, mapped_name,
|
||||
first_variable._shared_name,
|
||||
second_variable._shared_name, network_scope_name
|
||||
))
|
||||
|
||||
|
||||
def _restore_custom_map_func_error_message(
|
||||
mapped_name, first_variable, second_variable,
|
||||
network_name, network_scope_name):
|
||||
return (
|
||||
("The map_func passed to Network.restore for the Network '%s' "
|
||||
"resulted in two variables named '%s' (originally '%s' and '%s'). Since "
|
||||
"this is also an error on Network.save, this Network was "
|
||||
"probably not saved with this map_func. Note that map_func "
|
||||
"always maps from full variable names to checkpoint names; "
|
||||
"there is no need to specify an inverse mapping.\n\n"
|
||||
"Try stripping less from the variable names, or renaming parts "
|
||||
"of the Network. For reference, variables created by sub-Layers "
|
||||
"of this Network are prefixed with '%s', but if they are "
|
||||
"re-used after being added to another Network they will have "
|
||||
"that Network's full variable prefix instead.") % (
|
||||
network_name, mapped_name,
|
||||
first_variable._shared_name,
|
||||
second_variable._shared_name,
|
||||
network_scope_name))
|
||||
|
||||
|
||||
def _make_custom_getter_for_deferred_restorations():
|
||||
"""Returns a custom getter which searches `deferred_restorations`.
|
||||
|
||||
Returns: A tuple of (_custom_getter, deferred_restorations)
|
||||
_custom_getter: The getter which should be added to variable_scopes where
|
||||
variables will be created.
|
||||
deferred_restorations: A list for _DeferredRestoration objects. Typically
|
||||
empty when the getter is set, and expanded as deferred restorations are
|
||||
requested. All new deferred restorations should be appended to the end of
|
||||
the list, where they will have priority over older deferred restorations.
|
||||
"""
|
||||
deferred_restorations = []
|
||||
|
||||
def _custom_getter(getter, name, shape=None, dtype=None,
|
||||
initializer=None,
|
||||
*args, **kwargs):
|
||||
"""A custom getter which processes deferred restorations."""
|
||||
# Iterate over restorations, newest first (newer restorations will take
|
||||
# precedence over older restorations, just like with immediate restorations
|
||||
# into existing variables).
|
||||
delayed_restoration = None
|
||||
found_value = False
|
||||
value_to_restore = None
|
||||
for delayed_restoration in reversed(
|
||||
deferred_restorations):
|
||||
checkpoint_name = delayed_restoration.map_func(name)
|
||||
if (checkpoint_name
|
||||
in delayed_restoration.checkpointed_variables_to_restore):
|
||||
found_value = True
|
||||
value_to_restore = (
|
||||
delayed_restoration.checkpointed_variables_to_restore[
|
||||
checkpoint_name])
|
||||
if found_value:
|
||||
break
|
||||
# value_to_restore may be False because this variable is not in any
|
||||
# checkpoint we are restoring, or None because we have explicitly set it to
|
||||
# None when it was previously fetched. In either case, we don't need to
|
||||
# set an initializer.
|
||||
if found_value and value_to_restore is not None:
|
||||
initializer = value_to_restore
|
||||
shape = None
|
||||
variable = getter(name, shape=shape, dtype=dtype, initializer=initializer,
|
||||
*args, **kwargs)
|
||||
if found_value and value_to_restore is not None:
|
||||
# Mark as already restored from this checkpoint.
|
||||
delayed_restoration.checkpointed_variables_to_restore[
|
||||
checkpoint_name] = None
|
||||
if context.in_graph_mode():
|
||||
delayed_restoration.session.run(variable.initializer)
|
||||
if found_value:
|
||||
# Error checking should run even if we've already restored a value.
|
||||
if delayed_restoration.restored_variables.setdefault(
|
||||
checkpoint_name, variable) is not variable:
|
||||
# Naming conflict. We've tried to initialize two variables with the
|
||||
# same value from the checkpoint.
|
||||
if delayed_restoration.map_func_is_user:
|
||||
raise ValueError(
|
||||
_restore_custom_map_func_error_message(
|
||||
mapped_name=checkpoint_name,
|
||||
first_variable=delayed_restoration.restored_variables[
|
||||
checkpoint_name],
|
||||
second_variable=variable,
|
||||
network_name=delayed_restoration.network_name,
|
||||
network_scope_name=delayed_restoration.network_scope_name))
|
||||
else:
|
||||
raise ValueError(
|
||||
_default_naming_conflict_error_message(
|
||||
mapped_name=checkpoint_name,
|
||||
first_variable=delayed_restoration.restored_variables[
|
||||
checkpoint_name],
|
||||
second_variable=variable,
|
||||
network_name=delayed_restoration.network_name,
|
||||
network_scope_name=delayed_restoration.network_scope_name))
|
||||
return variable
|
||||
return _custom_getter, deferred_restorations
|
||||
|
||||
|
||||
class Network(base.Layer):
|
||||
"""Represents the composition of a set of Layers.
|
||||
|
||||
|
|
@ -41,7 +192,6 @@ class Network(base.Layer):
|
|||
- Convert inputs to __call__ to tensors.
|
||||
- Prevent variables from being created after the first __call__?
|
||||
(Think about restoring from a checkpoint).
|
||||
- Save & restore
|
||||
"""
|
||||
|
||||
def __init__(self, name=None):
|
||||
|
|
@ -60,6 +210,8 @@ class Network(base.Layer):
|
|||
self._owned_layers = {}
|
||||
# The scope to use if we end up without a parent.
|
||||
self._default_parent_variable_scope = variable_scope.get_variable_scope()
|
||||
self._custom_getter, self._deferred_restorations = (
|
||||
_make_custom_getter_for_deferred_restorations())
|
||||
|
||||
def _init_set_name(self, name):
|
||||
# Anonymous Networks (name=None) defer setting a final name until they are
|
||||
|
|
@ -87,7 +239,8 @@ class Network(base.Layer):
|
|||
avoid_names = None
|
||||
self._name, self._base_name = self._make_unique_name(
|
||||
name_uid_map=name_uid_map, avoid_names=avoid_names)
|
||||
if self._first_parent is None or self._first_parent() is None:
|
||||
if self._first_parent is None or (self._first_parent # False = no parent
|
||||
and self._first_parent() is None):
|
||||
# Save a pointer to the parent Network so that we can later check that the
|
||||
# scope name we get is correct.
|
||||
if not parent_network:
|
||||
|
|
@ -151,26 +304,32 @@ class Network(base.Layer):
|
|||
"of Networks in which they were first created). To set "
|
||||
"options, try `with tf.variable_scope(''):`. If this "
|
||||
"limitation bothers you, please file a feature request.")
|
||||
for non_network_constituent in self._non_network_sublayers:
|
||||
if non_network_constituent._scope is None:
|
||||
if non_network_constituent._first_parent is None:
|
||||
constituent_first_parent = None
|
||||
else:
|
||||
constituent_first_parent = non_network_constituent._first_parent()
|
||||
if constituent_first_parent:
|
||||
constituent_first_parent._set_scope()
|
||||
parent_scope = constituent_first_parent._scope
|
||||
else:
|
||||
parent_scope = (
|
||||
non_network_constituent._default_parent_variable_scope)
|
||||
with variable_scope.variable_scope(parent_scope):
|
||||
# Horrid hack to make Layer variable names which are direct
|
||||
# sub-layers of Networks conform to the Network variable naming
|
||||
# conventions.
|
||||
with variable_scope.variable_scope(
|
||||
None, use_resource=True,
|
||||
default_name=non_network_constituent.name) as sub_scope:
|
||||
non_network_constituent._scope = sub_scope
|
||||
for non_network_sublayer in self._non_network_sublayers:
|
||||
self._set_scope_for_nonnetwork_sublayer(non_network_sublayer)
|
||||
|
||||
def _set_scope_for_nonnetwork_sublayer(self, sublayer):
  """Assigns a variable scope to a plain (non-Network) sub-Layer.

  Places the sublayer's scope under its first parent Network's scope so that
  its variable names follow the Network naming convention. No-op if the
  sublayer already has a scope.

  Raises:
    ValueError: If the sublayer's first parent Network has been garbage
      collected before the sublayer was built (no scope to nest under).
  """
  if sublayer._scope is None:
    if sublayer._first_parent is None:
      constituent_first_parent = None
    else:
      # _first_parent is a weakref; calling it yields the parent or None.
      constituent_first_parent = sublayer._first_parent()
    if constituent_first_parent:
      constituent_first_parent._set_scope()
      parent_scope = constituent_first_parent._scope
    else:
      # Finalize our own name first so the error message can include it.
      self._finalize_name(False)
      raise ValueError(
          ("The parent of a Layer added to Network %s was garbage collected "
           "before the Layer was built. If this limitation bothers you "
           "please, file a feature request.") % (self.name,))
    with variable_scope.variable_scope(parent_scope):
      # Horrid hack to make Layer variable names which are direct
      # sub-layers of Networks conform to the Network variable naming
      # conventions.
      with variable_scope.variable_scope(
          None, use_resource=True,
          default_name=sublayer.name) as sub_scope:
        sublayer._scope = sub_scope
|
||||
|
||||
@base.Layer.name.getter
|
||||
def name(self):
|
||||
|
|
@ -327,6 +486,270 @@ class Network(base.Layer):
|
|||
"at https://github.com/tensorflow/tensorflow/issues/new if this is "
|
||||
"important to you")
|
||||
|
||||
def _strip_variable_prefix(self, original_variable_name):
|
||||
"""The default map_func for saving or restoring variables.
|
||||
|
||||
Strips the variable prefix for the Network on which save/restore was called,
|
||||
and leaves other variable names fully qualified in the checkpoint.
|
||||
|
||||
Args:
|
||||
original_variable_name: The _shared_name of the variable (no :0
|
||||
suffix) to map.
|
||||
Returns:
|
||||
The checkpoint name of the variable.
|
||||
"""
|
||||
scope_name_with_slash = self.scope_name + "/"
|
||||
if original_variable_name.startswith(scope_name_with_slash):
|
||||
return original_variable_name[len(scope_name_with_slash):]
|
||||
else:
|
||||
return original_variable_name
|
||||
|
||||
def save(self, save_path, global_step=None, map_func=None):
  """Save variables from the Network to a checkpoint.

  Args:
    save_path: Either a checkpoint prefix or the name of a directory to save
      the checkpoint in (in which case the checkpoint will be named based on
      the Network name).
    global_step: The global step to use when naming the checkpoint. If None
      (default), we will first try to get the default global step. If that
      fails because no default global step exists, then the checkpoint is
      created without a global step suffix.
    map_func: A function mapping fully qualified variable names
      (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By
      default (if `map_func=None`), the variable prefix for the network being
      restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped
      and all other variable names (shared with other Networks) are left
      unchanged.
  Returns:
    The checkpoint prefix for the saved checkpoint, which may be passed to
    `Network.restore`.
  Raises:
    ValueError: If the Network has not yet been called, or if map_func results
      in a name collision.
  """
  if not self.built:
    raise ValueError(
        "Attempt to save the Network before it was first called. This means "
        "variables have not yet been created, so there is nothing to save.")
  self._set_scope()  # scope_name should be available to map_funcs
  if global_step is None:
    # Best-effort: returns None when no default global step exists, in which
    # case the checkpoint name carries no step suffix.
    global_step = training_util.get_global_step()
  if os.path.isdir(save_path):
    # If we were passed a directory, default to naming based on the Network
    # name.
    save_path = os.path.join(save_path, self.name)
  user_map_func = map_func
  if map_func is None:
    map_func = self._strip_variable_prefix
  # Build checkpoint-name -> variable map, detecting collisions as we go.
  variable_map = {}
  for variable in self.variables:
    mapped_name = map_func(variable._shared_name)
    if variable_map.setdefault(mapped_name, variable) is not variable:
      if user_map_func is None:
        # Instead of erroring out, we could just re-try and silently use the
        # full variable names in the checkpoint. This could be odd for deeply
        # nested sub-Networks (since the full prefix from the nesting would
        # get added), so for now we'll let the user deal with this case.
        raise ValueError(_default_naming_conflict_error_message(
            mapped_name=mapped_name,
            first_variable=variable_map[mapped_name],
            second_variable=variable,
            network_name=self.name,
            network_scope_name=self.scope_name))
      else:
        # The user passed their own problematic map_func.
        raise ValueError(
            ("The map_func passed to Network.save for the Network '%s' "
             "resulted in two variables named '%s' ('%s' and '%s'). Try "
             "stripping less from the variable names, or renaming parts of "
             "the Network. For reference, variables created by sub-Layers of "
             "this Network are prefixed with '%s', but if they are re-used "
             "after being added to another Network, they will have that "
             "Network's full variable prefix instead.") % (
                 self.name, mapped_name,
                 variable_map[mapped_name]._shared_name,
                 variable._shared_name,
                 self.scope_name))
  # Saver.save needs a session only in graph mode.
  if context.in_eager_mode():
    sess = None
  else:
    sess = ops.get_default_session()
  return saver_lib.Saver(variable_map).save(
      sess=sess, save_path=save_path, write_meta_graph=False,
      global_step=global_step)
|
||||
|
||||
def _restore_existing_variables(self, save_path, map_func, user_map_func):
  """Use a standard Saver to restore existing variables from a checkpoint.

  Args:
    save_path: The checkpoint prefix or directory to read from.
    map_func: The function to use when mapping from variable names to
      checkpoint names.
    user_map_func: The original map_func passed by the user, for error
      checking.
  Returns:
    A dictionary mapping from checkpoint names to variable objects which have
    been restored (for bookkeeping to avoid deferred restorations on these
    variables).
  Raises:
    ValueError: If there is a name collision.
  """
  existing_variables_by_checkpoint_name = {}
  for variable in self.variables:
    checkpoint_name = map_func(variable._shared_name)
    if existing_variables_by_checkpoint_name.setdefault(
        checkpoint_name, variable) is not variable:
      # Two distinct variables mapped to the same checkpoint name; which
      # message we raise depends on whether the map_func was user-supplied.
      if user_map_func is None:
        raise ValueError(_default_naming_conflict_error_message(
            mapped_name=checkpoint_name,
            first_variable=existing_variables_by_checkpoint_name[
                checkpoint_name],
            second_variable=variable,
            network_name=self.name,
            network_scope_name=self.scope_name))
      else:
        raise ValueError(_restore_custom_map_func_error_message(
            mapped_name=checkpoint_name,
            first_variable=existing_variables_by_checkpoint_name[
                checkpoint_name],
            second_variable=variable,
            network_name=self.name,
            network_scope_name=self.scope_name))
  if existing_variables_by_checkpoint_name:
    # Saver.restore needs a session only in graph mode.
    if context.in_eager_mode():
      sess = None
    else:
      sess = ops.get_default_session()
    saver_lib.Saver(var_list=existing_variables_by_checkpoint_name).restore(
        sess=sess, save_path=save_path)
  return existing_variables_by_checkpoint_name
|
||||
|
||||
def _set_restore_on_create(self, save_path, map_func, user_map_func,
                           existing_variables_by_checkpoint_name):
  """If necessary, request deferred restorations of variables.

  Reads the values of all checkpoint entries not yet matched to an existing
  variable and queues them for restoration when the variables are created.

  Args:
    save_path: The checkpoint prefix or directory to read from.
    map_func: The function mapping variable names to checkpoint names.
    user_map_func: The original map_func passed by the user (None if the
      default was used), recorded for later error messages.
    existing_variables_by_checkpoint_name: Checkpoint names already restored
      into existing variables; these entries are skipped.
  """
  checkpoint_reader = checkpoint_utils.load_checkpoint(save_path)
  checkpointed_variables_to_restore = {}
  for checkpoint_name, _ in checkpoint_utils.list_variables(save_path):
    if checkpoint_name in existing_variables_by_checkpoint_name:
      # This variable was already created and restored.
      continue
    # Save the variable for later restoration in a custom getter.
    checkpointed_variables_to_restore[checkpoint_name] = (
        checkpoint_reader.get_tensor(checkpoint_name))
  # Only set a deferred restoration if there are checkpoint variables which
  # have not been assigned to existing variables. Note that this loses out on
  # some opportunity for error checking, but avoids creating
  # _DeferredRestoration objects once a Network has been built (so that
  # restoring in a loop does not take increasing amounts of memory).
  if checkpointed_variables_to_restore:
    if context.in_eager_mode():
      sess = None
    else:
      sess = ops.get_default_session()
    # We need a name for error messages. If we haven't been added to another
    # Network yet, we're top-level.
    self._finalize_name(False)
    self._set_scope()
    # Save a record of this restoration for use in the custom getter.
    deferred_restoration = _DeferredRestoration(
        map_func=map_func,
        map_func_is_user=(user_map_func is not None),
        checkpointed_variables_to_restore=checkpointed_variables_to_restore,
        restored_variables={},
        session=sess,
        network_name=self.name,
        network_scope_name=self.scope_name)
    self._deferred_restorations.append(deferred_restoration)
    # Add the deferred registration to non-Network children, and request that
    # Networks propagate the request to their children.
    self._add_deferred_restoration(deferred_restoration)
|
||||
|
||||
def _add_deferred_restoration(self, deferred_restoration):
  """Add a deferred restoration to this Network and all children.

  Restorations which are requested later have higher priority, and the highest
  priority matching restoration is applied to a variable when it is created.

  Args:
    deferred_restoration: A _DeferredRestoration object.
  """
  # Networks don't create variables at the moment, so this append isn't
  # strictly necessary. We could get by with only adding deferred restorations
  # to non-Network Layers.
  self._set_scope()
  # We use set_custom_getter because it avoids recursively calling up the
  # variable_scope tree. We've done the tree traversal ourselves and have
  # added the request to each Layer which needs it.
  self._scope.set_custom_getter(self._custom_getter)
  self._deferred_restorations.append(deferred_restoration)
  for layer in self.layers:
    if isinstance(layer, Network):
      # For Networks, request that they propagate this deferred restoration
      # to all of their children recursively.
      layer._add_deferred_restoration(deferred_restoration)
    else:
      # For non-Network Layers, make sure they have a deferred restoration
      # queue and a custom getter, then add our request to it.
      if not hasattr(layer, "_custom_getter"):
        assert not hasattr(layer, "_deferred_restorations")
        layer._custom_getter, layer._deferred_restorations = (
            _make_custom_getter_for_deferred_restorations())
      self._set_scope_for_nonnetwork_sublayer(layer)
      layer._scope.set_custom_getter(layer._custom_getter)
      layer._deferred_restorations.append(deferred_restoration)
|
||||
|
||||
def restore(self, save_path, map_func=None):
  """Restore the Network from a checkpoint.

  If variables have already been created (typically when some or all of the
  `Network` is built), they are assigned values from the checkpoint
  immediately, overwriting any existing values (in graph mode the default
  session is used for the assignments).

  If there are checkpoint entries which do not correspond to any existing
  variables in the `Network`, these values are saved for deferred restoration;
  their initial values will be the checkpointed values once they are
  created. Requests for multiple deferred restorations behave the same way as
  immediate restorations, in that later requests will take priority over
  earlier requests relevant to the same variable.

  If this `Network` shares `Layer`s with another network, those `Layer`s will
  also have their variables restored from the checkpoint.

  Args:
    save_path: The return value of `Network.save`, or a directory to search
      for a checkpoint.
    map_func: A function mapping fully qualified variable names
      (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By
      default (if `map_func=None`), the variable prefix for the network being
      restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped
      and all other variable names (shared with other Networks) are left
      unchanged. Note that this is the _same_ map_func as `Network.save`, not
      an inverse mapping.
  """
  # If we don't have a name yet, set no parent.
  self._finalize_name(parent_network=False)
  self._set_scope()  # scope_name should be available to map_funcs
  if os.path.isdir(save_path):
    # If we were passed a directory, look for a checkpoint named after this
    # Network inside it.
    save_path = os.path.join(save_path, self.name)
  user_map_func = map_func
  if map_func is None:
    map_func = self._strip_variable_prefix
  # Step one is to restore any existing variables from the checkpoint.
  existing_variables_by_checkpoint_name = self._restore_existing_variables(
      save_path=save_path,
      map_func=map_func,
      user_map_func=user_map_func)
  # Step two is to set a custom getter which restores variables on creation,
  # for those variables which have not been added to sub-Layers yet.
  self._set_restore_on_create(
      save_path=save_path,
      map_func=map_func,
      user_map_func=user_map_func,
      existing_variables_by_checkpoint_name=(
          existing_variables_by_checkpoint_name))
|
||||
|
||||
# TODO(josh11b): Support other Layer methods needed for graph mode, such as for
|
||||
# losses and updates
|
||||
|
||||
|
|
|
|||
|
|
@ -21,12 +21,14 @@ import gc
|
|||
from tensorflow.contrib.eager.python import network
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.layers import core
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.training import training_util
|
||||
|
||||
|
||||
# pylint: disable=not-callable
|
||||
|
|
@ -42,6 +44,29 @@ class MyNetwork(network.Network):
|
|||
|
||||
class NetworkTest(test.TestCase):
|
||||
|
||||
def _save_modify_load_network_built(self, net, global_step=None):
  """Round-trips a built Network: save, perturb variables, restore, compare.

  Restores twice — once via the checkpoint directory and once via the
  explicit checkpoint path — verifying both forms of `save_path` work.
  """
  checkpoint_directory = self.get_temp_dir()
  checkpoint_path = net.save(
      save_path=checkpoint_directory, global_step=global_step)
  input_value = constant_op.constant([[42.0]])
  original_output = self.evaluate(net(input_value))
  # Perturb every variable so a successful restore is observable.
  for var in net.variables:
    self.evaluate(var.assign(var + 1.))
  self.assertGreater(
      self.evaluate(net(input_value)),
      original_output)
  # Either the returned explicit checkpoint path or the directory should work.
  net.restore(save_path=checkpoint_directory)
  self.assertAllEqual(
      original_output,
      self.evaluate(net(input_value)))
  for var in net.variables:
    self.evaluate(var.assign(var + 2.))
  net.restore(save_path=checkpoint_path)
  self.assertAllEqual(
      original_output,
      self.evaluate(net(input_value)))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testTrainableAttribute(self):
|
||||
net = network.Network()
|
||||
|
|
@ -60,6 +85,264 @@ class NetworkTest(test.TestCase):
|
|||
result = net(constant_op.constant([[2.0]]))
|
||||
self.assertEqual(34.0, self.evaluate(result))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
def testNetworkSaveRestoreAlreadyBuilt(self):
  """Saving before the first call fails; after building, round-trips work."""
  net = MyNetwork(name="abcd")
  with self.assertRaisesRegexp(
      ValueError, "Attempt to save the Network before it was first called"):
    net.save(self.get_temp_dir())
  net(constant_op.constant([[2.0]]))
  self.evaluate(net.trainable_variables[0].assign([[17.0]]))
  # Exercise both the no-global-step and explicit-global-step save paths.
  self._save_modify_load_network_built(net, global_step=None)
  self._save_modify_load_network_built(net, global_step=10)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
def testSaveRestoreDefaultGlobalStep(self):
  """With no explicit step, save() picks up the default global step."""
  net = MyNetwork(name="abcd")
  net(constant_op.constant([[2.0]]))
  self.evaluate(net.variables[0].assign([[3.]]))
  default_global_step = training_util.get_or_create_global_step()
  self.evaluate(default_global_step.assign(4242))
  save_path = net.save(self.get_temp_dir())
  # Checkpoint prefix should be "<network name>-<global step>".
  self.assertIn("abcd-4242", save_path)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
def testNetworkSaveAndRestoreIntoUnbuilt(self):
  """Restoring into an unbuilt Network defers until variables are created."""
  save_dir = self.get_temp_dir()
  net1 = MyNetwork()
  test_input = constant_op.constant([[2.0]])
  net1(test_input)
  self.evaluate(net1.trainable_variables[0].assign([[17.0]]))
  save_path = net1.save(save_dir)
  # With a pre-build restore we should have the same value.
  net2 = MyNetwork()
  net2.restore(save_path)
  self.assertAllEqual(self.evaluate(net1(test_input)),
                      self.evaluate(net2(test_input)))
  # The two Networks hold distinct variable objects with equal values.
  self.assertIsNot(net1.variables[0], net2.variables[0])
  self.assertAllEqual(self.evaluate(net1.variables[0]),
                      self.evaluate(net2.variables[0]))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
def testLoadIntoUnbuiltSharedLayer(self):
  """Deferred restore reaches a shared, not-yet-built Layer.

  Also verifies that restoring fails with a clear error when the shared
  Layer's first-parent Network has been garbage collected.
  """

  class Owner(network.Network):
    # Owns the Layer that will later be shared with `User`.

    def __init__(self, name=None):
      super(Owner, self).__init__(name=name)
      self.first = self.track_layer(core.Dense(
          1, name="first_layer", use_bias=False))

    def call(self, x):
      return self.first(x)

  first_owner = Owner()

  class User(network.Network):
    # Reuses a Layer first added to another Network.

    def __init__(self, use_layer, name=None):
      super(User, self).__init__(name=name)
      self.first = self.track_layer(use_layer)
      self.second = self.track_layer(core.Dense(
          1, name="second_layer", use_bias=False))

    def call(self, x):
      return self.second(self.first(x))

  class LikeUserButNotSharing(network.Network):
    # Same architecture as User but with no shared Layers; used only to
    # produce the checkpoint.

    def __init__(self, name=None):
      super(LikeUserButNotSharing, self).__init__(name=name)
      self.first = self.track_layer(core.Dense(
          1, name="first_layer", use_bias=False))
      self.second = self.track_layer(core.Dense(
          1, name="second_layer", use_bias=False))

    def call(self, x):
      return self.second(self.first(x))

  checkpoint_creator = LikeUserButNotSharing(name="checkpoint_creator")
  one = constant_op.constant([[1.0]])
  checkpoint_creator(one)
  self.assertEqual(2, len(checkpoint_creator.variables))
  self.evaluate(checkpoint_creator.variables[0].assign([[5.]]))
  self.evaluate(checkpoint_creator.variables[1].assign([[6.]]))
  # Re-map the variable names so that with default restore mapping we'll
  # attempt to restore into the unbuilt Layer.
  name_mapping = {
      "checkpoint_creator/first_layer/kernel": "owner_1/first_layer/kernel",
      "checkpoint_creator/second_layer/kernel": "second_layer/kernel",
  }
  save_path = checkpoint_creator.save(
      self.get_temp_dir(),
      map_func=lambda full_name: name_mapping[full_name])
  load_into = User(use_layer=first_owner.first)
  load_into.restore(save_path)
  # Nothing built yet: the restore was entirely deferred.
  self.assertEqual(0, len(first_owner.variables))
  self.assertAllEqual(self.evaluate(checkpoint_creator(one)),
                      self.evaluate(load_into(one)))
  # Calling load_into built the shared Layer, creating its variable.
  self.assertEqual(1, len(first_owner.variables))
  self.assertAllEqual([[5.]], self.evaluate(load_into.variables[0]))
  self.assertAllEqual([[6.]], self.evaluate(load_into.variables[1]))
  first_owner(one)
  self.assertAllEqual([[5.]], self.evaluate(first_owner.variables[0]))

  # Try again with a garbage collected parent.
  first_owner = Owner()
  load_into = User(use_layer=first_owner.first)
  del first_owner
  gc.collect()
  def _restore_map_func(original_name):
    if original_name.startswith("owner_1"):
      return original_name.replace("owner_1", "owner_2")
    else:
      return "user_2/" + original_name
  with self.assertRaisesRegexp(ValueError, "garbage collected"):
    load_into.restore(save_path, map_func=_restore_map_func)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
def testRestoreIntoSubNetwork(self):
  """Restoring a whole-Network checkpoint and sub-Network checkpoints compose.

  Later restores win over earlier ones, for both immediate and deferred
  loading, and an incompatible checkpoint raises NotFoundError.
  """

  class Parent(network.Network):

    def __init__(self, name=None):
      super(Parent, self).__init__(name=name)
      self.first = self.track_layer(MyNetwork())
      self.second = self.track_layer(MyNetwork())

    def call(self, x):
      return self.first(self.second(x))

  inp = constant_op.constant([[3.]])
  source_net = Parent()
  source_net(inp)
  self.evaluate(source_net.variables[0].assign([[15.]]))
  self.evaluate(source_net.variables[1].assign([[16.]]))
  full_ckpt = source_net.save(self.get_temp_dir())

  sub_net = MyNetwork()
  sub_net(inp)
  self.evaluate(sub_net.variables[0].assign([[5.]]))
  sub_ckpt = sub_net.save(self.get_temp_dir())
  target = Parent()
  target.restore(full_ckpt)
  target.first.restore(sub_ckpt)
  # Repeating a deferred restore is allowed.
  target.first.restore(sub_ckpt)
  # Building the Network triggers the deferred restorations.
  target(inp)
  self.assertAllEqual([[5.]], self.evaluate(target.variables[0]))
  self.assertAllEqual([[16.]], self.evaluate(target.variables[1]))

  # Reverse the restore order: deferred restoration must behave like
  # non-deferred restoration, so the later (whole-model) restore
  # overwrites the earlier sub-Network restore.
  target = Parent()
  target.first.restore(sub_ckpt)  # repeated deferred loads are fine
  target.restore(full_ckpt)
  target(inp)  # triggers deferred loading
  # The sub-Network restore has been overwritten by the full restore.
  self.assertAllEqual([[15.]], self.evaluate(target.variables[0]))
  self.assertAllEqual([[16.]], self.evaluate(target.variables[1]))

  self.evaluate(target.variables[0].assign([[3.]]))
  self.evaluate(target.variables[1].assign([[4.]]))
  target.second.restore(sub_ckpt)
  self.assertAllEqual([[5.]], self.evaluate(target.variables[1]))
  with self.assertRaisesRegexp(errors_impl.NotFoundError,
                               "not found in checkpoint"):
    # The sub-Network checkpoint cannot satisfy the whole Parent.
    target.restore(sub_ckpt)
||||
@test_util.run_in_graph_and_eager_modes()
def testCustomMapCollisionErrors(self):
  """A user map_func mapping two variables to one name raises ValueError.

  Covers save, restore-then-build (deferred), and build-then-restore.
  """

  class Parent(network.Network):

    def __init__(self, name=None):
      super(Parent, self).__init__(name=name)
      self.first = self.track_layer(MyNetwork())
      self.second = self.track_layer(MyNetwork())

    def call(self, x):
      return self.first(self.second(x))

  ckpt_source = Parent()
  x = constant_op.constant([[1.]])
  ckpt_source(x)
  self.evaluate(ckpt_source.variables[0].assign([[2.]]))
  self.evaluate(ckpt_source.variables[1].assign([[3.]]))
  # Saving with a constant map_func collapses both variables onto "foo".
  with self.assertRaisesRegexp(
      ValueError,
      "The map_func passed to Network.save for the Network 'parent_1' "
      "resulted in two variables named 'foo'"):
    ckpt_source.save(self.get_temp_dir(), map_func=lambda n: "foo")
  # The single-variable sub-Network saves fine under the same map_func.
  sub_ckpt = ckpt_source.first.save(
      self.get_temp_dir(), map_func=lambda n: "foo")
  restorer = Parent()
  # Deferred restore: the collision surfaces when the Network is built.
  restorer.restore(sub_ckpt, map_func=lambda n: "foo")
  with self.assertRaisesRegexp(
      ValueError,
      ("The map_func passed to Network.restore for the Network"
       " 'parent_2' resulted in two variables named 'foo'")):
    restorer(x)
  restorer = Parent()
  restorer(x)
  # Immediate restore: the collision surfaces at restore time.
  with self.assertRaisesRegexp(
      ValueError,
      ("The map_func passed to Network.restore for the Network"
       " 'parent_3' resulted in two variables named 'foo'")):
    restorer.restore(sub_ckpt, map_func=lambda n: "foo")
||||
@test_util.run_in_graph_and_eager_modes()
def testDefaultMapCollisionErrors(self):
  """The default name-mapping strategy reports conflicts on save and restore.

  A Layer built outside the Network keeps its outer name, which can
  collide with a sibling Layer's name under the default mapping.
  """

  x = constant_op.constant([[1.]])
  # Built outside any Network, so this layer's scope name is fixed here.
  shared_layer = core.Dense(1, name="dense_1", use_bias=False)
  shared_layer(x)

  class Parent(network.Network):

    def __init__(self, name=None):
      super(Parent, self).__init__(name=name)
      self.first = self.track_layer(shared_layer)
      self.second = self.track_layer(core.Dense(1, use_bias=False))

    def call(self, x):
      return self.first(self.second(x))

  conflicted = Parent()
  x = constant_op.constant([[1.]])
  conflicted(x)
  self.evaluate(conflicted.variables[0].assign([[2.]]))
  self.evaluate(conflicted.variables[1].assign([[3.]]))
  with self.assertRaisesRegexp(
      ValueError,
      ("The default checkpoint variable name mapping strategy for Network "
       "'parent_1' resulted in a naming conflict.")):
    conflicted.save(self.get_temp_dir())

  class Compatible(network.Network):

    def __init__(self, name=None):
      super(Compatible, self).__init__(name=name)
      self.first = self.track_layer(core.Dense(1, use_bias=False))

    def call(self, x):
      return self.first(x)

  # A conflict-free Network saves normally...
  ok_net = Compatible()
  ok_net(x)
  self.evaluate(ok_net.variables[0].assign([[-1.]]))
  ok_path = ok_net.save(self.get_temp_dir())
  # ...but restoring into a conflicted Network still fails.
  conflicted_loader = Parent()
  conflicted_loader(x)
  with self.assertRaisesRegexp(
      ValueError,
      ("The default checkpoint variable name mapping strategy for Network "
       "'parent_2' resulted in a naming conflict.")):
    conflicted_loader.restore(ok_path)
||||
def testNoReferenceCyclesAfterCall(self):
|
||||
|
||||
class ChildNetwork(network.Network):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user