mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Support i4 EmbeddingLookup in TFLite reference
PiperOrigin-RevId: 653782883
This commit is contained in:
parent
3b95472bb0
commit
87b00de034
|
|
@ -35,6 +35,7 @@
|
|||
* `Dequantize` op supports `TensorType_INT4`.
|
||||
* This change includes per-channel dequantization.
|
||||
* Add support for `stablehlo.composite`.
|
||||
* `EmbeddingLookup` op supports `TensorType_INT4` values.
|
||||
|
||||
## Keras
|
||||
|
||||
|
|
|
|||
|
|
@ -1646,7 +1646,7 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup",
|
|||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[I32]>:$lookup,
|
||||
TFL_TensorOf<[F32, I8, UI8]>:$value
|
||||
TFL_TensorOf<[F32, I8, UI8, QI4]>:$value
|
||||
);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, I8, UI8]>:$output);
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||
/* max_version = */ 3);
|
||||
AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP(),
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 3);
|
||||
/* max_version = */ 4);
|
||||
AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE,
|
||||
Register_EMBEDDING_LOOKUP_SPARSE());
|
||||
AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(),
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ limitations under the License.
|
|||
|
||||
#include <cstring>
|
||||
|
||||
#include "tensorflow/lite/c/c_api_types.h"
|
||||
#include "tensorflow/lite/core/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
|
@ -62,7 +63,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||
TF_LITE_ENSURE(context, qparams->zero_point != nullptr);
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
|
||||
if ((value->type == kTfLiteUInt8 || value->type == kTfLiteInt8) &&
|
||||
if ((value->type == kTfLiteUInt8 || value->type == kTfLiteInt8 ||
|
||||
value->type == kTfLiteInt4) &&
|
||||
(output->type == kTfLiteFloat32)) {
|
||||
// EvalHybrid supports only symmetric quantization for now.
|
||||
TF_LITE_ENSURE(context, qparams->zero_point->data[0] == 0);
|
||||
|
|
@ -70,7 +72,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||
if (qparams->scale->size > 1 || qparams->zero_point->size > 1) {
|
||||
// Per-axis quantization is supported by EvalHybrid only.
|
||||
TF_LITE_ENSURE(context, value->type == kTfLiteUInt8 ||
|
||||
value->type == kTfLiteInt8);
|
||||
value->type == kTfLiteInt8 ||
|
||||
value->type == kTfLiteInt4);
|
||||
TF_LITE_ENSURE(context, output->type == kTfLiteFloat32);
|
||||
// Per-axis quantization must have quantized_dimension == 0 and correct
|
||||
// sizes for scale and zero_point.
|
||||
|
|
@ -160,9 +163,21 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
|
|||
}
|
||||
}
|
||||
|
||||
for (int j = 0; j < col_size; j++) {
|
||||
output_ptr[j + i * col_size] =
|
||||
value_ptr[j + idx * col_size] * scaling_factor;
|
||||
if (value->type == kTfLiteInt4) {
|
||||
for (int j = 0; j < col_size; j++) {
|
||||
int i8_idx = j + idx * col_size;
|
||||
int i4_idx = i8_idx / 2;
|
||||
bool even = i8_idx % 2 == 0;
|
||||
int8_t i4_val = value_ptr[i4_idx];
|
||||
int8_t i8_val =
|
||||
even ? static_cast<int8_t>(i4_val << 4) >> 4 : i4_val >> 4;
|
||||
output_ptr[j + i * col_size] = i8_val * scaling_factor;
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < col_size; j++) {
|
||||
output_ptr[j + i * col_size] =
|
||||
value_ptr[j + idx * col_size] * scaling_factor;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -180,6 +195,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||
switch (value->type) {
|
||||
case kTfLiteFloat32:
|
||||
return EvalSimple(context, node, lookup, value, output);
|
||||
case kTfLiteInt4:
|
||||
return EvalHybrid(context, node, lookup, value, output);
|
||||
case kTfLiteUInt8:
|
||||
case kTfLiteInt8:
|
||||
if (output->type == kTfLiteFloat32) {
|
||||
|
|
|
|||
|
|
@ -352,5 +352,74 @@ TEST(PerAxisHybridEmbeddingLookupHybridOpTest, PerAxisSimple4DTestInt8) {
|
|||
kTestTolerance)));
|
||||
}
|
||||
|
||||
// Embedding lookup with a TensorType_INT4 value tensor of shape {3, 8}.
// Lookup indices {1, 0, 2} must return weight rows 1, 0 and 2, in that
// order, as floats equal to the original weights within kTestTolerance.
// NOTE(review): the constructor arguments look like (lookup shape, value
// shape, one scale per row, value type) -- confirm against
// PerAxisHybridEmbeddingLookupOpModel, which is not visible here.
TEST(PerAxisHybridEmbeddingLookupHybridOpTest, PerAxisSimple2DTestInt4) {
|
||||
PerAxisHybridEmbeddingLookupOpModel m({3}, {3, 8}, {0.001, 0.02, 0.3},
|
||||
TensorType_INT4);
|
||||
m.SetInput({1, 0, 2});
|
||||
m.SetSignedWeight({
|
||||
0.00, 0.007, 0.006, 0.005, 0.004, 0.003, 0.002, 0.001, // Row 0
|
||||
0.02, -0.02, 0.04, 0.06, 0.08, -0.04, -0.08, -0.06, // Row 1
|
||||
0.3, 0.6, 0.9, 1.2, 1.5, -0.3, -0.6, -0.9, // Row 2
|
||||
});
|
||||
|
||||
ASSERT_EQ(m.Invoke(), kTfLiteOk);
|
||||
|
||||
EXPECT_THAT(
|
||||
m.GetOutput<float>(),
|
||||
ElementsAreArray(ArrayFloatNear(
|
||||
{
|
||||
0.02, -0.02, 0.04, 0.06, 0.08, -0.04, -0.08, -0.06, // Row 1
|
||||
0.00, 0.007, 0.006, 0.005, 0.004, 0.003, 0.002, 0.001, // Row 0
|
||||
0.3, 0.6, 0.9, 1.2, 1.5, -0.3, -0.6, -0.9, // Row 2
|
||||
},
|
||||
kTestTolerance)));
|
||||
}
|
||||
|
||||
// Same lookup scenario as the 2D INT4 case, but the value tensor is shaped
// {3, 2, 4}: the same 24 weights, with 8 values per first-dimension entry.
// Lookup indices {1, 0, 2} must return entries 1, 0 and 2, in that order,
// as floats equal to the original weights within kTestTolerance.
// NOTE(review): presumably the three scales map one-to-one onto the first
// dimension -- confirm against PerAxisHybridEmbeddingLookupOpModel.
TEST(PerAxisHybridEmbeddingLookupHybridOpTest, PerAxisSimple3DTestInt4) {
|
||||
PerAxisHybridEmbeddingLookupOpModel m({3}, {3, 2, 4}, {0.001, 0.02, 0.3},
|
||||
TensorType_INT4);
|
||||
m.SetInput({1, 0, 2});
|
||||
m.SetSignedWeight({
|
||||
0.00, 0.007, 0.006, 0.005, 0.004, 0.003, 0.002, 0.001, // Row 0
|
||||
0.02, -0.02, 0.04, 0.06, 0.08, -0.04, -0.08, -0.06, // Row 1
|
||||
0.3, 0.6, 0.9, 1.2, 1.5, -0.3, -0.6, -0.9, // Row 2
|
||||
});
|
||||
|
||||
ASSERT_EQ(m.Invoke(), kTfLiteOk);
|
||||
|
||||
EXPECT_THAT(
|
||||
m.GetOutput<float>(),
|
||||
ElementsAreArray(ArrayFloatNear(
|
||||
{
|
||||
0.02, -0.02, 0.04, 0.06, 0.08, -0.04, -0.08, -0.06, // Row 1
|
||||
0.00, 0.007, 0.006, 0.005, 0.004, 0.003, 0.002, 0.001, // Row 0
|
||||
0.3, 0.6, 0.9, 1.2, 1.5, -0.3, -0.6, -0.9, // Row 2
|
||||
},
|
||||
kTestTolerance)));
|
||||
}
|
||||
|
||||
// Same lookup scenario as the 2D INT4 case, but the value tensor is shaped
// {3, 2, 2, 2}: the same 24 weights, with 8 values per first-dimension
// entry. Lookup indices {1, 0, 2} must return entries 1, 0 and 2, in that
// order, as floats equal to the original weights within kTestTolerance.
// NOTE(review): presumably the three scales map one-to-one onto the first
// dimension -- confirm against PerAxisHybridEmbeddingLookupOpModel.
TEST(PerAxisHybridEmbeddingLookupHybridOpTest, PerAxisSimple4DTestInt4) {
|
||||
PerAxisHybridEmbeddingLookupOpModel m({3}, {3, 2, 2, 2}, {0.001, 0.02, 0.3},
|
||||
TensorType_INT4);
|
||||
m.SetInput({1, 0, 2});
|
||||
m.SetSignedWeight({
|
||||
0.00, 0.007, 0.006, 0.005, 0.004, 0.003, 0.002, 0.001, // Row 0
|
||||
0.02, -0.02, 0.04, 0.06, 0.08, -0.04, -0.08, -0.06, // Row 1
|
||||
0.3, 0.6, 0.9, 1.2, 1.5, -0.3, -0.6, -0.9, // Row 2
|
||||
});
|
||||
|
||||
ASSERT_EQ(m.Invoke(), kTfLiteOk);
|
||||
|
||||
EXPECT_THAT(
|
||||
m.GetOutput<float>(),
|
||||
ElementsAreArray(ArrayFloatNear(
|
||||
{
|
||||
0.02, -0.02, 0.04, 0.06, 0.08, -0.04, -0.08, -0.06, // Row 1
|
||||
0.00, 0.007, 0.006, 0.005, 0.004, 0.003, 0.002, 0.001, // Row 0
|
||||
0.3, 0.6, 0.9, 1.2, 1.5, -0.3, -0.6, -0.9, // Row 2
|
||||
},
|
||||
kTestTolerance)));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
|
|
|||
|
|
@ -149,6 +149,13 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
case BuiltinOperator_EMBEDDING_LOOKUP: {
|
||||
if (op_sig.inputs.at(1).type == kTfLiteInt4) {
|
||||
return 4;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
case BuiltinOperator_FAKE_QUANT: {
|
||||
auto fake_quant_params =
|
||||
reinterpret_cast<TfLiteFakeQuantParams*>(op_sig.builtin_data);
|
||||
|
|
|
|||
|
|
@ -118,6 +118,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
|
|||
{{BuiltinOperator_EMBEDDING_LOOKUP, 1}, "1.13.0"},
|
||||
{{BuiltinOperator_EMBEDDING_LOOKUP, 2}, "1.14.0"},
|
||||
{{BuiltinOperator_EMBEDDING_LOOKUP, 3}, "1.14.0"},
|
||||
{{BuiltinOperator_EMBEDDING_LOOKUP, 4}, "2.18.0"},
|
||||
{{BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, 1}, "1.5.0"},
|
||||
{{BuiltinOperator_FAKE_QUANT, 1}, "1.5.0"},
|
||||
{{BuiltinOperator_FAKE_QUANT, 2}, "1.10.0"},
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user