Support i4 EmbeddingLookup in TFLite reference

PiperOrigin-RevId: 653782883
This commit is contained in:
Pauline Sho 2024-07-18 15:58:10 -07:00 committed by TensorFlower Gardener
parent 3b95472bb0
commit 87b00de034
7 changed files with 102 additions and 7 deletions

View File

@ -35,6 +35,7 @@
* `Dequantize` op supports `TensorType_INT4`.
* This change includes per-channel dequantization.
* Add support for `stablehlo.composite`.
* `EmbeddingLookup` op supports `TensorType_INT4` values.
## Keras

View File

@ -1646,7 +1646,7 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup",
let arguments = (ins
TFL_TensorOf<[I32]>:$lookup,
TFL_TensorOf<[F32, I8, UI8]>:$value
TFL_TensorOf<[F32, I8, UI8, QI4]>:$value
);
let results = (outs TFL_TensorOf<[F32, I8, UI8]>:$output);

View File

@ -77,7 +77,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP(),
/* min_version = */ 1,
/* max_version = */ 3);
/* max_version = */ 4);
AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE,
Register_EMBEDDING_LOOKUP_SPARSE());
AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(),

View File

@ -33,6 +33,7 @@ limitations under the License.
#include <cstring>
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
@ -62,7 +63,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, qparams->zero_point != nullptr);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
if ((value->type == kTfLiteUInt8 || value->type == kTfLiteInt8) &&
if ((value->type == kTfLiteUInt8 || value->type == kTfLiteInt8 ||
value->type == kTfLiteInt4) &&
(output->type == kTfLiteFloat32)) {
// EvalHybrid supports only symmetric quantization for now.
TF_LITE_ENSURE(context, qparams->zero_point->data[0] == 0);
@ -70,7 +72,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (qparams->scale->size > 1 || qparams->zero_point->size > 1) {
// Per-axis quantization is supported by EvalHybrid only.
TF_LITE_ENSURE(context, value->type == kTfLiteUInt8 ||
value->type == kTfLiteInt8);
value->type == kTfLiteInt8 ||
value->type == kTfLiteInt4);
TF_LITE_ENSURE(context, output->type == kTfLiteFloat32);
// Per-axis quantization must have quantized_dimension == 0 and correct
// sizes for scale and zero_point.
@ -160,9 +163,21 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
}
}
for (int j = 0; j < col_size; j++) {
output_ptr[j + i * col_size] =
value_ptr[j + idx * col_size] * scaling_factor;
if (value->type == kTfLiteInt4) {
for (int j = 0; j < col_size; j++) {
int i8_idx = j + idx * col_size;
int i4_idx = i8_idx / 2;
bool even = i8_idx % 2 == 0;
int8_t i4_val = value_ptr[i4_idx];
int8_t i8_val =
even ? static_cast<int8_t>(i4_val << 4) >> 4 : i4_val >> 4;
output_ptr[j + i * col_size] = i8_val * scaling_factor;
}
} else {
for (int j = 0; j < col_size; j++) {
output_ptr[j + i * col_size] =
value_ptr[j + idx * col_size] * scaling_factor;
}
}
}
}
@ -180,6 +195,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
switch (value->type) {
case kTfLiteFloat32:
return EvalSimple(context, node, lookup, value, output);
case kTfLiteInt4:
return EvalHybrid(context, node, lookup, value, output);
case kTfLiteUInt8:
case kTfLiteInt8:
if (output->type == kTfLiteFloat32) {

View File

@ -352,5 +352,74 @@ TEST(PerAxisHybridEmbeddingLookupHybridOpTest, PerAxisSimple4DTestInt8) {
kTestTolerance)));
}
TEST(PerAxisHybridEmbeddingLookupHybridOpTest, PerAxisSimple2DTestInt4) {
PerAxisHybridEmbeddingLookupOpModel m({3}, {3, 8}, {0.001, 0.02, 0.3},
TensorType_INT4);
m.SetInput({1, 0, 2});
m.SetSignedWeight({
0.00, 0.007, 0.006, 0.005, 0.004, 0.003, 0.002, 0.001, // Row 0
0.02, -0.02, 0.04, 0.06, 0.08, -0.04, -0.08, -0.06, // Row 1
0.3, 0.6, 0.9, 1.2, 1.5, -0.3, -0.6, -0.9, // Row 2
});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(
m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear(
{
0.02, -0.02, 0.04, 0.06, 0.08, -0.04, -0.08, -0.06, // Row 1
0.00, 0.007, 0.006, 0.005, 0.004, 0.003, 0.002, 0.001, // Row 0
0.3, 0.6, 0.9, 1.2, 1.5, -0.3, -0.6, -0.9, // Row 2
},
kTestTolerance)));
}
TEST(PerAxisHybridEmbeddingLookupHybridOpTest, PerAxisSimple3DTestInt4) {
PerAxisHybridEmbeddingLookupOpModel m({3}, {3, 2, 4}, {0.001, 0.02, 0.3},
TensorType_INT4);
m.SetInput({1, 0, 2});
m.SetSignedWeight({
0.00, 0.007, 0.006, 0.005, 0.004, 0.003, 0.002, 0.001, // Row 0
0.02, -0.02, 0.04, 0.06, 0.08, -0.04, -0.08, -0.06, // Row 1
0.3, 0.6, 0.9, 1.2, 1.5, -0.3, -0.6, -0.9, // Row 2
});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(
m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear(
{
0.02, -0.02, 0.04, 0.06, 0.08, -0.04, -0.08, -0.06, // Row 1
0.00, 0.007, 0.006, 0.005, 0.004, 0.003, 0.002, 0.001, // Row 0
0.3, 0.6, 0.9, 1.2, 1.5, -0.3, -0.6, -0.9, // Row 2
},
kTestTolerance)));
}
TEST(PerAxisHybridEmbeddingLookupHybridOpTest, PerAxisSimple4DTestInt4) {
PerAxisHybridEmbeddingLookupOpModel m({3}, {3, 2, 2, 2}, {0.001, 0.02, 0.3},
TensorType_INT4);
m.SetInput({1, 0, 2});
m.SetSignedWeight({
0.00, 0.007, 0.006, 0.005, 0.004, 0.003, 0.002, 0.001, // Row 0
0.02, -0.02, 0.04, 0.06, 0.08, -0.04, -0.08, -0.06, // Row 1
0.3, 0.6, 0.9, 1.2, 1.5, -0.3, -0.6, -0.9, // Row 2
});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(
m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear(
{
0.02, -0.02, 0.04, 0.06, 0.08, -0.04, -0.08, -0.06, // Row 1
0.00, 0.007, 0.006, 0.005, 0.004, 0.003, 0.002, 0.001, // Row 0
0.3, 0.6, 0.9, 1.2, 1.5, -0.3, -0.6, -0.9, // Row 2
},
kTestTolerance)));
}
} // namespace
} // namespace tflite

View File

@ -149,6 +149,13 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
return 1;
}
case BuiltinOperator_EMBEDDING_LOOKUP: {
if (op_sig.inputs.at(1).type == kTfLiteInt4) {
return 4;
}
return 1;
}
case BuiltinOperator_FAKE_QUANT: {
auto fake_quant_params =
reinterpret_cast<TfLiteFakeQuantParams*>(op_sig.builtin_data);

View File

@ -118,6 +118,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_EMBEDDING_LOOKUP, 1}, "1.13.0"},
{{BuiltinOperator_EMBEDDING_LOOKUP, 2}, "1.14.0"},
{{BuiltinOperator_EMBEDDING_LOOKUP, 3}, "1.14.0"},
{{BuiltinOperator_EMBEDDING_LOOKUP, 4}, "2.18.0"},
{{BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, 1}, "1.5.0"},
{{BuiltinOperator_FAKE_QUANT, 1}, "1.5.0"},
{{BuiltinOperator_FAKE_QUANT, 2}, "1.10.0"},