mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Add test cases for double support of tf.decode_csv
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
73aaed655b
commit
285ea39108
|
|
@ -34,7 +34,7 @@ class DecodeCSVOpTest(test.TestCase):
|
|||
out = sess.run(decode)
|
||||
|
||||
for i, field in enumerate(out):
|
||||
if field.dtype == np.float32:
|
||||
if field.dtype == np.float32 or field.dtype == np.float64:
|
||||
self.assertAllClose(field, expected_out[i])
|
||||
else:
|
||||
self.assertAllEqual(field, expected_out[i])
|
||||
|
|
@ -85,6 +85,17 @@ class DecodeCSVOpTest(test.TestCase):
|
|||
|
||||
self._test(args, expected_out)
|
||||
|
||||
def testDouble(self):
|
||||
args = {
|
||||
"records": ["1.0", "-1.79e+308", '"1.79e+308"'],
|
||||
"record_defaults": [np.array(
|
||||
[], dtype=np.double)],
|
||||
}
|
||||
|
||||
expected_out = [[1.0, -1.79e+308, 1.79e+308]]
|
||||
|
||||
self._test(args, expected_out)
|
||||
|
||||
def testInt64(self):
|
||||
args = {
|
||||
"records": ["1", "2", '"2147483648"'],
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user