Adds save and restore methods to tfe.Network

Save writes the Network's variables to a checkpoint. Restore either restores values into existing variables immediately or, for variables which have not yet been created, defers the restoration to variable creation time using a custom getter.

PiperOrigin-RevId: 173703075
Allen Lavoie 2017-10-27 12:19:12 -07:00 committed by TensorFlower Gardener
parent 9158f974a3
commit d7cffe9c03
2 changed files with 728 additions and 22 deletions

tensorflow/contrib/eager/python/network.py

@@ -19,11 +19,17 @@ from __future__ import division
from __future__ import print_function
import collections
import os
import weakref
from tensorflow.python.eager import context
from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.framework import ops
from tensorflow.python.layers import base
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
# pylint: disable=protected-access
# Explanation for protected-access disable: Network has lots of same-class and
@@ -31,6 +37,151 @@ from tensorflow.python.ops import variable_scope
# functions in base.py which should be reused.
_DeferredRestoration = collections.namedtuple(
"_DeferredRestoration",
[
# The map_func to use (either user-specified or the default).
"map_func",
# Boolean, True if the user specified an explicit map_func, for error
# messages.
"map_func_is_user",
# A mapping from checkpoint names to initial values of not-yet-created
# variables which should be restored. These values come from parsing a
# checkpoint.
"checkpointed_variables_to_restore",
# A mapping from checkpoint name to variable objects of variables which
# have already been restored, for error checking.
"restored_variables",
# The session to restore with (if in graph mode).
"session",
# Names of the Network where the restore was requested, for error
# messages.
"network_name",
"network_scope_name"
])
def _default_naming_conflict_error_message(
mapped_name, first_variable, second_variable,
network_name, network_scope_name):
return (
("The default checkpoint variable name mapping strategy for Network "
"'%s' resulted in a naming conflict. We attempted to strip off the "
"variable prefix for the Network ('%s'), but this resulted in two "
"variables named '%s' (originally '%s' and '%s'). This should only "
"happen when using variable sharing (i.e. the Network contains Networks "
"or Layers which were first added to another Network, and therefore "
"have that Network's variable prefix). One solution is to pass "
"`map_func=lambda n: n` to Network.save and Network.restore to use "
"fully qualified variable names in the checkpoint, although this will "
"require that the variable prefix of the Network being restored into "
"is also '%s'. You may alternatively write an arbitrary mapping.")
% (
network_name, network_scope_name, mapped_name,
first_variable._shared_name,
second_variable._shared_name, network_scope_name
))
def _restore_custom_map_func_error_message(
mapped_name, first_variable, second_variable,
network_name, network_scope_name):
return (
("The map_func passed to Network.restore for the Network '%s' "
"resulted in two variables named '%s' (originally '%s' and '%s'). Since "
"this is also an error on Network.save, this Network was "
"probably not saved with this map_func. Note that map_func "
"always maps from full variable names to checkpoint names; "
"there is no need to specify an inverse mapping.\n\n"
"Try stripping less from the variable names, or renaming parts "
"of the Network. For reference, variables created by sub-Layers "
"of this Network are prefixed with '%s', but if they are "
"re-used after being added to another Network they will have "
"that Network's full variable prefix instead.") % (
network_name, mapped_name,
first_variable._shared_name,
second_variable._shared_name,
network_scope_name))
def _make_custom_getter_for_deferred_restorations():
"""Returns a custom getter which searches `deferred_restorations`.
Returns: A tuple of (_custom_getter, deferred_restorations)
_custom_getter: The getter which should be added to variable_scopes where
variables will be created.
deferred_restorations: A list of _DeferredRestoration objects. Typically
empty when the getter is set, and expanded as deferred restorations are
requested. All new deferred restorations should be appended to the end of
the list, where they will have priority over older deferred restorations.
"""
deferred_restorations = []
def _custom_getter(getter, name, shape=None, dtype=None,
initializer=None,
*args, **kwargs):
"""A custom getter which processes deferred restorations."""
# Iterate over restorations, newest first (newer restorations will take
# precedence over older restorations, just like with immediate restorations
# into existing variables).
delayed_restoration = None
found_value = False
value_to_restore = None
for delayed_restoration in reversed(
deferred_restorations):
checkpoint_name = delayed_restoration.map_func(name)
if (checkpoint_name
in delayed_restoration.checkpointed_variables_to_restore):
found_value = True
value_to_restore = (
delayed_restoration.checkpointed_variables_to_restore[
checkpoint_name])
if found_value:
break
# found_value may be False because this variable is not in any checkpoint
# we are restoring, or value_to_restore may be None because we have
# explicitly set it to None when it was previously fetched. In either case,
# we don't need to set an initializer.
if found_value and value_to_restore is not None:
initializer = value_to_restore
shape = None
variable = getter(name, shape=shape, dtype=dtype, initializer=initializer,
*args, **kwargs)
if found_value and value_to_restore is not None:
# Mark as already restored from this checkpoint.
delayed_restoration.checkpointed_variables_to_restore[
checkpoint_name] = None
if context.in_graph_mode():
delayed_restoration.session.run(variable.initializer)
if found_value:
# Error checking should run even if we've already restored a value.
if delayed_restoration.restored_variables.setdefault(
checkpoint_name, variable) is not variable:
# Naming conflict. We've tried to initialize two variables with the
# same value from the checkpoint.
if delayed_restoration.map_func_is_user:
raise ValueError(
_restore_custom_map_func_error_message(
mapped_name=checkpoint_name,
first_variable=delayed_restoration.restored_variables[
checkpoint_name],
second_variable=variable,
network_name=delayed_restoration.network_name,
network_scope_name=delayed_restoration.network_scope_name))
else:
raise ValueError(
_default_naming_conflict_error_message(
mapped_name=checkpoint_name,
first_variable=delayed_restoration.restored_variables[
checkpoint_name],
second_variable=variable,
network_name=delayed_restoration.network_name,
network_scope_name=delayed_restoration.network_scope_name))
return variable
return _custom_getter, deferred_restorations
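For context, a minimal standalone sketch of the variable_scope custom-getter mechanism the deferred restorations build on (the restored value here is hypothetical):

import tensorflow as tf  # 1.x-era API, contemporary with this commit

restored_value = [[42.0]]  # stands in for a value parsed from a checkpoint

def _getter(getter, name, shape=None, dtype=None, initializer=None,
            *args, **kwargs):
  # Swap in the checkpointed value as the initializer and drop the shape,
  # mirroring what the deferred-restoration getter above does.
  return getter(name, shape=None, dtype=dtype, initializer=restored_value,
                *args, **kwargs)

with tf.variable_scope("demo", use_resource=True, custom_getter=_getter):
  v = tf.get_variable("kernel", shape=[1, 1])  # initialized from restored_value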
class Network(base.Layer):
"""Represents the composition of a set of Layers.
@@ -41,7 +192,6 @@ class Network(base.Layer):
- Convert inputs to __call__ to tensors.
- Prevent variables from being created after the first __call__?
(Think about restoring from a checkpoint).
- Save & restore
"""
def __init__(self, name=None):
@@ -60,6 +210,8 @@
self._owned_layers = {}
# The scope to use if we end up without a parent.
self._default_parent_variable_scope = variable_scope.get_variable_scope()
self._custom_getter, self._deferred_restorations = (
_make_custom_getter_for_deferred_restorations())
def _init_set_name(self, name):
# Anonymous Networks (name=None) defer setting a final name until they are
@@ -87,7 +239,8 @@
avoid_names = None
self._name, self._base_name = self._make_unique_name(
name_uid_map=name_uid_map, avoid_names=avoid_names)
if self._first_parent is None or self._first_parent() is None:
if self._first_parent is None or (self._first_parent # False = no parent
and self._first_parent() is None):
# Save a pointer to the parent Network so that we can later check that the
# scope name we get is correct.
if not parent_network:
@@ -151,26 +304,32 @@
"of Networks in which they were first created). To set "
"options, try `with tf.variable_scope(''):`. If this "
"limitation bothers you, please file a feature request.")
for non_network_constituent in self._non_network_sublayers:
if non_network_constituent._scope is None:
if non_network_constituent._first_parent is None:
constituent_first_parent = None
else:
constituent_first_parent = non_network_constituent._first_parent()
if constituent_first_parent:
constituent_first_parent._set_scope()
parent_scope = constituent_first_parent._scope
else:
parent_scope = (
non_network_constituent._default_parent_variable_scope)
with variable_scope.variable_scope(parent_scope):
# Horrid hack to make Layer variable names which are direct
# sub-layers of Networks conform to the Network variable naming
# conventions.
with variable_scope.variable_scope(
None, use_resource=True,
default_name=non_network_constituent.name) as sub_scope:
non_network_constituent._scope = sub_scope
for non_network_sublayer in self._non_network_sublayers:
self._set_scope_for_nonnetwork_sublayer(non_network_sublayer)
def _set_scope_for_nonnetwork_sublayer(self, sublayer):
if sublayer._scope is None:
if sublayer._first_parent is None:
constituent_first_parent = None
else:
constituent_first_parent = sublayer._first_parent()
if constituent_first_parent:
constituent_first_parent._set_scope()
parent_scope = constituent_first_parent._scope
else:
self._finalize_name(False)
raise ValueError(
("The parent of a Layer added to Network %s was garbage collected "
"before the Layer was built. If this limitation bothers you "
"please, file a feature request.") % (self.name,))
with variable_scope.variable_scope(parent_scope):
# Horrid hack to make Layer variable names which are direct
# sub-layers of Networks conform to the Network variable naming
# conventions.
with variable_scope.variable_scope(
None, use_resource=True,
default_name=sublayer.name) as sub_scope:
sublayer._scope = sub_scope
@base.Layer.name.getter
def name(self):
@@ -327,6 +486,270 @@
"at https://github.com/tensorflow/tensorflow/issues/new if this is "
"important to you")
def _strip_variable_prefix(self, original_variable_name):
"""The default map_func for saving or restoring variables.
Strips the variable prefix for the Network on which save/restore was called,
and leaves other variable names fully qualified in the checkpoint.
Args:
original_variable_name: The _shared_name of the variable (no :0
suffix) to map.
Returns:
The checkpoint name of the variable.
"""
scope_name_with_slash = self.scope_name + "/"
if original_variable_name.startswith(scope_name_with_slash):
return original_variable_name[len(scope_name_with_slash):]
else:
return original_variable_name
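For example, assuming net.scope_name is 'my_network_1' (names illustrative; the protected helper is called directly purely for demonstration):

net._strip_variable_prefix("my_network_1/dense_1/kernel")  # -> 'dense_1/kernel'
net._strip_variable_prefix("other_net/dense_1/kernel")     # -> 'other_net/dense_1/kernel' (shared; left fully qualified)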
def save(self, save_path, global_step=None, map_func=None):
"""Save variables from the Network to a checkpoint.
Args:
save_path: Either a checkpoint prefix or the name of a directory to save
the checkpoint in (in which case the checkpoint will be named based on
the Network name).
global_step: The global step to use when naming the checkpoint. If None
(default), we will first try to get the default global step. If that
fails because no default global step exists, then the checkpoint is
created without a global step suffix.
map_func: A function mapping fully qualified variable names
(e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By
default (if `map_func=None`), the variable prefix for the network being
restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped
and all other variable names (shared with other Networks) are left
unchanged.
Returns:
The checkpoint prefix for the saved checkpoint, which may be passed to
`Network.restore`.
Raises:
ValueError: If the Network has not yet been called, or if map_func results
in a name collision.
"""
if not self.built:
raise ValueError(
"Attempt to save the Network before it was first called. This means "
"variables have not yet been created, so there is nothing to save.")
self._set_scope() # scope_name should be available to map_funcs
if global_step is None:
global_step = training_util.get_global_step()
if os.path.isdir(save_path):
# If we were passed a directory, default to naming based on the Network
# name.
save_path = os.path.join(save_path, self.name)
user_map_func = map_func
if map_func is None:
map_func = self._strip_variable_prefix
variable_map = {}
for variable in self.variables:
mapped_name = map_func(variable._shared_name)
if variable_map.setdefault(mapped_name, variable) is not variable:
if user_map_func is None:
# Instead of erroring out, we could just re-try and silently use the
# full variable names in the checkpoint. This could be odd for deeply
# nested sub-Networks (since the full prefix from the nesting would
# get added), so for now we'll let the user deal with this case.
raise ValueError(_default_naming_conflict_error_message(
mapped_name=mapped_name,
first_variable=variable_map[mapped_name],
second_variable=variable,
network_name=self.name,
network_scope_name=self.scope_name))
else:
# The user passed their own problematic map_func.
raise ValueError(
("The map_func passed to Network.save for the Network '%s' "
"resulted in two variables named '%s' ('%s' and '%s'). Try "
"stripping less from the variable names, or renaming parts of "
"the Network. For reference, variables created by sub-Layers of "
"this Network are prefixed with '%s', but if they are re-used "
"after being added to another Network, they will have that "
"Network's full variable prefix instead.") % (
self.name, mapped_name,
variable_map[mapped_name]._shared_name,
variable._shared_name,
self.scope_name))
if context.in_eager_mode():
sess = None
else:
sess = ops.get_default_session()
return saver_lib.Saver(variable_map).save(
sess=sess, save_path=save_path, write_meta_graph=False,
global_step=global_step)
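A short example of the non-default save options (a directory destination plus the identity map_func suggested in the error message above; the path is illustrative):

# Saving to a directory writes <dir>/<network name>, suffixed with the
# default global step if one exists. The identity map_func keeps fully
# qualified variable names in the checkpoint.
prefix = net.save("/tmp/ckpts", map_func=lambda name: name)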
def _restore_existing_variables(self, save_path, map_func, user_map_func):
"""Use a standard Saver to restore existing variables from a checkpoint.
Args:
save_path: The checkpoint prefix or directory to read from.
map_func: The function to use when mapping from variable names to
checkpoint names.
user_map_func: The original map_func passed by the user, for error
checking.
Returns:
A dictionary mapping from checkpoint names to variable objects which have
been restored (for bookkeeping to avoid deferred restorations on these
variables).
Raises:
ValueError: If there is a name collision.
"""
existing_variables_by_checkpoint_name = {}
for variable in self.variables:
checkpoint_name = map_func(variable._shared_name)
if existing_variables_by_checkpoint_name.setdefault(
checkpoint_name, variable) is not variable:
if user_map_func is None:
raise ValueError(_default_naming_conflict_error_message(
mapped_name=checkpoint_name,
first_variable=existing_variables_by_checkpoint_name[
checkpoint_name],
second_variable=variable,
network_name=self.name,
network_scope_name=self.scope_name))
else:
raise ValueError(_restore_custom_map_func_error_message(
mapped_name=checkpoint_name,
first_variable=existing_variables_by_checkpoint_name[
checkpoint_name],
second_variable=variable,
network_name=self.name,
network_scope_name=self.scope_name))
if existing_variables_by_checkpoint_name:
if context.in_eager_mode():
sess = None
else:
sess = ops.get_default_session()
saver_lib.Saver(var_list=existing_variables_by_checkpoint_name).restore(
sess=sess, save_path=save_path)
return existing_variables_by_checkpoint_name
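For reference, a minimal standalone sketch of the Saver-with-a-name-map pattern used here (graph mode; names and the path are illustrative):

import tensorflow as tf

v = tf.get_variable("my_network_1/dense_1/kernel", shape=[1, 1])
# Keys are names as they appear in the checkpoint; values are the variables
# to restore into, just like existing_variables_by_checkpoint_name above.
saver = tf.train.Saver(var_list={"dense_1/kernel": v})
with tf.Session() as sess:
  saver.restore(sess, "/tmp/ckpts/my_network")  # hypothetical checkpoint prefix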
def _set_restore_on_create(self, save_path, map_func, user_map_func,
existing_variables_by_checkpoint_name):
"""If necessary, request deferred restorations of variables."""
checkpoint_reader = checkpoint_utils.load_checkpoint(save_path)
checkpointed_variables_to_restore = {}
for checkpoint_name, _ in checkpoint_utils.list_variables(save_path):
if checkpoint_name in existing_variables_by_checkpoint_name:
# This variable was already created and restored.
continue
# Save the variable for later restoration in a custom getter.
checkpointed_variables_to_restore[checkpoint_name] = (
checkpoint_reader.get_tensor(checkpoint_name))
# Only set a deferred restoration if there are checkpoint variables which
# have not been assigned to existing variables. Note that this loses out on
# some opportunity for error checking, but avoids creating
# _DeferredRestoration objects once a Network has been built (so that
# restoring in a loop does not take increasing amounts of memory).
if checkpointed_variables_to_restore:
if context.in_eager_mode():
sess = None
else:
sess = ops.get_default_session()
# We need a name for error messages. If we haven't been added to another
# Network yet, we're top-level.
self._finalize_name(False)
self._set_scope()
# Save a record of this restoration for use in the custom getter.
deferred_restoration = _DeferredRestoration(
map_func=map_func,
map_func_is_user=(user_map_func is not None),
checkpointed_variables_to_restore=checkpointed_variables_to_restore,
restored_variables={},
session=sess,
network_name=self.name,
network_scope_name=self.scope_name)
self._deferred_restorations.append(deferred_restoration)
# Add the deferred registration to non-Network children, and request that
# Networks propagate the request to their children.
self._add_deferred_restoration(deferred_restoration)
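A standalone sketch of the checkpoint-parsing calls used above (the path is illustrative):

from tensorflow.python.training import checkpoint_utils

reader = checkpoint_utils.load_checkpoint("/tmp/ckpts/my_network")
for name, _ in checkpoint_utils.list_variables("/tmp/ckpts/my_network"):
  value = reader.get_tensor(name)  # numpy value to stash for deferred restoration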
def _add_deferred_restoration(self, deferred_restoration):
"""Add a deferred restoration to this Network and all children.
Restorations which are requested later have higher priority, and the highest
priority matching restoration is applied to a variable when it is created.
Args:
deferred_restoration: A _DeferredRestoration object.
"""
# Networks don't create variables at the moment, so this append isn't
# strictly necessary. We could get by with only adding deferred restorations
# to non-Network Layers.
self._set_scope()
# We use set_custom_getter because it avoids recursively calling up the
# variable_scope tree. We've done the tree traversal ourselves and have
# added the request to each Layer which needs it.
self._scope.set_custom_getter(self._custom_getter)
self._deferred_restorations.append(deferred_restoration)
for layer in self.layers:
if isinstance(layer, Network):
# For Networks, request that they propagate this deferred restoration
# to all of their children recursively.
layer._add_deferred_restoration(deferred_restoration)
else:
# For non-Network Layers, make sure they have a deferred restoration
# queue and a custom getter, then add our request to it.
if not hasattr(layer, "_custom_getter"):
assert not hasattr(layer, "_deferred_restorations")
layer._custom_getter, layer._deferred_restorations = (
_make_custom_getter_for_deferred_restorations())
self._set_scope_for_nonnetwork_sublayer(layer)
layer._scope.set_custom_getter(layer._custom_getter)
layer._deferred_restorations.append(deferred_restoration)
def restore(self, save_path, map_func=None):
"""Restore the Network from a checkpoint.
If variables have already been created (typically when some or all of the
`Network` is built), they are assigned values from the checkpoint
immediately, overwriting any existing values (in graph mode the default
session is used for the assignments).
If there are checkpoint entries which do not correspond to any existing
variables in the `Network`, these values are saved for deferred restoration;
their initial values will be the checkpointed values once they are
created. Requests for multiple deferred restorations behave the same way as
immediate restorations, in that later requests will take priority over
earlier requests relevant to the same variable.
If this `Network` shares `Layer`s with another network, those `Layer`s will
also have their variables restored from the checkpoint.
Args:
save_path: The return value of `Network.save`, or a directory to search
for a checkpoint.
map_func: A function mapping fully qualified variable names
(e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By
default (if `map_func=None`), the variable prefix for the network being
restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped
and all other variable names (shared with other Networks) are left
unchanged. Note that this is the _same_ map_func as `Network.save`, not
an inverse mapping.
"""
self._finalize_name(parent_network=False)  # we may be top-level (no parent)
self._set_scope()  # scope_name should be available to map_funcs
if os.path.isdir(save_path):
# If we were passed a directory, default to naming based on the Network
# name, matching the behavior of save().
save_path = os.path.join(save_path, self.name)
user_map_func = map_func
if map_func is None:
map_func = self._strip_variable_prefix
# Step one is to restore any existing variables from the checkpoint.
existing_variables_by_checkpoint_name = self._restore_existing_variables(
save_path=save_path,
map_func=map_func,
user_map_func=user_map_func)
# Step two is to set a custom getter which restores variables on creation,
# for those variables which have not been added to sub-Layers yet.
self._set_restore_on_create(
save_path=save_path,
map_func=map_func,
user_map_func=user_map_func,
existing_variables_by_checkpoint_name=(
existing_variables_by_checkpoint_name))
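A sketch of how stacked deferred restorations resolve (mirroring testRestoreIntoSubNetwork in the test file below; later requests win on overlap):

net = Parent()                     # hypothetical Network containing sub-Networks
net.first.restore(sub_checkpoint)  # deferred: nothing is built yet
net.restore(whole_checkpoint)      # later request takes priority where they overlap
net(tf.constant([[1.0]]))          # variables created here come from whole_checkpoint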
# TODO(josh11b): Support other Layer methods needed for graph mode, such as for
# losses and updates

tensorflow/contrib/eager/python/network_test.py

@@ -21,12 +21,14 @@ import gc
from tensorflow.contrib.eager.python import network
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.layers import core
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import training_util
# pylint: disable=not-callable
@@ -42,6 +44,29 @@ class MyNetwork(network.Network):
class NetworkTest(test.TestCase):
def _save_modify_load_network_built(self, net, global_step=None):
checkpoint_directory = self.get_temp_dir()
checkpoint_path = net.save(
save_path=checkpoint_directory, global_step=global_step)
input_value = constant_op.constant([[42.0]])
original_output = self.evaluate(net(input_value))
for var in net.variables:
self.evaluate(var.assign(var + 1.))
self.assertGreater(
self.evaluate(net(input_value)),
original_output)
# Either the returned explicit checkpoint path or the directory should work.
net.restore(save_path=checkpoint_directory)
self.assertAllEqual(
original_output,
self.evaluate(net(input_value)))
for var in net.variables:
self.evaluate(var.assign(var + 2.))
net.restore(save_path=checkpoint_path)
self.assertAllEqual(
original_output,
self.evaluate(net(input_value)))
@test_util.run_in_graph_and_eager_modes()
def testTrainableAttribute(self):
net = network.Network()
@@ -60,6 +85,264 @@
result = net(constant_op.constant([[2.0]]))
self.assertEqual(34.0, self.evaluate(result))
@test_util.run_in_graph_and_eager_modes()
def testNetworkSaveRestoreAlreadyBuilt(self):
net = MyNetwork(name="abcd")
with self.assertRaisesRegexp(
ValueError, "Attempt to save the Network before it was first called"):
net.save(self.get_temp_dir())
net(constant_op.constant([[2.0]]))
self.evaluate(net.trainable_variables[0].assign([[17.0]]))
self._save_modify_load_network_built(net, global_step=None)
self._save_modify_load_network_built(net, global_step=10)
@test_util.run_in_graph_and_eager_modes()
def testSaveRestoreDefaultGlobalStep(self):
net = MyNetwork(name="abcd")
net(constant_op.constant([[2.0]]))
self.evaluate(net.variables[0].assign([[3.]]))
default_global_step = training_util.get_or_create_global_step()
self.evaluate(default_global_step.assign(4242))
save_path = net.save(self.get_temp_dir())
self.assertIn("abcd-4242", save_path)
@test_util.run_in_graph_and_eager_modes()
def testNetworkSaveAndRestoreIntoUnbuilt(self):
save_dir = self.get_temp_dir()
net1 = MyNetwork()
test_input = constant_op.constant([[2.0]])
net1(test_input)
self.evaluate(net1.trainable_variables[0].assign([[17.0]]))
save_path = net1.save(save_dir)
# With a pre-build restore we should have the same value.
net2 = MyNetwork()
net2.restore(save_path)
self.assertAllEqual(self.evaluate(net1(test_input)),
self.evaluate(net2(test_input)))
self.assertIsNot(net1.variables[0], net2.variables[0])
self.assertAllEqual(self.evaluate(net1.variables[0]),
self.evaluate(net2.variables[0]))
@test_util.run_in_graph_and_eager_modes()
def testLoadIntoUnbuiltSharedLayer(self):
class Owner(network.Network):
def __init__(self, name=None):
super(Owner, self).__init__(name=name)
self.first = self.track_layer(core.Dense(
1, name="first_layer", use_bias=False))
def call(self, x):
return self.first(x)
first_owner = Owner()
class User(network.Network):
def __init__(self, use_layer, name=None):
super(User, self).__init__(name=name)
self.first = self.track_layer(use_layer)
self.second = self.track_layer(core.Dense(
1, name="second_layer", use_bias=False))
def call(self, x):
return self.second(self.first(x))
class LikeUserButNotSharing(network.Network):
def __init__(self, name=None):
super(LikeUserButNotSharing, self).__init__(name=name)
self.first = self.track_layer(core.Dense(
1, name="first_layer", use_bias=False))
self.second = self.track_layer(core.Dense(
1, name="second_layer", use_bias=False))
def call(self, x):
return self.second(self.first(x))
checkpoint_creator = LikeUserButNotSharing(name="checkpoint_creator")
one = constant_op.constant([[1.0]])
checkpoint_creator(one)
self.assertEqual(2, len(checkpoint_creator.variables))
self.evaluate(checkpoint_creator.variables[0].assign([[5.]]))
self.evaluate(checkpoint_creator.variables[1].assign([[6.]]))
# Re-map the variable names so that with default restore mapping we'll
# attempt to restore into the unbuilt Layer.
name_mapping = {
"checkpoint_creator/first_layer/kernel": "owner_1/first_layer/kernel",
"checkpoint_creator/second_layer/kernel": "second_layer/kernel",
}
save_path = checkpoint_creator.save(
self.get_temp_dir(),
map_func=lambda full_name: name_mapping[full_name])
load_into = User(use_layer=first_owner.first)
load_into.restore(save_path)
self.assertEqual(0, len(first_owner.variables))
self.assertAllEqual(self.evaluate(checkpoint_creator(one)),
self.evaluate(load_into(one)))
self.assertEqual(1, len(first_owner.variables))
self.assertAllEqual([[5.]], self.evaluate(load_into.variables[0]))
self.assertAllEqual([[6.]], self.evaluate(load_into.variables[1]))
first_owner(one)
self.assertAllEqual([[5.]], self.evaluate(first_owner.variables[0]))
# Try again with a garbage collected parent.
first_owner = Owner()
load_into = User(use_layer=first_owner.first)
del first_owner
gc.collect()
def _restore_map_func(original_name):
if original_name.startswith("owner_1"):
return original_name.replace("owner_1", "owner_2")
else:
return "user_2/" + original_name
with self.assertRaisesRegexp(ValueError, "garbage collected"):
load_into.restore(save_path, map_func=_restore_map_func)
@test_util.run_in_graph_and_eager_modes()
def testRestoreIntoSubNetwork(self):
class Parent(network.Network):
def __init__(self, name=None):
super(Parent, self).__init__(name=name)
self.first = self.track_layer(MyNetwork())
self.second = self.track_layer(MyNetwork())
def call(self, x):
return self.first(self.second(x))
one = constant_op.constant([[3.]])
whole_model_saver = Parent()
whole_model_saver(one)
self.evaluate(whole_model_saver.variables[0].assign([[15.]]))
self.evaluate(whole_model_saver.variables[1].assign([[16.]]))
whole_model_checkpoint = whole_model_saver.save(self.get_temp_dir())
save_from = MyNetwork()
save_from(one)
self.evaluate(save_from.variables[0].assign([[5.]]))
checkpoint = save_from.save(self.get_temp_dir())
save_into_parent = Parent()
save_into_parent.restore(whole_model_checkpoint)
save_into_parent.first.restore(checkpoint)
save_into_parent.first.restore(checkpoint) # deferred loading multiple
# times is fine
save_into_parent(one) # deferred loading
self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[0]))
self.assertAllEqual([[16.]], self.evaluate(save_into_parent.variables[1]))
# Try again with the opposite ordering, and we should get different results
# (deferred restoration should happen the same way non-deferred happens,
# with later restorations overwriting older ones).
save_into_parent = Parent()
save_into_parent.first.restore(checkpoint) # deferred loading multiple
# times is fine
save_into_parent.restore(whole_model_checkpoint)
save_into_parent(one) # deferred loading
# We've overwritten the sub-Network restore.
self.assertAllEqual([[15.]], self.evaluate(save_into_parent.variables[0]))
self.assertAllEqual([[16.]], self.evaluate(save_into_parent.variables[1]))
self.evaluate(save_into_parent.variables[0].assign([[3.]]))
self.evaluate(save_into_parent.variables[1].assign([[4.]]))
save_into_parent.second.restore(checkpoint)
self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[1]))
with self.assertRaisesRegexp(errors_impl.NotFoundError,
"not found in checkpoint"):
# The checkpoint is incompatible.
save_into_parent.restore(checkpoint)
@test_util.run_in_graph_and_eager_modes()
def testCustomMapCollisionErrors(self):
class Parent(network.Network):
def __init__(self, name=None):
super(Parent, self).__init__(name=name)
self.first = self.track_layer(MyNetwork())
self.second = self.track_layer(MyNetwork())
def call(self, x):
return self.first(self.second(x))
make_checkpoint = Parent()
one = constant_op.constant([[1.]])
make_checkpoint(one)
self.evaluate(make_checkpoint.variables[0].assign([[2.]]))
self.evaluate(make_checkpoint.variables[1].assign([[3.]]))
with self.assertRaisesRegexp(
ValueError,
"The map_func passed to Network.save for the Network 'parent_1' "
"resulted in two variables named 'foo'"):
make_checkpoint.save(self.get_temp_dir(), map_func=lambda n: "foo")
checkpoint = make_checkpoint.first.save(
self.get_temp_dir(), map_func=lambda n: "foo")
loader = Parent()
loader.restore(checkpoint, map_func=lambda n: "foo")
with self.assertRaisesRegexp(
ValueError,
("The map_func passed to Network.restore for the Network"
" 'parent_2' resulted in two variables named 'foo'")):
loader(one)
loader = Parent()
loader(one)
with self.assertRaisesRegexp(
ValueError,
("The map_func passed to Network.restore for the Network"
" 'parent_3' resulted in two variables named 'foo'")):
loader.restore(checkpoint, map_func=lambda n: "foo")
@test_util.run_in_graph_and_eager_modes()
def testDefaultMapCollisionErrors(self):
one = constant_op.constant([[1.]])
first = core.Dense(1, name="dense_1", use_bias=False)
first(one)
class Parent(network.Network):
def __init__(self, name=None):
super(Parent, self).__init__(name=name)
self.first = self.track_layer(first)
self.second = self.track_layer(core.Dense(1, use_bias=False))
def call(self, x):
return self.first(self.second(x))
make_checkpoint = Parent()
one = constant_op.constant([[1.]])
make_checkpoint(one)
self.evaluate(make_checkpoint.variables[0].assign([[2.]]))
self.evaluate(make_checkpoint.variables[1].assign([[3.]]))
with self.assertRaisesRegexp(
ValueError,
("The default checkpoint variable name mapping strategy for Network "
"'parent_1' resulted in a naming conflict.")):
make_checkpoint.save(self.get_temp_dir())
class Compatible(network.Network):
def __init__(self, name=None):
super(Compatible, self).__init__(name=name)
self.first = self.track_layer(core.Dense(1, use_bias=False))
def call(self, x):
return self.first(x)
successful_checkpoint = Compatible()
successful_checkpoint(one)
self.evaluate(successful_checkpoint.variables[0].assign([[-1.]]))
checkpoint_path = successful_checkpoint.save(self.get_temp_dir())
load_checkpoint = Parent()
load_checkpoint(one)
with self.assertRaisesRegexp(
ValueError,
("The default checkpoint variable name mapping strategy for Network "
"'parent_2' resulted in a naming conflict.")):
load_checkpoint.restore(checkpoint_path)
def testNoReferenceCyclesAfterCall(self):
class ChildNetwork(network.Network):