mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
- Remove slice hack to properly initialize missing entries in weight matrices
- Add real support for EmbeddingColumns / input_layer() - Fix warmstarting for non-PartitionedVariables PiperOrigin-RevId: 174083777
This commit is contained in:
parent
0cddb9bcaf
commit
f1916f8f6c
|
|
@ -23,7 +23,6 @@ import six
|
|||
|
||||
from tensorflow.python.feature_column import feature_column
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
|
|
@ -125,7 +124,7 @@ def _infer_var_name(var):
|
|||
Name of the `var`
|
||||
"""
|
||||
name_to_var_dict = saver.BaseSaverBuilder.OpListToDict(var)
|
||||
if len(name_to_var_dict.keys()) > 1:
|
||||
if len(name_to_var_dict) > 1:
|
||||
raise TypeError("`var` passed as arg violates the constraints.")
|
||||
return list(name_to_var_dict.keys())[0]
|
||||
|
||||
|
|
@ -138,26 +137,69 @@ def _warmstart_var(var, prev_ckpt, prev_tensor_name=None):
|
|||
Can be either of the following:
|
||||
(i) `Variable`
|
||||
(ii) `ResourceVariable`
|
||||
(iii) list of `Variable`: The list must contain slices of the same larger
|
||||
variable.
|
||||
(iv) `PartitionedVariable`
|
||||
(iii) `PartitionedVariable`
|
||||
(iv) list of `Variable` and/or `PartitionedVariable`: The list may
|
||||
contain one or more variables that has been sharded. For example:
|
||||
[Variable('a/part_0'), Variable('b/part_0'), Variable('a/part_1'),
|
||||
PartitionedVariable([Variable('c/part_0'), Variable('c/part_1')])]
|
||||
where we have three whole Variables represented ('a', 'b', and 'c').
|
||||
prev_ckpt: A string specifying the directory with checkpoint file(s) or path
|
||||
to checkpoint. The given checkpoint must have tensor with name
|
||||
`prev_tensor_name` (if not None) or tensor with name same as given `var`.
|
||||
prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
|
||||
None, we lookup tensor with same name as given `var`.
|
||||
|
||||
Raises:
|
||||
ValueError: If prev_tensor_name is not None, but the given var represents
|
||||
more than one Variable.
|
||||
TypeError: If var is not one of the allowed types.
|
||||
"""
|
||||
if _is_variable(var):
|
||||
current_var_name = _infer_var_name([var])
|
||||
elif isinstance(var, list) and all(_is_variable(v) for v in var):
|
||||
current_var_name = _infer_var_name(var)
|
||||
elif isinstance(var, variables.PartitionedVariable):
|
||||
current_var_name = _infer_var_name([var])
|
||||
var = var._get_variable_list() # pylint: disable=protected-access
|
||||
elif (isinstance(var, list) and all(
|
||||
_is_variable(v) or isinstance(v, variables.PartitionedVariable)
|
||||
for v in var)):
|
||||
# Convert length-1 lists of vars to single tf.Variables. This ensures that
|
||||
# checkpoint_utils.init_from_checkpoint() doesn't incorrectly assume
|
||||
# slice info is present.
|
||||
if len(var) == 1:
|
||||
current_var_name = _infer_var_name(var)
|
||||
var = var[0]
|
||||
else:
|
||||
# If we have multiple elements in var, we cannot assume they all
|
||||
# represent the same Variable.
|
||||
name_to_var_dict = saver.BaseSaverBuilder.OpListToDict(
|
||||
var, convert_variable_to_tensor=False)
|
||||
if prev_tensor_name:
|
||||
# Providing a prev_tensor_name is only viable if var representes a
|
||||
# single Variable.
|
||||
if len(name_to_var_dict) > 1:
|
||||
raise ValueError("var represented more than one Variable, but "
|
||||
"prev_tensor_name was provided.")
|
||||
checkpoint_utils.init_from_checkpoint(prev_ckpt, {
|
||||
prev_tensor_name: var
|
||||
})
|
||||
else:
|
||||
# OpListToDict gives us roughly what we need, but
|
||||
# the values in the dict may be PartitionedVariables (which
|
||||
# init_from_checkpoint does not expect) that we need to convert to
|
||||
# lists.
|
||||
name_to_var_dict_fixed = {}
|
||||
for name, var in six.iteritems(name_to_var_dict):
|
||||
if isinstance(var, variables.PartitionedVariable):
|
||||
name_to_var_dict_fixed[name] = var._get_variable_list() # pylint: disable=protected-access
|
||||
else:
|
||||
name_to_var_dict_fixed[name] = var
|
||||
checkpoint_utils.init_from_checkpoint(prev_ckpt, name_to_var_dict_fixed)
|
||||
return
|
||||
else:
|
||||
raise TypeError(
|
||||
"var MUST be one of the following: a Variable, list of Variable or "
|
||||
"PartitionedVariable, but is {}".format(type(var)))
|
||||
"var MUST be one of the following: a Variable, PartitionedVariable, or "
|
||||
"list of Variable's and/or PartitionedVariable's, but is {}".format(
|
||||
type(var)))
|
||||
if not prev_tensor_name:
|
||||
# Assume tensor name remains the same.
|
||||
prev_tensor_name = current_var_name
|
||||
|
|
@ -173,7 +215,8 @@ def _warmstart_var_with_vocab(var,
|
|||
prev_ckpt,
|
||||
prev_vocab_path,
|
||||
current_oov_buckets=0,
|
||||
prev_tensor_name=None):
|
||||
prev_tensor_name=None,
|
||||
initializer=None):
|
||||
"""Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.
|
||||
|
||||
Use this method when the `var` is backed by vocabulary. This method stitches
|
||||
|
|
@ -200,6 +243,8 @@ def _warmstart_var_with_vocab(var,
|
|||
buckets used for given `var`.
|
||||
prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
|
||||
None, we lookup tensor with same name as given `var`.
|
||||
initializer: Variable initializer to be used for missing entries. If None,
|
||||
missing entries will be zero-initialized.
|
||||
|
||||
Raises:
|
||||
ValueError: If required args are not provided.
|
||||
|
|
@ -232,18 +277,6 @@ def _warmstart_var_with_vocab(var,
|
|||
full_shape=slice_info.full_shape,
|
||||
var_offset=slice_info.var_offset)
|
||||
|
||||
# TODO(vihanjain): This is brittle. Can we instead infer actual initializer
|
||||
# used originally for the variable or use a fixed initializer?
|
||||
def _missing_ids_init(shape, dtype=None):
|
||||
# pylint: disable=cell-var-from-loop
|
||||
if dtype and dtype.base_dtype != v.dtype.base_dtype:
|
||||
raise ValueError("Trying to initialize missing ids with a different "
|
||||
"dtype `{}` than variable's dtype `{}`".format(
|
||||
dtype, v.dtype))
|
||||
return array_ops.slice(v.initial_value, [0, 0], shape)
|
||||
|
||||
# pylint: enable=cell-var-from-loop
|
||||
|
||||
# TODO(vihanjain): Support _WarmstartSettings where class vocabularies need
|
||||
# remapping too.
|
||||
init = checkpoint_ops._load_and_remap_matrix_initializer(
|
||||
|
|
@ -257,7 +290,7 @@ def _warmstart_var_with_vocab(var,
|
|||
new_col_vocab_file=None,
|
||||
num_row_oov_buckets=current_oov_buckets,
|
||||
num_col_oov_buckets=0,
|
||||
initializer=_missing_ids_init)
|
||||
initializer=initializer)
|
||||
new_init_val = ops.convert_to_tensor(
|
||||
init(shape=v_shape, partition_info=partition_info))
|
||||
v._initializer_op = state_ops.assign(v, new_init_val)
|
||||
|
|
@ -305,6 +338,11 @@ def _warmstart_input_layer(cols_to_vars, warmstart_settings):
|
|||
```
|
||||
|
||||
The above example effectively warm-starts full linear model.
|
||||
|
||||
Raises:
|
||||
ValueError: If a column in cols_to_vars has an entry in
|
||||
warmstart_settings.cols_to_prev_vocab, but is not an instance of
|
||||
_VocabularyFileCategoricalColumn or _EmbeddingColumn.
|
||||
"""
|
||||
for col, var in six.iteritems(cols_to_vars):
|
||||
if not isinstance(col, feature_column._FeatureColumn): # pylint: disable=protected-access
|
||||
|
|
@ -316,21 +354,43 @@ def _warmstart_input_layer(cols_to_vars, warmstart_settings):
|
|||
continue
|
||||
|
||||
prev_tensor_name = warmstart_settings.col_to_prev_tensor.get(col)
|
||||
if isinstance(col, feature_column._VocabularyFileCategoricalColumn): # pylint: disable=protected-access
|
||||
# pylint: disable=protected-access
|
||||
is_sparse_vocab_column = isinstance(
|
||||
col, feature_column._VocabularyFileCategoricalColumn)
|
||||
is_embedding_vocab_column = (
|
||||
isinstance(col, feature_column._EmbeddingColumn) and
|
||||
isinstance(col.categorical_column,
|
||||
feature_column._VocabularyFileCategoricalColumn))
|
||||
if is_sparse_vocab_column or is_embedding_vocab_column:
|
||||
# pylint: enable=protected-access
|
||||
initializer = None
|
||||
if is_embedding_vocab_column:
|
||||
initializer = col.initializer
|
||||
vocabulary_file = col.categorical_column.vocabulary_file
|
||||
vocabulary_size = col.categorical_column.vocabulary_size
|
||||
num_oov_buckets = col.categorical_column.num_oov_buckets
|
||||
else:
|
||||
vocabulary_file = col.vocabulary_file
|
||||
vocabulary_size = col.vocabulary_size
|
||||
num_oov_buckets = col.num_oov_buckets
|
||||
prev_vocab_path = warmstart_settings.col_to_prev_vocab.get(
|
||||
col, col.vocabulary_file)
|
||||
col, vocabulary_file)
|
||||
logging.info("Warm-starting column: {}; prev_vocab: {}; prev_tensor: {}".
|
||||
format(col.name, prev_vocab_path, (
|
||||
prev_tensor_name or "Unchanged")))
|
||||
_warmstart_var_with_vocab(
|
||||
var,
|
||||
current_vocab_path=col.vocabulary_file,
|
||||
current_vocab_size=col.vocabulary_size,
|
||||
current_vocab_path=vocabulary_file,
|
||||
current_vocab_size=vocabulary_size,
|
||||
prev_ckpt=warmstart_settings.ckpt_to_initialize_from,
|
||||
prev_vocab_path=prev_vocab_path,
|
||||
current_oov_buckets=col.num_oov_buckets,
|
||||
prev_tensor_name=prev_tensor_name)
|
||||
current_oov_buckets=num_oov_buckets,
|
||||
prev_tensor_name=prev_tensor_name,
|
||||
initializer=initializer)
|
||||
else:
|
||||
if col in warmstart_settings.col_to_prev_vocab:
|
||||
raise ValueError("Vocabulary provided for column %s which is not a "
|
||||
"_VocabularyFileCategoricalColumn or _EmbeddingColumn")
|
||||
logging.info("Warm-starting column: {}; prev_tensor: {}".format(
|
||||
col.name, prev_tensor_name or "Unchanged"))
|
||||
_warmstart_var(var, warmstart_settings.ckpt_to_initialize_from,
|
||||
|
|
|
|||
|
|
@ -72,6 +72,36 @@ class WarmStartingUtilTest(test.TestCase):
|
|||
var = var._get_variable_list()
|
||||
return var, sess.run(var)
|
||||
|
||||
def _create_prev_run_multiple_vars(self,
|
||||
var_names,
|
||||
initializers,
|
||||
shapes=None,
|
||||
partitioners=None):
|
||||
if not shapes:
|
||||
shapes = [None] * len(var_names)
|
||||
if not partitioners:
|
||||
partitioners = [None] * len(var_names)
|
||||
with ops.Graph().as_default() as g:
|
||||
with self.test_session(graph=g) as sess:
|
||||
var_list = []
|
||||
for var_name, shape, initializer, partitioner in zip(
|
||||
var_names, shapes, initializers, partitioners):
|
||||
var_list.append(
|
||||
variable_scope.get_variable(
|
||||
var_name,
|
||||
shape=shape,
|
||||
initializer=initializer,
|
||||
partitioner=partitioner))
|
||||
self._write_checkpoint(sess)
|
||||
run_vars = []
|
||||
for var, partitioner in zip(var_list, partitioners):
|
||||
if partitioner:
|
||||
self.assertTrue(isinstance(var, variables.PartitionedVariable))
|
||||
run_vars.append(sess.run(var._get_variable_list()))
|
||||
else:
|
||||
run_vars.append(sess.run(var))
|
||||
return var_list, run_vars
|
||||
|
||||
def _create_dummy_inputs(self):
|
||||
return {
|
||||
"sc_int": array_ops.sparse_placeholder(dtypes.int32),
|
||||
|
|
@ -98,7 +128,7 @@ class WarmStartingUtilTest(test.TestCase):
|
|||
def _assert_cols_to_vars(self, cols_to_vars, cols_to_expected_values, sess):
|
||||
for col, expected_values in six.iteritems(cols_to_expected_values):
|
||||
for i, var in enumerate(cols_to_vars[col]):
|
||||
self.assertAllEqual(expected_values[i], var.eval(sess))
|
||||
self.assertAllClose(expected_values[i], var.eval(sess))
|
||||
|
||||
def testWarmStartVar(self):
|
||||
_, prev_val = self._create_prev_run_var(
|
||||
|
|
@ -175,6 +205,99 @@ class WarmStartingUtilTest(test.TestCase):
|
|||
[fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
|
||||
self.assertAllEqual(prev_val, new_val)
|
||||
|
||||
def testWarmStartVarMultipleVars(self):
|
||||
_, prev_vals = self._create_prev_run_multiple_vars(
|
||||
var_names=["fruit_weights", "other_weights"],
|
||||
initializers=[[[0.5], [1.], [1.5], [2.]], [[.05], [.1], [.15], [.2]]])
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
with self.test_session(graph=g) as sess:
|
||||
fruit_weights = variable_scope.get_variable(
|
||||
"fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
|
||||
other_weights = variable_scope.get_variable(
|
||||
"other_weights", initializer=[[0.], [0.], [0.], [0.]])
|
||||
ws_util._warmstart_var([fruit_weights, other_weights],
|
||||
self.get_temp_dir())
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.assertAllEqual(prev_vals[0], fruit_weights.eval(sess))
|
||||
self.assertAllEqual(prev_vals[1], other_weights.eval(sess))
|
||||
|
||||
def testWarmStartVarMultipleVarsBothPartitioned(self):
|
||||
_, prev_vals = self._create_prev_run_multiple_vars(
|
||||
var_names=["fruit_weights", "other_weights"],
|
||||
shapes=[[4, 1], [4, 1]],
|
||||
initializers=[[[0.5], [1.], [1.5], [2.]], [[.05], [.1], [.15], [.2]]],
|
||||
partitioners=[lambda shape, dtype: [2, 1], lambda shape, dtype: [2, 1]])
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
with self.test_session(graph=g) as sess:
|
||||
fruit_weights = variable_scope.get_variable(
|
||||
"fruit_weights",
|
||||
shape=[4, 1],
|
||||
initializer=[[0.], [0.], [0.], [0.]],
|
||||
partitioner=lambda shape, dtype: [2, 1])
|
||||
other_weights = variable_scope.get_variable(
|
||||
"other_weights",
|
||||
shape=[4, 1],
|
||||
initializer=[[0.], [0.], [0.], [0.]],
|
||||
partitioner=lambda shape, dtype: [2, 1])
|
||||
ws_util._warmstart_var([fruit_weights, other_weights],
|
||||
self.get_temp_dir())
|
||||
sess.run(variables.global_variables_initializer())
|
||||
fruit_weights = fruit_weights._get_variable_list()
|
||||
new_fruit_weights_val = np.concatenate(
|
||||
[fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
|
||||
other_weights = other_weights._get_variable_list()
|
||||
new_other_weights_val = np.concatenate(
|
||||
[other_weights[0].eval(sess), other_weights[1].eval(sess)], axis=0)
|
||||
self.assertAllEqual(
|
||||
np.concatenate(prev_vals[0], axis=0), new_fruit_weights_val)
|
||||
self.assertAllEqual(
|
||||
np.concatenate(prev_vals[1], axis=0), new_other_weights_val)
|
||||
|
||||
def testWarmStartVarMultipleVarsMixOfPartitions(self):
|
||||
# First is not partitioned, but the second two are.
|
||||
_, prev_vals = self._create_prev_run_multiple_vars(
|
||||
var_names=["fruit_weights", "other_weights", "veggie_weights"],
|
||||
shapes=[None, [4, 1], [4, 1]],
|
||||
initializers=[[[0.5], [1.], [1.5], [2.]], [[.05], [.1], [.15], [.2]],
|
||||
[[5.], [10.], [15.], [20.]]],
|
||||
partitioners=[
|
||||
None, lambda shape, dtype: [2, 1], lambda shape, dtype: [2, 1]
|
||||
])
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
with self.test_session(graph=g) as sess:
|
||||
fruit_weights = variable_scope.get_variable(
|
||||
"fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
|
||||
other_weights = variable_scope.get_variable(
|
||||
"other_weights",
|
||||
shape=[4, 1],
|
||||
initializer=[[0.], [0.], [0.], [0.]],
|
||||
partitioner=lambda shape, dtype: [2, 1])
|
||||
veggie_weights = variable_scope.get_variable(
|
||||
"veggie_weights",
|
||||
shape=[4, 1],
|
||||
initializer=[[0.], [0.], [0.], [0.]],
|
||||
partitioner=lambda shape, dtype: [2, 1])
|
||||
# Flatten one of the partitioned variables.
|
||||
ws_util._warmstart_var([fruit_weights, other_weights] +
|
||||
veggie_weights._get_variable_list(),
|
||||
self.get_temp_dir())
|
||||
sess.run(variables.global_variables_initializer())
|
||||
veggie_weights = veggie_weights._get_variable_list()
|
||||
new_veggie_weights_val = np.concatenate(
|
||||
[veggie_weights[0].eval(sess), veggie_weights[1].eval(sess)],
|
||||
axis=0)
|
||||
other_weights = other_weights._get_variable_list()
|
||||
new_other_weights_val = np.concatenate(
|
||||
[other_weights[0].eval(sess), other_weights[1].eval(sess)], axis=0)
|
||||
self.assertAllEqual(prev_vals[0], fruit_weights.eval(sess))
|
||||
self.assertAllEqual(
|
||||
np.concatenate(prev_vals[1], axis=0), new_other_weights_val)
|
||||
self.assertAllEqual(
|
||||
np.concatenate(prev_vals[2], axis=0), new_veggie_weights_val)
|
||||
|
||||
def testWarmStartVarWithVocab(self):
|
||||
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
|
||||
"old_vocab")
|
||||
|
|
@ -558,6 +681,66 @@ class WarmStartingUtilTest(test.TestCase):
|
|||
]
|
||||
}, sess)
|
||||
|
||||
def testWarmStartInputLayerEmbeddingColumn(self):
|
||||
# Create old and new vocabs for embedding column "sc_vocab".
|
||||
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
|
||||
"old_vocab")
|
||||
new_vocab_path = self._write_vocab(
|
||||
["orange", "guava", "banana", "apple", "raspberry", "blueberry"],
|
||||
"new_vocab")
|
||||
|
||||
# Save checkpoint from which to warm-start.
|
||||
with ops.Graph().as_default() as g:
|
||||
with self.test_session(graph=g) as sess:
|
||||
_ = variable_scope.get_variable(
|
||||
"input_layer/sc_vocab_embedding/embedding_weights",
|
||||
initializer=[[0.5, 0.4], [1., 1.1], [2., 2.2], [3., 3.3]])
|
||||
self._write_checkpoint(sess)
|
||||
|
||||
def _partitioner(shape, dtype): # pylint:disable=unused-argument
|
||||
# Partition each var into 2 equal slices.
|
||||
partitions = [1] * len(shape)
|
||||
partitions[0] = min(2, shape[0].value)
|
||||
return partitions
|
||||
|
||||
# Create feature columns.
|
||||
sc_vocab = fc.categorical_column_with_vocabulary_file(
|
||||
"sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6)
|
||||
emb_vocab = fc.embedding_column(
|
||||
categorical_column=sc_vocab,
|
||||
dimension=2,
|
||||
# Can't use constant_initializer with load_and_remap. In practice,
|
||||
# use a truncated normal initializer.
|
||||
initializer=init_ops.random_uniform_initializer(
|
||||
minval=0.42, maxval=0.42))
|
||||
all_deep_cols = [emb_vocab]
|
||||
# New graph, new session with warmstarting.
|
||||
with ops.Graph().as_default() as g:
|
||||
with self.test_session(graph=g) as sess:
|
||||
cols_to_vars = {}
|
||||
with variable_scope.variable_scope("", partitioner=_partitioner):
|
||||
# Create the variables.
|
||||
fc.input_layer(
|
||||
features=self._create_dummy_inputs(),
|
||||
feature_columns=all_deep_cols,
|
||||
cols_to_vars=cols_to_vars)
|
||||
ws_settings = ws_util._WarmStartSettings(
|
||||
self.get_temp_dir(), col_to_prev_vocab={
|
||||
emb_vocab: prev_vocab_path
|
||||
})
|
||||
ws_util._warmstart_input_layer(cols_to_vars, ws_settings)
|
||||
sess.run(variables.global_variables_initializer())
|
||||
# Verify weights were correctly warmstarted. Var corresponding to
|
||||
# emb_vocab should be correctly warmstarted after vocab remapping.
|
||||
# Missing values are filled in with the EmbeddingColumn's initializer.
|
||||
self._assert_cols_to_vars(
|
||||
cols_to_vars, {
|
||||
emb_vocab: [
|
||||
np.array([[3., 3.3], [2., 2.2], [1., 1.1]]),
|
||||
np.array([[0.5, 0.4], [0.42, 0.42], [0.42, 0.42]])
|
||||
]
|
||||
}, sess)
|
||||
|
||||
def testErrorConditions(self):
|
||||
self.assertRaises(ValueError, ws_util._WarmStartSettings, None)
|
||||
x = variable_scope.get_variable(
|
||||
|
|
@ -566,8 +749,7 @@ class WarmStartingUtilTest(test.TestCase):
|
|||
initializer=ones(),
|
||||
partitioner=lambda shape, dtype: [2, 1])
|
||||
|
||||
# List of PartitionedVariable is invalid type.
|
||||
self.assertRaises(TypeError, ws_util._warmstart_var, [x], prev_ckpt="/tmp")
|
||||
# List of PartitionedVariable is invalid type when warmstarting with vocab.
|
||||
self.assertRaises(TypeError, ws_util._warmstart_var_with_vocab, [x], "/tmp",
|
||||
5, "/tmp", "/tmp")
|
||||
# Keys of type other than FeatureColumn.
|
||||
|
|
|
|||
|
|
@ -503,11 +503,13 @@ class BaseSaverBuilder(object):
|
|||
return sorted(per_device.items(), key=lambda t: t[0])
|
||||
|
||||
@staticmethod
|
||||
def OpListToDict(op_list):
|
||||
def OpListToDict(op_list, convert_variable_to_tensor=True):
|
||||
"""Create a dictionary of names to operation lists.
|
||||
|
||||
Args:
|
||||
op_list: A list, tuple, or set of Variables or SaveableObjects.
|
||||
convert_variable_to_tensor: Whether or not to convert single Variables
|
||||
with no slice info into Tensors.
|
||||
|
||||
Returns:
|
||||
A dictionary of names to the operations that must be saved under
|
||||
|
|
@ -543,9 +545,10 @@ class BaseSaverBuilder(object):
|
|||
names_to_saveables[name] = [var]
|
||||
else:
|
||||
if context.in_graph_mode():
|
||||
var = ops.internal_convert_to_tensor(var, as_ref=True)
|
||||
if not BaseSaverBuilder._IsVariable(var):
|
||||
raise TypeError("Variable to save is not a Variable: %s" % var)
|
||||
if convert_variable_to_tensor:
|
||||
var = ops.internal_convert_to_tensor(var, as_ref=True)
|
||||
if not BaseSaverBuilder._IsVariable(var):
|
||||
raise TypeError("Variable to save is not a Variable: %s" % var)
|
||||
if var.op.type == "ReadVariableOp":
|
||||
name = var.op.inputs[0].op.name
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user