mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Fix index_table_from_file to allow vocabulary_file be a Tensor
PiperOrigin-RevId: 157740677
This commit is contained in:
parent
0aa3e01941
commit
9fc1642250
|
|
@ -871,7 +871,7 @@ def index_table_from_file(vocabulary_file=None,
|
|||
```
|
||||
|
||||
Args:
|
||||
vocabulary_file: The vocabulary filename.
|
||||
vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
|
||||
num_oov_buckets: The number of out-of-vocabulary buckets.
|
||||
vocab_size: Number of the elements in the vocabulary, if known.
|
||||
default_value: The value to use for out-of-vocabulary feature values.
|
||||
|
|
@ -889,8 +889,9 @@ def index_table_from_file(vocabulary_file=None,
|
|||
ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
|
||||
than zero.
|
||||
"""
|
||||
if not vocabulary_file:
|
||||
raise ValueError("vocabulary_file must be specified.")
|
||||
if vocabulary_file is None or (
|
||||
isinstance(vocabulary_file, str) and not vocabulary_file):
|
||||
raise ValueError("vocabulary_file must be specified and must not be empty.")
|
||||
if num_oov_buckets < 0:
|
||||
raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
|
||||
% num_oov_buckets)
|
||||
|
|
|
|||
|
|
@ -1187,6 +1187,18 @@ class IndexTableFromFile(test.TestCase):
|
|||
lookup_ops.tables_initializer().run()
|
||||
self.assertAllEqual((1, 2, 3), ids.eval())
|
||||
|
||||
def test_string_index_table_from_file_tensor_filename(self):
|
||||
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
|
||||
with self.test_session():
|
||||
vocabulary_file = constant_op.constant(vocabulary_file)
|
||||
table = lookup.index_table_from_file(
|
||||
vocabulary_file=vocabulary_file, num_oov_buckets=1)
|
||||
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
|
||||
|
||||
self.assertRaises(errors_impl.OpError, ids.eval)
|
||||
lookup_ops.tables_initializer().run()
|
||||
self.assertAllEqual((1, 2, 3), ids.eval())
|
||||
|
||||
def test_int32_index_table_from_file(self):
|
||||
vocabulary_file = self._createVocabFile(
|
||||
"f2i_vocab2.txt", values=("42", "1", "-1000"))
|
||||
|
|
@ -1245,7 +1257,13 @@ class IndexTableFromFile(test.TestCase):
|
|||
860), # 3 + fingerprint("toccata") mod 300.
|
||||
ids.eval())
|
||||
|
||||
def test_index_table_from_file_with_only_oov_buckets(self):
|
||||
def test_index_table_from_file_fails_with_empty_vocabulary_file_name(self):
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
lookup.index_table_from_file,
|
||||
vocabulary_file="")
|
||||
|
||||
def test_index_table_from_file_fails_with_empty_vocabulary(self):
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
lookup.index_table_from_file,
|
||||
|
|
|
|||
|
|
@ -280,6 +280,18 @@ class IndexTableFromFile(test.TestCase):
|
|||
lookup_ops.tables_initializer().run()
|
||||
self.assertAllEqual((1, 2, 3), ids.eval())
|
||||
|
||||
def test_string_index_table_from_file_tensor_filename(self):
|
||||
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
|
||||
with self.test_session():
|
||||
vocabulary_file = constant_op.constant(vocabulary_file)
|
||||
table = lookup_ops.index_table_from_file(
|
||||
vocabulary_file=vocabulary_file, num_oov_buckets=1)
|
||||
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
|
||||
|
||||
self.assertRaises(errors_impl.OpError, ids.eval)
|
||||
lookup_ops.tables_initializer().run()
|
||||
self.assertAllEqual((1, 2, 3), ids.eval())
|
||||
|
||||
def test_int32_index_table_from_file(self):
|
||||
vocabulary_file = self._createVocabFile(
|
||||
"f2i_vocab2.txt", values=("42", "1", "-1000"))
|
||||
|
|
@ -340,7 +352,11 @@ class IndexTableFromFile(test.TestCase):
|
|||
860), # 3 + fingerprint("toccata") mod 300.
|
||||
ids.eval())
|
||||
|
||||
def test_index_table_from_file_with_only_oov_buckets(self):
|
||||
def test_index_table_from_file_fails_with_empty_vocabulary_file_name(self):
|
||||
self.assertRaises(
|
||||
ValueError, lookup_ops.index_table_from_file, vocabulary_file="")
|
||||
|
||||
def test_index_table_from_file_fails_with_empty_vocabulary(self):
|
||||
self.assertRaises(
|
||||
ValueError, lookup_ops.index_table_from_file, vocabulary_file=None)
|
||||
|
||||
|
|
|
|||
|
|
@ -893,7 +893,7 @@ def index_table_from_file(vocabulary_file=None,
|
|||
```
|
||||
|
||||
Args:
|
||||
vocabulary_file: The vocabulary filename.
|
||||
vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
|
||||
num_oov_buckets: The number of out-of-vocabulary buckets.
|
||||
vocab_size: Number of the elements in the vocabulary, if known.
|
||||
default_value: The value to use for out-of-vocabulary feature values.
|
||||
|
|
@ -911,8 +911,9 @@ def index_table_from_file(vocabulary_file=None,
|
|||
ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
|
||||
than zero.
|
||||
"""
|
||||
if not vocabulary_file:
|
||||
raise ValueError("vocabulary_file must be specified.")
|
||||
if vocabulary_file is None or (
|
||||
isinstance(vocabulary_file, str) and not vocabulary_file):
|
||||
raise ValueError("vocabulary_file must be specified and must not be empty.")
|
||||
if num_oov_buckets < 0:
|
||||
raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
|
||||
% num_oov_buckets)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user