Add test cases for double support of tf.decode_csv

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2017-10-15 23:11:04 +00:00
parent 73aaed655b
commit 285ea39108

View File

@ -34,7 +34,7 @@ class DecodeCSVOpTest(test.TestCase):
out = sess.run(decode) out = sess.run(decode)
for i, field in enumerate(out): for i, field in enumerate(out):
if field.dtype == np.float32: if field.dtype == np.float32 or field.dtype == np.float64:
self.assertAllClose(field, expected_out[i]) self.assertAllClose(field, expected_out[i])
else: else:
self.assertAllEqual(field, expected_out[i]) self.assertAllEqual(field, expected_out[i])
@ -85,6 +85,17 @@ class DecodeCSVOpTest(test.TestCase):
self._test(args, expected_out) self._test(args, expected_out)
def testDouble(self):
args = {
"records": ["1.0", "-1.79e+308", '"1.79e+308"'],
"record_defaults": [np.array(
[], dtype=np.double)],
}
expected_out = [[1.0, -1.79e+308, 1.79e+308]]
self._test(args, expected_out)
def testInt64(self): def testInt64(self):
args = { args = {
"records": ["1", "2", '"2147483648"'], "records": ["1", "2", '"2147483648"'],