Add float16 and float64 input/output type support for TFLite operator 'cast'

Float16 and float64 input/output types for the TensorFlow 'cast' operator are used in some Federated Learning models, so adding support for these types to the TFLite 'cast' op lets those operators be converted to TFLite built-in ops instead of flex ops.
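For reference, a minimal conversion sketch (not part of the commit) of the path this change enables: a toy model containing a float64 tf.cast converted with only built-in ops, no SELECT_TF_OPS fallback. The model and names below are illustrative.

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
def cast_model(x):
  # float32 -> float64 cast; previously this required a flex (Select TF) op.
  return tf.cast(x, tf.float64)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [cast_model.get_concrete_function()])
# Only built-in ops allowed: CAST can now be emitted as a TFLite builtin.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
tflite_model = converter.convert()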

PiperOrigin-RevId: 506997479
Youchuan Hu 2023-02-03 14:38:25 -08:00 committed by TensorFlower Gardener
parent a38841aa17
commit 9fbf113704
12 changed files with 189 additions and 82 deletions

View File

@@ -83,6 +83,7 @@
* `tf.lite`:
* Add 16-bit float type support for built-in op `fill`.
* Add 16-bit and 64-bit float type support for built-in op `cast`.
* Transpose now supports 6D tensors.
* Float LSTM now supports diagonal recurrent tensors:
https://arxiv.org/abs/1903.08023

View File

@@ -3885,10 +3885,10 @@ def TFL_CastOp : TFL_Op<"cast", [
}];
let arguments = (ins
TFL_TensorOf<[F32, I1, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex<F<32>>]>:$input
TFL_TensorOf<[F16, F32, F64, I1, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex<F<32>>]>:$input
);
let results = (outs TFL_TensorOf<[F32, I1, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex<F<32>>]>:$output);
let results = (outs TFL_TensorOf<[F16, F32, F64, I1, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex<F<32>>]>:$output);
// TFLite's cast op does not utilize CastOptions, instead derives types
// from the TfLiteTensors.

View File

@@ -190,33 +190,3 @@ func.func @whileDifferentResultShapes(%arg0: tensor<i32>) -> tensor<?xf32>
// CHECK: (tensor<i32>, tensor<1xf32>, tensor<i32>) -> (tensor<i32>, tensor<?xf32>, tensor<i32>)
func.return %0#1 : tensor<?xf32>
}
// -----
func.func @unsupportedCast(%arg0: tensor<4x4x3xf32>) -> tensor<*xf32> {
%cst = arith.constant dense<0.000000e+00> : tensor<4x2xf32>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<4x4x3xf64>
%cst_1 = arith.constant dense<[1, 0, 2]> : tensor<3xi32>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<4x4x2xf32>
%cst_3 = arith.constant dense<4> : tensor<i32>
%cst_4 = arith.constant dense<0> : tensor<i32>
%cst_5 = arith.constant dense<0.000000e+00> : tensor<4x2xf64>
%0 = "tfl.transpose"(%arg0, %cst_1) : (tensor<4x4x3xf32>, tensor<3xi32>) -> tensor<4x4x3xf32>
%1:6 = "tfl.while"(%cst_4, %cst_4, %cst_2, %cst, %cst_5, %cst_0) ({
^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<*xf32>, %arg4: tensor<4x2xf32>, %arg5: tensor<4x2xf64>, %arg6: tensor<*xf64>):
%5 = "tfl.less"(%arg2, %cst_3) : (tensor<i32>, tensor<i32>) -> tensor<i1>
%6 = "tfl.less"(%arg1, %cst_3) : (tensor<i32>, tensor<i32>) -> tensor<i1>
%7 = tfl.logical_and %6, %5 : tensor<i1>
"tfl.yield"(%7) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<*xf32>, %arg4: tensor<4x2xf32>, %arg5: tensor<4x2xf64>, %arg6: tensor<*xf64>):
"tfl.yield"(%arg1, %arg2, %arg3, %arg4, %arg5, %cst_0) : (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf64>, tensor<4x4x3xf64>) -> ()
}) {is_stateless = true} : (tensor<i32>, tensor<i32>, tensor<4x4x2xf32>, tensor<4x2xf32>, tensor<4x2xf64>, tensor<4x4x3xf64>) -> (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf64>, tensor<*xf32>)
func.return %1#2 : tensor<*xf32>
}
// CHECK-LABEL: func @unsupportedCast(
// CHECK-LABEL: func private @tfl.while_body(
// CHECK-SAME: %arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<*xf32>, %arg3: tensor<4x2xf32>, %arg4: tensor<4x2xf64>, %arg5: tensor<*xf64>) -> (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf64>, tensor<*xf64>)
// CHECK: [[VAL:%.*]] = "tf.Cast"

View File

@@ -79,8 +79,10 @@ bool IsAlreadyOutlined(WhileOp while_op) {
bool IsCompatibleTypeWithTFLCastOp(Type type) {
auto elemType = getElementTypeOrSelf(type);
// F32 and BF16 types are allowed.
if (elemType.isBF16() || elemType.isF32()) return true;
// F16, F32, F64, BF16 types are allowed.
if (elemType.isBF16() || elemType.isF16() || elemType.isF32() ||
elemType.isF64())
return true;
// I1, I8, I16, I32, I64 types are allowed.
if (elemType.isInteger(1) || elemType.isInteger(8) ||

View File

@@ -171,7 +171,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_CAST, Register_CAST(),
/* min_version = */ 1,
/* max_version = */ 4);
/* max_version = */ 5);
AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
/* min_version = */ 1,
/* max_version = */ 5);

View File

@@ -70,6 +70,42 @@ void copyCast(const std::complex<float>* in, std::complex<float>* out,
[](std::complex<float> a) { return a; });
}
template <typename ToT>
void copyCast(const Eigen::half* in, ToT* out, int num_elements) {
std::transform(in, in + num_elements, out, [](Eigen::half a) {
return static_cast<ToT>(Eigen::half_impl::half_to_float(a));
});
}
template <>
void copyCast(const Eigen::half* in, std::complex<float>* out,
int num_elements) {
std::transform(in, in + num_elements, out, [](Eigen::half a) {
return std::complex<float>(Eigen::half_impl::half_to_float(a));
});
}
template <typename FromT>
void copyCastToFloat16(const FromT* in, Eigen::half* out, int num_elements) {
std::transform(in, in + num_elements, out, [](FromT a) {
return Eigen::half_impl::float_to_half_rtne(static_cast<float>(a));
});
}
template <>
void copyCastToFloat16(const std::complex<float>* in, Eigen::half* out,
int num_elements) {
std::transform(in, in + num_elements, out, [](std::complex<float> a) {
return Eigen::half_impl::float_to_half_rtne(std::real(a));
});
}
template <>
void copyCastToFloat16(const Eigen::half* in, Eigen::half* out,
int num_elements) {
std::transform(in, in + num_elements, out, [](Eigen::half a) { return a; });
}
template <typename FromT>
TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
TfLiteTensor* out, int num_elements) {
@@ -95,9 +131,16 @@ TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
case kTfLiteInt8:
copyCast(in, out->data.int8, num_elements);
break;
case kTfLiteFloat16:
copyCastToFloat16(in, reinterpret_cast<Eigen::half*>(out->data.f16),
num_elements);
break;
case kTfLiteFloat32:
copyCast(in, GetTensorData<float>(out), num_elements);
break;
case kTfLiteFloat64:
copyCast(in, out->data.f64, num_elements);
break;
case kTfLiteBool:
copyCast(in, out->data.b, num_elements);
break;
@@ -135,9 +178,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return copyToTensor(context, input->data.uint8, output, num_elements);
case kTfLiteInt8:
return copyToTensor(context, input->data.int8, output, num_elements);
case kTfLiteFloat16:
return copyToTensor(context,
reinterpret_cast<Eigen::half*>(input->data.f16),
output, num_elements);
case kTfLiteFloat32:
return copyToTensor(context, GetTensorData<float>(input), output,
num_elements);
case kTfLiteFloat64:
return copyToTensor(context, input->data.f64, output, num_elements);
case kTfLiteBool:
return copyToTensor(context, input->data.b, output, num_elements);
case kTfLiteComplex64:
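As a quick illustration of the numerics (not part of the change), the new float16 paths convert through float32 with round-to-nearest-even, which is the same behaviour as NumPy's float16 conversion; a short sketch:

import numpy as np

x = np.float32(3.14159274)
h = np.float16(x)        # round-to-nearest-even, like float_to_half_rtne above
print(h, np.float32(h))  # 3.14 3.140625 -- low-order bits are lost in fp16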

View File

@@ -228,6 +228,7 @@ cc_library(
hdrs = ["split.h"],
deps = [
"//tensorflow/lite:string",
"//third_party/eigen3",
],
)

View File

@@ -22,51 +22,103 @@ from tensorflow.lite.testing.zip_test_utils import register_make_test_function
@register_make_test_function()
def make_cast_tests(options):
"""Generate examples for cast."""
test_parameters = [{
"input_dtype": [tf.float32],
"output_dtype": [tf.int16],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}, {
"input_dtype": [tf.int16],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}, {
"input_dtype": [tf.int32],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}, {
"input_dtype": [tf.int8],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}, {
"input_dtype": [tf.float32],
"output_dtype": [tf.int8],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}, {
"input_dtype": [tf.uint16],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}, {
"input_dtype": [tf.uint32],
"output_dtype": [tf.int32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}, {
"input_dtype": [tf.uint8],
"output_dtype": [tf.int8],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}, {
"input_dtype": [tf.int8],
"output_dtype": [tf.uint8],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}, {
"input_dtype": [tf.uint16],
"output_dtype": [tf.int16],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}, {
"input_dtype": [tf.int16],
"output_dtype": [tf.uint16],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}]
test_parameters = [
{
"input_dtype": [tf.float32],
"output_dtype": [tf.int16],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.int16],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.int32],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.int8],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.float32],
"output_dtype": [tf.int8],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.uint16],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.uint32],
"output_dtype": [tf.int32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.uint8],
"output_dtype": [tf.int8],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.int8],
"output_dtype": [tf.uint8],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.uint16],
"output_dtype": [tf.int16],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.int16],
"output_dtype": [tf.uint16],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.int32],
"output_dtype": [tf.float64],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.float64],
"output_dtype": [tf.int32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.float32],
"output_dtype": [tf.float64],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.float64],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.int64],
"output_dtype": [tf.float64],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.float64],
"output_dtype": [tf.int64],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.float32],
"output_dtype": [tf.float16],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.float16],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
]
def build_graph(parameters):
"""Build the cast testing graph."""

View File

@@ -23,6 +23,7 @@ limitations under the License.
#include <utility>
#include <vector>
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/lite/string_type.h"
namespace tflite {
@@ -197,6 +198,17 @@ inline std::vector<std::complex<double>> Split(const string& s,
return fields;
}
template <>
inline std::vector<Eigen::half> Split(const string& s,
const string& delimiter) {
std::vector<Eigen::half> fields;
for (const auto& p : SplitToPos(s, delimiter)) {
fields.push_back(Eigen::half_impl::float_to_half_rtne(
strtof(s.data() + p.first, nullptr)));
}
return fields;
}
} // namespace testing
} // namespace tflite

View File

@@ -376,6 +376,8 @@ bool TfLiteDriver::DataExpectation::Check(bool verbose,
tensor);
case kTfLiteFloat64:
return TypedCheck<double, double>(verbose, tensor);
case kTfLiteFloat16:
return TypedCheck<Eigen::half, float>(verbose, tensor);
default:
fprintf(stderr, "Unsupported type %d in Check\n", tensor.type);
return false;
@@ -678,6 +680,15 @@ void TfLiteDriver::SetInput(const string& name, const string& csv_values) {
SetTensorData(values, tensor->data.raw);
break;
}
case kTfLiteFloat16: {
const auto& values = testing::Split<Eigen::half>(csv_values, ",");
for (auto k : values) {
TFLITE_LOG(INFO) << "input" << k;
}
if (!CheckSizes<Eigen::half>(tensor->bytes, values.size())) return;
SetTensorData(values, tensor->data.raw);
break;
}
default:
Invalidate(absl::StrCat("Unsupported tensor type ",
TfLiteTypeGetName(tensor->type),
@@ -755,6 +766,9 @@ void TfLiteDriver::SetExpectation(const string& name,
case kTfLiteComplex128:
expected_output_[id]->SetData<std::complex<double>>(csv_values);
break;
case kTfLiteFloat16:
expected_output_[id]->SetData<Eigen::half>(csv_values);
break;
default:
Invalidate(absl::StrCat("Unsupported tensor type ",
TfLiteTypeGetName(tensor->type),

View File

@@ -901,8 +901,13 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
}
return 2;
case BuiltinOperator_CAST:
if (op_sig.inputs.at(0).type == kTfLiteUInt16 ||
op_sig.outputs.at(0).type == kTfLiteUInt16) {
if (op_sig.inputs.at(0).type == kTfLiteFloat64 ||
op_sig.outputs.at(0).type == kTfLiteFloat64 ||
op_sig.inputs.at(0).type == kTfLiteFloat16 ||
op_sig.outputs.at(0).type == kTfLiteFloat16) {
return 5;
} else if (op_sig.inputs.at(0).type == kTfLiteUInt16 ||
op_sig.outputs.at(0).type == kTfLiteUInt16) {
return 4;
} else if (op_sig.inputs.at(0).type == kTfLiteInt8 ||
op_sig.outputs.at(0).type == kTfLiteInt8) {

View File

@@ -102,6 +102,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_CAST, 2}, "2.7.0"},
{{BuiltinOperator_CAST, 3}, "2.8.0"},
{{BuiltinOperator_CAST, 4}, "2.9.0"},
{{BuiltinOperator_CAST, 5}, "2.12.0"},
{{BuiltinOperator_CONCATENATION, 1}, "1.5.0"},
{{BuiltinOperator_CONCATENATION, 2}, "1.14.0"},
{{BuiltinOperator_CONCATENATION, 3}, "2.3.0"},