mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Make tf.contrib.lookup python functions use the kernels v2 that uses the resource tensor as handler.
PiperOrigin-RevId: 158291836
This commit is contained in:
parent
ebae3deba8
commit
55f987692a
|
|
@ -209,7 +209,9 @@ def _get_replica_device_setter(config):
|
||||||
"""
|
"""
|
||||||
ps_ops = [
|
ps_ops = [
|
||||||
'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
|
'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
|
||||||
'MutableHashTableOfTensors', 'MutableDenseHashTable'
|
'MutableHashTableV2', 'MutableHashTableOfTensors',
|
||||||
|
'MutableHashTableOfTensorsV2', 'MutableDenseHashTable',
|
||||||
|
'MutableDenseHashTableV2'
|
||||||
]
|
]
|
||||||
|
|
||||||
if config.task_type:
|
if config.task_type:
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -88,10 +88,11 @@ limitations under the License.
|
||||||
// shapes, particularly when restoring a graph from GraphDef
|
// shapes, particularly when restoring a graph from GraphDef
|
||||||
// produced at version 22 or later. (04/10/2016)
|
// produced at version 22 or later. (04/10/2016)
|
||||||
// 23. Remove NonMaxSuppression in favor of NonMaxSuppressionV2.
|
// 23. Remove NonMaxSuppression in favor of NonMaxSuppressionV2.
|
||||||
|
// 24. Deprecate lookup ops (v1) ops in favor of v2 (30may2017)
|
||||||
|
|
||||||
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
|
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
|
||||||
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
|
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
|
||||||
#define TF_GRAPH_DEF_VERSION 23
|
#define TF_GRAPH_DEF_VERSION 24
|
||||||
|
|
||||||
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
|
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
|
||||||
//
|
//
|
||||||
|
|
|
||||||
|
|
@ -725,7 +725,9 @@ def _get_replica_device_setter(config):
|
||||||
"""
|
"""
|
||||||
ps_ops = [
|
ps_ops = [
|
||||||
'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
|
'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
|
||||||
'MutableHashTableOfTensors', 'MutableDenseHashTable'
|
'MutableHashTableV2', 'MutableHashTableOfTensors',
|
||||||
|
'MutableHashTableOfTensorsV2', 'MutableDenseHashTable',
|
||||||
|
'MutableDenseHashTableV2'
|
||||||
]
|
]
|
||||||
|
|
||||||
if config.task_type:
|
if config.task_type:
|
||||||
|
|
|
||||||
|
|
@ -1286,7 +1286,7 @@ class EstimatorExportTest(test.TestCase):
|
||||||
self.assertTrue('input_example_tensor' in graph_ops)
|
self.assertTrue('input_example_tensor' in graph_ops)
|
||||||
self.assertTrue('ParseExample/ParseExample' in graph_ops)
|
self.assertTrue('ParseExample/ParseExample' in graph_ops)
|
||||||
# Note that the SavedModel builder replaced the Saver with a new one
|
# Note that the SavedModel builder replaced the Saver with a new one
|
||||||
self.assertTrue('save_1/LookupTableImport' in graph_ops)
|
self.assertTrue('save_1/LookupTableImportV2' in graph_ops)
|
||||||
|
|
||||||
# Clean up.
|
# Clean up.
|
||||||
gfile.DeleteRecursively(tmpdir)
|
gfile.DeleteRecursively(tmpdir)
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ class CheckpointedOp(object):
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
def __init__(self, name, table_ref=None):
|
def __init__(self, name, table_ref=None):
|
||||||
if table_ref is None:
|
if table_ref is None:
|
||||||
self.table_ref = gen_lookup_ops._mutable_hash_table(
|
self.table_ref = gen_lookup_ops._mutable_hash_table_v2(
|
||||||
key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name)
|
key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name)
|
||||||
else:
|
else:
|
||||||
self.table_ref = table_ref
|
self.table_ref = table_ref
|
||||||
|
|
@ -52,10 +52,10 @@ class CheckpointedOp(object):
|
||||||
return self._saveable
|
return self._saveable
|
||||||
|
|
||||||
def insert(self, keys, values):
|
def insert(self, keys, values):
|
||||||
return gen_lookup_ops._lookup_table_insert(self.table_ref, keys, values)
|
return gen_lookup_ops._lookup_table_insert_v2(self.table_ref, keys, values)
|
||||||
|
|
||||||
def lookup(self, keys, default):
|
def lookup(self, keys, default):
|
||||||
return gen_lookup_ops._lookup_table_find(self.table_ref, keys, default)
|
return gen_lookup_ops._lookup_table_find_v2(self.table_ref, keys, default)
|
||||||
|
|
||||||
def keys(self):
|
def keys(self):
|
||||||
return self._export()[0]
|
return self._export()[0]
|
||||||
|
|
@ -64,8 +64,8 @@ class CheckpointedOp(object):
|
||||||
return self._export()[1]
|
return self._export()[1]
|
||||||
|
|
||||||
def _export(self):
|
def _export(self):
|
||||||
return gen_lookup_ops._lookup_table_export(self.table_ref, dtypes.string,
|
return gen_lookup_ops._lookup_table_export_v2(self.table_ref, dtypes.string,
|
||||||
dtypes.float32)
|
dtypes.float32)
|
||||||
|
|
||||||
class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject):
|
class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject):
|
||||||
"""A custom saveable for CheckpointedOp."""
|
"""A custom saveable for CheckpointedOp."""
|
||||||
|
|
@ -81,6 +81,6 @@ class CheckpointedOp(object):
|
||||||
super(CheckpointedOp.CustomSaveable, self).__init__(table, specs, name)
|
super(CheckpointedOp.CustomSaveable, self).__init__(table, specs, name)
|
||||||
|
|
||||||
def restore(self, restore_tensors, shapes):
|
def restore(self, restore_tensors, shapes):
|
||||||
return gen_lookup_ops._lookup_table_import(
|
return gen_lookup_ops._lookup_table_import_v2(
|
||||||
self.op.table_ref, restore_tensors[0], restore_tensors[1])
|
self.op.table_ref, restore_tensors[0], restore_tensors[1])
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user