Support uint32 cast for tflite

PiperOrigin-RevId: 394943535
Change-Id: I98cc78ad52898b2a5704a24811da25bd4c288e87
This commit is contained in:
Renjie Liu 2021-09-05 03:53:32 -07:00 committed by TensorFlower Gardener
parent cfcb091c4e
commit c38e12c163
10 changed files with 56 additions and 21 deletions

View File

@ -42,6 +42,7 @@
for the migration.
* Add experimental API `experimental_from_jax` to support conversion from Jax
models to TensorFlow Lite.
* Support uint32 data type for cast op.
* TF Core:
* `tf.Graph.get_name_scope()` now always returns a string, as documented.

View File

@ -3575,10 +3575,10 @@ def TFL_CastOp : TFL_Op<"cast", [
}];
let arguments = (ins
TFL_TensorOf<[F32, I1, I16, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$input
TFL_TensorOf<[F32, I1, I16, I32, UI32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$input
);
let results = (outs TFL_TensorOf<[F32, I1, I16, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$output);
let results = (outs TFL_TensorOf<[F32, I1, I16, I32, UI32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$output);
// TFLite's cast op does not utilize CastOptions, instead derives types
// from the TfLiteTensors.

View File

@ -80,6 +80,9 @@ TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
case kTfLiteInt32:
copyCast(in, out->data.i32, num_elements);
break;
case kTfLiteUInt32:
copyCast(in, out->data.u32, num_elements);
break;
case kTfLiteInt16:
copyCast(in, out->data.i16, num_elements);
break;
@ -116,6 +119,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return copyToTensor(context, input->data.i64, output, num_elements);
case kTfLiteInt32:
return copyToTensor(context, input->data.i32, output, num_elements);
case kTfLiteUInt32:
return copyToTensor(context, input->data.u32, output, num_elements);
case kTfLiteInt16:
return copyToTensor(context, input->data.i16, output, num_elements);
case kTfLiteUInt8:

View File

@ -215,5 +215,21 @@ TEST(CastOpModel, CastComplex64ToComplex64) {
std::complex<float>(6.0f, 16.0f)}));
}
TEST(CastOpModel, CastUInt32ToInt32) {
CastOpModel m({TensorType_UINT32, {2, 3}}, {TensorType_INT32, {2, 3}});
m.PopulateTensor<uint32_t>(m.input(), {100, 200, 300, 400, 500, 600});
m.Invoke();
EXPECT_THAT(m.ExtractVector<int32_t>(m.output()),
ElementsAreArray({100, 200, 300, 400, 500, 600}));
}
TEST(CastOpModel, CastInt32ToUInt32) {
CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_UINT32, {2, 3}});
m.PopulateTensor<int32_t>(m.input(), {100, 200, 300, 400, 500, 600});
m.Invoke();
EXPECT_THAT(m.ExtractVector<uint32_t>(m.output()),
ElementsAreArray({100, 200, 300, 400, 500, 600}));
}
} // namespace
} // namespace tflite

View File

@ -168,7 +168,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX(),
/* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_CAST, Register_CAST());
AddBuiltin(BuiltinOperator_CAST, Register_CAST(),
/* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
/* min_version = */ 1,
/* max_version = */ 5);

View File

@ -336,7 +336,9 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX_REF(),
/* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_CAST, Register_CAST());
AddBuiltin(BuiltinOperator_CAST, Register_CAST(),
/* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE_REF(),
/* min_version = */ 1,
/* max_version = */ 4);

View File

@ -27,22 +27,23 @@ from tensorflow.lite.testing.zip_test_utils import register_make_test_function
def make_cast_tests(options):
"""Generate examples for cast."""
if options.use_experimental_converter:
test_parameters = [
{
"input_dtype": [tf.float32],
"output_dtype": [tf.int16],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.int16],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.int32],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}]
test_parameters = [{
"input_dtype": [tf.float32],
"output_dtype": [tf.int16],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}, {
"input_dtype": [tf.int16],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}, {
"input_dtype": [tf.int32],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}, {
"input_dtype": [tf.uint32],
"output_dtype": [tf.int32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}]
else:
test_parameters = [
{

View File

@ -77,6 +77,7 @@ TF_TYPE_INFO = {
tf.float16: (np.float16, "FLOAT"),
tf.float64: (np.float64, "FLOAT64"),
tf.int32: (np.int32, "INT32"),
tf.uint32: (np.uint32, "UINT32"),
tf.uint8: (np.uint8, "QUANTIZED_UINT8"),
tf.int16: (np.int16, "QUANTIZED_INT16"),
tf.int64: (np.int64, "INT64"),
@ -115,7 +116,7 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
real = (max_value - min_value) * np.random.random_sample(shape) + min_value
imag = (max_value - min_value) * np.random.random_sample(shape) + min_value
value = real + imag * 1j
elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16):
elif dtype in (tf.uint32, tf.int32, tf.uint8, tf.int64, tf.int16):
value = np.random.randint(min_value, max_value + 1, shape)
elif dtype == tf.bool:
value = np.random.choice([True, False], size=shape)

View File

@ -788,6 +788,12 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
return 3;
}
return 2;
case BuiltinOperator_CAST:
if (op_sig.inputs.at(0).type == kTfLiteUInt32 ||
op_sig.outputs.at(0).type == kTfLiteUInt32) {
return 2;
}
return 1;
default:
return 1;
}

View File

@ -94,6 +94,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_BATCH_TO_SPACE_ND, 2}, "1.14.0"},
{{BuiltinOperator_BATCH_TO_SPACE_ND, 3}, "2.3.0"},
{{BuiltinOperator_CAST, 1}, "1.5.0"},
{{BuiltinOperator_CAST, 2}, "2.7.0"},
{{BuiltinOperator_CONCATENATION, 1}, "1.5.0"},
{{BuiltinOperator_CONCATENATION, 2}, "1.14.0"},
{{BuiltinOperator_CONCATENATION, 3}, "2.3.0"},