Add int16x8 kernel support for equal and not_equal ops

PiperOrigin-RevId: 807869896
This commit is contained in:
Maria Lyubimtseva 2025-09-16 15:19:27 -07:00 committed by TensorFlower Gardener
parent bcf149d89c
commit e791c5217d
10 changed files with 123 additions and 30 deletions

View File

@@ -21,6 +21,7 @@
* `tf.lite` * `tf.lite`
* Adds int8 and int16x8 support for SQRT operator. * Adds int8 and int16x8 support for SQRT operator.
* Adds int16x8 support for EQUAL and NOT_EQUAL operators.
### Bug Fixes and Other Changes ### Bug Fixes and Other Changes

View File

@@ -1589,8 +1589,8 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [
}]; }];
let arguments = ( let arguments = (
ins TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, TFL_Quint8, TFL_Str]>:$lhs, ins TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, QI16, TFL_Quint8, TFL_Str]>:$lhs,
TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, TFL_Quint8, TFL_Str]>:$rhs); TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, QI16, TFL_Quint8, TFL_Str]>:$rhs);
let results = (outs TFL_BoolTensor:$output); let results = (outs TFL_BoolTensor:$output);
@@ -1729,8 +1729,8 @@ def TFL_EqualOp: TFL_Op<"equal", [
let arguments = ( let arguments = (
ins ins
TFL_TensorOf<[I1, F32, I16, I32, I64, QI8, QUI8, UI8, TFL_Str]>:$x, TFL_TensorOf<[I1, F32, I16, I32, I64, QI8, QUI8, QI16, UI8, TFL_Str]>:$x,
TFL_TensorOf<[I1, F32, I16, I32, I64, QI8, QUI8, UI8, TFL_Str]>:$y TFL_TensorOf<[I1, F32, I16, I32, I64, QI8, QUI8, QI16, UI8, TFL_Str]>:$y
); );
let results = (outs TFL_BoolTensor:$output); let results = (outs TFL_BoolTensor:$output);

View File

@@ -1209,6 +1209,40 @@ func.func @testEqualInt16(tensor<? x i16>, tensor<? x i16>) -> tensor<? x i1> {
func.return %0#0 : tensor<? x i1> func.return %0#0 : tensor<? x i1>
} }
// CHECK-LABEL: testEqualQuant
func.func @testEqualQuant(%arg0: tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>, %arg1: tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>) -> tensor<1x80x1xi1> {
%0 = "tfl.equal"(%arg0, %arg1) : (tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>, tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>) -> tensor<1x80x1xi1>
func.return %0 : tensor<1x80x1xi1>
}
// CHECK-LABEL: testEqualQuantWithQI16
func.func @testEqualQuantWithQI16(%arg0: tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>, %arg1: tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>) -> tensor<1x80x1xi1> {
%0 = "tfl.equal"(%arg0, %arg1) : (tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>, tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>) -> tensor<1x80x1xi1>
func.return %0 : tensor<1x80x1xi1>
}
// -----
// CHECK-LABEL: testNotEqual
func.func @testNotEqual(tensor<? x f32>, tensor<? x f32>) -> tensor<? x i1> {
^bb0(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>):
// CHECK: tfl.not_equal(%arg0, %arg1)
%0 = "tfl.not_equal"(%arg0, %arg1) : (tensor<? x f32>, tensor<? x f32>) -> tensor<? x i1>
func.return %0#0 : tensor<? x i1>
}
// CHECK-LABEL: testNotEqualQuant
func.func @testNotEqualQuant(%arg0: tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>, %arg1: tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>) -> tensor<1x80x1xi1> {
%0 = "tfl.not_equal"(%arg0, %arg1) : (tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>, tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>) -> tensor<1x80x1xi1>
func.return %0 : tensor<1x80x1xi1>
}
// CHECK-LABEL: testNotEqualQuantWithQI16
func.func @testNotEqualQuantWithQI16(%arg0: tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>, %arg1: tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>) -> tensor<1x80x1xi1> {
%0 = "tfl.not_equal"(%arg0, %arg1) : (tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>, tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>) -> tensor<1x80x1xi1>
func.return %0 : tensor<1x80x1xi1>
}
// ----- // -----
// CHECK-LABEL: testPad // CHECK-LABEL: testPad

View File

@@ -789,7 +789,7 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
case BuiltinOperator_EQUAL: case BuiltinOperator_EQUAL:
if (!op_sig.inputs.empty()) { if (!op_sig.inputs.empty()) {
if (op_sig.inputs.at(0).type == kTfLiteInt16) { if (op_sig.inputs.at(0).type == kTfLiteInt16) {
return 4; return 5;
} }
if (op_sig.inputs.at(0).type == kTfLiteString) { if (op_sig.inputs.at(0).type == kTfLiteString) {
return 3; return 3;
@@ -801,6 +801,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
return 1; return 1;
case BuiltinOperator_NOT_EQUAL: case BuiltinOperator_NOT_EQUAL:
if (!op_sig.inputs.empty()) { if (!op_sig.inputs.empty()) {
if (op_sig.inputs.at(0).type == kTfLiteInt16) {
return 4;
}
if (op_sig.inputs.at(0).type == kTfLiteString) { if (op_sig.inputs.at(0).type == kTfLiteString) {
return 3; return 3;
} }

View File

@@ -177,21 +177,35 @@ void SimpleOutputVersioningTest(BuiltinOperator op) {
} }
TEST(OpVersionTest, VersioningEqualTest) { TEST(OpVersionTest, VersioningEqualTest) {
SimpleVersioningTest(BuiltinOperator_EQUAL); OpSignature fake_op_sig = {};
OpSignature fake_op_sig = { fake_op_sig.op = BuiltinOperator_EQUAL;
.op = BuiltinOperator_EQUAL, fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32);
.inputs = CreateOpSignatureTensorSpecs(kTfLiteString), EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
};
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteString);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
} }
TEST(OpVersionTest, VersioningNotEqualTest) { TEST(OpVersionTest, VersioningNotEqualTest) {
SimpleVersioningTest(BuiltinOperator_NOT_EQUAL); OpSignature fake_op_sig = {};
OpSignature fake_op_sig = { fake_op_sig.op = BuiltinOperator_NOT_EQUAL;
.op = BuiltinOperator_NOT_EQUAL, fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32);
.inputs = CreateOpSignatureTensorSpecs(kTfLiteString), EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
};
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteString);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
} }
TEST(OpVersionTest, VersioningLessTest) { TEST(OpVersionTest, VersioningLessTest) {

View File

@@ -330,9 +330,11 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_EQUAL, 2}, "1.14.0"}, {{BuiltinOperator_EQUAL, 2}, "1.14.0"},
{{BuiltinOperator_EQUAL, 3}, "2.3.0"}, {{BuiltinOperator_EQUAL, 3}, "2.3.0"},
{{BuiltinOperator_EQUAL, 4}, "2.13.0"}, {{BuiltinOperator_EQUAL, 4}, "2.13.0"},
{{BuiltinOperator_EQUAL, 5}, "2.21.0"},
{{BuiltinOperator_NOT_EQUAL, 1}, "1.14.0"}, {{BuiltinOperator_NOT_EQUAL, 1}, "1.14.0"},
{{BuiltinOperator_NOT_EQUAL, 2}, "1.14.0"}, {{BuiltinOperator_NOT_EQUAL, 2}, "1.14.0"},
{{BuiltinOperator_NOT_EQUAL, 3}, "2.3.0"}, {{BuiltinOperator_NOT_EQUAL, 3}, "2.3.0"},
{{BuiltinOperator_NOT_EQUAL, 4}, "2.21.0"},
{{BuiltinOperator_GREATER, 1}, "1.14.0"}, {{BuiltinOperator_GREATER, 1}, "1.14.0"},
{{BuiltinOperator_GREATER, 2}, "1.14.0"}, {{BuiltinOperator_GREATER, 2}, "1.14.0"},
{{BuiltinOperator_GREATER_EQUAL, 1}, "1.14.0"}, {{BuiltinOperator_GREATER_EQUAL, 1}, "1.14.0"},

View File

@@ -246,10 +246,10 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* max_version = */ 3); /* max_version = */ 3);
AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL(), AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL(),
/* min_version = */ 1, /* min_version = */ 1,
/* max_version = */ 4); /* max_version = */ 5);
AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL(), AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL(),
/* min_version = */ 1, /* min_version = */ 1,
/* max_version = */ 3); /* max_version = */ 4);
AddBuiltin(BuiltinOperator_SQRT, Register_SQRT(), AddBuiltin(BuiltinOperator_SQRT, Register_SQRT(),
/* min_version = */ 1, /* min_version = */ 1,
/* max_version = */ 2); /* max_version = */ 2);

View File

@@ -95,7 +95,8 @@ void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
template <typename input_dtype, reference_ops::ComparisonFn<int32> opname> template <typename input_dtype, reference_ops::ComparisonFn<int32> opname>
void ComparisonQuantized(const TfLiteTensor* input1, const TfLiteTensor* input2, void ComparisonQuantized(const TfLiteTensor* input1, const TfLiteTensor* input2,
TfLiteTensor* output, bool requires_broadcast) { TfLiteTensor* output, bool requires_broadcast) {
if (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8) { if (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8 ||
input1->type == kTfLiteInt16) {
auto input1_offset = -input1->params.zero_point; auto input1_offset = -input1->params.zero_point;
auto input2_offset = -input2->params.zero_point; auto input2_offset = -input2->params.zero_point;
const int left_shift = 8; const int left_shift = 8;
@@ -180,8 +181,13 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
requires_broadcast); requires_broadcast);
break; break;
case kTfLiteInt16: case kTfLiteInt16:
if (input1->quantization.type == kTfLiteNoQuantization) {
Comparison<int16_t, reference_ops::EqualFn>(input1, input2, output, Comparison<int16_t, reference_ops::EqualFn>(input1, input2, output,
requires_broadcast); requires_broadcast);
} else {
ComparisonQuantized<int16_t, reference_ops::EqualFn>(
input1, input2, output, requires_broadcast);
}
break; break;
case kTfLiteInt32: case kTfLiteInt32:
Comparison<int32_t, reference_ops::EqualFn>(input1, input2, output, Comparison<int32_t, reference_ops::EqualFn>(input1, input2, output,
@@ -204,9 +210,9 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
requires_broadcast); requires_broadcast);
break; break;
default: default:
TF_LITE_KERNEL_LOG( TF_LITE_KERNEL_LOG(context,
context, "Does not support type %d, requires "
"Does not support type %d, requires bool|float|int|uint8|string", "bool|float|int|uint8|int16|string",
input1->type); input1->type);
return kTfLiteError; return kTfLiteError;
} }
@@ -249,14 +255,20 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
ComparisonQuantized<int8_t, reference_ops::NotEqualFn>( ComparisonQuantized<int8_t, reference_ops::NotEqualFn>(
input1, input2, output, requires_broadcast); input1, input2, output, requires_broadcast);
break; break;
case kTfLiteInt16:
if (input1->quantization.type != kTfLiteNoQuantization) {
ComparisonQuantized<int16_t, reference_ops::NotEqualFn>(
input1, input2, output, requires_broadcast);
}
break;
case kTfLiteString: case kTfLiteString:
ComparisonString(reference_ops::StringRefNotEqualFn, input1, input2, ComparisonString(reference_ops::StringRefNotEqualFn, input1, input2,
output, requires_broadcast); output, requires_broadcast);
break; break;
default: default:
TF_LITE_KERNEL_LOG( TF_LITE_KERNEL_LOG(context,
context, "Does not support type %d, requires "
"Does not support type %d, requires bool|float|int|uint8|string", "bool|float|int|uint8|qint16|string",
input1->type); input1->type);
return kTfLiteError; return kTfLiteError;
} }

View File

@@ -15,6 +15,7 @@ limitations under the License.
#include <stdint.h> #include <stdint.h>
#include <initializer_list> #include <initializer_list>
#include <limits>
#include <string> #include <string>
#include <vector> #include <vector>
@@ -538,6 +539,19 @@ TEST(QuantizedComparisonsTest, EqualInt8Quantized) {
EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false)); EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false));
} }
TEST(QuantizedComparisonsTest, EqualInt16Quantized) {
const float kMin = std::numeric_limits<int16_t>::min() + 1;
const float kMax = std::numeric_limits<int16_t>::max();
ComparisonOpModel model({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
{TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
TensorType_INT16, BuiltinOperator_EQUAL);
model.QuantizeAndPopulate<int16_t>(model.input1(), {10, -90, 70, kMin});
model.QuantizeAndPopulate<int16_t>(model.input2(), {10, 20, 71, kMin});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
}
TEST(QuantizedComparisonsTest, NotEqualUInt8Quantized) { TEST(QuantizedComparisonsTest, NotEqualUInt8Quantized) {
const float kMin = -1.f; const float kMin = -1.f;
const float kMax = 128.f; const float kMax = 128.f;
@@ -564,6 +578,19 @@ TEST(QuantizedComparisonsTest, NotEqualInt8Quantized) {
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, true)); EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, true));
} }
TEST(QuantizedComparisonsTest, NotEqualInt16Quantized) {
const float kMin = std::numeric_limits<int16_t>::min() + 1;
const float kMax = std::numeric_limits<int16_t>::max();
ComparisonOpModel model({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
{TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
TensorType_INT16, BuiltinOperator_NOT_EQUAL);
model.QuantizeAndPopulate<int16_t>(model.input1(), {10, -90, 70, kMin + 1});
model.QuantizeAndPopulate<int16_t>(model.input2(), {10, 20, 71, kMin + 2});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, true));
}
TEST(ComparisonsTest, GreaterQuantized) { TEST(ComparisonsTest, GreaterQuantized) {
const float kMin = -1.f; const float kMin = -1.f;
const float kMax = 128.f; const float kMax = 128.f;

View File

@@ -444,10 +444,10 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
/* max_version = */ 3); /* max_version = */ 3);
AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL(), AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL(),
/* min_version = */ 1, /* min_version = */ 1,
/* max_version = */ 4); /* max_version = */ 5);
AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL(), AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL(),
/* min_version = */ 1, /* min_version = */ 1,
/* max_version = */ 3); /* max_version = */ 4);
AddBuiltin(BuiltinOperator_SQRT, Register_SQRT(), AddBuiltin(BuiltinOperator_SQRT, Register_SQRT(),
/* min_version = */ 1, /* min_version = */ 1,
/* max_version = */ 2); /* max_version = */ 2);