Add a slot/accumulator warm-start initializer that, at call time, overrides the provided partition info with one passed at construction time. This is intended to be used for slot Variables (such as accumulators) associated with Optimizers, since these Variables are created in a fashion that relies on replicating the exact shape of the associated primary variables (see slot_creator).

PiperOrigin-RevId: 157453498
This commit is contained in:
A. Unique TensorFlower 2017-05-30 05:31:26 -07:00 committed by TensorFlower Gardener
parent 73d10599fe
commit 8c2a079ec8
3 changed files with 183 additions and 3 deletions

View File

@@ -77,6 +77,7 @@ See the @{$python/contrib.framework} guide.
@@load_and_remap_matrix_initializer @@load_and_remap_matrix_initializer
@@load_embedding_initializer @@load_embedding_initializer
@@load_linear_multiclass_bias_initializer @@load_linear_multiclass_bias_initializer
@@load_variable_slot_initializer
""" """
from __future__ import absolute_import from __future__ import absolute_import

View File

@@ -488,3 +488,91 @@ def load_linear_multiclass_bias_initializer(ckpt_path,
num_row_oov_buckets=num_class_oov_buckets, num_row_oov_buckets=num_class_oov_buckets,
num_col_oov_buckets=0, num_col_oov_buckets=0,
initializer=initializer) initializer=initializer)
def load_variable_slot_initializer(ckpt_path,
                                   old_tensor_name,
                                   primary_partition_info,
                                   new_row_vocab_size,
                                   new_col_vocab_size,
                                   old_row_vocab_file=None,
                                   new_row_vocab_file=None,
                                   old_col_vocab_file=None,
                                   new_col_vocab_file=None,
                                   num_row_oov_buckets=0,
                                   num_col_oov_buckets=0,
                                   initializer=None):
  """Loads pre-trained multi-class slots for linear models from checkpoint.

  Thin wrapper around `load_and_remap_matrix_initializer()` specialized for
  loading multi-class slots (such as optimizer accumulators) and remapping
  them according to the provided vocab files. See the docs of
  `load_and_remap_matrix_initializer()` for details. The caller supplies a
  `variable_scope._PartitionInfo` describing the slot's primary `Variable`'s
  partitioning; the returned initializer always uses that instead of whatever
  `partition_info` it is invoked with, since accumulator `Variable` creation
  ignores primary scoping and partitioning information.

  Args:
    ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
      from which the old matrix `Tensor` will be loaded.
    old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
    primary_partition_info: A `variable_scope._PartitionInfo` containing this
      slot's primary `Variable`'s partitioning information. This is used to
      calculate the offset and override the partition_info passed to the call
      to _initialize.
    new_row_vocab_size: `int` specifying the number of entries in
      `new_row_vocab_file`. If no row remapping is needed (no row vocab
      provided), this should be equal to the number of rows to load from the
      old matrix (which can theoretically be smaller than the number of rows
      in the old matrix).
    new_col_vocab_size: `int` specifying the number of entries in
      `new_col_vocab_file`. If no column remapping is needed (no column vocab
      provided), this should be equal to the number of columns in the old
      matrix.
    old_row_vocab_file: A scalar `Tensor` of type `string` containing the
      path to the old row vocabulary file. Can be None, which represents no
      remapping on the row axis.
    new_row_vocab_file: A scalar `Tensor` of type `string` containing the path
      to the new row vocabulary file. Can be None, which represents no
      remapping on the row axis.
    old_col_vocab_file: A scalar `Tensor` of type `string` containing the
      path to the old column vocabulary file. Can be None, which represents no
      remapping on the column axis.
    new_col_vocab_file: A scalar `Tensor` of type `string` containing the path
      to the new column vocabulary file. Can be None, which represents no
      remapping on the column axis.
    num_row_oov_buckets: `int` specifying the number of out-of-vocabulary rows
      to append. Must be >= 0.
    num_col_oov_buckets: `int` specifying the number of out-of-vocabulary
      columns to append. Must be >= 0.
    initializer: Initializer function to initialize missing values. Accepts a
      1-D tensor as the arg to specify the shape of the returned tensor. If
      `None`, defaults to using `zeros_initializer()`.

  Returns:
    A variable initializer function that should be used to initialize a
    (potentially partitioned) `Variable` whose complete shape is
    `[new_row_vocab_size + num_row_oov_buckets, new_col_vocab_size +
    num_col_oov_buckets]`.

  Raises:
    TypeError: If `initializer` is specified but not callable.
  """
  # Build the general-purpose remapping initializer first; all remapping
  # arguments are simply forwarded.
  base_initializer = load_and_remap_matrix_initializer(
      ckpt_path=ckpt_path,
      old_tensor_name=old_tensor_name,
      new_row_vocab_size=new_row_vocab_size,
      new_col_vocab_size=new_col_vocab_size,
      old_row_vocab_file=old_row_vocab_file,
      new_row_vocab_file=new_row_vocab_file,
      old_col_vocab_file=old_col_vocab_file,
      new_col_vocab_file=new_col_vocab_file,
      num_row_oov_buckets=num_row_oov_buckets,
      num_col_oov_buckets=num_col_oov_buckets,
      initializer=initializer)

  def _slot_initializer(shape, dtype=dtypes.float32, partition_info=None):
    # Whatever partition_info the variable-creation machinery passes in is
    # deliberately discarded in favor of the primary Variable's partitioning
    # captured at construction time.
    del partition_info  # Unused by this override.
    return base_initializer(
        shape, dtype, partition_info=primary_partition_info)

  return _slot_initializer

View File

@@ -90,7 +90,7 @@ class GenerateVocabRemappingTest(test.TestCase):
class LoadAndRemapMatrixTest(test.TestCase): class LoadAndRemapMatrixTest(test.TestCase):
"""Tests for the load_and_remap_weight_matrix() op.""" """Tests for the load_and_remap_matrix() op."""
def setUp(self): def setUp(self):
ops.reset_default_graph() ops.reset_default_graph()
@@ -276,7 +276,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
def test_load_and_remap_matrix(self): def test_load_and_remap_matrix(self):
"""Tests the end-to-end loading / remapping of weights.""" """Tests the end-to-end loading / remapping of weights."""
# load_and_remap_matrix() is the generalized wrapper that takes in row and # _load_and_remap_matrix() is the generalized wrapper that takes in row and
# column vocabulary files, calls the relevant remappings, and returns the # column vocabulary files, calls the relevant remappings, and returns the
# weight matrix. Take this example to be linear multi-class by providing # weight matrix. Take this example to be linear multi-class by providing
# both row and column vocabularies. # both row and column vocabularies.
@@ -458,7 +458,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
remapped_matrix.as_tensor().eval()) remapped_matrix.as_tensor().eval())
def test_load_embedding_initializer(self): def test_load_embedding_initializer(self):
"""Tests for the load_embedding initializer wrapper.""" """Tests for the load_embedding_initializer wrapper."""
embedding_loading_initializer = ( embedding_loading_initializer = (
contrib_framework.load_embedding_initializer( contrib_framework.load_embedding_initializer(
new_vocab_file=self.new_feature_vocab_file, new_vocab_file=self.new_feature_vocab_file,
@@ -553,5 +553,96 @@ class LoadMulticlassBiasTest(test.TestCase):
remapped_bias_vector.as_tensor().eval()) remapped_bias_vector.as_tensor().eval())
class LoadVariableSlotTest(test.TestCase):
  """Tests for the load_variable_slot_initializer functionality."""

  def setUp(self):
    ops.reset_default_graph()
    num_rows = 3
    num_cols = 1
    with ops.name_scope('some_scope'):
      # A [num_rows, num_cols] accumulator holding 0, 1, ..., num_rows*num_cols-1.
      ramp = math_ops.linspace(0.0, num_cols * num_rows - 1,
                               num_cols * num_rows)
      accum = variables.Variable(
          array_ops.reshape(ramp, (num_rows, num_cols)), name='accum')
    save = saver.Saver([accum])
    with self.test_session() as sess:
      variables.global_variables_initializer().run()
      self.bundle_file = os.path.join(test.get_temp_dir(), 'accum_checkpoint')
      save.save(sess, self.bundle_file)

    self.new_class_vocab_file = os.path.join(
        test.test_src_dir_path(_TESTDATA_PATH), 'keyword_new.txt')
    self.old_class_vocab_file = os.path.join(
        test.test_src_dir_path(_TESTDATA_PATH), 'keyword.txt')
    self.init_val = 42

    def _init_val_initializer(shape, dtype=None, partition_info=None):
      del dtype, partition_info  # Unused by this unit-testing initializer.
      # Fill the requested shape with self.init_val.
      return array_ops.tile(
          constant_op.constant([[self.init_val]], dtype=dtypes.float32), shape)

    self.initializer = _init_val_initializer

  def test_load_variable_slot_initializer(self):
    """Tests for the slot initializer wrapper."""
    # Two initializers, one per shard of a [5, 1] primary variable partitioned
    # into [3, 1] and [2, 1]. The partitioning information is provided at
    # initializer-construction time, not through a variable scope during
    # variable creation.
    shard_offsets = ([0, 0], [3, 0])
    shard_initializers = [
        contrib_framework.load_variable_slot_initializer(
            ckpt_path=[self.bundle_file],
            old_tensor_name='some_scope/accum',
            primary_partition_info=variable_scope._PartitionInfo(
                full_shape=[5, 1], var_offset=offset),
            new_row_vocab_size=4,
            new_col_vocab_size=1,
            old_row_vocab_file=self.old_class_vocab_file,
            new_row_vocab_file=self.new_class_vocab_file,
            num_row_oov_buckets=1,
            initializer=self.initializer) for offset in shard_offsets
    ]

    expected_remapped_accum_vector_part_0 = np.reshape(
        [2, 0, self.init_val], [3, 1])
    expected_remapped_accum_vector_part_1 = np.reshape(
        [1, self.init_val], [2, 1])

    # No variable scope is used here, so partition_info will be None at call
    # time. Had these been plain load_and_remap_matrix_initializer instances,
    # part_0 would still yield [2, 0, self.init_val], but part_1 would yield
    # [2, 0], since a None partition_info defaults to a single partition.
    remapped_accum_vector_part_0 = variable_scope.get_variable(
        name='accum/obtained_accum_vector_part_0',
        shape=[3, 1],
        initializer=shard_initializers[0])
    remapped_accum_vector_part_1 = variable_scope.get_variable(
        name='accum/obtained_accum_vector_part_1',
        shape=[2, 1],
        initializer=shard_initializers[1])

    with self.test_session():
      variables.global_variables_initializer().run()
      self.assertAllClose(expected_remapped_accum_vector_part_0,
                          remapped_accum_vector_part_0.eval())
      self.assertAllClose(expected_remapped_accum_vector_part_1,
                          remapped_accum_vector_part_1.eval())
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()