Make tf.contrib.lookup python functions use the kernels v2 that uses the resource tensor as handler.

PiperOrigin-RevId: 158291836
This commit is contained in:
Yutaka Leon 2017-06-07 11:16:59 -07:00 committed by TensorFlower Gardener
parent ebae3deba8
commit 55f987692a
6 changed files with 90 additions and 1065 deletions

View File

@ -209,7 +209,9 @@ def _get_replica_device_setter(config):
"""
ps_ops = [
'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
'MutableHashTableOfTensors', 'MutableDenseHashTable'
'MutableHashTableV2', 'MutableHashTableOfTensors',
'MutableHashTableOfTensorsV2', 'MutableDenseHashTable',
'MutableDenseHashTableV2'
]
if config.task_type:

File diff suppressed because it is too large Load Diff

View File

@ -88,10 +88,11 @@ limitations under the License.
// shapes, particularly when restoring a graph from GraphDef
// produced at version 22 or later. (04/10/2016)
// 23. Remove NonMaxSuppression in favor of NonMaxSuppressionV2.
// 24. Deprecate lookup ops (v1) ops in favor of v2 (30may2017)
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
#define TF_GRAPH_DEF_VERSION 23
#define TF_GRAPH_DEF_VERSION 24
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
//

View File

@ -725,7 +725,9 @@ def _get_replica_device_setter(config):
"""
ps_ops = [
'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
'MutableHashTableOfTensors', 'MutableDenseHashTable'
'MutableHashTableV2', 'MutableHashTableOfTensors',
'MutableHashTableOfTensorsV2', 'MutableDenseHashTable',
'MutableDenseHashTableV2'
]
if config.task_type:

View File

@ -1286,7 +1286,7 @@ class EstimatorExportTest(test.TestCase):
self.assertTrue('input_example_tensor' in graph_ops)
self.assertTrue('ParseExample/ParseExample' in graph_ops)
# Note that the SavedModel builder replaced the Saver with a new one
self.assertTrue('save_1/LookupTableImport' in graph_ops)
self.assertTrue('save_1/LookupTableImportV2' in graph_ops)
# Clean up.
gfile.DeleteRecursively(tmpdir)

View File

@ -34,7 +34,7 @@ class CheckpointedOp(object):
# pylint: disable=protected-access
def __init__(self, name, table_ref=None):
if table_ref is None:
self.table_ref = gen_lookup_ops._mutable_hash_table(
self.table_ref = gen_lookup_ops._mutable_hash_table_v2(
key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name)
else:
self.table_ref = table_ref
@ -52,10 +52,10 @@ class CheckpointedOp(object):
return self._saveable
def insert(self, keys, values):
return gen_lookup_ops._lookup_table_insert(self.table_ref, keys, values)
return gen_lookup_ops._lookup_table_insert_v2(self.table_ref, keys, values)
def lookup(self, keys, default):
return gen_lookup_ops._lookup_table_find(self.table_ref, keys, default)
return gen_lookup_ops._lookup_table_find_v2(self.table_ref, keys, default)
def keys(self):
return self._export()[0]
@ -64,8 +64,8 @@ class CheckpointedOp(object):
return self._export()[1]
def _export(self):
return gen_lookup_ops._lookup_table_export(self.table_ref, dtypes.string,
dtypes.float32)
return gen_lookup_ops._lookup_table_export_v2(self.table_ref, dtypes.string,
dtypes.float32)
class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject):
"""A custom saveable for CheckpointedOp."""
@ -81,6 +81,6 @@ class CheckpointedOp(object):
super(CheckpointedOp.CustomSaveable, self).__init__(table, specs, name)
def restore(self, restore_tensors, shapes):
return gen_lookup_ops._lookup_table_import(
return gen_lookup_ops._lookup_table_import_v2(
self.op.table_ref, restore_tensors[0], restore_tensors[1])
# pylint: enable=protected-access