Support uint32 cast for tflite

PiperOrigin-RevId: 394943535
Change-Id: I98cc78ad52898b2a5704a24811da25bd4c288e87

This commit is contained in:
parent cfcb091c4e
commit c38e12c163
@@ -42,6 +42,7 @@
     for the migration.
 * Add experimental API `experimental_from_jax` to support conversion from Jax
   models to TensorFlow Lite.
+* Support uint32 data type for cast op.
 
 * TF Core:
   * `tf.Graph.get_name_scope()` now always returns a string, as documented.
@@ -3575,10 +3575,10 @@ def TFL_CastOp : TFL_Op<"cast", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I1, I16, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$input
+    TFL_TensorOf<[F32, I1, I16, I32, UI32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$input
   );
 
-  let results = (outs TFL_TensorOf<[F32, I1, I16, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$output);
+  let results = (outs TFL_TensorOf<[F32, I1, I16, I32, UI32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$output);
 
   // TFLite's cast op does not utilize CastOptions, instead derives types
   // from the TfLiteTensors.
@@ -80,6 +80,9 @@ TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
     case kTfLiteInt32:
       copyCast(in, out->data.i32, num_elements);
       break;
+    case kTfLiteUInt32:
+      copyCast(in, out->data.u32, num_elements);
+      break;
     case kTfLiteInt16:
       copyCast(in, out->data.i16, num_elements);
       break;

@@ -116,6 +119,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       return copyToTensor(context, input->data.i64, output, num_elements);
     case kTfLiteInt32:
       return copyToTensor(context, input->data.i32, output, num_elements);
+    case kTfLiteUInt32:
+      return copyToTensor(context, input->data.u32, output, num_elements);
     case kTfLiteInt16:
       return copyToTensor(context, input->data.i16, output, num_elements);
     case kTfLiteUInt8:
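The two hunks above appear to be the TFLite cast kernel: Eval switches on the input tensor's type and hands off to the templated copyToTensor, which switches on the output type and calls a copyCast helper for the element-wise conversion. For readers without the file at hand, a minimal sketch of what such a helper could look like — the name copyCast comes from the diff, but this body is an assumption, not the kernel's actual implementation:

#include <algorithm>

// Sketch only: converts num_elements values from FromT to ToT via static_cast;
// out-of-range values (e.g. a uint32 above INT32_MAX cast to int32) follow the
// usual C++ conversion rules rather than saturating.
template <typename FromT, typename ToT>
void copyCast(const FromT* in, ToT* out, int num_elements) {
  std::transform(in, in + num_elements, out,
                 [](FromT a) { return static_cast<ToT>(a); });
}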
@@ -215,5 +215,21 @@ TEST(CastOpModel, CastComplex64ToComplex64) {
                                std::complex<float>(6.0f, 16.0f)}));
 }
 
+TEST(CastOpModel, CastUInt32ToInt32) {
+  CastOpModel m({TensorType_UINT32, {2, 3}}, {TensorType_INT32, {2, 3}});
+  m.PopulateTensor<uint32_t>(m.input(), {100, 200, 300, 400, 500, 600});
+  m.Invoke();
+  EXPECT_THAT(m.ExtractVector<int32_t>(m.output()),
+              ElementsAreArray({100, 200, 300, 400, 500, 600}));
+}
+
+TEST(CastOpModel, CastInt32ToUInt32) {
+  CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_UINT32, {2, 3}});
+  m.PopulateTensor<int32_t>(m.input(), {100, 200, 300, 400, 500, 600});
+  m.Invoke();
+  EXPECT_THAT(m.ExtractVector<uint32_t>(m.output()),
+              ElementsAreArray({100, 200, 300, 400, 500, 600}));
+}
+
 }  // namespace
 }  // namespace tflite
@@ -168,7 +168,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX(),
              /* min_version = */ 1,
              /* max_version = */ 2);
-  AddBuiltin(BuiltinOperator_CAST, Register_CAST());
+  AddBuiltin(BuiltinOperator_CAST, Register_CAST(),
+             /* min_version = */ 1,
+             /* max_version = */ 2);
   AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
              /* min_version = */ 1,
              /* max_version = */ 5);
@@ -336,7 +336,9 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
   AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX_REF(),
              /* min_version = */ 1,
              /* max_version = */ 2);
-  AddBuiltin(BuiltinOperator_CAST, Register_CAST());
+  AddBuiltin(BuiltinOperator_CAST, Register_CAST(),
+             /* min_version = */ 1,
+             /* max_version = */ 2);
   AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE_REF(),
             /* min_version = */ 1,
             /* max_version = */ 4);
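The two hunks above widen the registered version range for CAST from the default (version 1 only) to versions 1-2 in both the optimized and the reference builtin op resolvers; without this, a model containing a version-2 CAST (one that casts to or from uint32) would fail to resolve the op. Applications that assemble their own resolver would make the equivalent call; a rough sketch, assuming the MutableOpResolver API and the Register_CAST() declaration from the TFLite headers (the include paths are my guess):

#include "tensorflow/lite/kernels/builtin_op_kernels.h"  // Register_CAST() (assumed path)
#include "tensorflow/lite/mutable_op_resolver.h"         // MutableOpResolver (assumed path)

// Sketch only: register CAST with a selective resolver, accepting both the
// original version 1 and the new uint32-capable version 2.
void RegisterCastUpToV2(tflite::MutableOpResolver* resolver) {
  resolver->AddBuiltin(tflite::BuiltinOperator_CAST,
                       tflite::ops::builtin::Register_CAST(),
                       /* min_version = */ 1,
                       /* max_version = */ 2);
}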
@@ -27,22 +27,23 @@ from tensorflow.lite.testing.zip_test_utils import register_make_test_function
 def make_cast_tests(options):
   """Generate examples for cast."""
   if options.use_experimental_converter:
-    test_parameters = [
-        {
-            "input_dtype": [tf.float32],
-            "output_dtype": [tf.int16],
-            "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
-        },
-        {
-            "input_dtype": [tf.int16],
-            "output_dtype": [tf.float32],
-            "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
-        },
-        {
-            "input_dtype": [tf.int32],
-            "output_dtype": [tf.float32],
-            "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
-        }]
+    test_parameters = [{
+        "input_dtype": [tf.float32],
+        "output_dtype": [tf.int16],
+        "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
+    }, {
+        "input_dtype": [tf.int16],
+        "output_dtype": [tf.float32],
+        "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
+    }, {
+        "input_dtype": [tf.int32],
+        "output_dtype": [tf.float32],
+        "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
+    }, {
+        "input_dtype": [tf.uint32],
+        "output_dtype": [tf.int32],
+        "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
+    }]
   else:
     test_parameters = [
         {
@@ -77,6 +77,7 @@ TF_TYPE_INFO = {
     tf.float16: (np.float16, "FLOAT"),
     tf.float64: (np.float64, "FLOAT64"),
     tf.int32: (np.int32, "INT32"),
+    tf.uint32: (np.uint32, "UINT32"),
     tf.uint8: (np.uint8, "QUANTIZED_UINT8"),
     tf.int16: (np.int16, "QUANTIZED_INT16"),
     tf.int64: (np.int64, "INT64"),

@@ -115,7 +116,7 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
     real = (max_value - min_value) * np.random.random_sample(shape) + min_value
     imag = (max_value - min_value) * np.random.random_sample(shape) + min_value
     value = real + imag * 1j
-  elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16):
+  elif dtype in (tf.uint32, tf.int32, tf.uint8, tf.int64, tf.int16):
     value = np.random.randint(min_value, max_value + 1, shape)
   elif dtype == tf.bool:
     value = np.random.choice([True, False], size=shape)
@@ -788,6 +788,12 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
         return 3;
       }
       return 2;
+    case BuiltinOperator_CAST:
+      if (op_sig.inputs.at(0).type == kTfLiteUInt32 ||
+          op_sig.outputs.at(0).type == kTfLiteUInt32) {
+        return 2;
+      }
+      return 1;
     default:
       return 1;
   }
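This hunk is the converter-side versioning rule in GetBuiltinOperatorVersion: a CAST node is written out as op version 2 only when its input or output tensor is uint32, and stays at version 1 otherwise, so models that never touch uint32 keep their old op version. The predicate in isolation, as an illustrative free function (the function name and the include path are my assumptions, not the real API):

#include "tensorflow/lite/c/common.h"  // assumed path for TfLiteType / kTfLiteUInt32

// Illustrative only: mirrors the branch added to GetBuiltinOperatorVersion.
int CastOpVersion(TfLiteType input_type, TfLiteType output_type) {
  return (input_type == kTfLiteUInt32 || output_type == kTfLiteUInt32) ? 2 : 1;
}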
@@ -94,6 +94,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
           {{BuiltinOperator_BATCH_TO_SPACE_ND, 2}, "1.14.0"},
           {{BuiltinOperator_BATCH_TO_SPACE_ND, 3}, "2.3.0"},
           {{BuiltinOperator_CAST, 1}, "1.5.0"},
+          {{BuiltinOperator_CAST, 2}, "2.7.0"},
           {{BuiltinOperator_CONCATENATION, 1}, "1.5.0"},
           {{BuiltinOperator_CONCATENATION, 2}, "1.14.0"},
           {{BuiltinOperator_CONCATENATION, 3}, "2.3.0"},
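The final hunk records the runtime requirement for the new version: CAST at version 1 dates back to runtime 1.5.0, while version 2 is first supported by the 2.7.0 runtime, which is what FindMinimumRuntimeVersionForOp reports and what gets stamped into a model's minimum-runtime metadata. Conceptually the table is a (builtin op, op version) -> version-string lookup; a self-contained sketch with only the two CAST entries (the enum and map below are stand-ins, not the real data structures):

#include <map>
#include <string>
#include <utility>

// Stand-in type for illustration; the real table keys on the TFLite
// BuiltinOperator enum from the schema.
enum class Op { kCast };

std::string MinRuntimeVersionFor(Op op, int op_version) {
  static const std::map<std::pair<Op, int>, std::string> kMinRuntime = {
      {{Op::kCast, 1}, "1.5.0"},
      {{Op::kCast, 2}, "2.7.0"},
  };
  auto it = kMinRuntime.find({op, op_version});
  return it == kMinRuntime.end() ? "" : it->second;
}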