Make tf.contrib.lookup python functions use the kernels v2 that uses the resource tensor as handler.

PiperOrigin-RevId: 158291836
This commit is contained in:
Yutaka Leon 2017-06-07 11:16:59 -07:00 committed by TensorFlower Gardener
parent ebae3deba8
commit 55f987692a
6 changed files with 90 additions and 1065 deletions

View File

@ -209,7 +209,9 @@ def _get_replica_device_setter(config):
""" """
ps_ops = [ ps_ops = [
'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable', 'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
'MutableHashTableOfTensors', 'MutableDenseHashTable' 'MutableHashTableV2', 'MutableHashTableOfTensors',
'MutableHashTableOfTensorsV2', 'MutableDenseHashTable',
'MutableDenseHashTableV2'
] ]
if config.task_type: if config.task_type:

File diff suppressed because it is too large Load Diff

View File

@ -88,10 +88,11 @@ limitations under the License.
// shapes, particularly when restoring a graph from GraphDef // shapes, particularly when restoring a graph from GraphDef
// produced at version 22 or later. (04/10/2016) // produced at version 22 or later. (04/10/2016)
// 23. Remove NonMaxSuppression in favor of NonMaxSuppressionV2. // 23. Remove NonMaxSuppression in favor of NonMaxSuppressionV2.
// 24. Deprecate lookup ops (v1) ops in favor of v2 (30may2017)
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
#define TF_GRAPH_DEF_VERSION 23 #define TF_GRAPH_DEF_VERSION 24
// Checkpoint compatibility versions (the versions field in SavedSliceMeta). // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
// //

View File

@ -725,7 +725,9 @@ def _get_replica_device_setter(config):
""" """
ps_ops = [ ps_ops = [
'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable', 'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
'MutableHashTableOfTensors', 'MutableDenseHashTable' 'MutableHashTableV2', 'MutableHashTableOfTensors',
'MutableHashTableOfTensorsV2', 'MutableDenseHashTable',
'MutableDenseHashTableV2'
] ]
if config.task_type: if config.task_type:

View File

@ -1286,7 +1286,7 @@ class EstimatorExportTest(test.TestCase):
self.assertTrue('input_example_tensor' in graph_ops) self.assertTrue('input_example_tensor' in graph_ops)
self.assertTrue('ParseExample/ParseExample' in graph_ops) self.assertTrue('ParseExample/ParseExample' in graph_ops)
# Note that the SavedModel builder replaced the Saver with a new one # Note that the SavedModel builder replaced the Saver with a new one
self.assertTrue('save_1/LookupTableImport' in graph_ops) self.assertTrue('save_1/LookupTableImportV2' in graph_ops)
# Clean up. # Clean up.
gfile.DeleteRecursively(tmpdir) gfile.DeleteRecursively(tmpdir)

View File

@ -34,7 +34,7 @@ class CheckpointedOp(object):
# pylint: disable=protected-access # pylint: disable=protected-access
def __init__(self, name, table_ref=None): def __init__(self, name, table_ref=None):
if table_ref is None: if table_ref is None:
self.table_ref = gen_lookup_ops._mutable_hash_table( self.table_ref = gen_lookup_ops._mutable_hash_table_v2(
key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name) key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name)
else: else:
self.table_ref = table_ref self.table_ref = table_ref
@ -52,10 +52,10 @@ class CheckpointedOp(object):
return self._saveable return self._saveable
def insert(self, keys, values): def insert(self, keys, values):
return gen_lookup_ops._lookup_table_insert(self.table_ref, keys, values) return gen_lookup_ops._lookup_table_insert_v2(self.table_ref, keys, values)
def lookup(self, keys, default): def lookup(self, keys, default):
return gen_lookup_ops._lookup_table_find(self.table_ref, keys, default) return gen_lookup_ops._lookup_table_find_v2(self.table_ref, keys, default)
def keys(self): def keys(self):
return self._export()[0] return self._export()[0]
@ -64,8 +64,8 @@ class CheckpointedOp(object):
return self._export()[1] return self._export()[1]
def _export(self): def _export(self):
return gen_lookup_ops._lookup_table_export(self.table_ref, dtypes.string, return gen_lookup_ops._lookup_table_export_v2(self.table_ref, dtypes.string,
dtypes.float32) dtypes.float32)
class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject): class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject):
"""A custom saveable for CheckpointedOp.""" """A custom saveable for CheckpointedOp."""
@ -81,6 +81,6 @@ class CheckpointedOp(object):
super(CheckpointedOp.CustomSaveable, self).__init__(table, specs, name) super(CheckpointedOp.CustomSaveable, self).__init__(table, specs, name)
def restore(self, restore_tensors, shapes): def restore(self, restore_tensors, shapes):
return gen_lookup_ops._lookup_table_import( return gen_lookup_ops._lookup_table_import_v2(
self.op.table_ref, restore_tensors[0], restore_tensors[1]) self.op.table_ref, restore_tensors[0], restore_tensors[1])
# pylint: enable=protected-access # pylint: enable=protected-access