Add support for int4 in dequantize op.

This CL adds int4 support to the reference and optimized dequantize kernels, including per-channel dequantization.
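For context, a minimal illustrative sketch of what int4 dequantization computes (this is not the kernel code in this CL; the helper name and layout are assumptions for illustration). Two signed 4-bit values are packed per byte, low nibble first, and each unpacked value q maps to (q - zero_point) * scale:

#include <cstdint>
#include <vector>

// Illustrative only: dequantize num_elements int4 values packed two per byte.
std::vector<float> DequantizeInt4(const std::vector<int8_t>& packed,
                                  int num_elements, float scale,
                                  int zero_point) {
  std::vector<float> out;
  out.reserve(num_elements);
  for (int i = 0; i < num_elements; ++i) {
    const int8_t byte = packed[i / 2];
    // Low nibble holds the even-indexed element, high nibble the odd one.
    int q = (i % 2 == 0) ? (byte & 0x0F) : ((byte >> 4) & 0x0F);
    if (q >= 8) q -= 16;  // Sign-extend the 4-bit value into [-8, 7].
    out.push_back((q - zero_point) * scale);
  }
  return out;
}

With scale=0.5 and zero_point=-1 (the values used by the Int4 unit test below), q=7 dequantizes to 4 and q=-8 to -3.5.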

PiperOrigin-RevId: 642055280
Jae H. Yoo 2024-06-10 16:13:22 -07:00 committed by TensorFlower Gardener
parent 05eeb7762b
commit 37ff47b612
9 changed files with 104 additions and 9 deletions


@@ -26,6 +26,10 @@
 * <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
 * <NOTES SHOULD BE GROUPED PER AREA>
+* `tf.lite`
+  * `Dequantize` op supports `TensorType_INT4`.
+  * This change includes per-channel dequantization.
 ## Keras
 <INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>


@@ -178,7 +178,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
              /* max_version = */ 6);
   AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
              /* min_version = */ 1,
-             /* max_version = */ 5);
+             /* max_version = */ 6);
   AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
   AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM(),
              /* min_version = */ 1,


@@ -55,7 +55,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   OpContext op_context(context, node);
-  TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8 ||
+  TF_LITE_ENSURE(context, op_context.input->type == kTfLiteInt4 ||
+                              op_context.input->type == kTfLiteUInt8 ||
                               op_context.input->type == kTfLiteInt8 ||
                               op_context.input->type == kTfLiteInt16 ||
                               op_context.input->type == kTfLiteFloat16);


@@ -17,9 +17,12 @@ limitations under the License.
 #include <stdint.h>
+#include <memory>
 #include "Eigen/Core"  // from @eigen_archive
 #include "tensorflow/lite/core/c/common.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
 #include "tensorflow/lite/kernels/internal/reference/dequantize.h"
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h"
 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
@@ -60,6 +63,19 @@ inline TfLiteStatus PerChannelDequantizeImpl(TfLiteContext* context,
       quantization_params->quantized_dimension;
   per_channel_op_params.scale = quantization_params->scale->data;
   per_channel_op_params.zero_point = quantization_params->zero_point->data;
+  const int8_t* input_data;
+  const size_t bytes_unpacked = input->bytes * 2;
+  auto unpacked_input_data = std::make_unique<int8_t[]>(bytes_unpacked);
+  if (input->type == kTfLiteInt4) {
+    tflite::tensor_utils::UnpackDenseInt4IntoInt8(
+        GetTensorData<int8_t>(input), GetTensorShape(input).FlatSize(),
+        unpacked_input_data.get());
+    input_data = unpacked_input_data.get();
+  } else {
+    input_data = GetTensorData<int8_t>(input);
+  }
   switch (input->type) {
     case kTfLiteUInt8:
       reference_ops::PerChannelDequantize<uint8_t>(
@@ -67,11 +83,11 @@ inline TfLiteStatus PerChannelDequantizeImpl(TfLiteContext* context,
           GetTensorData<uint8_t>(input), GetTensorShape(output),
           GetTensorData<float>(output));
       break;
+    case kTfLiteInt4:
     case kTfLiteInt8:
       reference_ops::PerChannelDequantize<int8_t>(
-          per_channel_op_params, GetTensorShape(input),
-          GetTensorData<int8_t>(input), GetTensorShape(output),
-          GetTensorData<float>(output));
+          per_channel_op_params, GetTensorShape(input), input_data,
+          GetTensorShape(output), GetTensorData<float>(output));
       break;
     default:
       TF_LITE_KERNEL_LOG(context, "Type %d not supported for per-channel.",
@@ -90,6 +106,20 @@ TfLiteStatus DequantizeImpl(TfLiteContext* context, TfLiteNode* node,
   DequantizationParams op_params;
   op_params.zero_point = input->params.zero_point;
   op_params.scale = input->params.scale;
+  const int8_t* input_data;
+  const size_t bytes_unpacked = input->bytes * 2;
+  auto unpacked_input_data = std::make_unique<int8_t[]>(bytes_unpacked);
+  if (input->type == kTfLiteInt4) {
+    // Use GetTensorShape(input).FlatSize() for num_elements.
+    tflite::tensor_utils::UnpackDenseInt4IntoInt8(
+        GetTensorData<int8_t>(input), GetTensorShape(input).FlatSize(),
+        unpacked_input_data.get());
+    input_data = unpacked_input_data.get();
+  } else {
+    input_data = GetTensorData<int8_t>(input);
+  }
   switch (input->type) {
     case kTfLiteUInt8:
       if (kernel_type == kReference) {
@@ -102,15 +132,16 @@ TfLiteStatus DequantizeImpl(TfLiteContext* context, TfLiteNode* node,
             GetTensorShape(output), GetTensorData<float>(output));
       }
       break;
+    case kTfLiteInt4:
     case kTfLiteInt8:
       if (kernel_type == kReference) {
         reference_integer_ops::Dequantize<int8_t>(
-            op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
+            op_params, GetTensorShape(input), input_data,
             GetTensorShape(output), GetTensorData<float>(output));
       } else {
-        optimized_ops::Dequantize(
-            op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
-            GetTensorShape(output), GetTensorData<float>(output));
+        optimized_ops::Dequantize(op_params, GetTensorShape(input), input_data,
+                                  GetTensorShape(output),
+                                  GetTensorData<float>(output));
       }
       break;
     case kTfLiteInt16:
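The int4 cases above fall through to the int8 paths after unpacking, so the per-channel math itself lives in reference_ops::PerChannelDequantize, whose implementation is not shown in this diff. Conceptually, for a 2-D [channels, width] tensor quantized along dimension 0, it computes the following (simplified sketch with a hypothetical helper, not the reference implementation):

#include <cstdint>

// Simplified sketch: each channel c has its own scale and zero point.
// output[c][w] = (input[c][w] - zero_point[c]) * scale[c]
void PerChannelDequantize2D(const int8_t* input, int channels, int width,
                            const float* scale, const int* zero_point,
                            float* output) {
  for (int c = 0; c < channels; ++c) {
    for (int w = 0; w < width; ++w) {
      const int index = c * width + w;
      output[index] = (input[index] - zero_point[c]) * scale[c];
    }
  }
}

Applied to the unpacked int4 data used in the per-channel test below ({-1, -1, 0, 0, 1, 4, 1, -8} with scales {0.5, 0.25} and zero points {-1, -1}), this reproduces the expected outputs {0, 0, 0.5, 0.5, 0.5, 1.25, 0.5, -1.75}.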


@@ -66,6 +66,15 @@ class DequantizeOpModel : public SingleOpModel {
     PopulateTensor(input_, data);
   }
+  template <typename T>
+  void SetInputInt4(int input, const std::vector<T> data) {
+    auto non_const = *const_cast<std::vector<T>*>(&data);
+    std::vector<int8_t> data_int8(non_const.size());
+    std::copy(non_const.begin(), non_const.end(), data_int8.begin());
+    PopulateTensor4bit(input, 0, data_int8.data(),
+                       data_int8.data() + data_int8.size());
+  }
   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

 protected:
@@ -73,6 +82,16 @@ class DequantizeOpModel : public SingleOpModel {
   int output_;
 };
+TEST(DequantizeOpTest, Int4) {
+  // [-3.5, 4] -> scale=0.5, zero_point=-1 for INT4
+  DequantizeOpModel m(TensorType_INT4, {2, 2}, 0.5, -1, 6);
+  m.SetInputInt4<int8_t>(0, {7, 6, -7, -8});
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray(ArrayFloatNear({4, 3.5, -3, -3.5})));
+}
+
 TEST(DequantizeOpTest, Uint8) {
   // [-63.5, 64] -> scale=0.5 zero_point=127 for UINT8
   DequantizeOpModel m(TensorType_UINT8, {2, 5}, 0.5, 127, 1);


@@ -1442,6 +1442,7 @@ cc_test(
     srcs = ["per_channel_dequantize_test.cc"],
     deps = [
         ":reference_base",
+        ":tensor_utils_no_eigen",
        ":types",
        "//tensorflow/lite/kernels:test_util",
        "@com_google_googletest//:gtest_main",


@@ -12,11 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include <cstddef>
 #include <cstdint>
+#include <memory>
 #include <vector>
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
 #include "tensorflow/lite/kernels/internal/reference/dequantize.h"
 #include "tensorflow/lite/kernels/internal/types.h"
 #include "tensorflow/lite/kernels/test_util.h"
@@ -118,5 +121,34 @@ TEST(PerChannelDequantize, TestInt8ToFloat_4DDim3) {
                                -124, 62, 30.75, 63, 31.25, 127})));
 }
+TEST(PerChannelDequantize, TestInt4ToFloat_2D) {
+  const std::vector<float> scales = {0.5, 0.25};
+  const std::vector<int> zero_points = {-1, -1};
+  const int quantized_dimension = 0;
+  const RuntimeShape unpacked_shape({2, 4});
+  const std::vector<int8_t> packed_int4_input = {-1, 0, 65, -127};
+  std::vector<float> output(8, -1);
+
+  const size_t bytes_unpacked = packed_int4_input.size() * 2;
+  auto unpacked_input_data = std::make_unique<int8_t[]>(bytes_unpacked);
+  tflite::tensor_utils::UnpackDenseInt4IntoInt8(
+      packed_int4_input.data(), bytes_unpacked, unpacked_input_data.get());
+  EXPECT_THAT(std::vector<int8_t>(unpacked_input_data.get(),
+                                  unpacked_input_data.get() + bytes_unpacked),
+              ElementsAreArray(ArrayFloatNear({-1, -1, 0, 0, 1, 4, 1, -8})));
+
+  PerChannelDequantizationParams op_params;
+  op_params.zero_point = zero_points.data();
+  op_params.scale = scales.data();
+  op_params.quantized_dimension = quantized_dimension;
+  reference_ops::PerChannelDequantize(op_params, unpacked_shape,
+                                      unpacked_input_data.get(),
+                                      unpacked_shape, output.data());
+  // This comes from (UNPACKED - zero_point) * scale.
+  EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear(
+                          {0, 0, 0.5, 0.5, 0.5, 1.25, 0.5, -1.75})));
+}
+
 }  // namespace
 }  // namespace tflite


@@ -70,6 +70,12 @@ void ApplySignbitToVector(const float* __restrict__ vector, int v_size,
 void UnpackDenseInt4IntoInt8(const int8_t* src_buffer, int num_elements,
                              int8_t* dst_buffer) {
+  // num_elements is the number of elements regardless of packed or unpacked.
+  // For example, 3 elements means both:
+  //   1) Packed: 3 int4's = 12 bits -> 16 bits (padded) = 2 bytes,
+  //      stored in src_buffer[0] and src_buffer[1] (i = 0..1).
+  //   2) Unpacked: 3 int8's = 3 bytes,
+  //      stored in dst_buffer[0], dst_buffer[1] and dst_buffer[2] (j = 0..2).
   for (int i = 0; i < num_elements / 2; i++) {
     int8_t byte = src_buffer[i];
     // Shift left first so that sign is properly extended when shifted right
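For reference, a small usage sketch of this helper (the input values are illustrative). It follows the packing convention exercised by the per-channel test above: the low nibble of each byte holds the earlier element.

#include <cstdint>
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"

void UnpackExample() {
  // Two packed bytes hold the nibble pairs (1, 4) and (-1, -8).
  const int8_t packed[2] = {0x41, static_cast<int8_t>(0x8F)};
  int8_t unpacked[4] = {};
  tflite::tensor_utils::UnpackDenseInt4IntoInt8(packed, /*num_elements=*/4,
                                                unpacked);
  // unpacked now contains {1, 4, -1, -8}.
}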


@@ -316,6 +316,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
               {{BuiltinOperator_DEQUANTIZE, 3}, "1.15.0"},
               {{BuiltinOperator_DEQUANTIZE, 4}, "2.2.0"},
               {{BuiltinOperator_DEQUANTIZE, 5}, "2.7.0"},
+              {{BuiltinOperator_DEQUANTIZE, 6}, "2.18.0"},
               {{BuiltinOperator_REVERSE_SEQUENCE, 1}, "1.14.0"},
               {{BuiltinOperator_EQUAL, 1}, "1.14.0"},
               {{BuiltinOperator_EQUAL, 2}, "1.14.0"},