Add support for int2/int4 in tfl.cast

PiperOrigin-RevId: 820509011
Majid Dadashi 2025-10-16 20:36:50 -07:00 committed by TensorFlower Gardener
parent 5592d364ec
commit f67cb87691
13 changed files with 293 additions and 42 deletions

View File

@ -23,6 +23,7 @@
* Adds int8 and int16x8 support for SQRT operator.
* Adds int16x8 support for EQUAL and NOT_EQUAL operators.
* Adds support for int2 type.
* Adds support for int2/int4 in tfl.cast.
### Bug Fixes and Other Changes

View File

@ -112,6 +112,7 @@ class TFL_VariadicTensorOf<list<Type> allowedRuntimeTypes,
Variadic<TensorOf<allowedOpTypes>>,
TFL_RuntimeType<Variadic<TensorOf<allowedRuntimeTypes>>>;
def TFL_I2 : I<2>;
def TFL_I4 : I<4>;
def TFL_Int32Or64 : SignlessIntOfWidths<[32, 64]>;
@ -4072,13 +4073,10 @@ def TFL_CastOp : TFL_Op<"cast", [
}];
let arguments = (ins
TFL_TensorOf<[F16, BF16, F32, F64, I1, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex<F<32>>]>:$input
TFL_TensorOf<[F16, BF16, F32, F64, I1, TFL_I2, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex<F<32>>]>:$input
);
// TODO(b/393644251): Temporary support for INT4 TFL_CastOp. The runtime
// probably already supports INT4. We should remove the INT4 support here or
// make sure the runtime support is there, as part of closing the bug.
let results = (outs TFL_TensorOf<[F16, BF16, F32, F64, I1, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex<F<32>>]>:$output);
let results = (outs TFL_TensorOf<[F16, BF16, F32, F64, I1, TFL_I2, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex<F<32>>]>:$output);
// TFLite's cast op does not use CastOptions; instead it derives types
// from the TfLiteTensors.

View File

@ -1073,8 +1073,11 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
}
return 2;
case BuiltinOperator_CAST:
if (op_sig.inputs.at(0).type == kTfLiteBFloat16 ||
op_sig.outputs.at(0).type == kTfLiteBFloat16) {
if (op_sig.inputs.at(0).type == kTfLiteInt2 ||
op_sig.outputs.at(0).type == kTfLiteInt2) {
return 8;
} else if (op_sig.inputs.at(0).type == kTfLiteBFloat16 ||
op_sig.outputs.at(0).type == kTfLiteBFloat16) {
return 7;
} else if (op_sig.inputs.at(0).type == kTfLiteInt4 &&
op_sig.outputs.at(0).type == kTfLiteFloat32) {

View File

@ -1467,4 +1467,72 @@ TEST(OpVersionTest, VersioningSqrtTest) {
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
}
TEST(OpVersionTest, VersioningCastTest) {
OpSignature fake_op_sig = {};
fake_op_sig.op = BuiltinOperator_CAST;
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt2);
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 8);
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt2);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 8);
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteBFloat16);
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 7);
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteBFloat16);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 7);
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt4);
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6);
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat64);
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat64);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat16);
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat16);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt16);
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt16);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8);
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32);
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
}
} // namespace tflite

View File

@ -112,6 +112,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_CAST, 5}, "2.12.0"},
{{BuiltinOperator_CAST, 6}, "2.15.0"},
{{BuiltinOperator_CAST, 7}, "2.17.0"},
{{BuiltinOperator_CAST, 8}, "2.21.0"},
{{BuiltinOperator_CONCATENATION, 1}, "1.5.0"},
{{BuiltinOperator_CONCATENATION, 2}, "1.14.0"},
{{BuiltinOperator_CONCATENATION, 3}, "2.3.0"},

View File

@ -59,10 +59,10 @@ bool IsCompatibleTypeWithTFLCastOp(Type type) {
elemType.isF64())
return true;
// I1, I4, I8, I16, I32, I64 types are allowed.
if (elemType.isInteger(1) || elemType.isInteger(4) || elemType.isInteger(8) ||
elemType.isInteger(16) || elemType.isInteger(32) ||
elemType.isInteger(64))
// I1, I2, I4, I8, I16, I32, I64 types are allowed.
if (elemType.isInteger(1) || elemType.isInteger(2) || elemType.isInteger(4) ||
elemType.isInteger(8) || elemType.isInteger(16) ||
elemType.isInteger(32) || elemType.isInteger(64))
return true;
// Complex<F<32>> is allowed.

View File

@ -176,7 +176,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_CAST, Register_CAST(),
/* min_version = */ 1,
/* max_version = */ 7);
/* max_version = */ 8);
AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
/* min_version = */ 1,
/* max_version = */ 6);

View File

@ -172,6 +172,8 @@ cc_library(
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
@ -1490,11 +1492,14 @@ cc_test(
tags = ["tflite_nnapi"],
deps = [
":cast_test_common",
":kernel_util",
":test_main",
":test_util",
"//tensorflow/lite/c:common",
"//tensorflow/lite/core/c:c_api_types",
"//tensorflow/lite/kernels/internal:tensor_utils_no_eigen",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/random",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest",
"@eigen_archive//:eigen3",

View File

@ -18,11 +18,14 @@ limitations under the License.
#include <cstddef>
#include <cstdint>
#include <limits>
#include <type_traits>
#include <vector>
#include "Eigen/Core" // from @eigen_archive
#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/interpreter_options.h"
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
@ -183,6 +186,19 @@ void copyCastToBFloat16(const Eigen::half* in, Eigen::bfloat16* out,
});
}
TfLiteStatus castInt2ToFloat(TfLiteContext* context, const TfLiteTensor* in,
TfLiteTensor* out, int num_elements) {
const int8_t* in_data = (const int8_t*)in->data.data;
float* out_data = (float*)out->data.data;
std::vector<int8_t> unpacked_temp(num_elements);
tensor_utils::UnpackPackedIntToInt8(in_data, num_elements, /*bit_width=*/2,
unpacked_temp.data());
for (int i = 0; i < num_elements; ++i) {
out_data[i] = static_cast<float>(unpacked_temp[i]);
}
return kTfLiteOk;
}
TfLiteStatus castInt4ToFloat(TfLiteContext* context, const TfLiteTensor* in,
TfLiteTensor* out, int num_elements) {
const int8_t* in_data = (const int8_t*)in->data.data;
@ -240,6 +256,34 @@ TfLiteStatus castInt4ToFloat(TfLiteContext* context, const TfLiteTensor* in,
return kTfLiteOk;
}
TfLiteStatus castFloatToInt4(const float* in, TfLiteTensor* out,
int num_elements) {
const float min_val = -8.0f;
const float max_val = 7.0f;
std::vector<int8_t> unpacked_temp(num_elements);
for (int i = 0; i < num_elements; ++i) {
unpacked_temp[i] =
static_cast<int8_t>(std::max(min_val, std::min(max_val, in[i])));
}
tensor_utils::PackInt8IntoDenseInt(unpacked_temp.data(), num_elements,
/*bit_width=*/4, (int8_t*)out->data.data);
return kTfLiteOk;
}
TfLiteStatus castFloatToInt2(const float* in, TfLiteTensor* out,
int num_elements) {
const float min_val = -2.0f;
const float max_val = 1.0f;
std::vector<int8_t> unpacked_temp(num_elements);
for (int i = 0; i < num_elements; ++i) {
unpacked_temp[i] =
static_cast<int8_t>(std::max(min_val, std::min(max_val, in[i])));
}
tensor_utils::PackInt8IntoDenseInt(unpacked_temp.data(), num_elements,
/*bit_width=*/2, (int8_t*)out->data.data);
return kTfLiteOk;
}
template <typename FromT>
TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
TfLiteTensor* out, int num_elements) {
@ -286,6 +330,20 @@ TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
copyCast(in, reinterpret_cast<std::complex<float>*>(out->data.c64),
num_elements);
break;
case kTfLiteInt4:
if (std::is_same<FromT, float>::value) {
return castFloatToInt4(reinterpret_cast<const float*>(in), out,
num_elements);
} else {
TF_LITE_UNSUPPORTED_TYPE(context, out->type, "Cast");
}
case kTfLiteInt2:
if (std::is_same<FromT, float>::value) {
return castFloatToInt2(reinterpret_cast<const float*>(in), out,
num_elements);
} else {
TF_LITE_UNSUPPORTED_TYPE(context, out->type, "Cast");
}
default:
// Unsupported type.
TF_LITE_UNSUPPORTED_TYPE(context, out->type, "Cast");
@ -334,6 +392,11 @@ TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
TF_LITE_UNSUPPORTED_TYPE(context, output->type, "Cast");
}
return castInt4ToFloat(context, input, output, num_elements);
case kTfLiteInt2:
if (output->type != kTfLiteFloat32) {
TF_LITE_UNSUPPORTED_TYPE(context, output->type, "Cast");
}
return castInt2ToFloat(context, input, output, num_elements);
default:
// Unsupported type.
TF_LITE_UNSUPPORTED_TYPE(context, input->type, "Cast");
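
For readers unfamiliar with the packed layout these kernels depend on, here is a minimal, self-contained sketch of a 2-bit pack/unpack round trip. It assumes LSB-first packing (element i occupies bits 2*(i%4) through 2*(i%4)+1 of byte i/4) and merely stands in for tensor_utils::PackInt8IntoDenseInt / UnpackPackedIntToInt8; the helper names below are illustrative, not the TFLite API.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Clamp to the int2 range [-2, 1] and pack four values per byte, LSB first.
std::vector<uint8_t> PackInt2Lsb(const std::vector<int8_t>& values) {
  std::vector<uint8_t> packed((values.size() + 3) / 4, 0);
  for (size_t i = 0; i < values.size(); ++i) {
    const int8_t v = std::max<int8_t>(-2, std::min<int8_t>(1, values[i]));
    packed[i / 4] |= (static_cast<uint8_t>(v) & 0x3) << (2 * (i % 4));
  }
  return packed;
}

// Undo the packing, sign-extending each 2-bit field back to int8_t.
std::vector<int8_t> UnpackInt2Lsb(const std::vector<uint8_t>& packed,
                                  size_t num_elements) {
  std::vector<int8_t> values(num_elements);
  for (size_t i = 0; i < num_elements; ++i) {
    const uint8_t bits = (packed[i / 4] >> (2 * (i % 4))) & 0x3;
    // 0b10 -> -2, 0b11 -> -1, 0b00 -> 0, 0b01 -> 1.
    values[i] = (bits & 0x2) ? static_cast<int8_t>(bits) - 4
                             : static_cast<int8_t>(bits);
  }
  return values;
}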

View File

@ -17,16 +17,18 @@ limitations under the License.
#include <algorithm>
#include <complex>
#include <limits>
#include <random>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/random/random.h"
#include "absl/types/span.h"
#include "Eigen/Core" // from @eigen_archive
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/c/c_api_types.h"
#include "tensorflow/lite/kernels/cast_test_common.h"
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
@ -45,10 +47,10 @@ TEST(CastOpModel, CastInt4ToFloat) {
TEST(CastOpModel, CastInt4ToFloatLarge) {
int num_elements = 40;
std::random_device random_device;
auto rng = std::mt19937(random_device());
std::uniform_int_distribution<int8_t> i8dist(-8, 7);
auto i8rng = [&] { return i8dist(rng); };
absl::BitGen bitgen;
auto i8rng = [&] {
return absl::Uniform<int8_t>(absl::IntervalClosed, bitgen, -8, 7);
};
std::vector<int8_t> input(num_elements);
std::generate(input.begin(), input.end(), i8rng);
CastOpModel m({TensorType_INT4, {num_elements}},
@ -60,6 +62,85 @@ TEST(CastOpModel, CastInt4ToFloatLarge) {
}
}
TEST(CastOpModel, CastInt2ToFloat) {
CastOpModel m({TensorType_INT2, {2, 4}}, {TensorType_FLOAT32, {2, 4}});
m.Set2BitInput({1, 0, -1, -2, 1, 0, -1, -2});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.ExtractVector<float>(m.output()),
Pointwise(FloatingPointEq(),
{1.f, 0.f, -1.f, -2.f, 1.f, 0.f, -1.f, -2.f}));
}
TEST(CastOpModel, CastInt2ToFloatLarge) {
int num_elements = 40;
absl::BitGen bitgen;
auto i2rng = [&] {
return absl::Uniform<int8_t>(absl::IntervalClosed, bitgen, -2, 1);
};
std::vector<int8_t> input(num_elements);
std::generate(input.begin(), input.end(), i2rng);
CastOpModel m({TensorType_INT2, {num_elements}},
{TensorType_FLOAT32, {num_elements}});
m.Set2BitInput(input);
ASSERT_EQ(m.Invoke(), kTfLiteOk);
for (int i = 0; i < input.size(); ++i) {
EXPECT_EQ(m.ExtractVector<float>(m.output())[i], input[i]);
}
}
TEST(CastOpModel, CastFloatToInt4) {
CastOpModel m({TensorType_FLOAT32, {2, 4}}, {TensorType_INT4, {2, 4}});
m.PopulateTensor<float>(m.input(), {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, -8.f});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
TfLiteTensor* output = m.GetOutputTensor(0);
int num_elements = NumElements(output);
std::vector<int8_t> unpacked_output(num_elements);
tensor_utils::UnpackPackedIntToInt8(
reinterpret_cast<int8_t*>(output->data.data), num_elements,
/*bit_width=*/4, unpacked_output.data());
EXPECT_THAT(unpacked_output, ElementsAreArray({1, 2, 3, 4, 5, 6, 7, -8}));
}
TEST(CastOpModel, CastFloatToInt4Clamp) {
CastOpModel m({TensorType_FLOAT32, {1, 4}}, {TensorType_INT4, {1, 4}});
m.PopulateTensor<float>(m.input(), {100.f, -100.f, 7.9f, -8.9f});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
TfLiteTensor* output = m.GetOutputTensor(0);
int num_elements = NumElements(output);
std::vector<int8_t> unpacked_output(num_elements);
tensor_utils::UnpackPackedIntToInt8(
reinterpret_cast<int8_t*>(output->data.data), num_elements,
/*bit_width=*/4, unpacked_output.data());
EXPECT_THAT(unpacked_output, ElementsAreArray({7, -8, 7, -8}));
}
TEST(CastOpModel, CastFloatToInt2) {
CastOpModel m({TensorType_FLOAT32, {2, 4}}, {TensorType_INT2, {2, 4}});
m.PopulateTensor<float>(m.input(),
{1.f, 0.f, -1.f, -2.f, 1.f, 0.f, -1.f, -2.f});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
TfLiteTensor* output = m.GetOutputTensor(0);
int num_elements = NumElements(output);
std::vector<int8_t> unpacked_output(num_elements);
tensor_utils::UnpackPackedIntToInt8(
reinterpret_cast<int8_t*>(output->data.data), num_elements,
/*bit_width=*/2, unpacked_output.data());
EXPECT_THAT(unpacked_output, ElementsAreArray({1, 0, -1, -2, 1, 0, -1, -2}));
}
TEST(CastOpModel, CastFloatToInt2Clamp) {
CastOpModel m({TensorType_FLOAT32, {1, 4}}, {TensorType_INT2, {1, 4}});
m.PopulateTensor<float>(m.input(), {100.f, -100.f, 1.9f, -2.9f});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
TfLiteTensor* output = m.GetOutputTensor(0);
int num_elements = NumElements(output);
std::vector<int8_t> unpacked_output(num_elements);
tensor_utils::UnpackPackedIntToInt8(
reinterpret_cast<int8_t*>(output->data.data), num_elements,
/*bit_width=*/2, unpacked_output.data());
EXPECT_THAT(unpacked_output, ElementsAreArray({1, -2, 1, -2}));
}
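
As a side note on the Clamp tests, the expected values follow from a clamp-then-truncate rule: out-of-range inputs saturate at the int2 (or int4) bounds, and in-range fractional inputs are truncated toward zero by the integer cast. A small sketch of the same arithmetic, mirroring the clamp in castFloatToInt2 above:

#include <algorithm>
#include <cstdint>

int8_t SaturateToInt2(float x) {
  // Clamp to [-2, 1], then truncate toward zero, as castFloatToInt2 does.
  const float clamped = std::max(-2.0f, std::min(1.0f, x));
  return static_cast<int8_t>(clamped);
}
// SaturateToInt2(100.f) == 1, SaturateToInt2(-100.f) == -2,
// SaturateToInt2(1.9f) == 1, SaturateToInt2(-2.9f) == -2.
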
TEST(CastOpModel, CastFloatToUint8Infinity) {
CastOpModel m({TensorType_FLOAT32, {2}}, {TensorType_UINT8, {2}});
m.PopulateTensor<float>(m.input(), {std::numeric_limits<float>::infinity(),

View File

@ -59,6 +59,10 @@ class CastOpModel : public SingleOpModel {
PopulateTensor4bit(input_, 0, f.data(), f.data() + f.size());
}
void Set2BitInput(absl::Span<const int8_t> data) {
PopulateTensor2bit(input_, 0, data.data(), data.data() + data.size());
}
int input() const { return input_; }
int output() const { return output_; }

View File

@ -377,7 +377,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_CAST, Register_CAST(),
/* min_version = */ 1,
/* max_version = */ 7);
/* max_version = */ 8);
AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE_REF(),
/* min_version = */ 1,
/* max_version = */ 6);

View File

@ -33,7 +33,6 @@ limitations under the License.
#include <ostream>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
@ -41,14 +40,15 @@ limitations under the License.
#include <gtest/gtest.h>
#include "fp16/fp16.h" // from @FP16
#include "absl/algorithm/container.h"
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/types/span.h"
#include "Eigen/Core" // from @eigen_archive
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/core/interpreter.h"
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/utils/sparsity_format_converter.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"
@ -57,7 +57,6 @@ limitations under the License.
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/testing/util.h" // IWYU pragma: keep
#include "tensorflow/lite/tools/optimize/quantization_utils.h"
#include "tensorflow/lite/type_to_tflitetype.h"
#include "tensorflow/lite/util.h"
#include "tsl/platform/logging.h"
@ -489,14 +488,14 @@ class SingleOpModel {
reinterpret_cast<const uint8_t*>(q.data()), q.size());
buffers_.push_back(CreateBuffer(builder_, data_buffer));
} else if (is_quantized) {
CHECK_EQ(t.type, TensorType_INT8)
ABSL_CHECK_EQ(t.type, TensorType_INT8)
<< "The INT8 quantization is only supported for sparsified tensor";
std::vector<int8_t> quantized_output(sparse_data.size());
std::vector<float> scales;
std::vector<int64_t> zero_points;
if (t.per_channel_quantization) {
CHECK_EQ(t.per_channel_quantization_scales.size(), // NOLINT
t.per_channel_quantization_offsets.size())
ABSL_CHECK_EQ(t.per_channel_quantization_scales.size(), // NOLINT
t.per_channel_quantization_offsets.size())
<< "Per channel quantization scales and offsets should have the "
"same size";
std::vector<int8_t> temp_data(dense_data.size());
@ -703,7 +702,7 @@ class SingleOpModel {
TfLiteTensor* t = interpreter_->tensor(index);
auto* params =
reinterpret_cast<TfLiteAffineQuantization*>(t->quantization.params);
CHECK(t->type == kTfLiteInt32 || t->type == kTfLiteInt64);
ABSL_CHECK(t->type == kTfLiteInt32 || t->type == kTfLiteInt64);
if (t->type == kTfLiteInt32) {
PerChannelQuantizeBiasPopulateTensor<int32_t>(index, input_data, params);
} else {
@ -783,7 +782,7 @@ class SingleOpModel {
std::vector<T> ExtractVector(int index) const {
const T* v = interpreter_->typed_tensor<T>(index);
const auto* tensor = interpreter_->tensor(index);
CHECK(v) << "Could not extract vector at index: " << index;
ABSL_CHECK(v) << "Could not extract vector at index: " << index;
int tensor_size;
if (tensor->sparsity) {
// Getting the size of the sparse buffer this way is based on the
@ -815,7 +814,7 @@ class SingleOpModel {
// Sets the number of threads available to the interpreter.
// Reconstruct the interpreter if reset_interpreter is true.
void SetNumThreads(int num_threads, bool reset_interpreter = false) {
CHECK(interpreter_ != nullptr);
ABSL_CHECK(interpreter_ != nullptr);
if (reset_interpreter) {
// Reconstruct interpreter as number of threads may affect internal
// state, e.g. scratch buffer allocation.
@ -890,7 +889,7 @@ class SingleOpModel {
std::tie(t.scale, t.zero_point) =
QuantizationParams<int8_t>(t.min, t.max, kTfLiteInt4);
} else {
LOG(FATAL) << "No support for the requested quantized type";
ABSL_LOG(FATAL) << "No support for the requested quantized type";
}
t.min = 0;
t.max = 0;
@ -949,12 +948,12 @@ class SingleOpModel {
const float qmax_double = qmax;
// 0 should always be a representable value. Let's assume that the initial
// min,max range contains 0.
CHECK_LE(f_min, 0);
CHECK_GE(f_max, 0);
ABSL_CHECK_LE(f_min, 0);
ABSL_CHECK_GE(f_max, 0);
if (f_min == f_max) {
// Special case where the min,max range is a point. Should be {0}.
CHECK_EQ(f_min, 0);
CHECK_EQ(f_max, 0);
ABSL_CHECK_EQ(f_min, 0);
ABSL_CHECK_EQ(f_max, 0);
return {scale, zero_point};
}
@ -1003,8 +1002,8 @@ class SingleOpModel {
// The zero point should always be in the range of quantized values,
// [qmin, qmax].
CHECK_GE(nudged_zero_point, qmin);
CHECK_LE(nudged_zero_point, qmax);
ABSL_CHECK_GE(nudged_zero_point, qmin);
ABSL_CHECK_LE(nudged_zero_point, qmax);
zero_point = nudged_zero_point;
// finally, return the values
@ -1028,15 +1027,42 @@ class SingleOpModel {
if (!v) {
auto* t = interpreter_->tensor(index);
CHECK(t) << "No tensor with index " << index << ".";
CHECK(t->data.raw) << "Empty data for tensor with index " << index << ".";
LOG(FATAL) << "Unknown tensor error.";
ABSL_CHECK(t) << "No tensor with index " << index << ".";
ABSL_CHECK(t->data.raw)
<< "Empty data for tensor with index " << index << ".";
ABSL_LOG(FATAL) << "Unknown tensor error.";
}
absl::c_copy(data, v + offset);
PackInt4ValuesDenselyInPlace(v, ElementCount(*tensor_ptr->dims));
tensor_ptr->bytes = ((ElementCount(*tensor_ptr->dims) + 1) / 2);
}
// Partially populates the tensor, starting at the given offset.
void PopulateTensor2bit(int index, int offset, const int8_t* begin,
const int8_t* end) {
auto data = absl::Span<const int8_t>(begin, end - begin);
TfLiteTensor* tensor_ptr = interpreter_->tensor(index);
uint8_t* v = nullptr;
if (tensor_ptr) {
v = reinterpret_cast<uint8_t*>(tensor_ptr->data.data);
}
if (!v) {
auto* t = interpreter_->tensor(index);
ABSL_CHECK(t) << "No tensor with index " << index << ".";
ABSL_CHECK(t->data.raw)
<< "Empty data for tensor with index " << index << ".";
ABSL_LOG(FATAL) << "Unknown tensor error.";
}
int num_elements = data.size();
int num_bytes = (num_elements + 3) / 4;
std::vector<int8_t> packed(num_bytes);
tensor_utils::PackInt8IntoDenseInt(data.data(), num_elements,
/*bit_width=*/2, packed.data());
memcpy(v + offset, packed.data(), packed.size());
tensor_ptr->bytes = num_bytes;
}
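
A quick check of the size bookkeeping above: four 2-bit elements share one byte, so N int2 values occupy ceil(N / 4) bytes. A minimal compile-time sketch (hypothetical helper, shown only to make the arithmetic concrete):

constexpr int PackedInt2Bytes(int num_elements) {
  return (num_elements + 3) / 4;
}
static_assert(PackedInt2Bytes(8) == 2, "the {2, 4} test tensors pack into 2 bytes");
static_assert(PackedInt2Bytes(40) == 10, "the 40-element test tensors pack into 10 bytes");
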
private:
// Populates the tensor starting at offset using given data.
template <typename T, typename Container>
@ -1044,13 +1070,14 @@ class SingleOpModel {
T* v = interpreter_->typed_tensor<T>(index);
if (!v) {
auto* t = interpreter_->tensor(index);
CHECK(t) << "No tensor with index " << index << ".";
CHECK(t->data.raw) << "Empty data for tensor with index " << index << ".";
CHECK_EQ(t->type, typeToTfLiteType<T>())
ABSL_CHECK(t) << "No tensor with index " << index << ".";
ABSL_CHECK(t->data.raw)
<< "Empty data for tensor with index " << index << ".";
ABSL_CHECK_EQ(t->type, typeToTfLiteType<T>())
<< "Type mismatch for tensor with index " << index << ". Requested "
<< TfLiteTypeGetName(typeToTfLiteType<T>()) << ", got "
<< TfLiteTypeGetName(t->type) << ".";
LOG(FATAL) << "Unknown tensor error.";
ABSL_LOG(FATAL) << "Unknown tensor error.";
}
absl::c_copy(data, v + offset);
}