Add double support for tf.decode_csv

In the current tensorflow `tf.decode_csv` accepts
`float`, `int32`, `int64`, `string` but not `double`.
It seems adding `double` support makes sense as `StringToNumber`
already support `double` type.

This fix adds `double` support for `tf.decode_csv`

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2017-10-15 23:07:30 +00:00
parent 231ca9dd4e
commit 3595d1613d
2 changed files with 20 additions and 1 deletions

View File

@ -137,6 +137,25 @@ class DecodeCSVOp : public OpKernel {
}
break;
}
case DT_DOUBLE: {
// If this field is empty or NA value, check if default is given:
// If yes, use default value; Otherwise report error.
if (fields[f].empty() || fields[f] == na_value_) {
OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
errors::InvalidArgument(
"Field ", f,
" is required but missing in record ", i, "!"));
output[f]->flat<double>()(i) = record_defaults[f].flat<double>()(0);
} else {
double value;
OP_REQUIRES(ctx, strings::safe_strtod(fields[f].c_str(), &value),
errors::InvalidArgument(
"Field ", f, " in record ", i,
" is not a valid double: ", fields[f]));
output[f]->flat<double>()(i) = value;
}
break;
}
case DT_STRING: {
// If this field is empty or NA value, check if default is given:
// If yes, use default value; Otherwise report error.

View File

@ -329,7 +329,7 @@ REGISTER_OP("DecodeCSV")
.Input("records: string")
.Input("record_defaults: OUT_TYPE")
.Output("output: OUT_TYPE")
.Attr("OUT_TYPE: list({float,int32,int64,string})")
.Attr("OUT_TYPE: list({float,double,int32,int64,string})")
.Attr("field_delim: string = ','")
.Attr("use_quote_delim: bool = true")
.Attr("na_value: string = ''")