Add i4 support in tfl.slice

PiperOrigin-RevId: 825217744
Authored by Majid Dadashi on 2025-10-28 14:58:07 -07:00; committed by TensorFlower Gardener
parent 202bd1ac59
commit 2feb74eeff
11 changed files with 170 additions and 6 deletions

View File

@@ -23,7 +23,9 @@
* Adds int8 and int16x8 support for SQRT operator.
* Adds int16x8 support for EQUAL and NOT_EQUAL operators.
* Adds support for int2 type.
* Adds support for int2/int4 in tfl.cast.
* Adds support for SRQ int2 in tfl.fully_connected.
* Adds support for int4 in tfl.slice.
### Bug Fixes and Other Changes

View File

@@ -2477,13 +2477,13 @@ equivalent to setting:
  }];

  let arguments = (ins
    TFL_TensorOf<[F32, I32, I64, QI4, I8, UI8, UI32, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$input,
    TFL_I32OrI64Tensor:$begin,
    TFL_I32OrI64Tensor:$size
  );

  let results = (outs
    TFL_TensorOf<[F32, I32, I64, QI4, I8, UI8, UI32, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$output
  );

  let hasVerifier = 1;

View File

@@ -468,6 +468,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
      return 1;
    case BuiltinOperator_SLICE:
      if (op_sig.inputs.at(0).type == kTfLiteInt4) {
        return 7;
      }
      if (op_sig.inputs.at(0).type == kTfLiteUInt32) {
        return 6;
      }
@@ -477,7 +480,6 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
      if (op_sig.inputs.at(0).type == kTfLiteInt16) {
        return 4;
      }
      if (op_sig.inputs.at(0).type == kTfLiteString) {
        return 3;
      }

View File

@@ -295,6 +295,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
          {{BuiltinOperator_SLICE, 4}, "2.4.0"},
          {{BuiltinOperator_SLICE, 5}, "2.5.0"},
          {{BuiltinOperator_SLICE, 6}, "2.14.0"},
          {{BuiltinOperator_SLICE, 7}, "2.21.0"},
          {{BuiltinOperator_TANH, 1}, "1.14.0"},
          {{BuiltinOperator_TANH, 2}, "1.14.0"},
          {{BuiltinOperator_TANH, 3}, "2.3.0"},

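The table above maps each (builtin operator, op version) pair to the oldest TFLite runtime able to execute it. A sketch of a lookup for the new entry (assuming the function takes the op version as an int second parameter, per the signature visible in the hunk header; not code from this commit):

// Sketch: resolve the minimum runtime for an int4 SLICE (op version 7).
const std::string min_runtime = tflite::FindMinimumRuntimeVersionForOp(
    tflite::BuiltinOperator_SLICE, /*op_version=*/7);
// Expected per the table above: "2.21.0".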
View File

@@ -217,7 +217,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
             /* max_version = */ 2);
  AddBuiltin(BuiltinOperator_SLICE, Register_SLICE(),
             /* min_version = */ 1,
             /* max_version = */ 7);
  AddBuiltin(BuiltinOperator_SIN, Register_SIN());
  AddBuiltin(BuiltinOperator_COS, Register_COS());
  AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV(),

View File

@@ -2573,10 +2573,13 @@ cc_test(
    ],
    tags = ["tflite_nnapi"],
    deps = [
        ":kernel_util",
        ":test_main",
        ":test_util",
        "//tensorflow/lite:string",
        "//tensorflow/lite/core/c:common",
        "//tensorflow/lite/kernels/internal:tensor_ctypes",
        "//tensorflow/lite/kernels/internal:tensor_utils_no_eigen",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_googletest//:gtest",
        "@eigen_archive//:eigen3",

View File

@@ -22,6 +22,7 @@ limitations under the License.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <tuple>
@@ -34,6 +35,8 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/reference/add.h"
#include "tensorflow/lite/kernels/internal/reference/mul.h"
#include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
#include "tensorflow/lite/kernels/internal/runtime_shape.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
#include <Accelerate/Accelerate.h>
@@ -4798,6 +4801,78 @@ inline void Slice(const tflite::SliceParams& op_params,
  return Slice(op_params, input_shape, output_shape, &writer);
}

// Iterates through the desired slice region and copies nibbles directly from
// the input to the output tensor.
inline void SliceInt4(const tflite::SliceParams& op_params,
                      const RuntimeShape& input_shape,
                      const TfLiteTensor* input,
                      const RuntimeShape& output_shape, TfLiteTensor* output) {
  ruy::profiler::ScopeLabel label("SliceInt4");
  const int8_t* input_data = GetTensorData<int8_t>(input);
  int8_t* output_data = GetTensorData<int8_t>(output);

  // Clear output buffer, as we will be writing nibbles.
  const int output_byte_size = (output_shape.FlatSize() + 1) / 2;
  memset(output_data, 0, output_byte_size);

  // Calculate the start and stop indices for each dimension of the slice.
  const RuntimeShape ext_input_shape =
      RuntimeShape::ExtendedShape(5, input_shape);
  TFLITE_DCHECK_LE(op_params.begin_count, 5);
  TFLITE_DCHECK_LE(op_params.size_count, 5);
  const int begin_count = op_params.begin_count;
  const int size_count = op_params.size_count;
  // We front-pad the begin and size vectors.
  int start[5];
  int stop[5];
  for (int i = 0; i < 5; ++i) {
    int padded_i = 5 - i;
    start[i] =
        begin_count < padded_i ? 0 : op_params.begin[begin_count - padded_i];
    stop[i] =
        (size_count < padded_i || op_params.size[size_count - padded_i] == -1)
            ? ext_input_shape.Dims(i)
            : start[i] + op_params.size[size_count - padded_i];
  }

  // Loop over the slice region and copy nibbles.
  int output_nibble_idx = 0;
  for (int i0 = start[0]; i0 < stop[0]; ++i0) {
    for (int i1 = start[1]; i1 < stop[1]; ++i1) {
      for (int i2 = start[2]; i2 < stop[2]; ++i2) {
        for (int i3 = start[3]; i3 < stop[3]; ++i3) {
          for (int i4 = start[4]; i4 < stop[4]; ++i4) {
            const int input_nibble_idx =
                Offset(ext_input_shape, i0, i1, i2, i3, i4);
            // Get nibble from input. Since int4 data is packed, two nibbles
            // share a byte.
            const int8_t input_byte = input_data[input_nibble_idx / 2];
            int8_t nibble;
            if (input_nibble_idx % 2 == 0) {  // low nibble
              // The `(val << 4) >> 4` trick is to sign-extend the 4-bit value.
              nibble = static_cast<int8_t>(input_byte << 4) >> 4;
            } else {  // high nibble
              nibble = input_byte >> 4;
            }
            // Set nibble in output.
            if (output_nibble_idx % 2 == 0) {
              // First nibble of a byte. We simply set the lower 4 bits.
              output_data[output_nibble_idx / 2] = (nibble & 0x0F);
            } else {
              // Second nibble. OR with existing low nibble.
              output_data[output_nibble_idx / 2] |= (nibble << 4);
            }
            output_nibble_idx++;
          }
        }
      }
    }
  }
}

template <typename T>
void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
             const T* input2_data, const RuntimeShape& output_shape,

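Aside: the `(val << 4) >> 4` pattern above decodes the low nibble of a packed byte, while a plain arithmetic shift decodes the high one. A minimal standalone sketch (hypothetical values, not part of the commit) showing both decode paths:

#include <cstdint>
#include <cstdio>

int main() {
  // One packed byte: low nibble 0x3 (= 3), high nibble 0xE (= -2 as int4).
  const uint8_t packed = 0xE3;
  // Low nibble: shift left so bit 3 lands in the sign bit, then shift right
  // arithmetically to sign-extend.
  const int8_t low = static_cast<int8_t>(packed << 4) >> 4;  // 3
  // High nibble: arithmetic shift of the byte reinterpreted as signed.
  const int8_t high = static_cast<int8_t>(packed) >> 4;  // -2
  std::printf("low=%d high=%d\n", low, high);
  return 0;
}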
View File

@@ -15,7 +15,14 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SLICE_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SLICE_H_

#include <cstdint>
#include <vector>

#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/kernels/internal/portable_tensor.h"
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
#include "tensorflow/lite/kernels/internal/runtime_shape.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
@@ -74,6 +81,27 @@ inline void Slice(const tflite::SliceParams& op_params,
  return Slice(op_params, input_shape, output_shape, &writer);
}

inline void SliceInt4(const tflite::SliceParams& op_params,
                      const RuntimeShape& input_shape,
                      const TfLiteTensor* input,
                      const RuntimeShape& output_shape, TfLiteTensor* output) {
  const int num_input_elements = input_shape.FlatSize();
  std::vector<int8_t> unpacked_input(num_input_elements);
  tensor_utils::UnpackPackedIntToInt8(GetTensorData<int8_t>(input),
                                      num_input_elements, 4,
                                      unpacked_input.data());

  const int num_output_elements = output_shape.FlatSize();
  std::vector<int8_t> unpacked_output(num_output_elements);

  reference_ops::Slice<int8_t>(op_params, input_shape, unpacked_input.data(),
                               output_shape, unpacked_output.data());

  tensor_utils::PackInt8IntoDenseInt(unpacked_output.data(),
                                     num_output_elements, 4,
                                     GetTensorData<int8_t>(output));
}
} // namespace reference_ops
} // namespace tflite

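The reference path trades memory for simplicity: it unpacks the whole input to one int8 per element, reuses the existing int8 Slice, and repacks the result, whereas the optimized kernel above copies nibbles directly. A sketch of the round trip it relies on, with free-standing stand-ins for the tensor_utils helpers (low-nibble-first packing assumed; these are not the actual implementations):

#include <cstdint>

// Stand-in for UnpackPackedIntToInt8 at bit_width 4: expand each packed byte
// into two sign-extended int8 values, low nibble first.
void UnpackInt4(const int8_t* packed, int num_elements, int8_t* out) {
  for (int i = 0; i < num_elements; ++i) {
    const uint8_t byte = static_cast<uint8_t>(packed[i / 2]);
    const uint8_t nibble = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
    out[i] = static_cast<int8_t>(nibble << 4) >> 4;  // sign-extend 4 -> 8 bits
  }
}

// Stand-in for PackInt8IntoDenseInt at bit_width 4: the inverse, masking each
// int8 back to 4 bits and writing it into its half of the packed byte.
void PackInt4(const int8_t* unpacked, int num_elements, int8_t* packed) {
  for (int i = 0; i < num_elements; ++i) {
    const uint8_t nibble = static_cast<uint8_t>(unpacked[i]) & 0x0F;
    uint8_t byte = static_cast<uint8_t>(packed[i / 2]);
    byte = (i % 2 == 0) ? static_cast<uint8_t>((byte & 0xF0) | nibble)
                        : static_cast<uint8_t>((byte & 0x0F) | (nibble << 4));
    packed[i / 2] = static_cast<int8_t>(byte);
  }
}

The two temporary std::vector buffers keep the reference kernel trivially correct at the cost of extra memory, which is the usual reference/optimized split in these kernels.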
View File

@@ -415,7 +415,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
  AddBuiltin(BuiltinOperator_SELECT_V2, Register_SELECT_V2());
  AddBuiltin(BuiltinOperator_SLICE, Register_SLICE_REF(),
             /* min_version = */ 1,
             /* max_version = */ 7);
  AddBuiltin(BuiltinOperator_SIN, Register_SIN());
  AddBuiltin(BuiltinOperator_COS, Register_COS());
  AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSECONV_REF(),

View File

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/slice.h"
#include <stdint.h>
#include <algorithm>
@ -206,6 +208,27 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// The dimensions in the kernel used to be in reverse-order, and TFLite
// arranged the begins and sizes vectors accordingly. This macro incorporates
// the needed reversing.
#define TF_LITE_SLICE_INT4()                                              \
  {                                                                       \
    TF_LITE_ENSURE_EQ(context, begins.size(), kMaxDim);                   \
    TF_LITE_ENSURE_EQ(context, sizes.size(), kMaxDim);                    \
    tflite::SliceParams op_params;                                        \
    op_params.begin_count = kMaxDim;                                      \
    op_params.size_count = kMaxDim;                                       \
    for (int i = 0; i < kMaxDim; ++i) {                                   \
      op_params.begin[i] = begins[i];                                     \
      op_params.size[i] = sizes[i];                                       \
    }                                                                     \
                                                                          \
    if (kernel_type == kGenericOptimized) {                               \
      optimized_ops::SliceInt4(op_params, GetTensorShape(input), input,   \
                               GetTensorShape(output), output);           \
    } else {                                                              \
      reference_ops::SliceInt4(op_params, GetTensorShape(input), input,   \
                               GetTensorShape(output), output);           \
    }                                                                     \
  }

#define TF_LITE_SLICE(data_type)                                          \
  {                                                                       \
    TF_LITE_ENSURE_EQ(context, begins.size(), kMaxDim);                   \
@@ -231,6 +254,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
    case kTfLiteFloat32:
      TF_LITE_SLICE(float);
      break;
    case kTfLiteInt4:
      TF_LITE_SLICE_INT4();
      break;
    case kTfLiteInt32:
      TF_LITE_SLICE(int32_t);
      break;

View File

@@ -16,11 +16,16 @@ limitations under the License.
#include <initializer_list>
#include <string>
#include <type_traits>
#include <vector>

#include "Eigen/Core"
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/string_type.h"
@@ -67,6 +72,12 @@ class SliceOpModel : public SingleOpModel {
  }

  void SetInput(std::initializer_list<input_type> data) {
    if constexpr (std::is_same<input_type, int8_t>::value) {
      if (interpreter_->tensor(input_)->type == kTfLiteInt4) {
        PopulateTensor4bit(input_, 0, data.begin(), data.end());
        return;
      }
    }
    PopulateTensor<input_type>(input_, data);
  }

  void SetStringInput(std::vector<string> data) {
@@ -253,6 +264,22 @@ TEST_P(SliceOpTest, SliceInt8) {
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
}

TEST_P(SliceOpTest, SliceInt4) {
  SliceOpModel<int8_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
                                  {2, 1, -1, 1}, TensorType_INT32,
                                  TensorType_INT4, GetParam());
  m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));

  const TfLiteTensor* output_tensor = m.GetOutputTensor();
  int num_elements = NumElements(output_tensor);
  std::vector<int8_t> unpacked_output(num_elements);
  tensor_utils::UnpackPackedIntToInt8(GetTensorData<int8_t>(output_tensor),
                                      num_elements,
                                      /*bit_width=*/4, unpacked_output.data());
  EXPECT_THAT(unpacked_output, ElementsAreArray({3, 3, 3, 5, 5, 5}));
}

TEST_P(SliceOpTest, SliceInt16) {
  SliceOpModel<int16_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
                                   {2, 1, -1, 1}, TensorType_INT32,