mirror of https://github.com/zebrajr/tensorflow.git (synced 2025-12-07 12:20:24 +01:00)
- Remove slice hack to properly initialize missing entries in weight matrices
- Add real support for EmbeddingColumns / input_layer()
- Fix warmstarting for non-PartitionedVariables

PiperOrigin-RevId: 174083777

This commit is contained in: parent 0cddb9bcaf, commit f1916f8f6c
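Taken together, the three changes let an input layer built from an _EmbeddingColumn be warm-started end to end: rows for vocabulary entries seen before are loaded (and remapped) from the previous checkpoint, while rows for new entries come from the column's own initializer. A minimal sketch of that flow, assuming the TF 1.4-era private module path tensorflow.python.estimator.warm_starting_util; the checkpoint and vocab paths are placeholders:

    import tensorflow as tf
    from tensorflow.python.estimator import warm_starting_util as ws_util

    prev_ckpt_dir = "/tmp/prev_model"                      # placeholder
    old_vocab, new_vocab = "/tmp/old.txt", "/tmp/new.txt"  # placeholders

    sc_vocab = tf.feature_column.categorical_column_with_vocabulary_file(
        "sc_vocab", vocabulary_file=new_vocab, vocabulary_size=6)
    emb_vocab = tf.feature_column.embedding_column(
        categorical_column=sc_vocab, dimension=2,
        initializer=tf.truncated_normal_initializer(stddev=0.1))

    # Build the input layer and capture the created variables per column.
    cols_to_vars = {}
    features = {"sc_vocab": tf.sparse_placeholder(tf.string)}
    tf.feature_column.input_layer(
        features=features, feature_columns=[emb_vocab],
        cols_to_vars=cols_to_vars)

    # Warm-start: previously-seen vocab rows are loaded from the checkpoint
    # after old->new vocab remapping; brand-new rows use the column's
    # initializer instead of the removed slice-of-initial-value hack.
    ws_settings = ws_util._WarmStartSettings(
        prev_ckpt_dir, col_to_prev_vocab={emb_vocab: old_vocab})
    ws_util._warmstart_input_layer(cols_to_vars, ws_settings)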
--- warm_starting_util.py
+++ warm_starting_util.py
@@ -23,7 +23,6 @@ import six
 from tensorflow.python.feature_column import feature_column
 from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
@@ -125,7 +124,7 @@ def _infer_var_name(var):
     Name of the `var`
   """
   name_to_var_dict = saver.BaseSaverBuilder.OpListToDict(var)
-  if len(name_to_var_dict.keys()) > 1:
+  if len(name_to_var_dict) > 1:
     raise TypeError("`var` passed as arg violates the constraints.")
   return list(name_to_var_dict.keys())[0]
@@ -138,26 +137,69 @@ def _warmstart_var(var, prev_ckpt, prev_tensor_name=None):
       Can be either of the following:
       (i) `Variable`
       (ii) `ResourceVariable`
-      (iii) list of `Variable`: The list must contain slices of the same larger
-        variable.
-      (iv) `PartitionedVariable`
+      (iii) `PartitionedVariable`
+      (iv) list of `Variable` and/or `PartitionedVariable`: The list may
+        contain one or more variables that have been sharded. For example:
+        [Variable('a/part_0'), Variable('b/part_0'), Variable('a/part_1'),
+         PartitionedVariable([Variable('c/part_0'), Variable('c/part_1')])]
+        where we have three whole Variables represented ('a', 'b', and 'c').
     prev_ckpt: A string specifying the directory with checkpoint file(s) or path
       to checkpoint. The given checkpoint must have tensor with name
       `prev_tensor_name` (if not None) or tensor with name same as given `var`.
     prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
       None, we lookup tensor with same name as given `var`.
+
+  Raises:
+    ValueError: If prev_tensor_name is not None, but the given var represents
+      more than one Variable.
+    TypeError: If var is not one of the allowed types.
   """
   if _is_variable(var):
     current_var_name = _infer_var_name([var])
-  elif isinstance(var, list) and all(_is_variable(v) for v in var):
-    current_var_name = _infer_var_name(var)
   elif isinstance(var, variables.PartitionedVariable):
     current_var_name = _infer_var_name([var])
     var = var._get_variable_list()  # pylint: disable=protected-access
+  elif (isinstance(var, list) and all(
+      _is_variable(v) or isinstance(v, variables.PartitionedVariable)
+      for v in var)):
+    # Convert length-1 lists of vars to single tf.Variables. This ensures that
+    # checkpoint_utils.init_from_checkpoint() doesn't incorrectly assume
+    # slice info is present.
+    if len(var) == 1:
+      current_var_name = _infer_var_name(var)
+      var = var[0]
+    else:
+      # If we have multiple elements in var, we cannot assume they all
+      # represent the same Variable.
+      name_to_var_dict = saver.BaseSaverBuilder.OpListToDict(
+          var, convert_variable_to_tensor=False)
+      if prev_tensor_name:
+        # Providing a prev_tensor_name is only viable if var represents a
+        # single Variable.
+        if len(name_to_var_dict) > 1:
+          raise ValueError("var represented more than one Variable, but "
+                           "prev_tensor_name was provided.")
+        checkpoint_utils.init_from_checkpoint(prev_ckpt, {
+            prev_tensor_name: var
+        })
+      else:
+        # OpListToDict gives us roughly what we need, but
+        # the values in the dict may be PartitionedVariables (which
+        # init_from_checkpoint does not expect) that we need to convert to
+        # lists.
+        name_to_var_dict_fixed = {}
+        for name, var in six.iteritems(name_to_var_dict):
+          if isinstance(var, variables.PartitionedVariable):
+            name_to_var_dict_fixed[name] = var._get_variable_list()  # pylint: disable=protected-access
+          else:
+            name_to_var_dict_fixed[name] = var
+        checkpoint_utils.init_from_checkpoint(prev_ckpt, name_to_var_dict_fixed)
+      return
   else:
     raise TypeError(
-        "var MUST be one of the following: a Variable, list of Variable or "
-        "PartitionedVariable, but is {}".format(type(var)))
+        "var MUST be one of the following: a Variable, PartitionedVariable, or "
+        "list of Variable's and/or PartitionedVariable's, but is {}".format(
+            type(var)))
   if not prev_tensor_name:
     # Assume tensor name remains the same.
     prev_tensor_name = current_var_name
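In short, _warmstart_var now accepts a heterogeneous list, groups its entries per whole variable via OpListToDict, and hands each group to init_from_checkpoint. A small sketch of a call the new elif branch accepts, assuming graph mode and a placeholder checkpoint directory that already contains variables named "a" and "b":

    import tensorflow as tf
    from tensorflow.python.estimator import warm_starting_util as ws_util

    prev_ckpt_dir = "/tmp/prev_model"  # placeholder checkpoint location

    a = tf.get_variable("a", shape=[4, 1])
    # "b" is sharded into a PartitionedVariable with two slices.
    b = tf.get_variable(
        "b", shape=[4, 1], partitioner=tf.fixed_size_partitioner(2))

    # Previously only a homogeneous list of slices of one variable was
    # allowed; the list may now mix Variables and PartitionedVariables.
    ws_util._warmstart_var([a, b], prev_ckpt_dir)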
@@ -173,7 +215,8 @@ def _warmstart_var_with_vocab(var,
                               prev_ckpt,
                               prev_vocab_path,
                               current_oov_buckets=0,
-                              prev_tensor_name=None):
+                              prev_tensor_name=None,
+                              initializer=None):
   """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.
 
   Use this method when the `var` is backed by vocabulary. This method stitches
@@ -200,6 +243,8 @@ def _warmstart_var_with_vocab(var,
       buckets used for given `var`.
     prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
       None, we lookup tensor with same name as given `var`.
+    initializer: Variable initializer to be used for missing entries. If None,
+      missing entries will be zero-initialized.
 
   Raises:
     ValueError: If required args are not provided.
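A hedged sketch of a direct call with the new keyword argument, assuming the same TF 1.4-era private module path as above; the checkpoint and both vocab paths are placeholders that must exist at run time:

    import tensorflow as tf
    from tensorflow.python.estimator import warm_starting_util as ws_util

    fruit_weights = tf.get_variable("fruit_weights", shape=[6, 2])
    ws_util._warmstart_var_with_vocab(
        fruit_weights,
        current_vocab_path="/tmp/new_vocab.txt",   # placeholder
        current_vocab_size=6,
        prev_ckpt="/tmp/prev_model",               # placeholder
        prev_vocab_path="/tmp/old_vocab.txt",      # placeholder
        # New in this change: rows whose vocab entries are missing from the
        # old checkpoint are drawn from this initializer (zeros if None).
        initializer=tf.truncated_normal_initializer(stddev=0.1))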
@@ -232,18 +277,6 @@ def _warmstart_var_with_vocab(var,
           full_shape=slice_info.full_shape,
           var_offset=slice_info.var_offset)
 
-    # TODO(vihanjain): This is brittle. Can we instead infer actual initializer
-    # used originally for the variable or use a fixed initializer?
-    def _missing_ids_init(shape, dtype=None):
-      # pylint: disable=cell-var-from-loop
-      if dtype and dtype.base_dtype != v.dtype.base_dtype:
-        raise ValueError("Trying to initialize missing ids with a different "
-                         "dtype `{}` than variable's dtype `{}`".format(
-                             dtype, v.dtype))
-      return array_ops.slice(v.initial_value, [0, 0], shape)
-    # pylint: enable=cell-var-from-loop
-
     # TODO(vihanjain): Support _WarmstartSettings where class vocabularies need
     # remapping too.
     init = checkpoint_ops._load_and_remap_matrix_initializer(
@@ -257,7 +290,7 @@ def _warmstart_var_with_vocab(var,
         new_col_vocab_file=None,
         num_row_oov_buckets=current_oov_buckets,
         num_col_oov_buckets=0,
-        initializer=_missing_ids_init)
+        initializer=initializer)
     new_init_val = ops.convert_to_tensor(
         init(shape=v_shape, partition_info=partition_info))
     v._initializer_op = state_ops.assign(v, new_init_val)
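The net effect of dropping _missing_ids_init: the remapping initializer no longer slices the variable's own initial value for missing rows, it uses the caller-supplied initializer (or zeros). The stitching semantics, illustrated with plain numpy; the words and values here are illustrative, not taken from the diff:

    import numpy as np

    # Old checkpoint rows, keyed by the old vocabulary.
    prev_rows = {"apple": [0.5], "banana": [1.0], "guava": [1.5]}
    # New vocabulary reorders the old entries and adds a new one.
    new_vocab = ["banana", "apple", "guava", "raspberry"]

    def missing_init():
      return [0.0]  # stand-in for `initializer` (None -> zeros)

    # Known rows are warm-started from the checkpoint; "raspberry" falls
    # back to the initializer.
    new_matrix = np.array([prev_rows.get(w, missing_init()) for w in new_vocab])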
@@ -305,6 +338,11 @@ def _warmstart_input_layer(cols_to_vars, warmstart_settings):
   ```
 
   The above example effectively warm-starts full linear model.
+
+  Raises:
+    ValueError: If a column in cols_to_vars has an entry in
+      warmstart_settings.col_to_prev_vocab, but is not an instance of
+      _VocabularyFileCategoricalColumn or _EmbeddingColumn.
   """
   for col, var in six.iteritems(cols_to_vars):
     if not isinstance(col, feature_column._FeatureColumn):  # pylint: disable=protected-access
@@ -316,21 +354,43 @@ def _warmstart_input_layer(cols_to_vars, warmstart_settings):
       continue
 
     prev_tensor_name = warmstart_settings.col_to_prev_tensor.get(col)
-    if isinstance(col, feature_column._VocabularyFileCategoricalColumn):  # pylint: disable=protected-access
+    # pylint: disable=protected-access
+    is_sparse_vocab_column = isinstance(
+        col, feature_column._VocabularyFileCategoricalColumn)
+    is_embedding_vocab_column = (
+        isinstance(col, feature_column._EmbeddingColumn) and
+        isinstance(col.categorical_column,
+                   feature_column._VocabularyFileCategoricalColumn))
+    if is_sparse_vocab_column or is_embedding_vocab_column:
+      # pylint: enable=protected-access
+      initializer = None
+      if is_embedding_vocab_column:
+        initializer = col.initializer
+        vocabulary_file = col.categorical_column.vocabulary_file
+        vocabulary_size = col.categorical_column.vocabulary_size
+        num_oov_buckets = col.categorical_column.num_oov_buckets
+      else:
+        vocabulary_file = col.vocabulary_file
+        vocabulary_size = col.vocabulary_size
+        num_oov_buckets = col.num_oov_buckets
       prev_vocab_path = warmstart_settings.col_to_prev_vocab.get(
-          col, col.vocabulary_file)
+          col, vocabulary_file)
       logging.info("Warm-starting column: {}; prev_vocab: {}; prev_tensor: {}".
                    format(col.name, prev_vocab_path, (
                        prev_tensor_name or "Unchanged")))
       _warmstart_var_with_vocab(
           var,
-          current_vocab_path=col.vocabulary_file,
-          current_vocab_size=col.vocabulary_size,
+          current_vocab_path=vocabulary_file,
+          current_vocab_size=vocabulary_size,
           prev_ckpt=warmstart_settings.ckpt_to_initialize_from,
           prev_vocab_path=prev_vocab_path,
-          current_oov_buckets=col.num_oov_buckets,
-          prev_tensor_name=prev_tensor_name)
+          current_oov_buckets=num_oov_buckets,
+          prev_tensor_name=prev_tensor_name,
+          initializer=initializer)
     else:
+      if col in warmstart_settings.col_to_prev_vocab:
+        raise ValueError("Vocabulary provided for column %s which is not a "
+                         "_VocabularyFileCategoricalColumn or _EmbeddingColumn")
       logging.info("Warm-starting column: {}; prev_tensor: {}".format(
          col.name, prev_tensor_name or "Unchanged"))
      _warmstart_var(var, warmstart_settings.ckpt_to_initialize_from,
--- warm_starting_util_test.py
+++ warm_starting_util_test.py
@@ -72,6 +72,36 @@ class WarmStartingUtilTest(test.TestCase):
       var = var._get_variable_list()
     return var, sess.run(var)
 
+  def _create_prev_run_multiple_vars(self,
+                                     var_names,
+                                     initializers,
+                                     shapes=None,
+                                     partitioners=None):
+    if not shapes:
+      shapes = [None] * len(var_names)
+    if not partitioners:
+      partitioners = [None] * len(var_names)
+    with ops.Graph().as_default() as g:
+      with self.test_session(graph=g) as sess:
+        var_list = []
+        for var_name, shape, initializer, partitioner in zip(
+            var_names, shapes, initializers, partitioners):
+          var_list.append(
+              variable_scope.get_variable(
+                  var_name,
+                  shape=shape,
+                  initializer=initializer,
+                  partitioner=partitioner))
+        self._write_checkpoint(sess)
+        run_vars = []
+        for var, partitioner in zip(var_list, partitioners):
+          if partitioner:
+            self.assertTrue(isinstance(var, variables.PartitionedVariable))
+            run_vars.append(sess.run(var._get_variable_list()))
+          else:
+            run_vars.append(sess.run(var))
+        return var_list, run_vars
+
   def _create_dummy_inputs(self):
     return {
         "sc_int": array_ops.sparse_placeholder(dtypes.int32),
@@ -98,7 +128,7 @@ class WarmStartingUtilTest(test.TestCase):
   def _assert_cols_to_vars(self, cols_to_vars, cols_to_expected_values, sess):
     for col, expected_values in six.iteritems(cols_to_expected_values):
       for i, var in enumerate(cols_to_vars[col]):
-        self.assertAllEqual(expected_values[i], var.eval(sess))
+        self.assertAllClose(expected_values[i], var.eval(sess))
 
   def testWarmStartVar(self):
     _, prev_val = self._create_prev_run_var(
@@ -175,6 +205,99 @@ class WarmStartingUtilTest(test.TestCase):
           [fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
       self.assertAllEqual(prev_val, new_val)
 
+  def testWarmStartVarMultipleVars(self):
+    _, prev_vals = self._create_prev_run_multiple_vars(
+        var_names=["fruit_weights", "other_weights"],
+        initializers=[[[0.5], [1.], [1.5], [2.]], [[.05], [.1], [.15], [.2]]])
+
+    with ops.Graph().as_default() as g:
+      with self.test_session(graph=g) as sess:
+        fruit_weights = variable_scope.get_variable(
+            "fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
+        other_weights = variable_scope.get_variable(
+            "other_weights", initializer=[[0.], [0.], [0.], [0.]])
+        ws_util._warmstart_var([fruit_weights, other_weights],
+                               self.get_temp_dir())
+        sess.run(variables.global_variables_initializer())
+        self.assertAllEqual(prev_vals[0], fruit_weights.eval(sess))
+        self.assertAllEqual(prev_vals[1], other_weights.eval(sess))
+
+  def testWarmStartVarMultipleVarsBothPartitioned(self):
+    _, prev_vals = self._create_prev_run_multiple_vars(
+        var_names=["fruit_weights", "other_weights"],
+        shapes=[[4, 1], [4, 1]],
+        initializers=[[[0.5], [1.], [1.5], [2.]], [[.05], [.1], [.15], [.2]]],
+        partitioners=[lambda shape, dtype: [2, 1], lambda shape, dtype: [2, 1]])
+
+    with ops.Graph().as_default() as g:
+      with self.test_session(graph=g) as sess:
+        fruit_weights = variable_scope.get_variable(
+            "fruit_weights",
+            shape=[4, 1],
+            initializer=[[0.], [0.], [0.], [0.]],
+            partitioner=lambda shape, dtype: [2, 1])
+        other_weights = variable_scope.get_variable(
+            "other_weights",
+            shape=[4, 1],
+            initializer=[[0.], [0.], [0.], [0.]],
+            partitioner=lambda shape, dtype: [2, 1])
+        ws_util._warmstart_var([fruit_weights, other_weights],
+                               self.get_temp_dir())
+        sess.run(variables.global_variables_initializer())
+        fruit_weights = fruit_weights._get_variable_list()
+        new_fruit_weights_val = np.concatenate(
+            [fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
+        other_weights = other_weights._get_variable_list()
+        new_other_weights_val = np.concatenate(
+            [other_weights[0].eval(sess), other_weights[1].eval(sess)], axis=0)
+        self.assertAllEqual(
+            np.concatenate(prev_vals[0], axis=0), new_fruit_weights_val)
+        self.assertAllEqual(
+            np.concatenate(prev_vals[1], axis=0), new_other_weights_val)
+
+  def testWarmStartVarMultipleVarsMixOfPartitions(self):
+    # First is not partitioned, but the second two are.
+    _, prev_vals = self._create_prev_run_multiple_vars(
+        var_names=["fruit_weights", "other_weights", "veggie_weights"],
+        shapes=[None, [4, 1], [4, 1]],
+        initializers=[[[0.5], [1.], [1.5], [2.]], [[.05], [.1], [.15], [.2]],
+                      [[5.], [10.], [15.], [20.]]],
+        partitioners=[
+            None, lambda shape, dtype: [2, 1], lambda shape, dtype: [2, 1]
+        ])
+
+    with ops.Graph().as_default() as g:
+      with self.test_session(graph=g) as sess:
+        fruit_weights = variable_scope.get_variable(
+            "fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
+        other_weights = variable_scope.get_variable(
+            "other_weights",
+            shape=[4, 1],
+            initializer=[[0.], [0.], [0.], [0.]],
+            partitioner=lambda shape, dtype: [2, 1])
+        veggie_weights = variable_scope.get_variable(
+            "veggie_weights",
+            shape=[4, 1],
+            initializer=[[0.], [0.], [0.], [0.]],
+            partitioner=lambda shape, dtype: [2, 1])
+        # Flatten one of the partitioned variables.
+        ws_util._warmstart_var([fruit_weights, other_weights] +
+                               veggie_weights._get_variable_list(),
+                               self.get_temp_dir())
+        sess.run(variables.global_variables_initializer())
+        veggie_weights = veggie_weights._get_variable_list()
+        new_veggie_weights_val = np.concatenate(
+            [veggie_weights[0].eval(sess), veggie_weights[1].eval(sess)],
+            axis=0)
+        other_weights = other_weights._get_variable_list()
+        new_other_weights_val = np.concatenate(
+            [other_weights[0].eval(sess), other_weights[1].eval(sess)], axis=0)
+        self.assertAllEqual(prev_vals[0], fruit_weights.eval(sess))
+        self.assertAllEqual(
+            np.concatenate(prev_vals[1], axis=0), new_other_weights_val)
+        self.assertAllEqual(
+            np.concatenate(prev_vals[2], axis=0), new_veggie_weights_val)
+
   def testWarmStartVarWithVocab(self):
     prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                         "old_vocab")
@@ -558,6 +681,66 @@ class WarmStartingUtilTest(test.TestCase):
             ]
         }, sess)
 
+  def testWarmStartInputLayerEmbeddingColumn(self):
+    # Create old and new vocabs for embedding column "sc_vocab".
+    prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
+                                        "old_vocab")
+    new_vocab_path = self._write_vocab(
+        ["orange", "guava", "banana", "apple", "raspberry", "blueberry"],
+        "new_vocab")
+
+    # Save checkpoint from which to warm-start.
+    with ops.Graph().as_default() as g:
+      with self.test_session(graph=g) as sess:
+        _ = variable_scope.get_variable(
+            "input_layer/sc_vocab_embedding/embedding_weights",
+            initializer=[[0.5, 0.4], [1., 1.1], [2., 2.2], [3., 3.3]])
+        self._write_checkpoint(sess)
+
+    def _partitioner(shape, dtype):  # pylint:disable=unused-argument
+      # Partition each var into 2 equal slices.
+      partitions = [1] * len(shape)
+      partitions[0] = min(2, shape[0].value)
+      return partitions
+
+    # Create feature columns.
+    sc_vocab = fc.categorical_column_with_vocabulary_file(
+        "sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6)
+    emb_vocab = fc.embedding_column(
+        categorical_column=sc_vocab,
+        dimension=2,
+        # Can't use constant_initializer with load_and_remap. In practice,
+        # use a truncated normal initializer.
+        initializer=init_ops.random_uniform_initializer(
+            minval=0.42, maxval=0.42))
+    all_deep_cols = [emb_vocab]
+    # New graph, new session with warmstarting.
+    with ops.Graph().as_default() as g:
+      with self.test_session(graph=g) as sess:
+        cols_to_vars = {}
+        with variable_scope.variable_scope("", partitioner=_partitioner):
+          # Create the variables.
+          fc.input_layer(
+              features=self._create_dummy_inputs(),
+              feature_columns=all_deep_cols,
+              cols_to_vars=cols_to_vars)
+        ws_settings = ws_util._WarmStartSettings(
+            self.get_temp_dir(), col_to_prev_vocab={
+                emb_vocab: prev_vocab_path
+            })
+        ws_util._warmstart_input_layer(cols_to_vars, ws_settings)
+        sess.run(variables.global_variables_initializer())
+        # Verify weights were correctly warmstarted. Var corresponding to
+        # emb_vocab should be correctly warmstarted after vocab remapping.
+        # Missing values are filled in with the EmbeddingColumn's initializer.
+        self._assert_cols_to_vars(
+            cols_to_vars, {
+                emb_vocab: [
+                    np.array([[3., 3.3], [2., 2.2], [1., 1.1]]),
+                    np.array([[0.5, 0.4], [0.42, 0.42], [0.42, 0.42]])
+                ]
+            }, sess)
+
   def testErrorConditions(self):
     self.assertRaises(ValueError, ws_util._WarmStartSettings, None)
     x = variable_scope.get_variable(
@@ -566,8 +749,7 @@ class WarmStartingUtilTest(test.TestCase):
         initializer=ones(),
         partitioner=lambda shape, dtype: [2, 1])
 
-    # List of PartitionedVariable is invalid type.
-    self.assertRaises(TypeError, ws_util._warmstart_var, [x], prev_ckpt="/tmp")
+    # List of PartitionedVariable is invalid type when warmstarting with vocab.
     self.assertRaises(TypeError, ws_util._warmstart_var_with_vocab, [x], "/tmp",
                       5, "/tmp", "/tmp")
     # Keys of type other than FeatureColumn.
--- saver.py
+++ saver.py
@@ -503,11 +503,13 @@ class BaseSaverBuilder(object):
     return sorted(per_device.items(), key=lambda t: t[0])
 
   @staticmethod
-  def OpListToDict(op_list):
+  def OpListToDict(op_list, convert_variable_to_tensor=True):
     """Create a dictionary of names to operation lists.
 
     Args:
       op_list: A list, tuple, or set of Variables or SaveableObjects.
+      convert_variable_to_tensor: Whether or not to convert single Variables
+        with no slice info into Tensors.
 
     Returns:
       A dictionary of names to the operations that must be saved under
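A sketch of what the new flag changes, again assuming graph mode and the private saver API: by default a lone Variable comes back converted to a ref Tensor before grouping; with convert_variable_to_tensor=False, which _warmstart_var now relies on, the returned dict keeps the original Variable/PartitionedVariable objects so they can be massaged into init_from_checkpoint()'s expected format:

    import tensorflow as tf
    from tensorflow.python.training import saver as saver_lib

    v = tf.get_variable("v", shape=[4, 1])
    p = tf.get_variable(
        "p", shape=[4, 1], partitioner=tf.fixed_size_partitioner(2))

    # Default: single variables come back converted to (ref) Tensors.
    d1 = saver_lib.BaseSaverBuilder.OpListToDict([v])

    # New flag: keep variables as-is; slices are still grouped under the
    # name of the whole variable they belong to.
    d2 = saver_lib.BaseSaverBuilder.OpListToDict(
        [v] + p._get_variable_list(), convert_variable_to_tensor=False)
    print(sorted(d2.keys()))  # expected: ["p", "v"]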
@@ -543,9 +545,10 @@ class BaseSaverBuilder(object):
         names_to_saveables[name] = [var]
       else:
         if context.in_graph_mode():
-          var = ops.internal_convert_to_tensor(var, as_ref=True)
-          if not BaseSaverBuilder._IsVariable(var):
-            raise TypeError("Variable to save is not a Variable: %s" % var)
+          if convert_variable_to_tensor:
+            var = ops.internal_convert_to_tensor(var, as_ref=True)
+            if not BaseSaverBuilder._IsVariable(var):
+              raise TypeError("Variable to save is not a Variable: %s" % var)
           if var.op.type == "ReadVariableOp":
             name = var.op.inputs[0].op.name
           else: