Add int16x8 kernel support for equal and not_equal ops

PiperOrigin-RevId: 807869896
This commit is contained in:
Maria Lyubimtseva 2025-09-16 15:19:27 -07:00 committed by TensorFlower Gardener
parent bcf149d89c
commit e791c5217d
10 changed files with 123 additions and 30 deletions

View File

@ -21,6 +21,7 @@
* `tf.lite`
* Adds int8 and int16x8 support for SQRT operator.
* Adds int16x8 support for EQUAL and NOT_EQUAL operators.
### Bug Fixes and Other Changes

View File

@ -1589,8 +1589,8 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [
}];
let arguments = (
ins TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, TFL_Quint8, TFL_Str]>:$lhs,
TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, TFL_Quint8, TFL_Str]>:$rhs);
ins TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, QI16, TFL_Quint8, TFL_Str]>:$lhs,
TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, QI16, TFL_Quint8, TFL_Str]>:$rhs);
let results = (outs TFL_BoolTensor:$output);
@ -1729,8 +1729,8 @@ def TFL_EqualOp: TFL_Op<"equal", [
let arguments = (
ins
TFL_TensorOf<[I1, F32, I16, I32, I64, QI8, QUI8, UI8, TFL_Str]>:$x,
TFL_TensorOf<[I1, F32, I16, I32, I64, QI8, QUI8, UI8, TFL_Str]>:$y
TFL_TensorOf<[I1, F32, I16, I32, I64, QI8, QUI8, QI16, UI8, TFL_Str]>:$x,
TFL_TensorOf<[I1, F32, I16, I32, I64, QI8, QUI8, QI16, UI8, TFL_Str]>:$y
);
let results = (outs TFL_BoolTensor:$output);

View File

@ -1209,6 +1209,40 @@ func.func @testEqualInt16(tensor<? x i16>, tensor<? x i16>) -> tensor<? x i1> {
func.return %0#0 : tensor<? x i1>
}
// CHECK-LABEL: testEqualQuant
// Verifies tfl.equal accepts 8-bit quantized (QI8) operands and yields a
// boolean tensor of the same shape.
func.func @testEqualQuant(%arg0: tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>, %arg1: tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>) -> tensor<1x80x1xi1> {
%0 = "tfl.equal"(%arg0, %arg1) : (tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>, tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>) -> tensor<1x80x1xi1>
func.return %0 : tensor<1x80x1xi1>
}
// CHECK-LABEL: testEqualQuantWithQI16
// Verifies tfl.equal now accepts 16-bit quantized (QI16) operands — coverage
// for the int16x8 EQUAL support added in this change. Note the zero point is
// 0, as required for symmetric int16 quantization.
func.func @testEqualQuantWithQI16(%arg0: tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>, %arg1: tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>) -> tensor<1x80x1xi1> {
%0 = "tfl.equal"(%arg0, %arg1) : (tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>, tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>) -> tensor<1x80x1xi1>
func.return %0 : tensor<1x80x1xi1>
}
// -----
// CHECK-LABEL: testNotEqual
// Baseline float32 tfl.not_equal round-trip with dynamic-shaped operands.
func.func @testNotEqual(tensor<? x f32>, tensor<? x f32>) -> tensor<? x i1> {
^bb0(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>):
// CHECK: tfl.not_equal(%arg0, %arg1)
%0 = "tfl.not_equal"(%arg0, %arg1) : (tensor<? x f32>, tensor<? x f32>) -> tensor<? x i1>
func.return %0#0 : tensor<? x i1>
}
// CHECK-LABEL: testNotEqualQuant
// Verifies tfl.not_equal accepts 8-bit quantized (QI8) operands.
func.func @testNotEqualQuant(%arg0: tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>, %arg1: tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>) -> tensor<1x80x1xi1> {
%0 = "tfl.not_equal"(%arg0, %arg1) : (tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>, tensor<1x80x1x!quant.uniform<i8:f32, 0.04:-128>>) -> tensor<1x80x1xi1>
func.return %0 : tensor<1x80x1xi1>
}
// CHECK-LABEL: testNotEqualQuantWithQI16
// Verifies tfl.not_equal now accepts 16-bit quantized (QI16) operands —
// coverage for the int16x8 NOT_EQUAL support added in this change.
func.func @testNotEqualQuantWithQI16(%arg0: tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>, %arg1: tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>) -> tensor<1x80x1xi1> {
%0 = "tfl.not_equal"(%arg0, %arg1) : (tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>, tensor<1x80x1x!quant.uniform<i16:f32, 0.04:0>>) -> tensor<1x80x1xi1>
func.return %0 : tensor<1x80x1xi1>
}
// -----
// CHECK-LABEL: testPad

View File

@ -789,7 +789,7 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
case BuiltinOperator_EQUAL:
if (!op_sig.inputs.empty()) {
if (op_sig.inputs.at(0).type == kTfLiteInt16) {
return 4;
return 5;
}
if (op_sig.inputs.at(0).type == kTfLiteString) {
return 3;
@ -801,6 +801,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
return 1;
case BuiltinOperator_NOT_EQUAL:
if (!op_sig.inputs.empty()) {
if (op_sig.inputs.at(0).type == kTfLiteInt16) {
return 4;
}
if (op_sig.inputs.at(0).type == kTfLiteString) {
return 3;
}

View File

@ -177,21 +177,35 @@ void SimpleOutputVersioningTest(BuiltinOperator op) {
}
// Verifies the version selected for BuiltinOperator_EQUAL by input type:
// v1 = float32 (baseline), v2 = int8, v3 = string, v5 = int16 (the
// quantized int16x8 kernel added in this change; v4 was the non-quantized
// int16 path).
TEST(OpVersionTest, VersioningEqualTest) {
  SimpleVersioningTest(BuiltinOperator_EQUAL);
  // Build one signature and mutate only the input spec between checks.
  OpSignature fake_op_sig = {};
  fake_op_sig.op = BuiltinOperator_EQUAL;
  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32);
  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8);
  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteString);
  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16);
  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
}
// Verifies the version selected for BuiltinOperator_NOT_EQUAL by input type:
// v1 = float32 (baseline), v2 = int8, v3 = string, v4 = int16 (the
// quantized int16x8 kernel added in this change).
TEST(OpVersionTest, VersioningNotEqualTest) {
  SimpleVersioningTest(BuiltinOperator_NOT_EQUAL);
  // Build one signature and mutate only the input spec between checks.
  OpSignature fake_op_sig = {};
  fake_op_sig.op = BuiltinOperator_NOT_EQUAL;
  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32);
  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8);
  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteString);
  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16);
  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
}
TEST(OpVersionTest, VersioningLessTest) {

View File

@ -330,9 +330,11 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_EQUAL, 2}, "1.14.0"},
{{BuiltinOperator_EQUAL, 3}, "2.3.0"},
{{BuiltinOperator_EQUAL, 4}, "2.13.0"},
{{BuiltinOperator_EQUAL, 5}, "2.21.0"},
{{BuiltinOperator_NOT_EQUAL, 1}, "1.14.0"},
{{BuiltinOperator_NOT_EQUAL, 2}, "1.14.0"},
{{BuiltinOperator_NOT_EQUAL, 3}, "2.3.0"},
{{BuiltinOperator_NOT_EQUAL, 4}, "2.21.0"},
{{BuiltinOperator_GREATER, 1}, "1.14.0"},
{{BuiltinOperator_GREATER, 2}, "1.14.0"},
{{BuiltinOperator_GREATER_EQUAL, 1}, "1.14.0"},

View File

@ -246,10 +246,10 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL(),
/* min_version = */ 1,
/* max_version = */ 4);
/* max_version = */ 5);
AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL(),
/* min_version = */ 1,
/* max_version = */ 3);
/* max_version = */ 4);
AddBuiltin(BuiltinOperator_SQRT, Register_SQRT(),
/* min_version = */ 1,
/* max_version = */ 2);

View File

@ -95,7 +95,8 @@ void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
template <typename input_dtype, reference_ops::ComparisonFn<int32> opname>
void ComparisonQuantized(const TfLiteTensor* input1, const TfLiteTensor* input2,
TfLiteTensor* output, bool requires_broadcast) {
if (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8) {
if (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8 ||
input1->type == kTfLiteInt16) {
auto input1_offset = -input1->params.zero_point;
auto input2_offset = -input2->params.zero_point;
const int left_shift = 8;
@ -180,8 +181,13 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
requires_broadcast);
break;
case kTfLiteInt16:
Comparison<int16_t, reference_ops::EqualFn>(input1, input2, output,
requires_broadcast);
if (input1->quantization.type == kTfLiteNoQuantization) {
Comparison<int16_t, reference_ops::EqualFn>(input1, input2, output,
requires_broadcast);
} else {
ComparisonQuantized<int16_t, reference_ops::EqualFn>(
input1, input2, output, requires_broadcast);
}
break;
case kTfLiteInt32:
Comparison<int32_t, reference_ops::EqualFn>(input1, input2, output,
@ -204,10 +210,10 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
requires_broadcast);
break;
default:
TF_LITE_KERNEL_LOG(
context,
"Does not support type %d, requires bool|float|int|uint8|string",
input1->type);
TF_LITE_KERNEL_LOG(context,
"Does not support type %d, requires "
"bool|float|int|uint8|int16|string",
input1->type);
return kTfLiteError;
}
return kTfLiteOk;
@ -249,15 +255,21 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
ComparisonQuantized<int8_t, reference_ops::NotEqualFn>(
input1, input2, output, requires_broadcast);
break;
case kTfLiteInt16:
if (input1->quantization.type != kTfLiteNoQuantization) {
ComparisonQuantized<int16_t, reference_ops::NotEqualFn>(
input1, input2, output, requires_broadcast);
}
break;
case kTfLiteString:
ComparisonString(reference_ops::StringRefNotEqualFn, input1, input2,
output, requires_broadcast);
break;
default:
TF_LITE_KERNEL_LOG(
context,
"Does not support type %d, requires bool|float|int|uint8|string",
input1->type);
TF_LITE_KERNEL_LOG(context,
"Does not support type %d, requires "
"bool|float|int|uint8|qint16|string",
input1->type);
return kTfLiteError;
}
return kTfLiteOk;

View File

@ -15,6 +15,7 @@ limitations under the License.
#include <stdint.h>
#include <initializer_list>
#include <limits>
#include <string>
#include <vector>
@ -538,6 +539,19 @@ TEST(QuantizedComparisonsTest, EqualInt8Quantized) {
EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false));
}
// End-to-end check of the quantized int16 EQUAL kernel: quantize two 2x2
// inputs over (nearly) the full int16 range and compare element-wise.
TEST(QuantizedComparisonsTest, EqualInt16Quantized) {
// kMin excludes int16 min itself (min + 1), presumably to keep the range
// symmetric around a zero point of 0 — TODO confirm against the model helper.
const float kMin = std::numeric_limits<int16_t>::min() + 1;
const float kMax = std::numeric_limits<int16_t>::max();
ComparisonOpModel model({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
{TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
TensorType_INT16, BuiltinOperator_EQUAL);
model.QuantizeAndPopulate<int16_t>(model.input1(), {10, -90, 70, kMin});
model.QuantizeAndPopulate<int16_t>(model.input2(), {10, 20, 71, kMin});
model.Invoke();
// Equal only where both quantized values coincide: positions 0 and 3.
EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
}
TEST(QuantizedComparisonsTest, NotEqualUInt8Quantized) {
const float kMin = -1.f;
const float kMax = 128.f;
@ -564,6 +578,19 @@ TEST(QuantizedComparisonsTest, NotEqualInt8Quantized) {
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, true));
}
// End-to-end check of the quantized int16 NOT_EQUAL kernel: quantize two 2x2
// inputs over (nearly) the full int16 range and compare element-wise.
TEST(QuantizedComparisonsTest, NotEqualInt16Quantized) {
// kMin excludes int16 min itself (min + 1), presumably to keep the range
// symmetric around a zero point of 0 — TODO confirm against the model helper.
const float kMin = std::numeric_limits<int16_t>::min() + 1;
const float kMax = std::numeric_limits<int16_t>::max();
ComparisonOpModel model({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
{TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
TensorType_INT16, BuiltinOperator_NOT_EQUAL);
model.QuantizeAndPopulate<int16_t>(model.input1(), {10, -90, 70, kMin + 1});
model.QuantizeAndPopulate<int16_t>(model.input2(), {10, 20, 71, kMin + 2});
model.Invoke();
// Only position 0 matches; adjacent extreme values (kMin+1 vs kMin+2) are
// expected to quantize to distinct int16 values and thus compare not-equal.
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, true));
}
TEST(ComparisonsTest, GreaterQuantized) {
const float kMin = -1.f;
const float kMax = 128.f;

View File

@ -444,10 +444,10 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL(),
/* min_version = */ 1,
/* max_version = */ 4);
/* max_version = */ 5);
AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL(),
/* min_version = */ 1,
/* max_version = */ 3);
/* max_version = */ 4);
AddBuiltin(BuiltinOperator_SQRT, Register_SQRT(),
/* min_version = */ 1,
/* max_version = */ 2);