From 2feb74eeff171c362a26d2f5e2cb42419a790500 Mon Sep 17 00:00:00 2001 From: Majid Dadashi Date: Tue, 28 Oct 2025 14:58:07 -0700 Subject: [PATCH] Add i4 support in tfl.slice PiperOrigin-RevId: 825217744 --- RELEASE.md | 4 +- tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 4 +- .../mlir/lite/tools/versioning/op_version.cc | 4 +- .../lite/tools/versioning/runtime_version.cc | 1 + tensorflow/lite/core/kernels/register.cc | 2 +- tensorflow/lite/kernels/BUILD | 3 + .../internal/optimized/optimized_ops.h | 75 +++++++++++++++++++ .../lite/kernels/internal/reference/slice.h | 28 +++++++ tensorflow/lite/kernels/register_ref.cc | 2 +- tensorflow/lite/kernels/slice.cc | 26 +++++++ tensorflow/lite/kernels/slice_test.cc | 27 +++++++ 11 files changed, 170 insertions(+), 6 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index b65c1839862..6255a4a1d86 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -23,7 +23,9 @@ * Adds int8 and int16x8 support for SQRT operator. * Adds int16x8 support for EQUAL and NOT_EQUAL operators. * Adds support for int2 type. - * Adds support for int2/int4 in tfl.cast. + * Adds support for int2/int4 in tfl.cast. + * Adds support for SRQ int2 in tfl.fully_connected. + * Adds support for int4 in tfl.slice.
### Bug Fixes and Other Changes diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 5b701e674dc..e6b1e37eb6c 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -2477,13 +2477,13 @@ equivalent to setting: }]; let arguments = (ins - TFL_TensorOf<[F32, I32, I64, I8, UI8, UI32, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$input, + TFL_TensorOf<[F32, I32, I64, QI4, I8, UI8, UI32, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$input, TFL_I32OrI64Tensor:$begin, TFL_I32OrI64Tensor:$size ); let results = (outs - TFL_TensorOf<[F32, I32, I64, I8, UI8, UI32, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$output + TFL_TensorOf<[F32, I32, I64, QI4, I8, UI8, UI32, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$output ); let hasVerifier = 1; diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/op_version.cc b/tensorflow/compiler/mlir/lite/tools/versioning/op_version.cc index 6a238409ea8..9ccda1d0c95 100644 --- a/tensorflow/compiler/mlir/lite/tools/versioning/op_version.cc +++ b/tensorflow/compiler/mlir/lite/tools/versioning/op_version.cc @@ -468,6 +468,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { return 1; case BuiltinOperator_SLICE: + if (op_sig.inputs.at(0).type == kTfLiteInt4) { + return 7; + } if (op_sig.inputs.at(0).type == kTfLiteUInt32) { return 6; } @@ -477,7 +480,6 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { if (op_sig.inputs.at(0).type == kTfLiteInt16) { return 4; } - // Version 3 supports string input types. 
if (op_sig.inputs.at(0).type == kTfLiteString) { return 3; } diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.cc b/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.cc index 54702a97d7a..aca1b463878 100644 --- a/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.cc @@ -295,6 +295,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_SLICE, 4}, "2.4.0"}, {{BuiltinOperator_SLICE, 5}, "2.5.0"}, {{BuiltinOperator_SLICE, 6}, "2.14.0"}, + {{BuiltinOperator_SLICE, 7}, "2.21.0"}, {{BuiltinOperator_TANH, 1}, "1.14.0"}, {{BuiltinOperator_TANH, 2}, "1.14.0"}, {{BuiltinOperator_TANH, 3}, "2.3.0"}, diff --git a/tensorflow/lite/core/kernels/register.cc b/tensorflow/lite/core/kernels/register.cc index 2c13edae231..848b28f108a 100644 --- a/tensorflow/lite/core/kernels/register.cc +++ b/tensorflow/lite/core/kernels/register.cc @@ -217,7 +217,7 @@ BuiltinOpResolver::BuiltinOpResolver() { /* max_version = */ 2); AddBuiltin(BuiltinOperator_SLICE, Register_SLICE(), /* min_version = */ 1, - /* max_version = */ 6); + /* max_version = */ 7); AddBuiltin(BuiltinOperator_SIN, Register_SIN()); AddBuiltin(BuiltinOperator_COS, Register_COS()); AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV(), diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 2decbb76ffc..6a3ec9f57e2 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -2573,10 +2573,13 @@ cc_test( ], tags = ["tflite_nnapi"], deps = [ + ":kernel_util", ":test_main", ":test_util", "//tensorflow/lite:string", "//tensorflow/lite/core/c:common", + "//tensorflow/lite/kernels/internal:tensor_ctypes", + "//tensorflow/lite/kernels/internal:tensor_utils_no_eigen", "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", "@eigen_archive//:eigen3", diff --git 
a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index d845b3ee6f5..debdc5142e9 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -34,6 +35,8 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/add.h" #include "tensorflow/lite/kernels/internal/reference/mul.h" #include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h" +#include "tensorflow/lite/kernels/internal/runtime_shape.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__) #include @@ -4798,6 +4801,78 @@ inline void Slice(const tflite::SliceParams& op_params, return Slice(op_params, input_shape, output_shape, &writer); } +// Iterates through the desired slice region and copies nibbles directly from +// the input to the output tensor. +inline void SliceInt4(const tflite::SliceParams& op_params, + const RuntimeShape& input_shape, + const TfLiteTensor* input, + const RuntimeShape& output_shape, TfLiteTensor* output) { + ruy::profiler::ScopeLabel label("SliceInt4"); + + const int8_t* input_data = GetTensorData(input); + int8_t* output_data = GetTensorData(output); + + // Clear output buffer, as we will be writing nibbles. + const int output_byte_size = (output_shape.FlatSize() + 1) / 2; + memset(output_data, 0, output_byte_size); + + // Calculate the start and stop indices for each dimension of the slice. + const RuntimeShape ext_input_shape = + RuntimeShape::ExtendedShape(5, input_shape); + TFLITE_DCHECK_LE(op_params.begin_count, 5); + TFLITE_DCHECK_LE(op_params.size_count, 5); + const int begin_count = op_params.begin_count; + const int size_count = op_params.size_count; + // We front-pad the begin and size vectors. 
+ int start[5]; + int stop[5]; + for (int i = 0; i < 5; ++i) { + int padded_i = 5 - i; + start[i] = + begin_count < padded_i ? 0 : op_params.begin[begin_count - padded_i]; + stop[i] = + (size_count < padded_i || op_params.size[size_count - padded_i] == -1) + ? ext_input_shape.Dims(i) + : start[i] + op_params.size[size_count - padded_i]; + } + + // Loop over the slice region and copy nibbles. + int output_nibble_idx = 0; + for (int i0 = start[0]; i0 < stop[0]; ++i0) { + for (int i1 = start[1]; i1 < stop[1]; ++i1) { + for (int i2 = start[2]; i2 < stop[2]; ++i2) { + for (int i3 = start[3]; i3 < stop[3]; ++i3) { + for (int i4 = start[4]; i4 < stop[4]; ++i4) { + const int input_nibble_idx = + Offset(ext_input_shape, i0, i1, i2, i3, i4); + + // Get nibble from input. Since int4 data is packed, two nibbles + // share a byte. + const int8_t input_byte = input_data[input_nibble_idx / 2]; + int8_t nibble; + if (input_nibble_idx % 2 == 0) { // low nibble + // The `(val << 4) >> 4` trick is to sign-extend the 4-bit value. + nibble = static_cast<int8_t>(input_byte << 4) >> 4; + } else { // high nibble + nibble = input_byte >> 4; + } + + // Set nibble in output. + if (output_nibble_idx % 2 == 0) { + // First nibble of a byte. We simply set the lower 4 bits. + output_data[output_nibble_idx / 2] = (nibble & 0x0F); + } else { + // Second nibble. OR with existing low nibble. + output_data[output_nibble_idx / 2] |= (nibble << 4); + } + output_nibble_idx++; + } + } + } + } + } +} + template <typename T> void Minimum(const RuntimeShape& input1_shape, const T* input1_data, const T* input2_data, const RuntimeShape& output_shape, diff --git a/tensorflow/lite/kernels/internal/reference/slice.h b/tensorflow/lite/kernels/internal/reference/slice.h index cb73ea0d0c4..feddd639584 100644 --- a/tensorflow/lite/kernels/internal/reference/slice.h +++ b/tensorflow/lite/kernels/internal/reference/slice.h @@ -15,7 +15,14 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SLICE_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SLICE_H_ +#include <cstdint> +#include <vector> + +#include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/internal/portable_tensor.h" +#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h" +#include "tensorflow/lite/kernels/internal/runtime_shape.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { @@ -74,6 +81,27 @@ inline void Slice(const tflite::SliceParams& op_params, return Slice(op_params, input_shape, output_shape, &writer); } +inline void SliceInt4(const tflite::SliceParams& op_params, + const RuntimeShape& input_shape, + const TfLiteTensor* input, + const RuntimeShape& output_shape, TfLiteTensor* output) { + const int num_input_elements = input_shape.FlatSize(); + std::vector<int8_t> unpacked_input(num_input_elements); + tensor_utils::UnpackPackedIntToInt8(GetTensorData<int8_t>(input), + num_input_elements, 4, + unpacked_input.data()); + + const int num_output_elements = output_shape.FlatSize(); + std::vector<int8_t> unpacked_output(num_output_elements); + + reference_ops::Slice(op_params, input_shape, unpacked_input.data(), + output_shape, unpacked_output.data()); + + tensor_utils::PackInt8IntoDenseInt(unpacked_output.data(), + num_output_elements, 4, + GetTensorData<int8_t>(output)); +} + } // namespace reference_ops } // namespace tflite diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc index f486e54da7b..842a9dc99d2 100644 --- a/tensorflow/lite/kernels/register_ref.cc +++ b/tensorflow/lite/kernels/register_ref.cc @@ -415,7 +415,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() { AddBuiltin(BuiltinOperator_SELECT_V2, Register_SELECT_V2()); AddBuiltin(BuiltinOperator_SLICE, Register_SLICE_REF(), /* min_version = */ 1, - /* max_version = */ 5); + /* max_version = */ 7); AddBuiltin(BuiltinOperator_SIN, Register_SIN());
AddBuiltin(BuiltinOperator_COS, Register_COS()); AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSECONV_REF(), diff --git a/tensorflow/lite/kernels/slice.cc b/tensorflow/lite/kernels/slice.cc index d8ff57364fe..62b4ae94406 100644 --- a/tensorflow/lite/kernels/slice.cc +++ b/tensorflow/lite/kernels/slice.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/lite/kernels/internal/reference/slice.h" + #include #include @@ -206,6 +208,27 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // The dimensions in the kernel used to be in reverse-order, and TFLite // arranged the begins and sizes vectors accordingly. This macro incorporates // the needed reversing. +#define TF_LITE_SLICE_INT4() \ + { \ + TF_LITE_ENSURE_EQ(context, begins.size(), kMaxDim); \ + TF_LITE_ENSURE_EQ(context, sizes.size(), kMaxDim); \ + tflite::SliceParams op_params; \ + op_params.begin_count = kMaxDim; \ + op_params.size_count = kMaxDim; \ + for (int i = 0; i < kMaxDim; ++i) { \ + op_params.begin[i] = begins[i]; \ + op_params.size[i] = sizes[i]; \ + } \ + \ + if (kernel_type == kGenericOptimized) { \ + optimized_ops::SliceInt4(op_params, GetTensorShape(input), input, \ + GetTensorShape(output), output); \ + } else { \ + reference_ops::SliceInt4(op_params, GetTensorShape(input), input, \ + GetTensorShape(output), output); \ + } \ + } + #define TF_LITE_SLICE(data_type) \ { \ TF_LITE_ENSURE_EQ(context, begins.size(), kMaxDim); \ @@ -231,6 +254,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteFloat32: TF_LITE_SLICE(float); break; + case kTfLiteInt4: + TF_LITE_SLICE_INT4(); + break; case kTfLiteInt32: TF_LITE_SLICE(int32_t); break; diff --git a/tensorflow/lite/kernels/slice_test.cc b/tensorflow/lite/kernels/slice_test.cc index 4a016c44a45..feb02c48d2f 100644 --- 
a/tensorflow/lite/kernels/slice_test.cc +++ b/tensorflow/lite/kernels/slice_test.cc @@ -16,11 +16,16 @@ limitations under the License. #include #include +#include #include #include "Eigen/Core" +#include #include #include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/test_util.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/string_type.h" @@ -67,6 +72,12 @@ class SliceOpModel : public SingleOpModel { } void SetInput(std::initializer_list data) { + if constexpr (std::is_same::value) { + if (interpreter_->tensor(input_)->type == kTfLiteInt4) { + PopulateTensor4bit(input_, 0, data.begin(), data.end()); + return; + } + } PopulateTensor(input_, data); } void SetStringInput(std::vector data) { @@ -253,6 +264,22 @@ TEST_P(SliceOpTest, SliceInt8) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5})); } +TEST_P(SliceOpTest, SliceInt4) { + SliceOpModel m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4}, + {2, 1, -1, 1}, TensorType_INT32, + TensorType_INT4, GetParam()); + m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1})); + const TfLiteTensor* output_tensor = m.GetOutputTensor(); + int num_elements = NumElements(output_tensor); + std::vector unpacked_output(num_elements); + tensor_utils::UnpackPackedIntToInt8(GetTensorData(output_tensor), + num_elements, + /*bit_width=*/4, unpacked_output.data()); + EXPECT_THAT(unpacked_output, ElementsAreArray({3, 3, 3, 5, 5, 5})); +} + TEST_P(SliceOpTest, SliceInt16) { SliceOpModel m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4}, {2, 1, -1, 1}, TensorType_INT32,