Mirror of https://github.com/zebrajr/tensorflow.git, synced 2025-12-06 12:20:11 +01:00.
- Add 4-bit support to depthwise_conv.cc and fully_connected.cc in TfLite, using the reference kernels' 4-bit functions for those ops, and add/update supporting functions so the tests run in fully_connected_test.cc.
- Add a 4-bit test (Simple4bit3x3FilterTest) to depthwise_conv_test.cc in TfLite, ported from the existing Simple3x3FilterTest with per-channel quantization scales adjusted for 4-bit input.
- Add a 4-bit test (SimpleTestQuantizedInt4) to fully_connected_test.cc in TfLite, ported from the existing SimpleTestQuantizedInt8 with outputs adjusted for 4-bit.

PiperOrigin-RevId: 507003918
This commit is contained in:
parent dd7d791e02
commit 788ce94edd
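The pattern throughout this change is that 4-bit weights stay packed in the model buffer (two values per byte) and are unpacked into int8 scratch memory immediately before the existing int8 reference kernels run, which is why the hunks below allocate `filter->bytes * 2` bytes. A minimal sketch of that unpacking step, assuming TfLite's two's-complement, low-nibble-first int4 layout; the helper name is hypothetical and not part of the commit:

#include <cstddef>
#include <cstdint>
#include <vector>

// Unpacks packed int4 weights into one int8 per element, as the
// *WithPackedInt4Weights reference kernels are expected to do internally.
std::vector<int8_t> UnpackInt4ToInt8(const int8_t* packed,
                                     size_t num_elements) {
  std::vector<int8_t> unpacked(num_elements);
  for (size_t i = 0; i < num_elements; ++i) {
    const uint8_t byte = static_cast<uint8_t>(packed[i / 2]);
    // Even elements come from the low nibble, odd elements from the high one.
    int value = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
    if (value & 0x08) value -= 16;  // sign-extend the 4-bit two's complement
    unpacked[i] = static_cast<int8_t>(value);
  }
  return unpacked;
}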
@@ -60,7 +60,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
              /* max_version = */ 7);
   AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D(),
              /* min_version = */ 1,
-             /* max_version = */ 6);
+             /* max_version = */ 7);
   AddBuiltin(BuiltinOperator_SVDF, Register_SVDF(),
              /* min_version = */ 1,
              /* max_version = */ 4);
@@ -82,7 +82,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
              Register_EMBEDDING_LOOKUP_SPARSE());
   AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(),
              /* min_version = */ 1,
-             /* max_version = */ 9);
+             /* max_version = */ 10);
   AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION());
   AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP());
   AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX(),
@@ -18,9 +18,11 @@ limitations under the License.
 #include <stddef.h>
 #include <stdint.h>
 
+#include <memory>
 #include <vector>
 
 #include "tensorflow/lite/core/c/builtin_op_data.h"
+#include "tensorflow/lite/core/c/c_api_types.h"
 #include "tensorflow/lite/core/c/common.h"
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
@@ -128,8 +130,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                  data_type == kTfLiteInt8 || data_type == kTfLiteInt16);
   TF_LITE_ENSURE_TYPES_EQ(context, output->type, data_type);
   if (!is_hybrid) {
-    TF_LITE_ENSURE(context,
-                   filter->type == data_type || data_type == kTfLiteInt16);
+    TF_LITE_ENSURE(context, filter->type == data_type ||
+                                data_type == kTfLiteInt16 ||
+                                filter->type == kTfLiteInt4);
   }
 
   if (data_type == kTfLiteInt16) {
@@ -400,23 +403,58 @@ TfLiteStatus EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
   TF_LITE_ENSURE_STATUS(ComputeDepthMultiplier(context, input, filter,
                                                &op_params.depth_multiplier));
 
-  if (kernel_type == kReference) {
-    reference_integer_ops::DepthwiseConvPerChannel(
-        op_params, data->per_channel_output_multiplier.data(),
-        data->per_channel_output_shift.data(), GetTensorShape(input),
-        GetTensorData<int8>(input), GetTensorShape(filter),
-        GetTensorData<int8>(filter), GetTensorShape(bias),
-        GetTensorData<int32>(bias), GetTensorShape(output),
-        GetTensorData<int8>(output));
-  } else {
-    optimized_integer_ops::DepthwiseConvPerChannel(
-        op_params, data->per_channel_output_multiplier.data(),
-        data->per_channel_output_shift.data(), GetTensorShape(input),
-        GetTensorData<int8>(input), GetTensorShape(filter),
-        GetTensorData<int8>(filter), GetTensorShape(bias),
-        GetTensorData<int32>(bias), GetTensorShape(output),
-        GetTensorData<int8>(output),
-        CpuBackendContext::GetFromContext(context));
+  KernelType effective_kernel_type = kernel_type;
+
+  if (filter->type == kTfLiteInt4) {
+    effective_kernel_type = kReference;
+  }
+
+  switch (effective_kernel_type) {
+    case kReference: {
+      switch (filter->type) {
+        case kTfLiteInt4: {
+          const size_t bytes_unpacked = filter->bytes * 2;
+          auto unpacked_filter_data =
+              std::make_unique<int8_t[]>(bytes_unpacked);
+          reference_integer_ops::DepthwiseConvPerChannelWithPackedInt4Weights(
+              op_params, data->per_channel_output_multiplier.data(),
+              data->per_channel_output_shift.data(), GetTensorShape(input),
+              GetTensorData<int8_t>(input), GetTensorShape(filter),
+              GetTensorData<int8_t>(filter), unpacked_filter_data.get(),
+              GetTensorShape(bias), GetTensorData<int32_t>(bias),
+              GetTensorShape(output), GetTensorData<int8_t>(output));
+          break;
+        }
+        case kTfLiteInt8: {
+          reference_integer_ops::DepthwiseConvPerChannel(
+              op_params, data->per_channel_output_multiplier.data(),
+              data->per_channel_output_shift.data(), GetTensorShape(input),
+              GetTensorData<int8>(input), GetTensorShape(filter),
+              GetTensorData<int8>(filter), GetTensorShape(bias),
+              GetTensorData<int32>(bias), GetTensorShape(output),
+              GetTensorData<int8>(output));
+          break;
+        }
+        default: {
+          printf("Weight type %s (%d) not supported.",
+                 TfLiteTypeGetName(filter->type), filter->type);
+          break;
+        }
+      }
+      break;
+    }
+    case kGenericOptimized:
+    case kNeonOptimized: {
+      optimized_integer_ops::DepthwiseConvPerChannel(
+          op_params, data->per_channel_output_multiplier.data(),
+          data->per_channel_output_shift.data(), GetTensorShape(input),
+          GetTensorData<int8>(input), GetTensorShape(filter),
+          GetTensorData<int8>(filter), GetTensorShape(bias),
+          GetTensorData<int32>(bias), GetTensorShape(output),
+          GetTensorData<int8>(output),
+          CpuBackendContext::GetFromContext(context));
+      break;
+    }
   }
   return kTfLiteOk;
 }
@@ -1851,6 +1851,109 @@ TEST_P(PerChannelQuantizedDepthwiseConvolutionOpTest, Simple3x3FilterTest) {
       ElementsAreArray(ArrayFloatNear({9, 18, 0, 0, 47, 54, 0, 0})));
 }
 
+// The expected values for this test were obtained by running the test with
+// the same parameters but with the filter type set to INT8.
+TEST_P(PerChannelQuantizedDepthwiseConvolutionOpTest, Simple4bit3x3FilterTest) {
+  // TODO(b/265987257) - remove when the NNAPI interaction with 4-bit
+  // depthwise_conv is fixed.
+  using testing::FloatEq;
+  using testing::Pointwise;
+
+  if (SingleOpModel::GetForceUseNnapi()) {
+    return;
+  }
+
+  PerChannelQuantizedDepthwiseConvolutionOpModel m(
+      GetRegistration(), {TensorType_INT8, {1, 3, 3, 8}, -63.5, 64, 0.5, -1},
+      {TensorType_INT4,
+       // [1 * 3 * 3 * 8] as [input_channel, y, x, output_channel]
+       {1, 3, 3, 8},
+       0,
+       0,
+       0,
+       0,
+       /*per_channel_quantization=*/true,
+       /*per_channel_quantization_scales=*/
+       {2.5, 2, 3, 4, 4, 3, 2, 2.5},
+       /*per_channel_quantization_offsets=*/{0, 0, 0, 0, 0, 0, 0, 0},
+       /*channel_index=*/3},
+      {TensorType_INT8, {}, -63.5, 64, 0.5, -1}, Padding_VALID);
+  m.SetInput({// array of 9 x 8 => with tensor dimensions [1, 3, 3, 8] as
+              // [input_channel, y, x, output_channel]
+              1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1,
+              0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0,
+              1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1,
+              0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0});
+  m.SetFilter(
+      /*filter data*/
+      {// array of 9 x 8 => with tensor dimensions [1, 3, 3, 8] as
+       // [input_channel, y, x, output_channel]
+       1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8,
+       1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8,
+       1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8});
+  m.SetBias({0, 0, 0, 0, 0, 0, 0, 0});
+
+  // Invoke and verify output.
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  EXPECT_THAT(m.GetDequantizedOutput(),
+              Pointwise(FloatEq(), {0, 18, 0, 0, 36, 54, 0, 0}));
+}
+
+// The expected values for this test were obtained by running the test with
+// the same parameters but with the filter type set to INT8.
+TEST_P(PerChannelQuantizedDepthwiseConvolutionOpTest, Simple4bitPerAxisTest) {
+  // TODO(b/265987257) - remove when the NNAPI interaction with 4-bit
+  // depthwise_conv is fixed.
+  if (SingleOpModel::GetForceUseNnapi()) {
+    return;
+  }
+
+  PerChannelQuantizedDepthwiseConvolutionOpModel m(
+      GetRegistration(), {TensorType_INT8, {1, 2, 3, 2}, -63.5, 64, 0.5, -1},
+      {TensorType_INT4,
+       // [1 * 2 * 2 * 4] as [input_channel, y, x, output_channel]
+       {1, 2, 2, 4},
+       0,
+       0,
+       0,
+       0,
+       /*per_channel_quantization=*/true,
+       /*per_channel_quantization_scales=*/{1, 2, 3, 4},
+       /*per_channel_quantization_offsets=*/{0, 0, 0, 0},
+       /*channel_index=*/3},
+      {TensorType_INT8, {}, -63.5, 64, 0.5, -1}, Padding_VALID);
+  m.SetInput({
+      // [1 * 2 * 3 * 2] as [batch, y, x, input_channel]
+      3, 2,    // batch = 0, y = 0, x = 0
+      1, -1,   // batch = 0, y = 0, x = 1
+      -2, -3,  // batch = 0, y = 0, x = 2
+      4, 3,    // batch = 0, y = 1, x = 0
+      2, -2,   // batch = 0, y = 1, x = 1
+      -3, -4,  // batch = 0, y = 1, x = 2
+  });
+  m.SetFilter(
+      /*filter data*/
+      {
+          // [1 * 2 * 2 * 4] as [input_channel, y, x, output_channel]
+          // depth multiplier = 2
+          1, 2, 3, 4,  // y = 0, x = 0
+          3, 4, 5, 6,  // y = 0, x = 1
+          7, 8, 5, 6,  // y = 1, x = 0
+          3, 4, 1, 2,  // y = 1, x = 1
+      });
+  m.SetBias({3, -2, 4, 6});
+
+  // Invoke and verify output.
+  // output has dimension [1 * 1 * 2 * 4] as [batch, y, x, output_channel]
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  EXPECT_THAT(
+      m.GetDequantizedOutput(),
+      ElementsAreArray(ArrayFloatNear({43, 48, 21, 22, 3, -4, -30, -54})));
+  EXPECT_THAT(m.GetOutput(),
+              testing::ElementsAre(85, 95, 41, 43, 5, -9, -61, -109));
+}
+
 TEST_P(PerChannelQuantizedDepthwiseConvolutionOpTest,
        Simple3x3FilterPaddingSameTest) {
   PerChannelQuantizedDepthwiseConvolutionOpModel m(
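A sanity check on the new test's expected values (not part of the commit): the raw and dequantized expectations in Simple4bitPerAxisTest agree under the output tensor's quantization parameters, since GetDequantizedOutput presumably applies the usual affine mapping (raw - zero_point) * scale with scale 0.5 and zero point -1:

// Consistency of the two EXPECT_THATs above, checked at compile time.
static_assert((85 - (-1)) * 0.5 == 43.0, "first output element");
static_assert((-109 - (-1)) * 0.5 == -54.0, "last output element");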
@@ -18,6 +18,7 @@ limitations under the License.
 #include <algorithm>
+#include <cstddef>
 #include <cstdint>
 #include <memory>
 
 #include "tensorflow/lite/core/c/builtin_op_data.h"
 #include "tensorflow/lite/core/c/common.h"
@@ -135,7 +136,8 @@ inline TfLiteStatus CheckTypes(TfLiteContext* context,
                                const TfLiteTensor* bias, TfLiteTensor* output,
                                TfLiteFullyConnectedParams* params) {
   const bool is_quantized =
-      ((filter->type == kTfLiteUInt8) || (filter->type == kTfLiteInt8));
+      ((filter->type == kTfLiteUInt8) || (filter->type == kTfLiteInt8) ||
+       (filter->type == kTfLiteInt4));
   const bool is_hybrid = is_quantized && (input->type == kTfLiteFloat32);
   const bool is_shuffled =
       is_quantized && (params->weights_format ==
@@ -287,7 +289,7 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
     // Currently only Int8/Int16 is supported for per channel quantization.
     TF_LITE_ENSURE(context,
                    input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
-    TF_LITE_ENSURE_EQ(context, filter->type, kTfLiteInt8);
+    TF_LITE_ENSURE(context, (filter->type == kTfLiteInt8));
     TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
                       per_channel_quantization_size);
     TF_LITE_ENSURE_EQ(
@@ -458,7 +460,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* input;
   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
   const bool is_quantized =
-      ((filter->type == kTfLiteUInt8) || (filter->type == kTfLiteInt8));
+      ((filter->type == kTfLiteUInt8) || (filter->type == kTfLiteInt8) ||
+       (filter->type == kTfLiteInt4));
   const bool is_hybrid = is_quantized && (input->type == kTfLiteFloat32);
   const bool is_pie = kernel_type == kLegacyPie;
 
@@ -803,7 +806,17 @@ void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input,
   op_params.quantized_activation_max = data->output_activation_max;
   op_params.lhs_cacheable = IsConstantTensor(filter);
   op_params.rhs_cacheable = IsConstantTensor(input);
-  if (kernel_type == kReference) {
+
+  if (filter->type == kTfLiteInt4) {
+    const size_t bytes_unpacked = filter->bytes * 2;
+    auto unpacked_filter_data = std::make_unique<int8_t[]>(bytes_unpacked);
+    reference_integer_ops::FullyConnectedWithPackedInt4Weights(
+        op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
+        GetTensorShape(filter), GetTensorData<int8_t>(filter),
+        unpacked_filter_data.get(), GetTensorShape(bias),
+        GetTensorData<int32_t>(bias), GetTensorShape(output),
+        GetTensorData<int8_t>(output));
+  } else if (kernel_type == kReference) {
     reference_integer_ops::FullyConnected(
         op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
         GetTensorShape(filter), GetTensorData<int8_t>(filter),
@@ -1025,8 +1038,19 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                          "Invalid quantized and sparse fully-connected format.");
         return kTfLiteError;
       }
-      if (sparsity.dim_metadata_size == kDimMetadataSizeBlockSparse &&
-          sparsity.dim_metadata[2].dense_size == 16) {
+      if (filter->type == kTfLiteInt4) {
+        const size_t bytes_unpacked = filter->bytes * 2;
+        auto unpacked_filter_data =
+            std::make_unique<int8_t[]>(bytes_unpacked);
+        reference_integer_ops::FullyConnectedWithPackedInt4Weights(
+            op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
+            GetTensorShape(filter), GetTensorData<int8_t>(filter),
+            unpacked_filter_data.get(), GetTensorShape(bias),
+            GetTensorData<int32_t>(bias), GetTensorShape(output),
+            GetTensorData<int8_t>(output));
+      } else if (sparsity.dim_metadata_size ==
+                     kDimMetadataSizeBlockSparse &&
+                 sparsity.dim_metadata[2].dense_size == 16) {
         // Block sparse with block size of 1x16.
         optimized_ops::FullyConnectedSparseWeight1x16(
             sparsity, op_params, input_shape, GetTensorData<int8_t>(input),
@@ -1304,6 +1328,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
         TF_LITE_KERNEL_LOG(context, "Unhandled fully-connected weights format");
         return kTfLiteError;
       }
+    case kTfLiteInt4:
+      if (params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault) {
+        return EvalQuantized<kernel_type>(context, node, params, data, input,
+                                          filter, bias, output);
+      } else {
+        TF_LITE_KERNEL_LOG(context, "Unhandled fully-connected weights format");
+        return kTfLiteError;
+      }
     default:
       TF_LITE_KERNEL_LOG(context,
                          "Filter data type %s currently not supported.",
@@ -143,7 +143,8 @@ class BaseFullyConnectedOpModel : public SingleOpModel {
       FullyConnectedOptionsWeightsFormat weights_format =
           FullyConnectedOptionsWeightsFormat_DEFAULT,
       int input_size = -1, bool weights_per_channel_quantized = false,
-      std::vector<float> per_channel_quantization_scales = {})
+      std::vector<float> per_channel_quantization_scales = {},
+      TfLiteType filter_type = kTfLiteNoType)
       : batches_(batches),
         units_(units),
         input_size_(input_size),
@@ -192,6 +193,11 @@ class BaseFullyConnectedOpModel : public SingleOpModel {
                            {units_, input_size_},
                            /*min=*/-63.5,
                            /*max=*/64});
+    } else if (filter_type == kTfLiteInt4) {
+      weights_ = AddInput({TensorType_INT4,
+                           {units_, input_size_},
+                           /*min=*/input.min,
+                           /*max=*/input.max});
     } else {
       weights_ =
           AddInput({input.type, {units_, input_size_}, input.min, input.max});
@@ -287,11 +293,11 @@ class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel {
       ActivationFunctionType activation_func = ActivationFunctionType_RELU,
       FullyConnectedOptionsWeightsFormat weights_format =
           FullyConnectedOptionsWeightsFormat_DEFAULT,
-      int input_size = -1)
-      : BaseFullyConnectedOpModel(registration, units, batches, input, output,
-                                  bias_type, keep_num_dims,
-                                  bias_tensor_optional, activation_func,
-                                  weights_format, input_size) {}
+      int input_size = -1, TfLiteType filter_type = kTfLiteNoType)
+      : BaseFullyConnectedOpModel(
+            registration, units, batches, input, output, bias_type,
+            keep_num_dims, bias_tensor_optional, activation_func,
+            weights_format, input_size, false, {}, filter_type) {}
 
   void SetBias(const std::vector<float>& data) {
     if (bias_type_ == TensorType_INT32) {
@@ -306,6 +312,10 @@ class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel {
     QuantizeAndPopulate<T>(weights_, data);
   }
 
+  void SetWeights4bit(const std::vector<float>& data) {
+    QuantizeAndPopulate4bit(weights_, data);
+  }
+
   template <typename T>
   void ShuffleAndSetWeights(const std::vector<float>& data, int input_depth,
                             int output_depth) {
@@ -700,6 +710,35 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedUint8NoBias) {
               ElementsAre(150, 150, 150, 184, 184, 184));
 }
 
+// The expected values for this test were obtained by running the test with
+// the same parameters but with the filter type set to INT8.
+TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt4) {
+  QuantizedFullyConnectedOpModel m(
+      GetRegistration(), /*units=*/3, /*batches*/ 2,
+      /*input=*/{TensorType_INT8, {2, 10}, -63.5, 64},
+      /*output=*/{TensorType_INT8, {}, -127, 128}, TensorType_INT32, false,
+      false, ActivationFunctionType_RELU,
+      FullyConnectedOptionsWeightsFormat_DEFAULT, -1, kTfLiteInt4);
+
+  // input_product_scale < output_scale was not true.
+  m.SetWeights4bit({
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
+  });
+  m.SetBias({1, 2, 3});
+
+  m.SetInput<int8_t>({
+      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
+      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
+  });
+
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
+              testing::Pointwise(testing::FloatEq(), {64, 64, 68, 82, 82, 87}));
+  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAre(63, 63, 67, 81, 81, 86));
+}
+
 TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt8) {
   QuantizedFullyConnectedOpModel m(
       GetRegistration(), /*units=*/3, /*batches*/ 2,
@@ -65,13 +65,21 @@ std::vector<::testing::Matcher<std::complex<float>>> ArrayComplex64Near(
 
 template <typename T>
 inline std::vector<T> Quantize(const std::vector<float>& data, float scale,
-                               int32_t zero_point) {
+                               int32_t zero_point,
+                               TfLiteType type = kTfLiteNoType) {
   std::vector<T> q;
+
+  T min = std::numeric_limits<T>::min();
+  T max = std::numeric_limits<T>::max();
+
+  if (type == kTfLiteInt4) {
+    min = -7;
+    max = 7;
+  }
+
   for (const auto& f : data) {
     q.push_back(static_cast<T>(std::max<float>(
-        std::numeric_limits<T>::min(),
-        std::min<float>(std::numeric_limits<T>::max(),
-                        std::round(zero_point + (f / scale))))));
+        min, std::min<float>(max, std::round(zero_point + (f / scale))))));
   }
   return q;
 }
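The new `type` argument only narrows the clamp range; the rounding logic in `Quantize` is unchanged. An illustrative call, consistent with the code above:

// With scale = 2.0f and zero_point = 0, the float 18.f rounds to 9.
// As a plain int8 quantization it stays 9; flagged as kTfLiteInt4 it is
// clamped to the symmetric 4-bit maximum of 7.
auto q8 = Quantize<int8_t>({18.f}, /*scale=*/2.0f, /*zero_point=*/0);  // {9}
auto q4 = Quantize<int8_t>({18.f}, /*scale=*/2.0f, /*zero_point=*/0,
                           kTfLiteInt4);                               // {7}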
@@ -457,10 +465,19 @@ class SingleOpModel {
   template <typename T>
   void QuantizeAndPopulate(int index, const std::vector<float>& data) {
     TfLiteTensor* t = interpreter_->tensor(index);
-    auto q = Quantize<T>(data, t->params.scale, t->params.zero_point);
+    auto q = Quantize<T>(data, t->params.scale, t->params.zero_point, t->type);
     PopulateTensor(index, 0, q.data(), q.data() + q.size());
   }
 
+  void QuantizeAndPopulate4bit(int index, const std::vector<float>& data) {
+    TfLiteTensor* t = interpreter_->tensor(index);
+    t->type = kTfLiteInt4;
+    std::vector<int8_t> quantized_output =
+        Quantize<int8_t>(data, t->params.scale, t->params.zero_point, t->type);
+    PopulateTensor4bit(index, /*offset=*/0, quantized_output.data(),
+                       quantized_output.data() + quantized_output.size());
+  }
+
   void SymmetricQuantizeAndPopulate(int index, const std::vector<float>& data) {
     std::vector<int8_t> q = QuantizeTensor(index, data);
     PopulateTensor(index, /*offset=*/0, reinterpret_cast<uint8_t*>(q.data()),
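`PopulateTensor4bit` itself is outside this diff; presumably it packs each pair of quantized values back into a single byte before writing the tensor buffer. A sketch of such packing, the inverse of the unpack sketch near the top of this page (helper name hypothetical):

#include <cstddef>
#include <cstdint>
#include <vector>

// Packs int8 values already clamped to the 4-bit range into bytes,
// low nibble first, matching the assumed int4 layout.
std::vector<int8_t> PackInt8IntoInt4(const int8_t* data, size_t num_elements) {
  std::vector<int8_t> packed((num_elements + 1) / 2, 0);
  for (size_t i = 0; i < num_elements; ++i) {
    const uint8_t nibble = static_cast<uint8_t>(data[i]) & 0x0F;
    if (i % 2 == 0) {
      packed[i / 2] = static_cast<int8_t>(nibble);  // low nibble
    } else {
      packed[i / 2] = static_cast<int8_t>(
          static_cast<uint8_t>(packed[i / 2]) | (nibble << 4));  // high nibble
    }
  }
  return packed;
}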
@@ -705,6 +722,9 @@ class SingleOpModel {
     } else if (t.type == TensorType_INT16) {
       std::tie(t.scale, t.zero_point) =
           QuantizationParams<int16_t>(t.min, t.max);
+    } else if (t.type == TensorType_INT4) {
+      std::tie(t.scale, t.zero_point) =
+          QuantizationParams<int8_t>(t.min, t.max, kTfLiteInt4);
     } else {
       LOG(FATAL) << "No support for the requested quantized type";
     }
@@ -79,6 +79,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
           {{BuiltinOperator_DEPTHWISE_CONV_2D, 4}, "2.2.0"},
           {{BuiltinOperator_DEPTHWISE_CONV_2D, 5}, "2.3.0"},
           {{BuiltinOperator_DEPTHWISE_CONV_2D, 6}, "2.3.0"},
+          {{BuiltinOperator_DEPTHWISE_CONV_2D, 7}, "2.11.0"},
           {{BuiltinOperator_ADD, 1}, "1.5.0"},
           {{BuiltinOperator_ADD, 2}, "1.14.0"},
           {{BuiltinOperator_ADD, 3}, "2.4.0"},
@@ -123,6 +124,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
           {{BuiltinOperator_FULLY_CONNECTED, 7}, "2.3.0"},
           {{BuiltinOperator_FULLY_CONNECTED, 8}, "2.3.0"},
           {{BuiltinOperator_FULLY_CONNECTED, 9}, "2.3.0"},
+          {{BuiltinOperator_FULLY_CONNECTED, 10}, "2.11.0"},
           {{BuiltinOperator_GATHER, 1}, "1.6.0"},
           {{BuiltinOperator_GATHER, 2}, "1.14.0"},
           {{BuiltinOperator_GATHER, 3}, "1.15.0"},
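These table entries record that models carrying the bumped op versions need at least the 2.11.0 runtime. A hedged usage sketch of the versioning helper whose signature appears in the hunk headers above (the enclosing namespace and return convention are assumed):

// Both lookups should report "2.11.0" per the entries added above.
std::string dw_min = tflite::FindMinimumRuntimeVersionForOp(
    tflite::BuiltinOperator_DEPTHWISE_CONV_2D, /*op_version=*/7);
std::string fc_min = tflite::FindMinimumRuntimeVersionForOp(
    tflite::BuiltinOperator_FULLY_CONNECTED, /*op_version=*/10);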