mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Add double support for tf.decode_csv
In the current tensorflow `tf.decode_csv` accepts `float`, `int32`, `int64`, `string` but not `double`. It seems adding `double` support makes sense as `StringToNumber` already support `double` type. This fix adds `double` support for `tf.decode_csv` Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
231ca9dd4e
commit
3595d1613d
|
|
@ -137,6 +137,25 @@ class DecodeCSVOp : public OpKernel {
|
|||
}
|
||||
break;
|
||||
}
|
||||
case DT_DOUBLE: {
|
||||
// If this field is empty or NA value, check if default is given:
|
||||
// If yes, use default value; Otherwise report error.
|
||||
if (fields[f].empty() || fields[f] == na_value_) {
|
||||
OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
|
||||
errors::InvalidArgument(
|
||||
"Field ", f,
|
||||
" is required but missing in record ", i, "!"));
|
||||
output[f]->flat<double>()(i) = record_defaults[f].flat<double>()(0);
|
||||
} else {
|
||||
double value;
|
||||
OP_REQUIRES(ctx, strings::safe_strtod(fields[f].c_str(), &value),
|
||||
errors::InvalidArgument(
|
||||
"Field ", f, " in record ", i,
|
||||
" is not a valid double: ", fields[f]));
|
||||
output[f]->flat<double>()(i) = value;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DT_STRING: {
|
||||
// If this field is empty or NA value, check if default is given:
|
||||
// If yes, use default value; Otherwise report error.
|
||||
|
|
|
|||
|
|
@ -329,7 +329,7 @@ REGISTER_OP("DecodeCSV")
|
|||
.Input("records: string")
|
||||
.Input("record_defaults: OUT_TYPE")
|
||||
.Output("output: OUT_TYPE")
|
||||
.Attr("OUT_TYPE: list({float,int32,int64,string})")
|
||||
.Attr("OUT_TYPE: list({float,double,int32,int64,string})")
|
||||
.Attr("field_delim: string = ','")
|
||||
.Attr("use_quote_delim: bool = true")
|
||||
.Attr("na_value: string = ''")
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user