Support Int16 activation, Int4 weight quantization in FullyConnected

The 4-bit weights are unpacked on the fly to 8 bits for computation.

PiperOrigin-RevId: 656240754
Joe Zou, 2024-07-25 21:55:32 -07:00, committed by TensorFlower Gardener
parent a9a9ecb331
commit 47c91c8159
6 changed files with 167 additions and 37 deletions
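
The commit description notes that the 4-bit weights are unpacked on the fly; the kernel changes below do this with `tflite::tensor_utils::UnpackDenseInt4IntoInt8` before handing the widened data to the existing reference kernels. As a minimal sketch of the technique (not the TFLite implementation itself; it assumes the low nibble of each byte holds the even-indexed element), dense int4-to-int8 unpacking can be written as:

#include <cstdint>

// Sketch: expand densely packed signed 4-bit values (two per byte) into
// int8. Nibbles in [8, 15] are sign-extended to [-8, -1]. The nibble order
// (low nibble = even index) is an assumption of this sketch.
void UnpackInt4ToInt8(const int8_t* packed, int num_elements, int8_t* out) {
  for (int i = 0; i < num_elements; ++i) {
    const uint8_t byte = static_cast<uint8_t>(packed[i / 2]);
    const int nibble = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
    out[i] = static_cast<int8_t>(nibble >= 8 ? nibble - 16 : nibble);
  }
}

For example, a packed byte 0x2F expands to {-1, 2}: the low nibble 0xF sign-extends to -1 and the high nibble 0x2 stays 2. This also matches the buffer sizing in the kernel diff, where the scratch buffer is allocated as filter->bytes * 2 int8 values.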

@@ -40,6 +40,8 @@
* This change includes per-channel dequantization.
* Add support for `stablehlo.composite`.
* `EmbeddingLookup` op supports `TensorType_INT4` values.
* `FullyConnected` op supports `TensorType_INT16` activation and
`TensorType_INT4` weight per-channel quantization.
* `tf.tensor_scatter_update`, `tf.tensor_scatter_add` and other reduce types.
* Support `bad_indices_policy`.
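
The `FullyConnected` release note above refers to per-channel weight quantization. As a rough intuition-only sketch (illustrative names and shapes, not TFLite API), per-channel symmetric quantization means each output channel of the [units, input_size] weight matrix has its own scale and a zero point of zero, so a stored int4 value q represents scale[channel] * q:

#include <cstdint>
#include <vector>

// Intuition sketch: reconstruct float weights from per-channel symmetric
// int4 values that have already been widened to int8. One scale per output
// channel; the zero point is 0 for symmetric weights.
std::vector<float> DequantizeWeightsPerChannel(const std::vector<int8_t>& q,
                                               const std::vector<float>& scale,
                                               int output_channels,
                                               int input_size) {
  std::vector<float> w(q.size());
  for (int c = 0; c < output_channels; ++c) {
    for (int i = 0; i < input_size; ++i) {
      w[c * input_size + i] = scale[c] * q[c * input_size + i];
    }
  }
  return w;
}

The integer kernel never materializes these floats; the per-channel scales are folded into the per_channel_output_multiplier and per_channel_output_shift values used in the kernel changes below.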

@@ -82,7 +82,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
Register_EMBEDDING_LOOKUP_SPARSE());
AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(),
/* min_version = */ 1,
/* max_version = */ 12);
/* max_version = */ 13);
AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION());
AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP());
AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX(),

@@ -443,7 +443,9 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node,
// Currently only Int8/Int16 is supported for per channel quantization.
TF_LITE_ENSURE(context,
input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
TF_LITE_ENSURE(context, (filter->type == kTfLiteInt8));
TF_LITE_ENSURE(context, (filter->type == kTfLiteInt8 ||
(filter->type == kTfLiteInt4 &&
input->type == kTfLiteInt16)));
TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
per_channel_quantization_size);
TF_LITE_ENSURE_EQ(
@@ -481,11 +483,14 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node,
data->output_shift = exponent;
}
if (input->type == kTfLiteUInt8 && output->type == kTfLiteInt16) {
TF_LITE_ENSURE(context, filter->type == kTfLiteUInt8);
}
TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
context, params->activation, output, &data->output_activation_min,
&data->output_activation_max));
}
if (input->type == kTfLiteInt16 && output->type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
@@ -1166,18 +1171,32 @@ void FullyConnectedInt16(const OpData* data, const TfLiteTensor* input,
op_params.quantized_activation_min = data->output_activation_min;
op_params.quantized_activation_max = data->output_activation_max;
const int8_t* filter_data;
std::unique_ptr<int8_t[]> unpacked_filter_data = nullptr;
if (filter->type == kTfLiteInt4) {
const size_t bytes_unpacked = filter->bytes * 2;
unpacked_filter_data = std::make_unique<int8_t[]>(bytes_unpacked);
tflite::tensor_utils::UnpackDenseInt4IntoInt8(
GetTensorData<int8_t>(filter), GetTensorShape(filter).FlatSize(),
unpacked_filter_data.get());
filter_data = unpacked_filter_data.get();
} else {
filter_data = GetTensorData<int8_t>(filter);
}
if (data->quantized_bias_type == kTfLiteInt32) {
reference_integer_ops::FullyConnected(
op_params, GetTensorShape(input), GetTensorData<int16_t>(input),
GetTensorShape(filter), GetTensorData<int8_t>(filter),
GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(output), GetTensorData<int16_t>(output));
GetTensorShape(filter), filter_data, GetTensorShape(bias),
GetTensorData<int32_t>(bias), GetTensorShape(output),
GetTensorData<int16_t>(output));
} else {
reference_integer_ops::FullyConnected(
op_params, GetTensorShape(input), GetTensorData<int16_t>(input),
GetTensorShape(filter), GetTensorData<int8_t>(filter),
GetTensorShape(bias), GetTensorData<int64_t>(bias),
GetTensorShape(output), GetTensorData<int16_t>(output));
GetTensorShape(filter), filter_data, GetTensorShape(bias),
GetTensorData<int64_t>(bias), GetTensorShape(output),
GetTensorData<int16_t>(output));
}
}
@@ -1231,22 +1250,33 @@ void FullyConnectedPerChannelInt16(const OpData* data,
op_params.quantized_activation_min = data->output_activation_min;
op_params.quantized_activation_max = data->output_activation_max;
const int8_t* filter_data;
std::unique_ptr<int8_t[]> unpacked_filter_data = nullptr;
if (filter->type == kTfLiteInt4) {
const size_t bytes_unpacked = filter->bytes * 2;
unpacked_filter_data = std::make_unique<int8_t[]>(bytes_unpacked);
tflite::tensor_utils::UnpackDenseInt4IntoInt8(
GetTensorData<int8_t>(filter), GetTensorShape(filter).FlatSize(),
unpacked_filter_data.get());
filter_data = unpacked_filter_data.get();
} else {
filter_data = GetTensorData<int8_t>(filter);
}
if (data->quantized_bias_type == kTfLiteInt32) {
reference_integer_ops::FullyConnectedPerChannel(
op_params, data->per_channel_output_multiplier.data(),
data->per_channel_output_shift.data(), GetTensorShape(input),
GetTensorData<int16_t>(input), GetTensorShape(filter),
GetTensorData<int8_t>(filter), GetTensorShape(bias),
GetTensorData<int32_t>(bias), GetTensorShape(output),
GetTensorData<int16_t>(output));
GetTensorData<int16_t>(input), GetTensorShape(filter), filter_data,
GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(output), GetTensorData<int16_t>(output));
} else {
reference_integer_ops::FullyConnectedPerChannel(
op_params, data->per_channel_output_multiplier.data(),
data->per_channel_output_shift.data(), GetTensorShape(input),
GetTensorData<int16_t>(input), GetTensorShape(filter),
GetTensorData<int8_t>(filter), GetTensorShape(bias),
GetTensorData<int64_t>(bias), GetTensorShape(output),
GetTensorData<int16_t>(output));
GetTensorData<int16_t>(input), GetTensorShape(filter), filter_data,
GetTensorShape(bias), GetTensorData<int64_t>(bias),
GetTensorShape(output), GetTensorData<int16_t>(output));
}
}
@@ -1331,6 +1361,7 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
switch (output->type) {
case kTfLiteUInt8:
if (kernel_type == kReference) {
TF_LITE_ENSURE(context, filter->type != kTfLiteInt4);
reference_ops::FullyConnected(
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
GetTensorShape(filter), GetTensorData<uint8_t>(filter),
@@ -1401,7 +1432,8 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
filter->params.zero_point ||
output->params.zero_point;
if (kernel_type == kReference || has_non_zero_point ||
(bias && bias->type == kTfLiteInt64)) {
(bias && bias->type == kTfLiteInt64) ||
(filter->type == kTfLiteInt4)) {
is_per_channel ? FullyConnectedPerChannelInt16<kernel_type>(
data, input, filter, bias, output)
: FullyConnectedInt16<kernel_type>(

@@ -163,6 +163,18 @@ class BaseFullyConnectedOpModel : public SingleOpModel {
std::vector<int64_t> per_channel_quantization_offsets(
per_channel_quantization_scales.size(), 0);
if (input.type == TensorType_INT16) {
if (filter_type == kTfLiteInt4) {
weights_ = AddInput({TensorType_INT4,
{units_, input_size_},
0,
0,
0,
0,
true,
per_channel_quantization_scales,
per_channel_quantization_offsets,
0});
} else {
weights_ = AddInput({TensorType_INT8,
{units_, input_size_},
0,
@@ -173,6 +185,7 @@ class BaseFullyConnectedOpModel : public SingleOpModel {
per_channel_quantization_scales,
per_channel_quantization_offsets,
0});
}
} else {
weights_ = AddInput({input.type,
{units_, input_size_},
@@ -187,12 +200,19 @@ class BaseFullyConnectedOpModel : public SingleOpModel {
}
} else {
if (input.type == TensorType_INT16) {
if (filter_type == kTfLiteInt4) {
weights_ = AddInput({TensorType_INT4,
{units_, input_size_},
/*min=*/-7,
/*max=*/7});
} else {
// Set min and max values that are used to calculate per-tensor scale
// and zero points.
weights_ = AddInput({TensorType_INT8,
{units_, input_size_},
/*min=*/-63.5,
/*max=*/64});
}
} else if (filter_type == kTfLiteInt4) {
weights_ = AddInput({TensorType_INT4,
{units_, input_size_},
@@ -375,12 +395,12 @@ class PerChannelQuantizedFullyConnectedOpModel
ActivationFunctionType activation_func = ActivationFunctionType_RELU,
FullyConnectedOptionsWeightsFormat weights_format =
FullyConnectedOptionsWeightsFormat_DEFAULT,
int input_size = -1)
int input_size = -1, TfLiteType filter_type = kTfLiteNoType)
: BaseFullyConnectedOpModel(
registration, units, batches, input, output, bias_type,
keep_num_dims, bias_tensor_optional, activation_func,
weights_format, input_size, true, per_channel_quantization_scales) {
}
weights_format, input_size, true, per_channel_quantization_scales,
filter_type) {}
void SetBias(const std::vector<float>& data) {
PerChannelQuantizeBias(bias_, data);
@@ -391,6 +411,12 @@ class PerChannelQuantizedFullyConnectedOpModel
PerChannelSymmetricQuantizeAndPopulate(weights_, data);
}
void SetWeights4bit(const std::vector<float>& data) {
// 4 bit logic handled in PerChannelSymmetricQuantizeAndPopulate.
CHECK_EQ(interpreter_->tensor(weights_)->type, kTfLiteInt4);
PerChannelSymmetricQuantizeAndPopulate(weights_, data);
}
template <typename T>
void SetInput(const std::vector<float>& data) {
QuantizeAndPopulate<T>(input_, data);
@@ -862,6 +888,38 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt16Bias32) {
ElementsAre(12288, 12800, 13312, 29696, 30208, 30720));
}
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt16Bias32Weight4) {
const float scale = 128.0 / 65536;
QuantizedFullyConnectedOpModel m(
GetRegistration(), /*units=*/3, /*batches=*/2,
/*input=*/{TensorType_INT16, {2, 10}, 0, 0, scale, 0},
/*output=*/{TensorType_INT16, {}, 0, 0, scale, 0},
/*bias_type=*/TensorType_INT32, /*keep_num_dims=*/false,
/*bias_tensor_optional=*/false,
/*activation_func=*/ActivationFunctionType_RELU,
/*weights_format=*/FullyConnectedOptionsWeightsFormat_DEFAULT,
/*input_size=*/-1, /*filter_type=*/kTfLiteInt4);
m.SetWeights4bit({
1, 2, 3, 4, 5, 6, -7, 1, 2, 3, // u = 0
1, 2, 3, 4, 5, 6, -7, 1, 2, 3, // u = 1
1, 2, 3, 4, 5, 6, -7, 1, 2, 3, // u = 2
});
m.SetBias({1, 2, 3});
m.SetInput<int16_t>({
1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
ElementsAreArray(ArrayFloatNear({3, 4, 5, 23, 24, 25})));
EXPECT_THAT(m.GetOutput<int16_t>(),
ElementsAre(1536, 2048, 2560, 11776, 12288, 12800));
}
TEST_P(QuantizedFullyConnectedOpTest,
SimpleTestPerChannelQuantizedInt16Bias32) {
const float scale = 128.0 / 65536;
@@ -893,6 +951,37 @@ TEST_P(QuantizedFullyConnectedOpTest,
ElementsAre(12288, 12800, 13312, 29696, 30208, 30720));
}
TEST_P(QuantizedFullyConnectedOpTest,
SimpleTestPerChannelQuantizedInt16Bias32Weight4) {
const float scale = 128.0 / 65536;
PerChannelQuantizedFullyConnectedOpModel m(
GetRegistration(), /*units=*/3, /*batches=*/2,
/*input=*/{TensorType_INT16, {2, 10}, 0, 0, scale, 0},
/*per_channel_quantization_scales=*/{1.0, 1.0, 1.0},
/*output=*/{TensorType_INT16, {}, 0, 0, scale, 0},
/*bias_type=*/TensorType_INT32, false, false, ActivationFunctionType_RELU,
FullyConnectedOptionsWeightsFormat_DEFAULT, -1, kTfLiteInt4);
m.SetWeights4bit({
1, 2, 3, 4, 5, 6, -7, 1, 2, 3, // u = 0
1, 2, 3, 4, 5, 6, -7, 1, 2, 3, // u = 1
1, 2, 3, 4, 5, 6, -7, 1, 2, 3, // u = 2
});
m.SetBias({1, 2, 3});
m.SetInput<int16_t>({
1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
ElementsAreArray(ArrayFloatNear({3, 4, 5, 23, 24, 25})));
EXPECT_THAT(m.GetOutput<int16_t>(),
ElementsAre(1536, 2048, 2560, 11776, 12288, 12800));
}
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt16Bias64) {
const float scale = 128.0 / 65536;
QuantizedFullyConnectedOpModel m(

@@ -180,6 +180,12 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
reinterpret_cast<TfLiteFullyConnectedParams*>(op_sig.builtin_data);
TFLITE_DCHECK(fully_connected_params != nullptr);
if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
op_sig.inputs.at(1).type == kTfLiteInt4 &&
op_sig.outputs.at(0).type == kTfLiteInt16) {
return 13;
}
if (op_sig.inputs.at(0).type == kTfLiteFloat32 &&
op_sig.inputs.at(1).type == kTfLiteInt8 &&
op_sig.outputs.at(0).type == kTfLiteFloat32 &&

@@ -134,6 +134,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_FULLY_CONNECTED, 10}, "2.11.0"},
{{BuiltinOperator_FULLY_CONNECTED, 11}, "2.15.0"},
{{BuiltinOperator_FULLY_CONNECTED, 12}, "2.17.0"},
{{BuiltinOperator_FULLY_CONNECTED, 13}, "2.18.0"},
{{BuiltinOperator_GATHER, 1}, "1.6.0"},
{{BuiltinOperator_GATHER, 2}, "1.14.0"},
{{BuiltinOperator_GATHER, 3}, "1.15.0"},