Add support for int2/int4 in tfl.cast
PiperOrigin-RevId: 820509011

This commit is contained in:
parent 5592d364ec
commit f67cb87691

@@ -23,6 +23,7 @@
 * Adds int8 and int16x8 support for SQRT operator.
 * Adds int16x8 support for EQUAL and NOT_EQUAL operators.
 * Adds support for int2 type.
+* Adds support for int2/int4 in tfl.cast.

 ### Bug Fixes and Other Changes

@@ -112,6 +112,7 @@ class TFL_VariadicTensorOf<list<Type> allowedRuntimeTypes,
   Variadic<TensorOf<allowedOpTypes>>,
   TFL_RuntimeType<Variadic<TensorOf<allowedRuntimeTypes>>>;

+def TFL_I2 : I<2>;
 def TFL_I4 : I<4>;
 def TFL_Int32Or64 : SignlessIntOfWidths<[32, 64]>;

@@ -4072,13 +4073,10 @@ def TFL_CastOp : TFL_Op<"cast", [
   }];

   let arguments = (ins
-    TFL_TensorOf<[F16, BF16, F32, F64, I1, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex<F<32>>]>:$input
+    TFL_TensorOf<[F16, BF16, F32, F64, I1, TFL_I2, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex<F<32>>]>:$input
   );

-  // TODO(b/393644251): Temporary support for INT4 TFL_CastOp. Runtime
-  // probably already supports INT4. We should remove the INT4 support here or
-  // make sure the runtime supports is there, as part of closing the bug.
-  let results = (outs TFL_TensorOf<[F16, BF16, F32, F64, I1, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex<F<32>>]>:$output);
+  let results = (outs TFL_TensorOf<[F16, BF16, F32, F64, I1, TFL_I2, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex<F<32>>]>:$output);

   // TFLite's cast op does not utilize CastOptions, instead derives types
   // from the TfLiteTensors.

@@ -1073,8 +1073,11 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
       }
       return 2;
     case BuiltinOperator_CAST:
-      if (op_sig.inputs.at(0).type == kTfLiteBFloat16 ||
-          op_sig.outputs.at(0).type == kTfLiteBFloat16) {
+      if (op_sig.inputs.at(0).type == kTfLiteInt2 ||
+          op_sig.outputs.at(0).type == kTfLiteInt2) {
+        return 8;
+      } else if (op_sig.inputs.at(0).type == kTfLiteBFloat16 ||
+                 op_sig.outputs.at(0).type == kTfLiteBFloat16) {
         return 7;
       } else if (op_sig.inputs.at(0).type == kTfLiteInt4 &&
                  op_sig.outputs.at(0).type == kTfLiteFloat32) {

@@ -1467,4 +1467,72 @@ TEST(OpVersionTest, VersioningSqrtTest) {
   fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16);
   EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
 }
+
+TEST(OpVersionTest, VersioningCastTest) {
+  OpSignature fake_op_sig = {};
+  fake_op_sig.op = BuiltinOperator_CAST;
+  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt2);
+  fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 8);
+
+  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
+  fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt2);
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 8);
+
+  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteBFloat16);
+  fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 7);
+
+  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
+  fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteBFloat16);
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 7);
+
+  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt4);
+  fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32);
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6);
+
+  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat64);
+  fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
+
+  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
+  fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat64);
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
+
+  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat16);
+  fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
+
+  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
+  fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat16);
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
+
+  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt16);
+  fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
+
+  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
+  fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt16);
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
+
+  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8);
+  fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+
+  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
+  fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8);
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+
+  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32);
+  fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+
+  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
+  fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32);
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+
+  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
+  fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32);
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+}
 } // namespace tflite
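
Read together, the assertions above pin down the whole CAST version ladder: int2 on either side now selects version 8, bfloat16 selects 7, int4-to-float32 selects 6, and so on down to the default of 1. The sketch below restates that mapping as a single standalone function; the local Type enum and the CastVersionForTypes helper are hypothetical names introduced only to summarize what the test expects, not TFLite's actual implementation.

```cpp
#include <iostream>

// Hypothetical stand-in for TfLiteType, limited to the types the test exercises.
enum class Type {
  kInt2, kInt4, kInt8, kUInt16, kInt32, kUInt32,
  kFloat16, kBFloat16, kFloat32, kFloat64
};

// Version ladder asserted above: int2 on either side -> 8, bfloat16 -> 7,
// int4 -> float32 -> 6, float64 or float16 -> 5, uint16 -> 4, int8 -> 3,
// uint32 -> 2, anything else (e.g. int32 -> int32) -> 1.
int CastVersionForTypes(Type in, Type out) {
  auto either = [&](Type t) { return in == t || out == t; };
  if (either(Type::kInt2)) return 8;
  if (either(Type::kBFloat16)) return 7;
  if (in == Type::kInt4 && out == Type::kFloat32) return 6;
  if (either(Type::kFloat64) || either(Type::kFloat16)) return 5;
  if (either(Type::kUInt16)) return 4;
  if (either(Type::kInt8)) return 3;
  if (either(Type::kUInt32)) return 2;
  return 1;
}

int main() {
  // Mirrors the first assertion above: an int2 <-> int32 cast needs version 8.
  std::cout << CastVersionForTypes(Type::kInt2, Type::kInt32) << "\n";  // prints 8
}
```
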
@@ -112,6 +112,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
           {{BuiltinOperator_CAST, 5}, "2.12.0"},
           {{BuiltinOperator_CAST, 6}, "2.15.0"},
           {{BuiltinOperator_CAST, 7}, "2.17.0"},
+          {{BuiltinOperator_CAST, 8}, "2.21.0"},
           {{BuiltinOperator_CONCATENATION, 1}, "1.5.0"},
           {{BuiltinOperator_CONCATENATION, 2}, "1.14.0"},
           {{BuiltinOperator_CONCATENATION, 3}, "2.3.0"},
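
The new row ties CAST version 8 to TensorFlow 2.21.0, so a converted model whose cast touches int2 tensors reports 2.21.0 as its minimum required runtime. A minimal sketch of that lookup, using a local map with just the CAST rows above; the MinRuntimeForCast helper is hypothetical, not the real FindMinimumRuntimeVersionForOp.

```cpp
#include <iostream>
#include <map>
#include <string>

// Hypothetical: the CAST rows of the table above, op version -> first
// TensorFlow Lite release whose runtime supports that version.
const std::map<int, std::string>& CastMinRuntime() {
  static const std::map<int, std::string> kTable = {
      {5, "2.12.0"}, {6, "2.15.0"}, {7, "2.17.0"}, {8, "2.21.0"}};
  return kTable;
}

std::string MinRuntimeForCast(int op_version) {
  const auto& table = CastMinRuntime();
  auto it = table.find(op_version);
  return it == table.end() ? "unknown" : it->second;
}

int main() {
  // A cast that touches int2 is emitted at op version 8, so the model's
  // minimum required runtime becomes at least 2.21.0.
  std::cout << MinRuntimeForCast(8) << "\n";  // prints 2.21.0
}
```
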
@@ -59,10 +59,10 @@ bool IsCompatibleTypeWithTFLCastOp(Type type) {
       elemType.isF64())
     return true;

-  // I1, I4, I8, I16, I32, I64 types are allowed.
-  if (elemType.isInteger(1) || elemType.isInteger(4) || elemType.isInteger(8) ||
-      elemType.isInteger(16) || elemType.isInteger(32) ||
-      elemType.isInteger(64))
+  // I1, I2, I4, I8, I16, I32, I64 types are allowed.
+  if (elemType.isInteger(1) || elemType.isInteger(2) || elemType.isInteger(4) ||
+      elemType.isInteger(8) || elemType.isInteger(16) ||
+      elemType.isInteger(32) || elemType.isInteger(64))
     return true;

   // Complex<F<32>> is allowed.

@@ -176,7 +176,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
              /* max_version = */ 2);
   AddBuiltin(BuiltinOperator_CAST, Register_CAST(),
              /* min_version = */ 1,
-             /* max_version = */ 7);
+             /* max_version = */ 8);
   AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
              /* min_version = */ 1,
              /* max_version = */ 6);

@@ -172,6 +172,8 @@ cc_library(
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/log:absl_log",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/synchronization",

@@ -1490,11 +1492,14 @@ cc_test(
    tags = ["tflite_nnapi"],
    deps = [
        ":cast_test_common",
        ":kernel_util",
        ":test_main",
        ":test_util",
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/core/c:c_api_types",
        "//tensorflow/lite/kernels/internal:tensor_utils_no_eigen",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_absl//absl/random",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest",
        "@eigen_archive//:eigen3",

@@ -18,11 +18,14 @@ limitations under the License.
#include <cstddef>
#include <cstdint>
#include <limits>
#include <type_traits>
#include <vector>

#include "Eigen/Core" // from @eigen_archive
#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/interpreter_options.h"
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

@@ -183,6 +186,19 @@ void copyCastToBFloat16(const Eigen::half* in, Eigen::bfloat16* out,
   });
 }

+TfLiteStatus castInt2ToFloat(TfLiteContext* context, const TfLiteTensor* in,
+                             TfLiteTensor* out, int num_elements) {
+  const int8_t* in_data = (const int8_t*)in->data.data;
+  float* out_data = (float*)out->data.data;
+  std::vector<int8_t> unpacked_temp(num_elements);
+  tensor_utils::UnpackPackedIntToInt8(in_data, num_elements, /*bit_width=*/2,
+                                      unpacked_temp.data());
+  for (int i = 0; i < num_elements; ++i) {
+    out_data[i] = static_cast<float>(unpacked_temp[i]);
+  }
+  return kTfLiteOk;
+}
+
 TfLiteStatus castInt4ToFloat(TfLiteContext* context, const TfLiteTensor* in,
                              TfLiteTensor* out, int num_elements) {
   const int8_t* in_data = (const int8_t*)in->data.data;

@@ -240,6 +256,34 @@ TfLiteStatus castInt4ToFloat(TfLiteContext* context, const TfLiteTensor* in,
   return kTfLiteOk;
 }

+TfLiteStatus castFloatToInt4(const float* in, TfLiteTensor* out,
+                             int num_elements) {
+  const float min_val = -8.0f;
+  const float max_val = 7.0f;
+  std::vector<int8_t> unpacked_temp(num_elements);
+  for (int i = 0; i < num_elements; ++i) {
+    unpacked_temp[i] =
+        static_cast<int8_t>(std::max(min_val, std::min(max_val, in[i])));
+  }
+  tensor_utils::PackInt8IntoDenseInt(unpacked_temp.data(), num_elements,
+                                     /*bit_width=*/4, (int8_t*)out->data.data);
+  return kTfLiteOk;
+}
+
+TfLiteStatus castFloatToInt2(const float* in, TfLiteTensor* out,
+                             int num_elements) {
+  const float min_val = -2.0f;
+  const float max_val = 1.0f;
+  std::vector<int8_t> unpacked_temp(num_elements);
+  for (int i = 0; i < num_elements; ++i) {
+    unpacked_temp[i] =
+        static_cast<int8_t>(std::max(min_val, std::min(max_val, in[i])));
+  }
+  tensor_utils::PackInt8IntoDenseInt(unpacked_temp.data(), num_elements,
+                                     /*bit_width=*/2, (int8_t*)out->data.data);
+  return kTfLiteOk;
+}
+
 template <typename FromT>
 TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
                           TfLiteTensor* out, int num_elements) {
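
Both directions of the new int2 path delegate the bit twiddling to tensor_utils::UnpackPackedIntToInt8 and tensor_utils::PackInt8IntoDenseInt. The standalone sketch below illustrates the underlying idea for bit_width=2: four signed values per byte, with sign extension on unpack. It assumes an LSB-first layout purely for illustration, which may not match the exact bit order TFLite's helpers use.

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Pack signed 2-bit values (range [-2, 1]) four to a byte, LSB-first.
std::vector<uint8_t> PackInt2(const std::vector<int8_t>& values) {
  std::vector<uint8_t> packed((values.size() + 3) / 4, 0);
  for (size_t i = 0; i < values.size(); ++i) {
    uint8_t two_bits = static_cast<uint8_t>(values[i]) & 0x3;
    packed[i / 4] |= two_bits << (2 * (i % 4));
  }
  return packed;
}

// Unpack back to one int8 per element, sign-extending each 2-bit field.
std::vector<int8_t> UnpackInt2(const std::vector<uint8_t>& packed, size_t n) {
  std::vector<int8_t> values(n);
  for (size_t i = 0; i < n; ++i) {
    uint8_t two_bits = (packed[i / 4] >> (2 * (i % 4))) & 0x3;
    // Field values 2 and 3 encode -2 and -1 in two's complement.
    values[i] = two_bits >= 2 ? static_cast<int8_t>(two_bits) - 4
                              : static_cast<int8_t>(two_bits);
  }
  return values;
}

int main() {
  std::vector<int8_t> in = {1, 0, -1, -2, 1};  // same values as the tests below
  auto packed = PackInt2(in);                  // 5 elements -> 2 bytes
  auto out = UnpackInt2(packed, in.size());
  for (int8_t v : out) std::cout << static_cast<int>(v) << " ";  // 1 0 -1 -2 1
  std::cout << "\n";
}
```
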
@@ -286,6 +330,20 @@ TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
       copyCast(in, reinterpret_cast<std::complex<float>*>(out->data.c64),
                num_elements);
       break;
+    case kTfLiteInt4:
+      if (std::is_same<FromT, float>::value) {
+        return castFloatToInt4(reinterpret_cast<const float*>(in), out,
+                               num_elements);
+      } else {
+        TF_LITE_UNSUPPORTED_TYPE(context, out->type, "Cast");
+      }
+    case kTfLiteInt2:
+      if (std::is_same<FromT, float>::value) {
+        return castFloatToInt2(reinterpret_cast<const float*>(in), out,
+                               num_elements);
+      } else {
+        TF_LITE_UNSUPPORTED_TYPE(context, out->type, "Cast");
+      }
     default:
       // Unsupported type.
       TF_LITE_UNSUPPORTED_TYPE(context, out->type, "Cast");

@@ -334,6 +392,11 @@ TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
         TF_LITE_UNSUPPORTED_TYPE(context, output->type, "Cast");
       }
       return castInt4ToFloat(context, input, output, num_elements);
+    case kTfLiteInt2:
+      if (output->type != kTfLiteFloat32) {
+        TF_LITE_UNSUPPORTED_TYPE(context, output->type, "Cast");
+      }
+      return castInt2ToFloat(context, input, output, num_elements);
     default:
       // Unsupported type.
       TF_LITE_UNSUPPORTED_TYPE(context, input->type, "Cast");

@@ -17,16 +17,18 @@ limitations under the License.
#include <algorithm>
#include <complex>
#include <limits>
#include <random>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/random/random.h"
#include "absl/types/span.h"
#include "Eigen/Core" // from @eigen_archive
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/c/c_api_types.h"
#include "tensorflow/lite/kernels/cast_test_common.h"
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"

@@ -45,10 +47,10 @@ TEST(CastOpModel, CastInt4ToFloat) {

 TEST(CastOpModel, CastInt4ToFloatLarge) {
   int num_elements = 40;
-  std::random_device random_device;
-  auto rng = std::mt19937(random_device());
-  std::uniform_int_distribution<int8_t> i8dist(-8, 7);
-  auto i8rng = [&] { return i8dist(rng); };
+  absl::BitGen bitgen;
+  auto i8rng = [&] {
+    return absl::Uniform<int8_t>(absl::IntervalClosed, bitgen, -8, 7);
+  };
   std::vector<int8_t> input(num_elements);
   std::generate(input.begin(), input.end(), i8rng);
   CastOpModel m({TensorType_INT4, {num_elements}},

@@ -60,6 +62,85 @@ TEST(CastOpModel, CastInt4ToFloatLarge) {
   }
 }

+TEST(CastOpModel, CastInt2ToFloat) {
+  CastOpModel m({TensorType_INT2, {2, 4}}, {TensorType_FLOAT32, {2, 4}});
+  m.Set2BitInput({1, 0, -1, -2, 1, 0, -1, -2});
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  EXPECT_THAT(m.ExtractVector<float>(m.output()),
+              Pointwise(FloatingPointEq(),
+                        {1.f, 0.f, -1.f, -2.f, 1.f, 0.f, -1.f, -2.f}));
+}
+
+TEST(CastOpModel, CastInt2ToFloatLarge) {
+  int num_elements = 40;
+  absl::BitGen bitgen;
+  auto i2rng = [&] {
+    return absl::Uniform<int8_t>(absl::IntervalClosed, bitgen, -2, 1);
+  };
+  std::vector<int8_t> input(num_elements);
+  std::generate(input.begin(), input.end(), i2rng);
+  CastOpModel m({TensorType_INT2, {num_elements}},
+                {TensorType_FLOAT32, {num_elements}});
+  m.Set2BitInput(input);
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  for (int i = 0; i < input.size(); ++i) {
+    EXPECT_EQ(m.ExtractVector<float>(m.output())[i], input[i]);
+  }
+}
+
+TEST(CastOpModel, CastFloatToInt4) {
+  CastOpModel m({TensorType_FLOAT32, {2, 4}}, {TensorType_INT4, {2, 4}});
+  m.PopulateTensor<float>(m.input(), {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, -8.f});
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  TfLiteTensor* output = m.GetOutputTensor(0);
+  int num_elements = NumElements(output);
+  std::vector<int8_t> unpacked_output(num_elements);
+  tensor_utils::UnpackPackedIntToInt8(
+      reinterpret_cast<int8_t*>(output->data.data), num_elements,
+      /*bit_width=*/4, unpacked_output.data());
+  EXPECT_THAT(unpacked_output, ElementsAreArray({1, 2, 3, 4, 5, 6, 7, -8}));
+}
+
+TEST(CastOpModel, CastFloatToInt4Clamp) {
+  CastOpModel m({TensorType_FLOAT32, {1, 4}}, {TensorType_INT4, {1, 4}});
+  m.PopulateTensor<float>(m.input(), {100.f, -100.f, 7.9f, -8.9f});
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  TfLiteTensor* output = m.GetOutputTensor(0);
+  int num_elements = NumElements(output);
+  std::vector<int8_t> unpacked_output(num_elements);
+  tensor_utils::UnpackPackedIntToInt8(
+      reinterpret_cast<int8_t*>(output->data.data), num_elements,
+      /*bit_width=*/4, unpacked_output.data());
+  EXPECT_THAT(unpacked_output, ElementsAreArray({7, -8, 7, -8}));
+}
+
+TEST(CastOpModel, CastFloatToInt2) {
+  CastOpModel m({TensorType_FLOAT32, {2, 4}}, {TensorType_INT2, {2, 4}});
+  m.PopulateTensor<float>(m.input(),
+                          {1.f, 0.f, -1.f, -2.f, 1.f, 0.f, -1.f, -2.f});
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  TfLiteTensor* output = m.GetOutputTensor(0);
+  int num_elements = NumElements(output);
+  std::vector<int8_t> unpacked_output(num_elements);
+  tensor_utils::UnpackPackedIntToInt8(
+      reinterpret_cast<int8_t*>(output->data.data), num_elements,
+      /*bit_width=*/2, unpacked_output.data());
+  EXPECT_THAT(unpacked_output, ElementsAreArray({1, 0, -1, -2, 1, 0, -1, -2}));
+}
+
+TEST(CastOpModel, CastFloatToInt2Clamp) {
+  CastOpModel m({TensorType_FLOAT32, {1, 4}}, {TensorType_INT2, {1, 4}});
+  m.PopulateTensor<float>(m.input(), {100.f, -100.f, 1.9f, -2.9f});
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  TfLiteTensor* output = m.GetOutputTensor(0);
+  int num_elements = NumElements(output);
+  std::vector<int8_t> unpacked_output(num_elements);
+  tensor_utils::UnpackPackedIntToInt8(
+      reinterpret_cast<int8_t*>(output->data.data), num_elements,
+      /*bit_width=*/2, unpacked_output.data());
+  EXPECT_THAT(unpacked_output, ElementsAreArray({1, -2, 1, -2}));
+}
+
 TEST(CastOpModel, CastFloatToUint8Infinity) {
   CastOpModel m({TensorType_FLOAT32, {2}}, {TensorType_UINT8, {2}});
   m.PopulateTensor<float>(m.input(), {std::numeric_limits<float>::infinity(),

@@ -59,6 +59,10 @@ class CastOpModel : public SingleOpModel {
     PopulateTensor4bit(input_, 0, f.data(), f.data() + f.size());
   }

+  void Set2BitInput(absl::Span<const int8_t> data) {
+    PopulateTensor2bit(input_, 0, data.data(), data.data() + data.size());
+  }
+
   int input() const { return input_; }
   int output() const { return output_; }

@@ -377,7 +377,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
              /* max_version = */ 2);
   AddBuiltin(BuiltinOperator_CAST, Register_CAST(),
              /* min_version = */ 1,
-             /* max_version = */ 7);
+             /* max_version = */ 8);
   AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE_REF(),
              /* min_version = */ 1,
              /* max_version = */ 6);

@@ -33,7 +33,6 @@ limitations under the License.
#include <ostream>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

@@ -41,14 +40,15 @@ limitations under the License.
#include <gtest/gtest.h>
#include "fp16/fp16.h" // from @FP16
#include "absl/algorithm/container.h"
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/types/span.h"
#include "Eigen/Core" // from @eigen_archive
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/core/interpreter.h"
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/utils/sparsity_format_converter.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"

@@ -57,7 +57,6 @@ limitations under the License.
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/testing/util.h" // IWYU pragma: keep
#include "tensorflow/lite/tools/optimize/quantization_utils.h"
#include "tensorflow/lite/type_to_tflitetype.h"
#include "tensorflow/lite/util.h"
#include "tsl/platform/logging.h"

@@ -489,14 +488,14 @@ class SingleOpModel {
           reinterpret_cast<const uint8_t*>(q.data()), q.size());
       buffers_.push_back(CreateBuffer(builder_, data_buffer));
     } else if (is_quantized) {
-      CHECK_EQ(t.type, TensorType_INT8)
+      ABSL_CHECK_EQ(t.type, TensorType_INT8)
           << "The INT8 quantization is only supported for sparsified tensor";
       std::vector<int8_t> quantized_output(sparse_data.size());
       std::vector<float> scales;
       std::vector<int64_t> zero_points;
       if (t.per_channel_quantization) {
-        CHECK_EQ(t.per_channel_quantization_scales.size(),  // NOLINT
-                 t.per_channel_quantization_offsets.size())
+        ABSL_CHECK_EQ(t.per_channel_quantization_scales.size(),  // NOLINT
+                      t.per_channel_quantization_offsets.size())
            << "Per channel quantization scales and offsets should have the "
               "same size";
        std::vector<int8_t> temp_data(dense_data.size());

@@ -703,7 +702,7 @@ class SingleOpModel {
     TfLiteTensor* t = interpreter_->tensor(index);
     auto* params =
         reinterpret_cast<TfLiteAffineQuantization*>(t->quantization.params);
-    CHECK(t->type == kTfLiteInt32 || t->type == kTfLiteInt64);
+    ABSL_CHECK(t->type == kTfLiteInt32 || t->type == kTfLiteInt64);
     if (t->type == kTfLiteInt32) {
       PerChannelQuantizeBiasPopulateTensor<int32_t>(index, input_data, params);
     } else {

@@ -783,7 +782,7 @@ class SingleOpModel {
   std::vector<T> ExtractVector(int index) const {
     const T* v = interpreter_->typed_tensor<T>(index);
     const auto* tensor = interpreter_->tensor(index);
-    CHECK(v) << "Could not extract vector at index: " << index;
+    ABSL_CHECK(v) << "Could not extract vector at index: " << index;
     int tensor_size;
     if (tensor->sparsity) {
       // Getting the size of the sparse buffer this way is based on the

@@ -815,7 +814,7 @@ class SingleOpModel {
   // Sets the number of threads available to the interpreter.
   // Reconstruct the interpreter if reset_interpreter is true.
   void SetNumThreads(int num_threads, bool reset_interpreter = false) {
-    CHECK(interpreter_ != nullptr);
+    ABSL_CHECK(interpreter_ != nullptr);
     if (reset_interpreter) {
       // Reconstruct interpreter as number of threads may affect internal
       // state, e.g. stratch buffer allocation.

@@ -890,7 +889,7 @@ class SingleOpModel {
       std::tie(t.scale, t.zero_point) =
           QuantizationParams<int8_t>(t.min, t.max, kTfLiteInt4);
     } else {
-      LOG(FATAL) << "No support for the requested quantized type";
+      ABSL_LOG(FATAL) << "No support for the requested quantized type";
     }
     t.min = 0;
     t.max = 0;

@@ -949,12 +948,12 @@ class SingleOpModel {
     const float qmax_double = qmax;
     // 0 should always be a representable value. Let's assume that the initial
     // min,max range contains 0.
-    CHECK_LE(f_min, 0);
-    CHECK_GE(f_max, 0);
+    ABSL_CHECK_LE(f_min, 0);
+    ABSL_CHECK_GE(f_max, 0);
     if (f_min == f_max) {
       // Special case where the min,max range is a point. Should be {0}.
-      CHECK_EQ(f_min, 0);
-      CHECK_EQ(f_max, 0);
+      ABSL_CHECK_EQ(f_min, 0);
+      ABSL_CHECK_EQ(f_max, 0);
       return {scale, zero_point};
     }

@@ -1003,8 +1002,8 @@ class SingleOpModel {

     // The zero point should always be in the range of quantized value,
     // // [qmin, qmax].
-    CHECK_GE(nudged_zero_point, qmin);
-    CHECK_LE(nudged_zero_point, qmax);
+    ABSL_CHECK_GE(nudged_zero_point, qmin);
+    ABSL_CHECK_LE(nudged_zero_point, qmax);

     zero_point = nudged_zero_point;
     // finally, return the values

@@ -1028,15 +1027,42 @@ class SingleOpModel {

     if (!v) {
       auto* t = interpreter_->tensor(index);
-      CHECK(t) << "No tensor with index " << index << ".";
-      CHECK(t->data.raw) << "Empty data for tensor with index " << index << ".";
-      LOG(FATAL) << "Unknown tensor error.";
+      ABSL_CHECK(t) << "No tensor with index " << index << ".";
+      ABSL_CHECK(t->data.raw)
+          << "Empty data for tensor with index " << index << ".";
+      ABSL_LOG(FATAL) << "Unknown tensor error.";
     }
     absl::c_copy(data, v + offset);
     PackInt4ValuesDenselyInPlace(v, ElementCount(*tensor_ptr->dims));
     tensor_ptr->bytes = ((ElementCount(*tensor_ptr->dims) + 1) / 2);
   }

+  // Partially populates the tensor, starting at the given offset.
+  void PopulateTensor2bit(int index, int offset, const int8_t* begin,
+                          const int8_t* end) {
+    auto data = absl::Span<const int8_t>(begin, end - begin);
+    TfLiteTensor* tensor_ptr = interpreter_->tensor(index);
+    uint8_t* v = nullptr;
+    if (tensor_ptr) {
+      v = reinterpret_cast<uint8_t*>(tensor_ptr->data.data);
+    }
+
+    if (!v) {
+      auto* t = interpreter_->tensor(index);
+      ABSL_CHECK(t) << "No tensor with index " << index << ".";
+      ABSL_CHECK(t->data.raw)
+          << "Empty data for tensor with index " << index << ".";
+      ABSL_LOG(FATAL) << "Unknown tensor error.";
+    }
+    int num_elements = data.size();
+    int num_bytes = (num_elements + 3) / 4;
+    std::vector<int8_t> packed(num_bytes);
+    tensor_utils::PackInt8IntoDenseInt(data.data(), num_elements,
+                                       /*bit_width=*/2, packed.data());
+    memcpy(v + offset, packed.data(), packed.size());
+    tensor_ptr->bytes = num_bytes;
+  }
+
  private:
   // Populates the tensor starting at offset using given data.
   template <typename T, typename Container>

@@ -1044,13 +1070,14 @@ class SingleOpModel {
     T* v = interpreter_->typed_tensor<T>(index);
     if (!v) {
       auto* t = interpreter_->tensor(index);
-      CHECK(t) << "No tensor with index " << index << ".";
-      CHECK(t->data.raw) << "Empty data for tensor with index " << index << ".";
-      CHECK_EQ(t->type, typeToTfLiteType<T>())
+      ABSL_CHECK(t) << "No tensor with index " << index << ".";
+      ABSL_CHECK(t->data.raw)
+          << "Empty data for tensor with index " << index << ".";
+      ABSL_CHECK_EQ(t->type, typeToTfLiteType<T>())
           << "Type mismatch for tensor with index " << index << ". Requested "
           << TfLiteTypeGetName(typeToTfLiteType<T>()) << ", got "
           << TfLiteTypeGetName(t->type) << ".";
-      LOG(FATAL) << "Unknown tensor error.";
+      ABSL_LOG(FATAL) << "Unknown tensor error.";
     }
     absl::c_copy(data, v + offset);
   }