Add support for int4 in dequantize op.

This CL adds int4 support to the reference/optimized dequantize op, including per-channel dequantization.

PiperOrigin-RevId: 642055280

parent 05eeb7762b
commit 37ff47b612
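
For context, TFLite's affine dequantization maps a quantized value back to a real value as (q - zero_point) * scale, and an int4 value spans [-8, 7] once sign-extended. A minimal standalone sketch of that mapping (illustrative only; the function name and the scale/zero-point values below are not from the commit):

    #include <cstdint>
    #include <cstdio>

    // Dequantize one int4 value that has already been sign-extended into an
    // int8_t in the range [-8, 7], using an affine scale/zero-point.
    float DequantizeInt4(int8_t q, float scale, int32_t zero_point) {
      return (q - zero_point) * scale;
    }

    int main() {
      // With scale = 0.5 and zero_point = -1 the int4 extremes map to:
      std::printf("%.1f\n", DequantizeInt4(-8, 0.5f, -1));  // -3.5
      std::printf("%.1f\n", DequantizeInt4(7, 0.5f, -1));   // 4.0
      return 0;
    }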
@@ -26,6 +26,10 @@
 * <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
 * <NOTES SHOULD BE GROUPED PER AREA>
 
+* `tf.lite`
+  * `Dequantize` op supports `TensorType_INT4`.
+    * This change includes per-channel dequantization.
+
 ## Keras
 
 <INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
@@ -178,7 +178,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
              /* max_version = */ 6);
   AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
              /* min_version = */ 1,
-             /* max_version = */ 5);
+             /* max_version = */ 6);
   AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
   AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM(),
              /* min_version = */ 1,
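
Raising max_version to 6 means the stock resolver now registers the Dequantize version introduced for int4 inputs. As a usage sketch (standard TFLite C++ API; the model path is a placeholder), a model whose Dequantize ops are at version 6 or lower resolves through BuiltinOpResolver without any custom registration:

    #include <memory>

    #include "tensorflow/lite/interpreter.h"
    #include "tensorflow/lite/kernels/register.h"
    #include "tensorflow/lite/model.h"

    std::unique_ptr<tflite::Interpreter> BuildInterpreter() {
      // "model.tflite" is a placeholder path.
      auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
      tflite::ops::builtin::BuiltinOpResolver resolver;
      std::unique_ptr<tflite::Interpreter> interpreter;
      tflite::InterpreterBuilder(*model, resolver)(&interpreter);
      return interpreter;
    }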
@@ -55,7 +55,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 
   OpContext op_context(context, node);
 
-  TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8 ||
+  TF_LITE_ENSURE(context, op_context.input->type == kTfLiteInt4 ||
+                              op_context.input->type == kTfLiteUInt8 ||
                               op_context.input->type == kTfLiteInt8 ||
                               op_context.input->type == kTfLiteInt16 ||
                               op_context.input->type == kTfLiteFloat16);
@@ -17,9 +17,12 @@ limitations under the License.
 
 #include <stdint.h>
 
+#include <memory>
+
 #include "Eigen/Core"  // from @eigen_archive
 #include "tensorflow/lite/core/c/common.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
 #include "tensorflow/lite/kernels/internal/reference/dequantize.h"
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h"
 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
@@ -60,6 +63,19 @@ inline TfLiteStatus PerChannelDequantizeImpl(TfLiteContext* context,
       quantization_params->quantized_dimension;
   per_channel_op_params.scale = quantization_params->scale->data;
   per_channel_op_params.zero_point = quantization_params->zero_point->data;
+  const int8_t* input_data;
+  const size_t bytes_unpacked = input->bytes * 2;
+  auto unpacked_input_data = std::make_unique<int8_t[]>(bytes_unpacked);
+
+  if (input->type == kTfLiteInt4) {
+    tflite::tensor_utils::UnpackDenseInt4IntoInt8(
+        GetTensorData<int8_t>(input), GetTensorShape(input).FlatSize(),
+        unpacked_input_data.get());
+    input_data = unpacked_input_data.get();
+  } else {
+    input_data = GetTensorData<int8_t>(input);
+  }
+
   switch (input->type) {
     case kTfLiteUInt8:
       reference_ops::PerChannelDequantize<uint8_t>(
@@ -67,11 +83,11 @@ inline TfLiteStatus PerChannelDequantizeImpl(TfLiteContext* context,
           GetTensorData<uint8_t>(input), GetTensorShape(output),
           GetTensorData<float>(output));
       break;
+    case kTfLiteInt4:
     case kTfLiteInt8:
       reference_ops::PerChannelDequantize<int8_t>(
-          per_channel_op_params, GetTensorShape(input),
-          GetTensorData<int8_t>(input), GetTensorShape(output),
-          GetTensorData<float>(output));
+          per_channel_op_params, GetTensorShape(input), input_data,
+          GetTensorShape(output), GetTensorData<float>(output));
       break;
     default:
       TF_LITE_KERNEL_LOG(context, "Type %d not supported for per-channel.",
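
The per-channel path above unpacks int4 data into a temporary int8 buffer and then reuses the existing int8 reference kernel, which applies a separate scale and zero point along quantized_dimension. A simplified standalone sketch of the math that reference_ops::PerChannelDequantize performs for a 2-D tensor quantized along dimension 0 (an illustration, not the actual reference kernel; the function name is made up):

    #include <cstdint>
    #include <vector>

    // Each row (channel) uses its own scale and zero point:
    //   out[c][i] = (q[c][i] - zero_points[c]) * scales[c]
    std::vector<float> PerChannelDequantize2D(const std::vector<int8_t>& q,
                                              int channels, int per_channel,
                                              const std::vector<float>& scales,
                                              const std::vector<int>& zero_points) {
      std::vector<float> out(q.size());
      for (int c = 0; c < channels; ++c) {
        for (int i = 0; i < per_channel; ++i) {
          const int idx = c * per_channel + i;
          out[idx] = (q[idx] - zero_points[c]) * scales[c];
        }
      }
      return out;
    }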
@@ -90,6 +106,20 @@ TfLiteStatus DequantizeImpl(TfLiteContext* context, TfLiteNode* node,
   DequantizationParams op_params;
   op_params.zero_point = input->params.zero_point;
   op_params.scale = input->params.scale;
+  const int8_t* input_data;
+  const size_t bytes_unpacked = input->bytes * 2;
+  auto unpacked_input_data = std::make_unique<int8_t[]>(bytes_unpacked);
+
+  if (input->type == kTfLiteInt4) {
+    // Use GetTensorShape(input).FlatSize() for num_elements.
+    tflite::tensor_utils::UnpackDenseInt4IntoInt8(
+        GetTensorData<int8_t>(input), GetTensorShape(input).FlatSize(),
+        unpacked_input_data.get());
+    input_data = unpacked_input_data.get();
+  } else {
+    input_data = GetTensorData<int8_t>(input);
+  }
+
   switch (input->type) {
     case kTfLiteUInt8:
       if (kernel_type == kReference) {
@@ -102,15 +132,16 @@ TfLiteStatus DequantizeImpl(TfLiteContext* context, TfLiteNode* node,
             GetTensorShape(output), GetTensorData<float>(output));
       }
       break;
+    case kTfLiteInt4:
     case kTfLiteInt8:
       if (kernel_type == kReference) {
         reference_integer_ops::Dequantize<int8_t>(
-            op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
+            op_params, GetTensorShape(input), input_data,
             GetTensorShape(output), GetTensorData<float>(output));
       } else {
-        optimized_ops::Dequantize(
-            op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
-            GetTensorShape(output), GetTensorData<float>(output));
+        optimized_ops::Dequantize(op_params, GetTensorShape(input), input_data,
+                                  GetTensorShape(output),
+                                  GetTensorData<float>(output));
       }
       break;
     case kTfLiteInt16:
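
Both DequantizeImpl and PerChannelDequantizeImpl size the scratch buffer as input->bytes * 2 while passing GetTensorShape(input).FlatSize() as the element count to the unpack routine. Assuming int4 tensors are stored densely at two elements per byte, the two quantities are consistent: bytes = ceil(elements / 2), so bytes * 2 is always at least the flat element count. A tiny check of that arithmetic (illustrative values only):

    #include <cassert>
    #include <cstddef>

    int main() {
      // e.g. the {2, 2} int4 tensor from the test below: 4 elements.
      const size_t num_elements = 4;
      const size_t packed_bytes = (num_elements + 1) / 2;  // dense int4: 2 bytes
      const size_t bytes_unpacked = packed_bytes * 2;      // scratch buffer size
      assert(bytes_unpacked >= num_elements);
      return 0;
    }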
@@ -66,6 +66,15 @@ class DequantizeOpModel : public SingleOpModel {
     PopulateTensor(input_, data);
   }
 
+  template <typename T>
+  void SetInputInt4(int input, const std::vector<T> data) {
+    auto non_const = *const_cast<std::vector<T>*>(&data);
+    std::vector<int8_t> data_int8(non_const.size());
+    std::copy(non_const.begin(), non_const.end(), data_int8.begin());
+    PopulateTensor4bit(input, 0, data_int8.data(),
+                       data_int8.data() + data_int8.size());
+  }
+
   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
 
  protected:
@@ -73,6 +82,16 @@ class DequantizeOpModel : public SingleOpModel {
   int output_;
 };
 
+TEST(DequantizeOpTest, Int4) {
+  // [-3.5, 4] -> scale=0.5, zero_point=1 for INT4
+  DequantizeOpModel m(TensorType_INT4, {2, 2}, 0.5, -1, 6);
+
+  m.SetInputInt4<int8_t>(0, {7, 6, -7, -8});
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray(ArrayFloatNear({4, 3.5, -3, -3.5})));
+}
+
 TEST(DequantizeOpTest, Uint8) {
   // [-63.5, 64] -> scale=0.5 zero_point=127 for UINT8
   DequantizeOpModel m(TensorType_UINT8, {2, 5}, 0.5, 127, 1);
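
The expected values in the Int4 test follow directly from the affine mapping with scale 0.5 and zero point -1 (the constructor arguments above), spelled out:

    ( 7 - (-1)) * 0.5 =  4.0
    ( 6 - (-1)) * 0.5 =  3.5
    (-7 - (-1)) * 0.5 = -3.0
    (-8 - (-1)) * 0.5 = -3.5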
@@ -1442,6 +1442,7 @@ cc_test(
     srcs = ["per_channel_dequantize_test.cc"],
     deps = [
         ":reference_base",
+        ":tensor_utils_no_eigen",
        ":types",
        "//tensorflow/lite/kernels:test_util",
        "@com_google_googletest//:gtest_main",
@@ -12,11 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include <cstddef>
 #include <cstdint>
+#include <memory>
 #include <vector>
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
 #include "tensorflow/lite/kernels/internal/reference/dequantize.h"
 #include "tensorflow/lite/kernels/internal/types.h"
 #include "tensorflow/lite/kernels/test_util.h"
@@ -118,5 +121,34 @@ TEST(PerChannelDequantize, TestInt8ToFloat_4DDim3) {
                                   -124, 62, 30.75, 63, 31.25, 127})));
 }
 
+TEST(PerChannelDequantize, TestInt4ToFloat_2D) {
+  const std::vector<float> scales = {0.5, 0.25};
+  const std::vector<int> zero_points = {-1, -1};
+  const int quantized_dimension = 0;
+
+  const RuntimeShape unpacked_shape({2, 4});
+
+  const std::vector<int8_t> packed_int4_input = {-1, 0, 65, -127};
+  std::vector<float> output(8, -1);
+  const size_t bytes_unpacked = packed_int4_input.size() * 2;
+  auto unpacked_input_data = std::make_unique<int8_t[]>(bytes_unpacked);
+  tflite::tensor_utils::UnpackDenseInt4IntoInt8(
+      packed_int4_input.data(), bytes_unpacked, unpacked_input_data.get());
+  EXPECT_THAT(std::vector<int8_t>(unpacked_input_data.get(),
+                                  unpacked_input_data.get() + bytes_unpacked),
+              ElementsAreArray(ArrayFloatNear({-1, -1, 0, 0, 1, 4, 1, -8})));
+
+  PerChannelDequantizationParams op_params;
+  op_params.zero_point = zero_points.data();
+  op_params.scale = scales.data();
+  op_params.quantized_dimension = quantized_dimension;
+  reference_ops::PerChannelDequantize(op_params, unpacked_shape,
+                                      unpacked_input_data.get(), unpacked_shape,
+                                      output.data());
+  // This comes from (UNPACKED - zero_point) * scale.
+  EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear(
+                          {0, 0, 0.5, 0.5, 0.5, 1.25, 0.5, -1.75})));
+}
+
 }  // namespace
 }  // namespace tflite
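
In the per-channel test above the unpacked tensor is {2, 4} with quantized_dimension 0, so row 0 = {-1, -1, 0, 0} uses scale 0.5 / zero point -1 and row 1 = {1, 4, 1, -8} uses scale 0.25 / zero point -1. The expected floats are just (unpacked - zero_point) * scale per row:

    row 0: (-1 + 1) * 0.5  = 0     (-1 + 1) * 0.5  = 0      (0 + 1) * 0.5  = 0.5    (0 + 1) * 0.5   = 0.5
    row 1: ( 1 + 1) * 0.25 = 0.5   ( 4 + 1) * 0.25 = 1.25   (1 + 1) * 0.25 = 0.5    (-8 + 1) * 0.25 = -1.75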
@@ -70,6 +70,12 @@ void ApplySignbitToVector(const float* __restrict__ vector, int v_size,
 
 void UnpackDenseInt4IntoInt8(const int8_t* src_buffer, int num_elements,
                              int8_t* dst_buffer) {
+  // num_elements means the number of elements regardless of packed or unpacked.
+  // For example, 3 elements means both
+  //   1) Packed: 3 int4's = 12 bit -> 16 bits (padded) = 2 bytes.
+  //      stored in src_buffer[0] and src_buffer[1] (i = 0..1)
+  //   2) Unpacked: 3 int8's = 3 bytes.
+  //      stored in dst_buffer[0], dst_buffer[1] and dst_buffer[2] (j = 0..2)
   for (int i = 0; i < num_elements / 2; i++) {
     int8_t byte = src_buffer[i];
     // Shift left first so that sign is properly extended when shifted right
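
The packed layout puts the earlier element in the low nibble of each byte and the later element in the high nibble, and the shift-left-then-arithmetic-shift-right trick mentioned in the comment is what sign-extends each nibble. A standalone sketch of that loop, consistent with the test vectors above but not the library routine itself (the function name is made up; assumes a dense low-nibble-first layout):

    #include <cstdint>

    // Unpack densely packed int4 (low nibble first) into sign-extended int8.
    // num_elements counts int4 values; the final high nibble is unused when
    // num_elements is odd.
    void UnpackInt4Sketch(const int8_t* src, int num_elements, int8_t* dst) {
      for (int i = 0; i < num_elements / 2; ++i) {
        const int8_t byte = src[i];
        // Shift left so the low nibble's sign bit reaches bit 7, then shift
        // right arithmetically to sign-extend it.
        dst[2 * i] = static_cast<int8_t>(byte << 4) >> 4;
        dst[2 * i + 1] = byte >> 4;
      }
      if (num_elements % 2 != 0) {
        dst[num_elements - 1] =
            static_cast<int8_t>(src[num_elements / 2] << 4) >> 4;
      }
    }

For example, the packed byte 0x41 (65) unpacks to {1, 4} and 0x81 (-127) to {1, -8}, matching the per-channel test expectations.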
@@ -316,6 +316,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
           {{BuiltinOperator_DEQUANTIZE, 3}, "1.15.0"},
           {{BuiltinOperator_DEQUANTIZE, 4}, "2.2.0"},
           {{BuiltinOperator_DEQUANTIZE, 5}, "2.7.0"},
+          {{BuiltinOperator_DEQUANTIZE, 6}, "2.18.0"},
           {{BuiltinOperator_REVERSE_SEQUENCE, 1}, "1.14.0"},
           {{BuiltinOperator_EQUAL, 1}, "1.14.0"},
           {{BuiltinOperator_EQUAL, 2}, "1.14.0"},