- Remove slice hack to properly initialize missing entries in weight matrices

- Add real support for EmbeddingColumns / input_layer()
- Fix warmstarting for non-PartitionedVariables

PiperOrigin-RevId: 174083777
This commit is contained in:
A. Unique TensorFlower 2017-10-31 13:36:05 -07:00 committed by TensorFlower Gardener
parent 0cddb9bcaf
commit f1916f8f6c
3 changed files with 281 additions and 36 deletions

View File

@ -23,7 +23,6 @@ import six
from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
@ -125,7 +124,7 @@ def _infer_var_name(var):
Name of the `var`
"""
name_to_var_dict = saver.BaseSaverBuilder.OpListToDict(var)
if len(name_to_var_dict.keys()) > 1:
if len(name_to_var_dict) > 1:
raise TypeError("`var` passed as arg violates the constraints.")
return list(name_to_var_dict.keys())[0]
@ -138,26 +137,69 @@ def _warmstart_var(var, prev_ckpt, prev_tensor_name=None):
Can be either of the following:
(i) `Variable`
(ii) `ResourceVariable`
(iii) list of `Variable`: The list must contain slices of the same larger
variable.
(iv) `PartitionedVariable`
(iii) `PartitionedVariable`
(iv) list of `Variable` and/or `PartitionedVariable`: The list may
contain one or more variables that has been sharded. For example:
[Variable('a/part_0'), Variable('b/part_0'), Variable('a/part_1'),
PartitionedVariable([Variable('c/part_0'), Variable('c/part_1')])]
where we have three whole Variables represented ('a', 'b', and 'c').
prev_ckpt: A string specifying the directory with checkpoint file(s) or path
to checkpoint. The given checkpoint must have tensor with name
`prev_tensor_name` (if not None) or tensor with name same as given `var`.
prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
None, we lookup tensor with same name as given `var`.
Raises:
ValueError: If prev_tensor_name is not None, but the given var represents
more than one Variable.
TypeError: If var is not one of the allowed types.
"""
if _is_variable(var):
current_var_name = _infer_var_name([var])
elif isinstance(var, list) and all(_is_variable(v) for v in var):
current_var_name = _infer_var_name(var)
elif isinstance(var, variables.PartitionedVariable):
current_var_name = _infer_var_name([var])
var = var._get_variable_list() # pylint: disable=protected-access
elif (isinstance(var, list) and all(
_is_variable(v) or isinstance(v, variables.PartitionedVariable)
for v in var)):
# Convert length-1 lists of vars to single tf.Variables. This ensures that
# checkpoint_utils.init_from_checkpoint() doesn't incorrectly assume
# slice info is present.
if len(var) == 1:
current_var_name = _infer_var_name(var)
var = var[0]
else:
# If we have multiple elements in var, we cannot assume they all
# represent the same Variable.
name_to_var_dict = saver.BaseSaverBuilder.OpListToDict(
var, convert_variable_to_tensor=False)
if prev_tensor_name:
# Providing a prev_tensor_name is only viable if var representes a
# single Variable.
if len(name_to_var_dict) > 1:
raise ValueError("var represented more than one Variable, but "
"prev_tensor_name was provided.")
checkpoint_utils.init_from_checkpoint(prev_ckpt, {
prev_tensor_name: var
})
else:
# OpListToDict gives us roughly what we need, but
# the values in the dict may be PartitionedVariables (which
# init_from_checkpoint does not expect) that we need to convert to
# lists.
name_to_var_dict_fixed = {}
for name, var in six.iteritems(name_to_var_dict):
if isinstance(var, variables.PartitionedVariable):
name_to_var_dict_fixed[name] = var._get_variable_list() # pylint: disable=protected-access
else:
name_to_var_dict_fixed[name] = var
checkpoint_utils.init_from_checkpoint(prev_ckpt, name_to_var_dict_fixed)
return
else:
raise TypeError(
"var MUST be one of the following: a Variable, list of Variable or "
"PartitionedVariable, but is {}".format(type(var)))
"var MUST be one of the following: a Variable, PartitionedVariable, or "
"list of Variable's and/or PartitionedVariable's, but is {}".format(
type(var)))
if not prev_tensor_name:
# Assume tensor name remains the same.
prev_tensor_name = current_var_name
@ -173,7 +215,8 @@ def _warmstart_var_with_vocab(var,
prev_ckpt,
prev_vocab_path,
current_oov_buckets=0,
prev_tensor_name=None):
prev_tensor_name=None,
initializer=None):
"""Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.
Use this method when the `var` is backed by vocabulary. This method stitches
@ -200,6 +243,8 @@ def _warmstart_var_with_vocab(var,
buckets used for given `var`.
prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
None, we lookup tensor with same name as given `var`.
initializer: Variable initializer to be used for missing entries. If None,
missing entries will be zero-initialized.
Raises:
ValueError: If required args are not provided.
@ -232,18 +277,6 @@ def _warmstart_var_with_vocab(var,
full_shape=slice_info.full_shape,
var_offset=slice_info.var_offset)
# TODO(vihanjain): This is brittle. Can we instead infer actual initializer
# used originally for the variable or use a fixed initializer?
def _missing_ids_init(shape, dtype=None):
# pylint: disable=cell-var-from-loop
if dtype and dtype.base_dtype != v.dtype.base_dtype:
raise ValueError("Trying to initialize missing ids with a different "
"dtype `{}` than variable's dtype `{}`".format(
dtype, v.dtype))
return array_ops.slice(v.initial_value, [0, 0], shape)
# pylint: enable=cell-var-from-loop
# TODO(vihanjain): Support _WarmstartSettings where class vocabularies need
# remapping too.
init = checkpoint_ops._load_and_remap_matrix_initializer(
@ -257,7 +290,7 @@ def _warmstart_var_with_vocab(var,
new_col_vocab_file=None,
num_row_oov_buckets=current_oov_buckets,
num_col_oov_buckets=0,
initializer=_missing_ids_init)
initializer=initializer)
new_init_val = ops.convert_to_tensor(
init(shape=v_shape, partition_info=partition_info))
v._initializer_op = state_ops.assign(v, new_init_val)
@ -305,6 +338,11 @@ def _warmstart_input_layer(cols_to_vars, warmstart_settings):
```
The above example effectively warm-starts full linear model.
Raises:
ValueError: If a column in cols_to_vars has an entry in
warmstart_settings.cols_to_prev_vocab, but is not an instance of
_VocabularyFileCategoricalColumn or _EmbeddingColumn.
"""
for col, var in six.iteritems(cols_to_vars):
if not isinstance(col, feature_column._FeatureColumn): # pylint: disable=protected-access
@ -316,21 +354,43 @@ def _warmstart_input_layer(cols_to_vars, warmstart_settings):
continue
prev_tensor_name = warmstart_settings.col_to_prev_tensor.get(col)
if isinstance(col, feature_column._VocabularyFileCategoricalColumn): # pylint: disable=protected-access
# pylint: disable=protected-access
is_sparse_vocab_column = isinstance(
col, feature_column._VocabularyFileCategoricalColumn)
is_embedding_vocab_column = (
isinstance(col, feature_column._EmbeddingColumn) and
isinstance(col.categorical_column,
feature_column._VocabularyFileCategoricalColumn))
if is_sparse_vocab_column or is_embedding_vocab_column:
# pylint: enable=protected-access
initializer = None
if is_embedding_vocab_column:
initializer = col.initializer
vocabulary_file = col.categorical_column.vocabulary_file
vocabulary_size = col.categorical_column.vocabulary_size
num_oov_buckets = col.categorical_column.num_oov_buckets
else:
vocabulary_file = col.vocabulary_file
vocabulary_size = col.vocabulary_size
num_oov_buckets = col.num_oov_buckets
prev_vocab_path = warmstart_settings.col_to_prev_vocab.get(
col, col.vocabulary_file)
col, vocabulary_file)
logging.info("Warm-starting column: {}; prev_vocab: {}; prev_tensor: {}".
format(col.name, prev_vocab_path, (
prev_tensor_name or "Unchanged")))
_warmstart_var_with_vocab(
var,
current_vocab_path=col.vocabulary_file,
current_vocab_size=col.vocabulary_size,
current_vocab_path=vocabulary_file,
current_vocab_size=vocabulary_size,
prev_ckpt=warmstart_settings.ckpt_to_initialize_from,
prev_vocab_path=prev_vocab_path,
current_oov_buckets=col.num_oov_buckets,
prev_tensor_name=prev_tensor_name)
current_oov_buckets=num_oov_buckets,
prev_tensor_name=prev_tensor_name,
initializer=initializer)
else:
if col in warmstart_settings.col_to_prev_vocab:
raise ValueError("Vocabulary provided for column %s which is not a "
"_VocabularyFileCategoricalColumn or _EmbeddingColumn")
logging.info("Warm-starting column: {}; prev_tensor: {}".format(
col.name, prev_tensor_name or "Unchanged"))
_warmstart_var(var, warmstart_settings.ckpt_to_initialize_from,

View File

@ -72,6 +72,36 @@ class WarmStartingUtilTest(test.TestCase):
var = var._get_variable_list()
return var, sess.run(var)
def _create_prev_run_multiple_vars(self,
var_names,
initializers,
shapes=None,
partitioners=None):
if not shapes:
shapes = [None] * len(var_names)
if not partitioners:
partitioners = [None] * len(var_names)
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
var_list = []
for var_name, shape, initializer, partitioner in zip(
var_names, shapes, initializers, partitioners):
var_list.append(
variable_scope.get_variable(
var_name,
shape=shape,
initializer=initializer,
partitioner=partitioner))
self._write_checkpoint(sess)
run_vars = []
for var, partitioner in zip(var_list, partitioners):
if partitioner:
self.assertTrue(isinstance(var, variables.PartitionedVariable))
run_vars.append(sess.run(var._get_variable_list()))
else:
run_vars.append(sess.run(var))
return var_list, run_vars
def _create_dummy_inputs(self):
return {
"sc_int": array_ops.sparse_placeholder(dtypes.int32),
@ -98,7 +128,7 @@ class WarmStartingUtilTest(test.TestCase):
def _assert_cols_to_vars(self, cols_to_vars, cols_to_expected_values, sess):
for col, expected_values in six.iteritems(cols_to_expected_values):
for i, var in enumerate(cols_to_vars[col]):
self.assertAllEqual(expected_values[i], var.eval(sess))
self.assertAllClose(expected_values[i], var.eval(sess))
def testWarmStartVar(self):
_, prev_val = self._create_prev_run_var(
@ -175,6 +205,99 @@ class WarmStartingUtilTest(test.TestCase):
[fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
self.assertAllEqual(prev_val, new_val)
def testWarmStartVarMultipleVars(self):
_, prev_vals = self._create_prev_run_multiple_vars(
var_names=["fruit_weights", "other_weights"],
initializers=[[[0.5], [1.], [1.5], [2.]], [[.05], [.1], [.15], [.2]]])
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
fruit_weights = variable_scope.get_variable(
"fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
other_weights = variable_scope.get_variable(
"other_weights", initializer=[[0.], [0.], [0.], [0.]])
ws_util._warmstart_var([fruit_weights, other_weights],
self.get_temp_dir())
sess.run(variables.global_variables_initializer())
self.assertAllEqual(prev_vals[0], fruit_weights.eval(sess))
self.assertAllEqual(prev_vals[1], other_weights.eval(sess))
def testWarmStartVarMultipleVarsBothPartitioned(self):
_, prev_vals = self._create_prev_run_multiple_vars(
var_names=["fruit_weights", "other_weights"],
shapes=[[4, 1], [4, 1]],
initializers=[[[0.5], [1.], [1.5], [2.]], [[.05], [.1], [.15], [.2]]],
partitioners=[lambda shape, dtype: [2, 1], lambda shape, dtype: [2, 1]])
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
fruit_weights = variable_scope.get_variable(
"fruit_weights",
shape=[4, 1],
initializer=[[0.], [0.], [0.], [0.]],
partitioner=lambda shape, dtype: [2, 1])
other_weights = variable_scope.get_variable(
"other_weights",
shape=[4, 1],
initializer=[[0.], [0.], [0.], [0.]],
partitioner=lambda shape, dtype: [2, 1])
ws_util._warmstart_var([fruit_weights, other_weights],
self.get_temp_dir())
sess.run(variables.global_variables_initializer())
fruit_weights = fruit_weights._get_variable_list()
new_fruit_weights_val = np.concatenate(
[fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
other_weights = other_weights._get_variable_list()
new_other_weights_val = np.concatenate(
[other_weights[0].eval(sess), other_weights[1].eval(sess)], axis=0)
self.assertAllEqual(
np.concatenate(prev_vals[0], axis=0), new_fruit_weights_val)
self.assertAllEqual(
np.concatenate(prev_vals[1], axis=0), new_other_weights_val)
def testWarmStartVarMultipleVarsMixOfPartitions(self):
# First is not partitioned, but the second two are.
_, prev_vals = self._create_prev_run_multiple_vars(
var_names=["fruit_weights", "other_weights", "veggie_weights"],
shapes=[None, [4, 1], [4, 1]],
initializers=[[[0.5], [1.], [1.5], [2.]], [[.05], [.1], [.15], [.2]],
[[5.], [10.], [15.], [20.]]],
partitioners=[
None, lambda shape, dtype: [2, 1], lambda shape, dtype: [2, 1]
])
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
fruit_weights = variable_scope.get_variable(
"fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
other_weights = variable_scope.get_variable(
"other_weights",
shape=[4, 1],
initializer=[[0.], [0.], [0.], [0.]],
partitioner=lambda shape, dtype: [2, 1])
veggie_weights = variable_scope.get_variable(
"veggie_weights",
shape=[4, 1],
initializer=[[0.], [0.], [0.], [0.]],
partitioner=lambda shape, dtype: [2, 1])
# Flatten one of the partitioned variables.
ws_util._warmstart_var([fruit_weights, other_weights] +
veggie_weights._get_variable_list(),
self.get_temp_dir())
sess.run(variables.global_variables_initializer())
veggie_weights = veggie_weights._get_variable_list()
new_veggie_weights_val = np.concatenate(
[veggie_weights[0].eval(sess), veggie_weights[1].eval(sess)],
axis=0)
other_weights = other_weights._get_variable_list()
new_other_weights_val = np.concatenate(
[other_weights[0].eval(sess), other_weights[1].eval(sess)], axis=0)
self.assertAllEqual(prev_vals[0], fruit_weights.eval(sess))
self.assertAllEqual(
np.concatenate(prev_vals[1], axis=0), new_other_weights_val)
self.assertAllEqual(
np.concatenate(prev_vals[2], axis=0), new_veggie_weights_val)
def testWarmStartVarWithVocab(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
"old_vocab")
@ -558,6 +681,66 @@ class WarmStartingUtilTest(test.TestCase):
]
}, sess)
def testWarmStartInputLayerEmbeddingColumn(self):
# Create old and new vocabs for embedding column "sc_vocab".
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
"old_vocab")
new_vocab_path = self._write_vocab(
["orange", "guava", "banana", "apple", "raspberry", "blueberry"],
"new_vocab")
# Save checkpoint from which to warm-start.
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
_ = variable_scope.get_variable(
"input_layer/sc_vocab_embedding/embedding_weights",
initializer=[[0.5, 0.4], [1., 1.1], [2., 2.2], [3., 3.3]])
self._write_checkpoint(sess)
def _partitioner(shape, dtype): # pylint:disable=unused-argument
# Partition each var into 2 equal slices.
partitions = [1] * len(shape)
partitions[0] = min(2, shape[0].value)
return partitions
# Create feature columns.
sc_vocab = fc.categorical_column_with_vocabulary_file(
"sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6)
emb_vocab = fc.embedding_column(
categorical_column=sc_vocab,
dimension=2,
# Can't use constant_initializer with load_and_remap. In practice,
# use a truncated normal initializer.
initializer=init_ops.random_uniform_initializer(
minval=0.42, maxval=0.42))
all_deep_cols = [emb_vocab]
# New graph, new session with warmstarting.
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
cols_to_vars = {}
with variable_scope.variable_scope("", partitioner=_partitioner):
# Create the variables.
fc.input_layer(
features=self._create_dummy_inputs(),
feature_columns=all_deep_cols,
cols_to_vars=cols_to_vars)
ws_settings = ws_util._WarmStartSettings(
self.get_temp_dir(), col_to_prev_vocab={
emb_vocab: prev_vocab_path
})
ws_util._warmstart_input_layer(cols_to_vars, ws_settings)
sess.run(variables.global_variables_initializer())
# Verify weights were correctly warmstarted. Var corresponding to
# emb_vocab should be correctly warmstarted after vocab remapping.
# Missing values are filled in with the EmbeddingColumn's initializer.
self._assert_cols_to_vars(
cols_to_vars, {
emb_vocab: [
np.array([[3., 3.3], [2., 2.2], [1., 1.1]]),
np.array([[0.5, 0.4], [0.42, 0.42], [0.42, 0.42]])
]
}, sess)
def testErrorConditions(self):
self.assertRaises(ValueError, ws_util._WarmStartSettings, None)
x = variable_scope.get_variable(
@ -566,8 +749,7 @@ class WarmStartingUtilTest(test.TestCase):
initializer=ones(),
partitioner=lambda shape, dtype: [2, 1])
# List of PartitionedVariable is invalid type.
self.assertRaises(TypeError, ws_util._warmstart_var, [x], prev_ckpt="/tmp")
# List of PartitionedVariable is invalid type when warmstarting with vocab.
self.assertRaises(TypeError, ws_util._warmstart_var_with_vocab, [x], "/tmp",
5, "/tmp", "/tmp")
# Keys of type other than FeatureColumn.

View File

@ -503,11 +503,13 @@ class BaseSaverBuilder(object):
return sorted(per_device.items(), key=lambda t: t[0])
@staticmethod
def OpListToDict(op_list):
def OpListToDict(op_list, convert_variable_to_tensor=True):
"""Create a dictionary of names to operation lists.
Args:
op_list: A list, tuple, or set of Variables or SaveableObjects.
convert_variable_to_tensor: Whether or not to convert single Variables
with no slice info into Tensors.
Returns:
A dictionary of names to the operations that must be saved under
@ -543,9 +545,10 @@ class BaseSaverBuilder(object):
names_to_saveables[name] = [var]
else:
if context.in_graph_mode():
var = ops.internal_convert_to_tensor(var, as_ref=True)
if not BaseSaverBuilder._IsVariable(var):
raise TypeError("Variable to save is not a Variable: %s" % var)
if convert_variable_to_tensor:
var = ops.internal_convert_to_tensor(var, as_ref=True)
if not BaseSaverBuilder._IsVariable(var):
raise TypeError("Variable to save is not a Variable: %s" % var)
if var.op.type == "ReadVariableOp":
name = var.op.inputs[0].op.name
else: