mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Adding a slot / accumulator warmstart initializer that overrides the provided partitioner at call time with one passed at construction time. This is intended to be used for slot Variables (such as accumulators) associated with Optimizers, since these Variables are created in a fashion that relies on replicating the exact shape of the associated primary variables (see slot_creator).
PiperOrigin-RevId: 157453498
This commit is contained in:
parent
73d10599fe
commit
8c2a079ec8
|
|
@ -77,6 +77,7 @@ See the @{$python/contrib.framework} guide.
|
|||
@@load_and_remap_matrix_initializer
|
||||
@@load_embedding_initializer
|
||||
@@load_linear_multiclass_bias_initializer
|
||||
@@load_variable_slot_initializer
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
|
|
|||
|
|
@ -488,3 +488,91 @@ def load_linear_multiclass_bias_initializer(ckpt_path,
|
|||
num_row_oov_buckets=num_class_oov_buckets,
|
||||
num_col_oov_buckets=0,
|
||||
initializer=initializer)
|
||||
|
||||
|
||||
def load_variable_slot_initializer(ckpt_path,
|
||||
old_tensor_name,
|
||||
primary_partition_info,
|
||||
new_row_vocab_size,
|
||||
new_col_vocab_size,
|
||||
old_row_vocab_file=None,
|
||||
new_row_vocab_file=None,
|
||||
old_col_vocab_file=None,
|
||||
new_col_vocab_file=None,
|
||||
num_row_oov_buckets=0,
|
||||
num_col_oov_buckets=0,
|
||||
initializer=None):
|
||||
"""Loads pre-trained multi-class slots for linear models from checkpoint.
|
||||
|
||||
Wrapper around `load_and_remap_matrix_initializer()` specialized for loading
|
||||
multi-class slots (such as optimizer accumulators) and remapping them
|
||||
according to the provided vocab files. See docs for
|
||||
`load_and_remap_matrix_initializer()` for more details. Takes in a
|
||||
`variable_scope._PartitionInfo` representing the slot's primary `Variable`'s
|
||||
partitioning. This is necessary since accumulator `Variable` creation ignores
|
||||
primary scoping and partitioning information.
|
||||
|
||||
Args:
|
||||
ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
|
||||
from which the old matrix `Tensor` will be loaded.
|
||||
old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
|
||||
primary_partition_info: A `variable_scope._PartitionInfo` containing this
|
||||
slot's primary `Variable`'s partitioning information. This is used to
|
||||
calculate the offset and override the partition_info passed to the call to
|
||||
_initialize.
|
||||
new_row_vocab_size: `int` specifying the number of entries in
|
||||
`new_row_vocab_file`. If no row remapping is needed (no row vocab
|
||||
provided), this should be equal to the number of rows to load from the old
|
||||
matrix (which can theoretically be smaller than the number of rows in the
|
||||
old matrix).
|
||||
new_col_vocab_size: `int` specifying the number of entries in
|
||||
`new_col_vocab_file`. If no column remapping is needed (no column vocab
|
||||
provided), this should be equal to the number of columns in the old
|
||||
matrix.
|
||||
old_row_vocab_file: A scalar `Tensor` of type `string` containing the
|
||||
path to the old row vocabulary file. Can be None, which represents no
|
||||
remapping on the row axis.
|
||||
new_row_vocab_file: A scalar `Tensor` of type `string` containing the path
|
||||
to the new row vocabulary file. Can be None, which represents no remapping
|
||||
on the row axis.
|
||||
old_col_vocab_file: A scalar `Tensor` of type `string` containing the
|
||||
path to the old column vocabulary file. Can be None, which represents no
|
||||
remapping on the column axis.
|
||||
new_col_vocab_file: A scalar `Tensor` of type `string` containing the path
|
||||
to the new column vocabulary file. Can be None, which represents no
|
||||
remapping on the column axis.
|
||||
num_row_oov_buckets: `int` specifying the number of out-of-vocabulary rows
|
||||
to append. Must be >= 0.
|
||||
num_col_oov_buckets: `int` specifying the number of out-of-vocabulary
|
||||
columns to append. Must be >= 0.
|
||||
initializer: Initializer function to initialize missing values. Accepts a
|
||||
1-D tensor as the arg to specify the shape of the returned tensor. If
|
||||
`None`, defaults to using `zeros_initializer()`.
|
||||
|
||||
Returns:
|
||||
A variable initializer function that should be used to initialize a
|
||||
(potentially partitioned) `Variable` whose complete shape is
|
||||
`[new_row_vocab_size + num_row_oov_buckets, new_col_vocab_size +
|
||||
num_col_oov_buckets]`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `initializer` is specified but not callable.
|
||||
"""
|
||||
initializer_fn = load_and_remap_matrix_initializer(
|
||||
ckpt_path=ckpt_path,
|
||||
old_tensor_name=old_tensor_name,
|
||||
new_row_vocab_size=new_row_vocab_size,
|
||||
new_col_vocab_size=new_col_vocab_size,
|
||||
old_row_vocab_file=old_row_vocab_file,
|
||||
new_row_vocab_file=new_row_vocab_file,
|
||||
old_col_vocab_file=old_col_vocab_file,
|
||||
new_col_vocab_file=new_col_vocab_file,
|
||||
num_row_oov_buckets=num_row_oov_buckets,
|
||||
num_col_oov_buckets=num_col_oov_buckets,
|
||||
initializer=initializer)
|
||||
|
||||
def _initializer(shape, dtype=dtypes.float32, partition_info=None):
|
||||
del partition_info # Unused by this override.
|
||||
return initializer_fn(shape, dtype, partition_info=primary_partition_info)
|
||||
|
||||
return _initializer
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ class GenerateVocabRemappingTest(test.TestCase):
|
|||
|
||||
|
||||
class LoadAndRemapMatrixTest(test.TestCase):
|
||||
"""Tests for the load_and_remap_weight_matrix() op."""
|
||||
"""Tests for the load_and_remap_matrix() op."""
|
||||
|
||||
def setUp(self):
|
||||
ops.reset_default_graph()
|
||||
|
|
@ -276,7 +276,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
|
|||
|
||||
def test_load_and_remap_matrix(self):
|
||||
"""Tests the end-to-end loading / remapping of weights."""
|
||||
# load_and_remap_matrix() is the generalized wrapper that takes in row and
|
||||
# _load_and_remap_matrix() is the generalized wrapper that takes in row and
|
||||
# column vocabulary files, calls the relevant remappings, and returns the
|
||||
# weight matrix. Take this example to be linear multi-class by providing
|
||||
# both row and column vocabularies.
|
||||
|
|
@ -458,7 +458,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
|
|||
remapped_matrix.as_tensor().eval())
|
||||
|
||||
def test_load_embedding_initializer(self):
|
||||
"""Tests for the load_embedding initializer wrapper."""
|
||||
"""Tests for the load_embedding_initializer wrapper."""
|
||||
embedding_loading_initializer = (
|
||||
contrib_framework.load_embedding_initializer(
|
||||
new_vocab_file=self.new_feature_vocab_file,
|
||||
|
|
@ -553,5 +553,96 @@ class LoadMulticlassBiasTest(test.TestCase):
|
|||
remapped_bias_vector.as_tensor().eval())
|
||||
|
||||
|
||||
class LoadVariableSlotTest(test.TestCase):
|
||||
"""Tests for the load_variable_slot_initializer functionality."""
|
||||
|
||||
def setUp(self):
|
||||
ops.reset_default_graph()
|
||||
dim = 1
|
||||
num = 3
|
||||
with ops.name_scope('some_scope'):
|
||||
# Basically from 0 to dim*num-1.
|
||||
flat_data = math_ops.linspace(0.0, dim * num - 1, dim * num)
|
||||
accum = variables.Variable(
|
||||
array_ops.reshape(flat_data, (num, dim)), name='accum')
|
||||
save = saver.Saver([accum])
|
||||
with self.test_session() as sess:
|
||||
variables.global_variables_initializer().run()
|
||||
self.bundle_file = os.path.join(test.get_temp_dir(), 'accum_checkpoint')
|
||||
save.save(sess, self.bundle_file)
|
||||
|
||||
self.new_class_vocab_file = os.path.join(
|
||||
test.test_src_dir_path(_TESTDATA_PATH), 'keyword_new.txt')
|
||||
self.old_class_vocab_file = os.path.join(
|
||||
test.test_src_dir_path(_TESTDATA_PATH), 'keyword.txt')
|
||||
self.init_val = 42
|
||||
|
||||
def _init_val_initializer(shape, dtype=None, partition_info=None):
|
||||
del dtype, partition_info # Unused by this unit-testing initializer.
|
||||
return array_ops.tile(
|
||||
constant_op.constant([[self.init_val]], dtype=dtypes.float32), shape)
|
||||
|
||||
self.initializer = _init_val_initializer
|
||||
|
||||
def test_load_variable_slot_initializer(self):
|
||||
"""Tests for the slot initializer wrapper."""
|
||||
# We have an initializer for each of two partitioned variables, which will
|
||||
# be [3, 1] and [2, 1]. The partitioning information is passed here in
|
||||
# initializer construction, as opposed to through a variable scope during
|
||||
# variable creation.
|
||||
variable_slot_initializer_part_0 = (
|
||||
contrib_framework.load_variable_slot_initializer(
|
||||
new_row_vocab_file=self.new_class_vocab_file,
|
||||
old_row_vocab_file=self.old_class_vocab_file,
|
||||
new_row_vocab_size=4,
|
||||
new_col_vocab_size=1,
|
||||
primary_partition_info=variable_scope._PartitionInfo(
|
||||
full_shape=[5, 1], var_offset=[0, 0]),
|
||||
old_tensor_name='some_scope/accum',
|
||||
ckpt_path=[self.bundle_file],
|
||||
num_row_oov_buckets=1,
|
||||
initializer=self.initializer))
|
||||
variable_slot_initializer_part_1 = (
|
||||
contrib_framework.load_variable_slot_initializer(
|
||||
new_row_vocab_file=self.new_class_vocab_file,
|
||||
old_row_vocab_file=self.old_class_vocab_file,
|
||||
new_row_vocab_size=4,
|
||||
new_col_vocab_size=1,
|
||||
primary_partition_info=variable_scope._PartitionInfo(
|
||||
full_shape=[5, 1], var_offset=[3, 0]),
|
||||
old_tensor_name='some_scope/accum',
|
||||
ckpt_path=[self.bundle_file],
|
||||
num_row_oov_buckets=1,
|
||||
initializer=self.initializer))
|
||||
|
||||
expected_remapped_accum_vector_part_0 = np.reshape([2, 0, self.init_val],
|
||||
[3, 1])
|
||||
|
||||
expected_remapped_accum_vector_part_1 = np.reshape([1, self.init_val],
|
||||
[2, 1])
|
||||
|
||||
# Since there is no variable scope here, partition_info will be None, so
|
||||
# if variable_slot_initializer_part_0 and variable_slot_initializer_part_1
|
||||
# were instead instances of load_and_remap_matrix_initializer, the part_0
|
||||
# obtained vector would still be [2, 0, self.init_val], but the part_1
|
||||
# obtained vector would be [2, 0], since the partition_info would default to
|
||||
# assuming a single partition.
|
||||
remapped_accum_vector_part_0 = variable_scope.get_variable(
|
||||
name='accum/obtained_accum_vector_part_0',
|
||||
shape=[3, 1],
|
||||
initializer=variable_slot_initializer_part_0)
|
||||
remapped_accum_vector_part_1 = variable_scope.get_variable(
|
||||
name='accum/obtained_accum_vector_part_1',
|
||||
shape=[2, 1],
|
||||
initializer=variable_slot_initializer_part_1)
|
||||
|
||||
with self.test_session():
|
||||
variables.global_variables_initializer().run()
|
||||
self.assertAllClose(expected_remapped_accum_vector_part_0,
|
||||
remapped_accum_vector_part_0.eval())
|
||||
self.assertAllClose(expected_remapped_accum_vector_part_1,
|
||||
remapped_accum_vector_part_1.eval())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user