diff --git a/RELEASE.md b/RELEASE.md
index 08357c57932..289fd4bc1f5 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -26,6 +26,10 @@
 *
 *
+* `tf.lite`
+    * `Dequantize` op supports `TensorType_INT4`.
+    * This change includes per-channel dequantization.
+
 ## Keras
diff --git a/tensorflow/lite/core/kernels/register.cc b/tensorflow/lite/core/kernels/register.cc
index dcf3338f5b1..df466bb198b 100644
--- a/tensorflow/lite/core/kernels/register.cc
+++ b/tensorflow/lite/core/kernels/register.cc
@@ -178,7 +178,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
              /* max_version = */ 6);
   AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
              /* min_version = */ 1,
-             /* max_version = */ 5);
+             /* max_version = */ 6);
   AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
   AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM(),
              /* min_version = */ 1,
diff --git a/tensorflow/lite/kernels/dequantize.cc b/tensorflow/lite/kernels/dequantize.cc
index eab7650b6dc..8ec45fd5bd8 100644
--- a/tensorflow/lite/kernels/dequantize.cc
+++ b/tensorflow/lite/kernels/dequantize.cc
@@ -55,7 +55,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 
   OpContext op_context(context, node);
 
-  TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8 ||
+  TF_LITE_ENSURE(context, op_context.input->type == kTfLiteInt4 ||
+                              op_context.input->type == kTfLiteUInt8 ||
                               op_context.input->type == kTfLiteInt8 ||
                               op_context.input->type == kTfLiteInt16 ||
                               op_context.input->type == kTfLiteFloat16);
diff --git a/tensorflow/lite/kernels/dequantize.h b/tensorflow/lite/kernels/dequantize.h
index dea643aac29..f38abe4f8e7 100644
--- a/tensorflow/lite/kernels/dequantize.h
+++ b/tensorflow/lite/kernels/dequantize.h
@@ -17,9 +17,12 @@ limitations under the License.
 
 #include <stdint.h>
 
+#include <memory>
+
 #include "Eigen/Core"  // from @eigen_archive
 #include "tensorflow/lite/core/c/common.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
 #include "tensorflow/lite/kernels/internal/reference/dequantize.h"
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h"
 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
@@ -60,6 +63,19 @@ inline TfLiteStatus PerChannelDequantizeImpl(TfLiteContext* context,
       quantization_params->quantized_dimension;
   per_channel_op_params.scale = quantization_params->scale->data;
   per_channel_op_params.zero_point = quantization_params->zero_point->data;
+  const int8_t* input_data;
+  const size_t bytes_unpacked = input->bytes * 2;
+  auto unpacked_input_data = std::make_unique<int8_t[]>(bytes_unpacked);
+
+  if (input->type == kTfLiteInt4) {
+    tflite::tensor_utils::UnpackDenseInt4IntoInt8(
+        GetTensorData<int8_t>(input), GetTensorShape(input).FlatSize(),
+        unpacked_input_data.get());
+    input_data = unpacked_input_data.get();
+  } else {
+    input_data = GetTensorData<int8_t>(input);
+  }
+
   switch (input->type) {
     case kTfLiteUInt8:
       reference_ops::PerChannelDequantize(
@@ -67,11 +83,11 @@ inline TfLiteStatus PerChannelDequantizeImpl(TfLiteContext* context,
           GetTensorData<uint8_t>(input), GetTensorShape(output),
           GetTensorData<float>(output));
       break;
+    case kTfLiteInt4:
     case kTfLiteInt8:
       reference_ops::PerChannelDequantize(
-          per_channel_op_params, GetTensorShape(input),
-          GetTensorData<int8_t>(input), GetTensorShape(output),
-          GetTensorData<float>(output));
+          per_channel_op_params, GetTensorShape(input), input_data,
+          GetTensorShape(output), GetTensorData<float>(output));
       break;
     default:
       TF_LITE_KERNEL_LOG(context, "Type %d not supported for per-channel.",
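Note on the kernel changes above: for `kTfLiteInt4`, both `PerChannelDequantizeImpl` and `DequantizeImpl` first unpack the packed 4-bit buffer into a scratch int8 buffer (twice `input->bytes`), then reuse the existing int8 kernels. The per-channel math those kernels apply is the usual affine mapping along `quantized_dimension`. The sketch below is illustrative only; the function name and the fixed 2D row-major layout are assumptions for the example, not the TFLite API:

```cpp
#include <cstdint>
#include <vector>

// Illustrative sketch: per-channel dequantization of a row-major
// {rows, cols} tensor quantized along dimension 0, i.e.
//   out[r][c] = (in[r][c] - zero_points[r]) * scales[r].
std::vector<float> PerChannelDequantize2D(
    const std::vector<int8_t>& input, int rows, int cols,
    const std::vector<float>& scales,
    const std::vector<int32_t>& zero_points) {
  std::vector<float> output(input.size());
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      // Subtract the channel's zero point, then scale.
      output[r * cols + c] =
          (input[r * cols + c] - zero_points[r]) * scales[r];
    }
  }
  return output;
}
```

With the unpacked data used by the new `TestInt4ToFloat_2D` test later in this diff ({-1, -1, 0, 0, 1, 4, 1, -8}, scales {0.5, 0.25}, zero points {-1, -1}), this reproduces the expected {0, 0, 0.5, 0.5, 0.5, 1.25, 0.5, -1.75}.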
@@ -90,6 +106,20 @@ TfLiteStatus DequantizeImpl(TfLiteContext* context, TfLiteNode* node,
   DequantizationParams op_params;
   op_params.zero_point = input->params.zero_point;
   op_params.scale = input->params.scale;
+  const int8_t* input_data;
+  const size_t bytes_unpacked = input->bytes * 2;
+  auto unpacked_input_data = std::make_unique<int8_t[]>(bytes_unpacked);
+
+  if (input->type == kTfLiteInt4) {
+    // Use GetTensorShape(input).FlatSize() for num_elements.
+    tflite::tensor_utils::UnpackDenseInt4IntoInt8(
+        GetTensorData<int8_t>(input), GetTensorShape(input).FlatSize(),
+        unpacked_input_data.get());
+    input_data = unpacked_input_data.get();
+  } else {
+    input_data = GetTensorData<int8_t>(input);
+  }
+
   switch (input->type) {
     case kTfLiteUInt8:
       if (kernel_type == kReference) {
@@ -102,15 +132,16 @@ TfLiteStatus DequantizeImpl(TfLiteContext* context, TfLiteNode* node,
         GetTensorShape(output), GetTensorData<float>(output));
       }
       break;
+    case kTfLiteInt4:
     case kTfLiteInt8:
       if (kernel_type == kReference) {
         reference_integer_ops::Dequantize(
-            op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
+            op_params, GetTensorShape(input), input_data,
             GetTensorShape(output), GetTensorData<float>(output));
       } else {
-        optimized_ops::Dequantize(
-            op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
-            GetTensorShape(output), GetTensorData<float>(output));
+        optimized_ops::Dequantize(op_params, GetTensorShape(input), input_data,
+                                  GetTensorShape(output),
+                                  GetTensorData<float>(output));
       }
       break;
     case kTfLiteInt16:
diff --git a/tensorflow/lite/kernels/dequantize_test.cc b/tensorflow/lite/kernels/dequantize_test.cc
index 136f4aa1735..0ce8ce6f5d1 100644
--- a/tensorflow/lite/kernels/dequantize_test.cc
+++ b/tensorflow/lite/kernels/dequantize_test.cc
@@ -66,6 +66,15 @@ class DequantizeOpModel : public SingleOpModel {
     PopulateTensor(input_, data);
   }
 
+  template <typename T>
+  void SetInputInt4(int input, const std::vector<T> data) {
+    auto non_const = *const_cast<std::vector<T>*>(&data);
+    std::vector<int8_t> data_int8(non_const.size());
+    std::copy(non_const.begin(), non_const.end(), data_int8.begin());
+    PopulateTensor4bit(input, 0, data_int8.data(),
+                       data_int8.data() + data_int8.size());
+  }
+
   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
 
 protected:
@@ -73,6 +82,16 @@ class DequantizeOpModel : public SingleOpModel {
   int output_;
 };
 
+TEST(DequantizeOpTest, Int4) {
+  // [-3.5, 4] -> scale=0.5, zero_point=-1 for INT4
+  DequantizeOpModel m(TensorType_INT4, {2, 2}, 0.5, -1, 6);
+
+  m.SetInputInt4<int8_t>(0, {7, 6, -7, -8});
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray(ArrayFloatNear({4, 3.5, -3, -3.5})));
+}
+
 TEST(DequantizeOpTest, Uint8) {
   // [-63.5, 64] -> scale=0.5 zero_point=127 for UINT8
   DequantizeOpModel m(TensorType_UINT8, {2, 5}, 0.5, 127, 1);
diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD
index 11292aeb2bb..ff959befad1 100644
--- a/tensorflow/lite/kernels/internal/BUILD
+++ b/tensorflow/lite/kernels/internal/BUILD
@@ -1442,6 +1442,7 @@ cc_test(
     srcs = ["per_channel_dequantize_test.cc"],
     deps = [
         ":reference_base",
+        ":tensor_utils_no_eigen",
         ":types",
         "//tensorflow/lite/kernels:test_util",
         "@com_google_googletest//:gtest_main",
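Sanity check for `DequantizeOpTest.Int4` above: per-tensor dequantization computes `out = (q - zero_point) * scale`. A standalone snippet (not part of the change) reproducing the expected values with scale = 0.5 and zero_point = -1:

```cpp
#include <cstdio>

// Standalone check of the Int4 test expectations:
// out = (q - zero_point) * scale, scale = 0.5, zero_point = -1.
int main() {
  const int q[] = {7, 6, -7, -8};
  for (int v : q) {
    std::printf("%d -> %.1f\n", v, (v - (-1)) * 0.5f);
  }
  // Prints: 7 -> 4.0, 6 -> 3.5, -7 -> -3.0, -8 -> -3.5,
  // matching ArrayFloatNear({4, 3.5, -3, -3.5}) in the test.
  return 0;
}
```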
diff --git a/tensorflow/lite/kernels/internal/per_channel_dequantize_test.cc b/tensorflow/lite/kernels/internal/per_channel_dequantize_test.cc
index 89710b99b95..9d5f37944cf 100644
--- a/tensorflow/lite/kernels/internal/per_channel_dequantize_test.cc
+++ b/tensorflow/lite/kernels/internal/per_channel_dequantize_test.cc
@@ -12,11 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 implied. See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include <cstddef>
 #include <cstdint>
+#include <memory>
 #include <vector>
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
 #include "tensorflow/lite/kernels/internal/reference/dequantize.h"
 #include "tensorflow/lite/kernels/internal/types.h"
 #include "tensorflow/lite/kernels/test_util.h"
@@ -118,5 +121,34 @@ TEST(PerChannelDequantize, TestInt8ToFloat_4DDim3) {
                                        -124, 62, 30.75, 63, 31.25, 127})));
 }
 
+TEST(PerChannelDequantize, TestInt4ToFloat_2D) {
+  const std::vector<float> scales = {0.5, 0.25};
+  const std::vector<int> zero_points = {-1, -1};
+  const int quantized_dimension = 0;
+
+  const RuntimeShape unpacked_shape({2, 4});
+
+  const std::vector<int8_t> packed_int4_input = {-1, 0, 65, -127};
+  std::vector<float> output(8, -1);
+  const size_t bytes_unpacked = packed_int4_input.size() * 2;
+  auto unpacked_input_data = std::make_unique<int8_t[]>(bytes_unpacked);
+  tflite::tensor_utils::UnpackDenseInt4IntoInt8(
+      packed_int4_input.data(), bytes_unpacked, unpacked_input_data.get());
+  EXPECT_THAT(std::vector<int8_t>(unpacked_input_data.get(),
+                                  unpacked_input_data.get() + bytes_unpacked),
+              ElementsAreArray(ArrayFloatNear({-1, -1, 0, 0, 1, 4, 1, -8})));
+
+  PerChannelDequantizationParams op_params;
+  op_params.zero_point = zero_points.data();
+  op_params.scale = scales.data();
+  op_params.quantized_dimension = quantized_dimension;
+  reference_ops::PerChannelDequantize(op_params, unpacked_shape,
+                                      unpacked_input_data.get(),
+                                      unpacked_shape, output.data());
+  // This comes from (UNPACKED - zero_point) * scale.
+  EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear(
+                          {0, 0, 0.5, 0.5, 0.5, 1.25, 0.5, -1.75})));
+}
+
 }  // namespace
 }  // namespace tflite
diff --git a/tensorflow/lite/kernels/internal/portable_tensor_utils.cc b/tensorflow/lite/kernels/internal/portable_tensor_utils.cc
index 024043d75d3..577fc6b235b 100644
--- a/tensorflow/lite/kernels/internal/portable_tensor_utils.cc
+++ b/tensorflow/lite/kernels/internal/portable_tensor_utils.cc
@@ -70,6 +70,12 @@ void ApplySignbitToVector(const float* __restrict__ vector, int v_size,
 
 void UnpackDenseInt4IntoInt8(const int8_t* src_buffer, int num_elements,
                              int8_t* dst_buffer) {
+  // num_elements is the number of int4 elements, whether packed or unpacked.
+  // For example, 3 elements means both:
+  // 1) Packed: 3 int4 values = 12 bits -> 16 bits (padded) = 2 bytes,
+  //    stored in src_buffer[0] and src_buffer[1] (i = 0..1).
+  // 2) Unpacked: 3 int8 values = 3 bytes,
+  //    stored in dst_buffer[0], dst_buffer[1], and dst_buffer[2] (j = 0..2).
   for (int i = 0; i < num_elements / 2; i++) {
     int8_t byte = src_buffer[i];
     // Shift left first so that sign is properly extended when shifted right
diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc
index fae1e42c5db..2157790c36c 100644
--- a/tensorflow/lite/tools/versioning/runtime_version.cc
+++ b/tensorflow/lite/tools/versioning/runtime_version.cc
@@ -316,6 +316,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
           {{BuiltinOperator_DEQUANTIZE, 3}, "1.15.0"},
           {{BuiltinOperator_DEQUANTIZE, 4}, "2.2.0"},
           {{BuiltinOperator_DEQUANTIZE, 5}, "2.7.0"},
+          {{BuiltinOperator_DEQUANTIZE, 6}, "2.18.0"},
          {{BuiltinOperator_REVERSE_SEQUENCE, 1}, "1.14.0"},
           {{BuiltinOperator_EQUAL, 1}, "1.14.0"},
           {{BuiltinOperator_EQUAL, 2}, "1.14.0"},
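For reference, the nibble layout `UnpackDenseInt4IntoInt8` relies on is low-nibble-first with sign extension, which the packed test vector above confirms ({-1, 0, 65, -127} unpacks to {-1, -1, 0, 0, 1, 4, 1, -8}). A minimal sketch of that unpacking, assuming an even element count (the kernel's handling of a trailing odd element, described in its comment, is omitted here):

```cpp
#include <cstdint>

// Sketch only: unpack two signed 4-bit values per byte, low nibble first.
// Assumes num_elements is even.
void UnpackDenseInt4IntoInt8Sketch(const int8_t* src_buffer, int num_elements,
                                   int8_t* dst_buffer) {
  for (int i = 0; i < num_elements / 2; ++i) {
    const int8_t byte = src_buffer[i];
    // Shift left first so the nibble's sign bit is extended by the
    // arithmetic right shift.
    dst_buffer[2 * i] = static_cast<int8_t>(byte << 4) >> 4;  // low nibble
    dst_buffer[2 * i + 1] = byte >> 4;                        // high nibble
  }
}
```

Applied to the packed input {-1, 0, 65, -127} (0xFF, 0x00, 0x41, 0x81), this yields {-1, -1, 0, 0, 1, 4, 1, -8}, matching the `EXPECT_THAT` in `TestInt4ToFloat_2D`.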