mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Add test cases for double support of tf.decode_csv
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
73aaed655b
commit
285ea39108
|
|
@ -34,7 +34,7 @@ class DecodeCSVOpTest(test.TestCase):
|
||||||
out = sess.run(decode)
|
out = sess.run(decode)
|
||||||
|
|
||||||
for i, field in enumerate(out):
|
for i, field in enumerate(out):
|
||||||
if field.dtype == np.float32:
|
if field.dtype == np.float32 or field.dtype == np.float64:
|
||||||
self.assertAllClose(field, expected_out[i])
|
self.assertAllClose(field, expected_out[i])
|
||||||
else:
|
else:
|
||||||
self.assertAllEqual(field, expected_out[i])
|
self.assertAllEqual(field, expected_out[i])
|
||||||
|
|
@ -85,6 +85,17 @@ class DecodeCSVOpTest(test.TestCase):
|
||||||
|
|
||||||
self._test(args, expected_out)
|
self._test(args, expected_out)
|
||||||
|
|
||||||
|
def testDouble(self):
|
||||||
|
args = {
|
||||||
|
"records": ["1.0", "-1.79e+308", '"1.79e+308"'],
|
||||||
|
"record_defaults": [np.array(
|
||||||
|
[], dtype=np.double)],
|
||||||
|
}
|
||||||
|
|
||||||
|
expected_out = [[1.0, -1.79e+308, 1.79e+308]]
|
||||||
|
|
||||||
|
self._test(args, expected_out)
|
||||||
|
|
||||||
def testInt64(self):
|
def testInt64(self):
|
||||||
args = {
|
args = {
|
||||||
"records": ["1", "2", '"2147483648"'],
|
"records": ["1", "2", '"2147483648"'],
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user