Fix index_table_from_file to allow vocabulary_file be a Tensor

PiperOrigin-RevId: 157740677
This commit is contained in:
A. Unique TensorFlower 2017-06-01 11:41:49 -07:00 committed by TensorFlower Gardener
parent 0aa3e01941
commit 9fc1642250
4 changed files with 44 additions and 8 deletions

View File

@ -871,7 +871,7 @@ def index_table_from_file(vocabulary_file=None,
```
Args:
vocabulary_file: The vocabulary filename.
vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
num_oov_buckets: The number of out-of-vocabulary buckets.
vocab_size: Number of the elements in the vocabulary, if known.
default_value: The value to use for out-of-vocabulary feature values.
@ -889,8 +889,9 @@ def index_table_from_file(vocabulary_file=None,
ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
than zero.
"""
if not vocabulary_file:
raise ValueError("vocabulary_file must be specified.")
if vocabulary_file is None or (
isinstance(vocabulary_file, str) and not vocabulary_file):
raise ValueError("vocabulary_file must be specified and must not be empty.")
if num_oov_buckets < 0:
raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
% num_oov_buckets)

View File

@ -1187,6 +1187,18 @@ class IndexTableFromFile(test.TestCase):
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_string_index_table_from_file_tensor_filename(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
with self.test_session():
vocabulary_file = constant_op.constant(vocabulary_file)
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval)
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int32_index_table_from_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab2.txt", values=("42", "1", "-1000"))
@ -1245,7 +1257,13 @@ class IndexTableFromFile(test.TestCase):
860), # 3 + fingerprint("toccata") mod 300.
ids.eval())
def test_index_table_from_file_with_only_oov_buckets(self):
def test_index_table_from_file_fails_with_empty_vocabulary_file_name(self):
self.assertRaises(
ValueError,
lookup.index_table_from_file,
vocabulary_file="")
def test_index_table_from_file_fails_with_empty_vocabulary(self):
self.assertRaises(
ValueError,
lookup.index_table_from_file,

View File

@ -280,6 +280,18 @@ class IndexTableFromFile(test.TestCase):
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_string_index_table_from_file_tensor_filename(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
with self.test_session():
vocabulary_file = constant_op.constant(vocabulary_file)
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval)
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int32_index_table_from_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab2.txt", values=("42", "1", "-1000"))
@ -340,7 +352,11 @@ class IndexTableFromFile(test.TestCase):
860), # 3 + fingerprint("toccata") mod 300.
ids.eval())
def test_index_table_from_file_with_only_oov_buckets(self):
def test_index_table_from_file_fails_with_empty_vocabulary_file_name(self):
self.assertRaises(
ValueError, lookup_ops.index_table_from_file, vocabulary_file="")
def test_index_table_from_file_fails_with_empty_vocabulary(self):
self.assertRaises(
ValueError, lookup_ops.index_table_from_file, vocabulary_file=None)

View File

@ -893,7 +893,7 @@ def index_table_from_file(vocabulary_file=None,
```
Args:
vocabulary_file: The vocabulary filename.
vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
num_oov_buckets: The number of out-of-vocabulary buckets.
vocab_size: Number of the elements in the vocabulary, if known.
default_value: The value to use for out-of-vocabulary feature values.
@ -911,8 +911,9 @@ def index_table_from_file(vocabulary_file=None,
ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
than zero.
"""
if not vocabulary_file:
raise ValueError("vocabulary_file must be specified.")
if vocabulary_file is None or (
isinstance(vocabulary_file, str) and not vocabulary_file):
raise ValueError("vocabulary_file must be specified and must not be empty.")
if num_oov_buckets < 0:
raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
% num_oov_buckets)