mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Add int16x8 kernel support for equal and not_equal ops
PiperOrigin-RevId: 807869896
This commit is contained in:
parent
bcf149d89c
commit
e791c5217d
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
* `tf.lite`
|
||||
* Adds int8 and int16x8 support for SQRT operator.
|
||||
* Adds int16x8 support for EQUAL and NOT_EQUAL operators.
|
||||
|
||||
### Bug Fixes and Other Changes
|
||||
|
||||
|
|
|
|||
|
|
@ -1589,8 +1589,8 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [
|
|||
}];
|
||||
|
||||
let arguments = (
|
||||
ins TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, TFL_Quint8, TFL_Str]>:$lhs,
|
||||
TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, TFL_Quint8, TFL_Str]>:$rhs);
|
||||
ins TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, QI16, TFL_Quint8, TFL_Str]>:$lhs,
|
||||
TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, QI16, TFL_Quint8, TFL_Str]>:$rhs);
|
||||
|
||||
let results = (outs TFL_BoolTensor:$output);
|
||||
|
||||
|
|
@ -1729,8 +1729,8 @@ def TFL_EqualOp: TFL_Op<"equal", [
|
|||
|
||||
let arguments = (
|
||||
ins
|
||||
TFL_TensorOf<[I1, F32, I16, I32, I64, QI8, QUI8, UI8, TFL_Str]>:$x,
|
||||
TFL_TensorOf<[I1, F32, I16, I32, I64, QI8, QUI8, UI8, TFL_Str]>:$y
|
||||
TFL_TensorOf<[I1, F32, I16, I32, I64, QI8, QUI8, QI16, UI8, TFL_Str]>:$x,
|
||||
TFL_TensorOf<[I1, F32, I16, I32, I64, QI8, QUI8, QI16, UI8, TFL_Str]>:$y
|
||||
);
|
||||
|
||||
let results = (outs TFL_BoolTensor:$output);
|
||||
|
|
|
|||
|
|
@ -1209,6 +1209,40 @@ func.func @testEqualInt16(tensor<? x i16>, tensor<? x i16>) -> tensor<? x i1> {
|
|||
func.return %0#0 : tensor<? x i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testEqualQuant
|
||||
func.func @testEqualQuant(%arg0: tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>, %arg1: tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>) -> tensor<1x80x1xi1> {
|
||||
%0 = "tfl.equal"(%arg0, %arg1) : (tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>, tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>) -> tensor<1x80x1xi1>
|
||||
func.return %0 : tensor<1x80x1xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testEqualQuantWithQI16
|
||||
func.func @testEqualQuantWithQI16(%arg0: tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>, %arg1: tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>) -> tensor<1x80x1xi1> {
|
||||
%0 = "tfl.equal"(%arg0, %arg1) : (tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>, tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>) -> tensor<1x80x1xi1>
|
||||
func.return %0 : tensor<1x80x1xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testNotEqual
|
||||
func.func @testNotEqual(tensor<? x f32>, tensor<? x f32>) -> tensor<? x i1> {
|
||||
^bb0(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>):
|
||||
// CHECK: tfl.not_equal(%arg0, %arg1)
|
||||
%0 = "tfl.not_equal"(%arg0, %arg1) : (tensor<? x f32>, tensor<? x f32>) -> tensor<? x i1>
|
||||
func.return %0#0 : tensor<? x i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testNotEqualQuant
|
||||
func.func @testNotEqualQuant(%arg0: tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>, %arg1: tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>) -> tensor<1x80x1xi1> {
|
||||
%0 = "tfl.not_equal"(%arg0, %arg1) : (tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>, tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>) -> tensor<1x80x1xi1>
|
||||
func.return %0 : tensor<1x80x1xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testNotEqualQuantWithQI16
|
||||
func.func @testNotEqualQuantWithQI16(%arg0: tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>, %arg1: tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>) -> tensor<1x80x1xi1> {
|
||||
%0 = "tfl.not_equal"(%arg0, %arg1) : (tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>, tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>) -> tensor<1x80x1xi1>
|
||||
func.return %0 : tensor<1x80x1xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testPad
|
||||
|
|
|
|||
|
|
@ -789,7 +789,7 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
|||
case BuiltinOperator_EQUAL:
|
||||
if (!op_sig.inputs.empty()) {
|
||||
if (op_sig.inputs.at(0).type == kTfLiteInt16) {
|
||||
return 4;
|
||||
return 5;
|
||||
}
|
||||
if (op_sig.inputs.at(0).type == kTfLiteString) {
|
||||
return 3;
|
||||
|
|
@ -801,6 +801,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
|||
return 1;
|
||||
case BuiltinOperator_NOT_EQUAL:
|
||||
if (!op_sig.inputs.empty()) {
|
||||
if (op_sig.inputs.at(0).type == kTfLiteInt16) {
|
||||
return 4;
|
||||
}
|
||||
if (op_sig.inputs.at(0).type == kTfLiteString) {
|
||||
return 3;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -177,21 +177,35 @@ void SimpleOutputVersioningTest(BuiltinOperator op) {
|
|||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningEqualTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_EQUAL);
|
||||
OpSignature fake_op_sig = {
|
||||
.op = BuiltinOperator_EQUAL,
|
||||
.inputs = CreateOpSignatureTensorSpecs(kTfLiteString),
|
||||
};
|
||||
OpSignature fake_op_sig = {};
|
||||
fake_op_sig.op = BuiltinOperator_EQUAL;
|
||||
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32);
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
|
||||
|
||||
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8);
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
|
||||
|
||||
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteString);
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
|
||||
|
||||
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16);
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningNotEqualTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_NOT_EQUAL);
|
||||
OpSignature fake_op_sig = {
|
||||
.op = BuiltinOperator_NOT_EQUAL,
|
||||
.inputs = CreateOpSignatureTensorSpecs(kTfLiteString),
|
||||
};
|
||||
OpSignature fake_op_sig = {};
|
||||
fake_op_sig.op = BuiltinOperator_NOT_EQUAL;
|
||||
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32);
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
|
||||
|
||||
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8);
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
|
||||
|
||||
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteString);
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
|
||||
|
||||
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16);
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningLessTest) {
|
||||
|
|
|
|||
|
|
@ -330,9 +330,11 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
|
|||
{{BuiltinOperator_EQUAL, 2}, "1.14.0"},
|
||||
{{BuiltinOperator_EQUAL, 3}, "2.3.0"},
|
||||
{{BuiltinOperator_EQUAL, 4}, "2.13.0"},
|
||||
{{BuiltinOperator_EQUAL, 5}, "2.21.0"},
|
||||
{{BuiltinOperator_NOT_EQUAL, 1}, "1.14.0"},
|
||||
{{BuiltinOperator_NOT_EQUAL, 2}, "1.14.0"},
|
||||
{{BuiltinOperator_NOT_EQUAL, 3}, "2.3.0"},
|
||||
{{BuiltinOperator_NOT_EQUAL, 4}, "2.21.0"},
|
||||
{{BuiltinOperator_GREATER, 1}, "1.14.0"},
|
||||
{{BuiltinOperator_GREATER, 2}, "1.14.0"},
|
||||
{{BuiltinOperator_GREATER_EQUAL, 1}, "1.14.0"},
|
||||
|
|
|
|||
|
|
@ -246,10 +246,10 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||
/* max_version = */ 3);
|
||||
AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL(),
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 4);
|
||||
/* max_version = */ 5);
|
||||
AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL(),
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 3);
|
||||
/* max_version = */ 4);
|
||||
AddBuiltin(BuiltinOperator_SQRT, Register_SQRT(),
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 2);
|
||||
|
|
|
|||
|
|
@ -95,7 +95,8 @@ void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
|
|||
template <typename input_dtype, reference_ops::ComparisonFn<int32> opname>
|
||||
void ComparisonQuantized(const TfLiteTensor* input1, const TfLiteTensor* input2,
|
||||
TfLiteTensor* output, bool requires_broadcast) {
|
||||
if (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8) {
|
||||
if (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8 ||
|
||||
input1->type == kTfLiteInt16) {
|
||||
auto input1_offset = -input1->params.zero_point;
|
||||
auto input2_offset = -input2->params.zero_point;
|
||||
const int left_shift = 8;
|
||||
|
|
@ -180,8 +181,13 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
|
|||
requires_broadcast);
|
||||
break;
|
||||
case kTfLiteInt16:
|
||||
if (input1->quantization.type == kTfLiteNoQuantization) {
|
||||
Comparison<int16_t, reference_ops::EqualFn>(input1, input2, output,
|
||||
requires_broadcast);
|
||||
} else {
|
||||
ComparisonQuantized<int16_t, reference_ops::EqualFn>(
|
||||
input1, input2, output, requires_broadcast);
|
||||
}
|
||||
break;
|
||||
case kTfLiteInt32:
|
||||
Comparison<int32_t, reference_ops::EqualFn>(input1, input2, output,
|
||||
|
|
@ -204,9 +210,9 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
|
|||
requires_broadcast);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(
|
||||
context,
|
||||
"Does not support type %d, requires bool|float|int|uint8|string",
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Does not support type %d, requires "
|
||||
"bool|float|int|uint8|int16|string",
|
||||
input1->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
|
@ -249,14 +255,20 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
|
|||
ComparisonQuantized<int8_t, reference_ops::NotEqualFn>(
|
||||
input1, input2, output, requires_broadcast);
|
||||
break;
|
||||
case kTfLiteInt16:
|
||||
if (input1->quantization.type != kTfLiteNoQuantization) {
|
||||
ComparisonQuantized<int16_t, reference_ops::NotEqualFn>(
|
||||
input1, input2, output, requires_broadcast);
|
||||
}
|
||||
break;
|
||||
case kTfLiteString:
|
||||
ComparisonString(reference_ops::StringRefNotEqualFn, input1, input2,
|
||||
output, requires_broadcast);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(
|
||||
context,
|
||||
"Does not support type %d, requires bool|float|int|uint8|string",
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Does not support type %d, requires "
|
||||
"bool|float|int|uint8|qint16|string",
|
||||
input1->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
#include <stdint.h>
|
||||
|
||||
#include <initializer_list>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -538,6 +539,19 @@ TEST(QuantizedComparisonsTest, EqualInt8Quantized) {
|
|||
EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false));
|
||||
}
|
||||
|
||||
TEST(QuantizedComparisonsTest, EqualInt16Quantized) {
|
||||
const float kMin = std::numeric_limits<int16_t>::min() + 1;
|
||||
const float kMax = std::numeric_limits<int16_t>::max();
|
||||
ComparisonOpModel model({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
|
||||
{TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
|
||||
TensorType_INT16, BuiltinOperator_EQUAL);
|
||||
model.QuantizeAndPopulate<int16_t>(model.input1(), {10, -90, 70, kMin});
|
||||
model.QuantizeAndPopulate<int16_t>(model.input2(), {10, 20, 71, kMin});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
|
||||
}
|
||||
|
||||
TEST(QuantizedComparisonsTest, NotEqualUInt8Quantized) {
|
||||
const float kMin = -1.f;
|
||||
const float kMax = 128.f;
|
||||
|
|
@ -564,6 +578,19 @@ TEST(QuantizedComparisonsTest, NotEqualInt8Quantized) {
|
|||
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, true));
|
||||
}
|
||||
|
||||
TEST(QuantizedComparisonsTest, NotEqualInt16Quantized) {
|
||||
const float kMin = std::numeric_limits<int16_t>::min() + 1;
|
||||
const float kMax = std::numeric_limits<int16_t>::max();
|
||||
ComparisonOpModel model({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
|
||||
{TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
|
||||
TensorType_INT16, BuiltinOperator_NOT_EQUAL);
|
||||
model.QuantizeAndPopulate<int16_t>(model.input1(), {10, -90, 70, kMin + 1});
|
||||
model.QuantizeAndPopulate<int16_t>(model.input2(), {10, 20, 71, kMin + 2});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, true));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, GreaterQuantized) {
|
||||
const float kMin = -1.f;
|
||||
const float kMax = 128.f;
|
||||
|
|
|
|||
|
|
@ -444,10 +444,10 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
|
|||
/* max_version = */ 3);
|
||||
AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL(),
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 4);
|
||||
/* max_version = */ 5);
|
||||
AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL(),
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 3);
|
||||
/* max_version = */ 4);
|
||||
AddBuiltin(BuiltinOperator_SQRT, Register_SQRT(),
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 2);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user