mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Add i4 support in tfl.slice
PiperOrigin-RevId: 825217744
This commit is contained in:
parent
202bd1ac59
commit
2feb74eeff
|
|
@ -23,7 +23,9 @@
|
||||||
* Adds int8 and int16x8 support for SQRT operator.
|
* Adds int8 and int16x8 support for SQRT operator.
|
||||||
* Adds int16x8 support for EQUAL and NOT_EQUAL operators.
|
* Adds int16x8 support for EQUAL and NOT_EQUAL operators.
|
||||||
* Adds support for int2 type.
|
* Adds support for int2 type.
|
||||||
* Adds support for int2/int4 in tfl.cast.
|
* Adds support for int2/int4 in tfl.cast.
|
||||||
|
* Adds support for SRQ int2 in tfl.fully_connected.
|
||||||
|
* Adds support for int4 in tfl.slice.
|
||||||
|
|
||||||
### Bug Fixes and Other Changes
|
### Bug Fixes and Other Changes
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2477,13 +2477,13 @@ equivalent to setting:
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
TFL_TensorOf<[F32, I32, I64, I8, UI8, UI32, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$input,
|
TFL_TensorOf<[F32, I32, I64, QI4, I8, UI8, UI32, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$input,
|
||||||
TFL_I32OrI64Tensor:$begin,
|
TFL_I32OrI64Tensor:$begin,
|
||||||
TFL_I32OrI64Tensor:$size
|
TFL_I32OrI64Tensor:$size
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
TFL_TensorOf<[F32, I32, I64, I8, UI8, UI32, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$output
|
TFL_TensorOf<[F32, I32, I64, QI4, I8, UI8, UI32, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$output
|
||||||
);
|
);
|
||||||
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
|
|
|
||||||
|
|
@ -468,6 +468,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
||||||
return 1;
|
return 1;
|
||||||
|
|
||||||
case BuiltinOperator_SLICE:
|
case BuiltinOperator_SLICE:
|
||||||
|
if (op_sig.inputs.at(0).type == kTfLiteInt4) {
|
||||||
|
return 7;
|
||||||
|
}
|
||||||
if (op_sig.inputs.at(0).type == kTfLiteUInt32) {
|
if (op_sig.inputs.at(0).type == kTfLiteUInt32) {
|
||||||
return 6;
|
return 6;
|
||||||
}
|
}
|
||||||
|
|
@ -477,7 +480,6 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
||||||
if (op_sig.inputs.at(0).type == kTfLiteInt16) {
|
if (op_sig.inputs.at(0).type == kTfLiteInt16) {
|
||||||
return 4;
|
return 4;
|
||||||
}
|
}
|
||||||
// Version 3 supports string input types.
|
|
||||||
if (op_sig.inputs.at(0).type == kTfLiteString) {
|
if (op_sig.inputs.at(0).type == kTfLiteString) {
|
||||||
return 3;
|
return 3;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -295,6 +295,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
|
||||||
{{BuiltinOperator_SLICE, 4}, "2.4.0"},
|
{{BuiltinOperator_SLICE, 4}, "2.4.0"},
|
||||||
{{BuiltinOperator_SLICE, 5}, "2.5.0"},
|
{{BuiltinOperator_SLICE, 5}, "2.5.0"},
|
||||||
{{BuiltinOperator_SLICE, 6}, "2.14.0"},
|
{{BuiltinOperator_SLICE, 6}, "2.14.0"},
|
||||||
|
{{BuiltinOperator_SLICE, 7}, "2.21.0"},
|
||||||
{{BuiltinOperator_TANH, 1}, "1.14.0"},
|
{{BuiltinOperator_TANH, 1}, "1.14.0"},
|
||||||
{{BuiltinOperator_TANH, 2}, "1.14.0"},
|
{{BuiltinOperator_TANH, 2}, "1.14.0"},
|
||||||
{{BuiltinOperator_TANH, 3}, "2.3.0"},
|
{{BuiltinOperator_TANH, 3}, "2.3.0"},
|
||||||
|
|
|
||||||
|
|
@ -217,7 +217,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||||
/* max_version = */ 2);
|
/* max_version = */ 2);
|
||||||
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE(),
|
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE(),
|
||||||
/* min_version = */ 1,
|
/* min_version = */ 1,
|
||||||
/* max_version = */ 6);
|
/* max_version = */ 7);
|
||||||
AddBuiltin(BuiltinOperator_SIN, Register_SIN());
|
AddBuiltin(BuiltinOperator_SIN, Register_SIN());
|
||||||
AddBuiltin(BuiltinOperator_COS, Register_COS());
|
AddBuiltin(BuiltinOperator_COS, Register_COS());
|
||||||
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV(),
|
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV(),
|
||||||
|
|
|
||||||
|
|
@ -2573,10 +2573,13 @@ cc_test(
|
||||||
],
|
],
|
||||||
tags = ["tflite_nnapi"],
|
tags = ["tflite_nnapi"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":kernel_util",
|
||||||
":test_main",
|
":test_main",
|
||||||
":test_util",
|
":test_util",
|
||||||
"//tensorflow/lite:string",
|
"//tensorflow/lite:string",
|
||||||
"//tensorflow/lite/core/c:common",
|
"//tensorflow/lite/core/c:common",
|
||||||
|
"//tensorflow/lite/kernels/internal:tensor_ctypes",
|
||||||
|
"//tensorflow/lite/kernels/internal:tensor_utils_no_eigen",
|
||||||
"//tensorflow/lite/schema:schema_fbs",
|
"//tensorflow/lite/schema:schema_fbs",
|
||||||
"@com_google_googletest//:gtest",
|
"@com_google_googletest//:gtest",
|
||||||
"@eigen_archive//:eigen3",
|
"@eigen_archive//:eigen3",
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <cstring>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
|
@ -34,6 +35,8 @@ limitations under the License.
|
||||||
#include "tensorflow/lite/kernels/internal/reference/add.h"
|
#include "tensorflow/lite/kernels/internal/reference/add.h"
|
||||||
#include "tensorflow/lite/kernels/internal/reference/mul.h"
|
#include "tensorflow/lite/kernels/internal/reference/mul.h"
|
||||||
#include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
|
#include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/runtime_shape.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||||
|
|
||||||
#if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
|
#if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
|
||||||
#include <Accelerate/Accelerate.h>
|
#include <Accelerate/Accelerate.h>
|
||||||
|
|
@ -4798,6 +4801,78 @@ inline void Slice(const tflite::SliceParams& op_params,
|
||||||
return Slice(op_params, input_shape, output_shape, &writer);
|
return Slice(op_params, input_shape, output_shape, &writer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Iterates through the desired slice region and copies nibbles directly from
|
||||||
|
// the input to the output tensor.
|
||||||
|
inline void SliceInt4(const tflite::SliceParams& op_params,
|
||||||
|
const RuntimeShape& input_shape,
|
||||||
|
const TfLiteTensor* input,
|
||||||
|
const RuntimeShape& output_shape, TfLiteTensor* output) {
|
||||||
|
ruy::profiler::ScopeLabel label("SliceInt4");
|
||||||
|
|
||||||
|
const int8_t* input_data = GetTensorData<int8_t>(input);
|
||||||
|
int8_t* output_data = GetTensorData<int8_t>(output);
|
||||||
|
|
||||||
|
// Clear output buffer, as we will be writing nibbles.
|
||||||
|
const int output_byte_size = (output_shape.FlatSize() + 1) / 2;
|
||||||
|
memset(output_data, 0, output_byte_size);
|
||||||
|
|
||||||
|
// Calculate the start and stop indices for each dimension of the slice.
|
||||||
|
const RuntimeShape ext_input_shape =
|
||||||
|
RuntimeShape::ExtendedShape(5, input_shape);
|
||||||
|
TFLITE_DCHECK_LE(op_params.begin_count, 5);
|
||||||
|
TFLITE_DCHECK_LE(op_params.size_count, 5);
|
||||||
|
const int begin_count = op_params.begin_count;
|
||||||
|
const int size_count = op_params.size_count;
|
||||||
|
// We front-pad the begin and size vectors.
|
||||||
|
int start[5];
|
||||||
|
int stop[5];
|
||||||
|
for (int i = 0; i < 5; ++i) {
|
||||||
|
int padded_i = 5 - i;
|
||||||
|
start[i] =
|
||||||
|
begin_count < padded_i ? 0 : op_params.begin[begin_count - padded_i];
|
||||||
|
stop[i] =
|
||||||
|
(size_count < padded_i || op_params.size[size_count - padded_i] == -1)
|
||||||
|
? ext_input_shape.Dims(i)
|
||||||
|
: start[i] + op_params.size[size_count - padded_i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loop over the slice region and copy nibbles.
|
||||||
|
int output_nibble_idx = 0;
|
||||||
|
for (int i0 = start[0]; i0 < stop[0]; ++i0) {
|
||||||
|
for (int i1 = start[1]; i1 < stop[1]; ++i1) {
|
||||||
|
for (int i2 = start[2]; i2 < stop[2]; ++i2) {
|
||||||
|
for (int i3 = start[3]; i3 < stop[3]; ++i3) {
|
||||||
|
for (int i4 = start[4]; i4 < stop[4]; ++i4) {
|
||||||
|
const int input_nibble_idx =
|
||||||
|
Offset(ext_input_shape, i0, i1, i2, i3, i4);
|
||||||
|
|
||||||
|
// Get nibble from input. Since int4 data is packed, two nibbles
|
||||||
|
// share a byte.
|
||||||
|
const int8_t input_byte = input_data[input_nibble_idx / 2];
|
||||||
|
int8_t nibble;
|
||||||
|
if (input_nibble_idx % 2 == 0) { // low nibble
|
||||||
|
// The `(val << 4) >> 4` trick is to sign-extend the 4-bit value.
|
||||||
|
nibble = static_cast<int8_t>(input_byte << 4) >> 4;
|
||||||
|
} else { // high nibble
|
||||||
|
nibble = input_byte >> 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set nibble in output.
|
||||||
|
if (output_nibble_idx % 2 == 0) {
|
||||||
|
// First nibble of a byte. We simply set the lower 4 bits.
|
||||||
|
output_data[output_nibble_idx / 2] = (nibble & 0x0F);
|
||||||
|
} else {
|
||||||
|
// Second nibble. OR with existing low nibble.
|
||||||
|
output_data[output_nibble_idx / 2] |= (nibble << 4);
|
||||||
|
}
|
||||||
|
output_nibble_idx++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
|
void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
|
||||||
const T* input2_data, const RuntimeShape& output_shape,
|
const T* input2_data, const RuntimeShape& output_shape,
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,14 @@ limitations under the License.
|
||||||
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SLICE_H_
|
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SLICE_H_
|
||||||
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SLICE_H_
|
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SLICE_H_
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/core/c/common.h"
|
||||||
#include "tensorflow/lite/kernels/internal/portable_tensor.h"
|
#include "tensorflow/lite/kernels/internal/portable_tensor.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/runtime_shape.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||||
#include "tensorflow/lite/kernels/internal/types.h"
|
#include "tensorflow/lite/kernels/internal/types.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
|
|
@ -74,6 +81,27 @@ inline void Slice(const tflite::SliceParams& op_params,
|
||||||
return Slice(op_params, input_shape, output_shape, &writer);
|
return Slice(op_params, input_shape, output_shape, &writer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline void SliceInt4(const tflite::SliceParams& op_params,
|
||||||
|
const RuntimeShape& input_shape,
|
||||||
|
const TfLiteTensor* input,
|
||||||
|
const RuntimeShape& output_shape, TfLiteTensor* output) {
|
||||||
|
const int num_input_elements = input_shape.FlatSize();
|
||||||
|
std::vector<int8_t> unpacked_input(num_input_elements);
|
||||||
|
tensor_utils::UnpackPackedIntToInt8(GetTensorData<int8_t>(input),
|
||||||
|
num_input_elements, 4,
|
||||||
|
unpacked_input.data());
|
||||||
|
|
||||||
|
const int num_output_elements = output_shape.FlatSize();
|
||||||
|
std::vector<int8_t> unpacked_output(num_output_elements);
|
||||||
|
|
||||||
|
reference_ops::Slice<int8_t>(op_params, input_shape, unpacked_input.data(),
|
||||||
|
output_shape, unpacked_output.data());
|
||||||
|
|
||||||
|
tensor_utils::PackInt8IntoDenseInt(unpacked_output.data(),
|
||||||
|
num_output_elements, 4,
|
||||||
|
GetTensorData<int8_t>(output));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace reference_ops
|
} // namespace reference_ops
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -415,7 +415,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
|
||||||
AddBuiltin(BuiltinOperator_SELECT_V2, Register_SELECT_V2());
|
AddBuiltin(BuiltinOperator_SELECT_V2, Register_SELECT_V2());
|
||||||
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE_REF(),
|
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE_REF(),
|
||||||
/* min_version = */ 1,
|
/* min_version = */ 1,
|
||||||
/* max_version = */ 5);
|
/* max_version = */ 7);
|
||||||
AddBuiltin(BuiltinOperator_SIN, Register_SIN());
|
AddBuiltin(BuiltinOperator_SIN, Register_SIN());
|
||||||
AddBuiltin(BuiltinOperator_COS, Register_COS());
|
AddBuiltin(BuiltinOperator_COS, Register_COS());
|
||||||
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSECONV_REF(),
|
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSECONV_REF(),
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/kernels/internal/reference/slice.h"
|
||||||
|
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
@ -206,6 +208,27 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
// The dimensions in the kernel used to be in reverse-order, and TFLite
|
// The dimensions in the kernel used to be in reverse-order, and TFLite
|
||||||
// arranged the begins and sizes vectors accordingly. This macro incorporates
|
// arranged the begins and sizes vectors accordingly. This macro incorporates
|
||||||
// the needed reversing.
|
// the needed reversing.
|
||||||
|
#define TF_LITE_SLICE_INT4() \
|
||||||
|
{ \
|
||||||
|
TF_LITE_ENSURE_EQ(context, begins.size(), kMaxDim); \
|
||||||
|
TF_LITE_ENSURE_EQ(context, sizes.size(), kMaxDim); \
|
||||||
|
tflite::SliceParams op_params; \
|
||||||
|
op_params.begin_count = kMaxDim; \
|
||||||
|
op_params.size_count = kMaxDim; \
|
||||||
|
for (int i = 0; i < kMaxDim; ++i) { \
|
||||||
|
op_params.begin[i] = begins[i]; \
|
||||||
|
op_params.size[i] = sizes[i]; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
if (kernel_type == kGenericOptimized) { \
|
||||||
|
optimized_ops::SliceInt4(op_params, GetTensorShape(input), input, \
|
||||||
|
GetTensorShape(output), output); \
|
||||||
|
} else { \
|
||||||
|
reference_ops::SliceInt4(op_params, GetTensorShape(input), input, \
|
||||||
|
GetTensorShape(output), output); \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
#define TF_LITE_SLICE(data_type) \
|
#define TF_LITE_SLICE(data_type) \
|
||||||
{ \
|
{ \
|
||||||
TF_LITE_ENSURE_EQ(context, begins.size(), kMaxDim); \
|
TF_LITE_ENSURE_EQ(context, begins.size(), kMaxDim); \
|
||||||
|
|
@ -231,6 +254,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
case kTfLiteFloat32:
|
case kTfLiteFloat32:
|
||||||
TF_LITE_SLICE(float);
|
TF_LITE_SLICE(float);
|
||||||
break;
|
break;
|
||||||
|
case kTfLiteInt4:
|
||||||
|
TF_LITE_SLICE_INT4();
|
||||||
|
break;
|
||||||
case kTfLiteInt32:
|
case kTfLiteInt32:
|
||||||
TF_LITE_SLICE(int32_t);
|
TF_LITE_SLICE(int32_t);
|
||||||
break;
|
break;
|
||||||
|
|
|
||||||
|
|
@ -16,11 +16,16 @@ limitations under the License.
|
||||||
|
|
||||||
#include <initializer_list>
|
#include <initializer_list>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <type_traits>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "Eigen/Core"
|
#include "Eigen/Core"
|
||||||
|
#include <gmock/gmock.h>
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include "tensorflow/lite/core/c/common.h"
|
#include "tensorflow/lite/core/c/common.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||||
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
#include "tensorflow/lite/kernels/test_util.h"
|
#include "tensorflow/lite/kernels/test_util.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
#include "tensorflow/lite/string_type.h"
|
#include "tensorflow/lite/string_type.h"
|
||||||
|
|
@ -67,6 +72,12 @@ class SliceOpModel : public SingleOpModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetInput(std::initializer_list<input_type> data) {
|
void SetInput(std::initializer_list<input_type> data) {
|
||||||
|
if constexpr (std::is_same<input_type, int8_t>::value) {
|
||||||
|
if (interpreter_->tensor(input_)->type == kTfLiteInt4) {
|
||||||
|
PopulateTensor4bit(input_, 0, data.begin(), data.end());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
PopulateTensor<input_type>(input_, data);
|
PopulateTensor<input_type>(input_, data);
|
||||||
}
|
}
|
||||||
void SetStringInput(std::vector<string> data) {
|
void SetStringInput(std::vector<string> data) {
|
||||||
|
|
@ -253,6 +264,22 @@ TEST_P(SliceOpTest, SliceInt8) {
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(SliceOpTest, SliceInt4) {
|
||||||
|
SliceOpModel<int8_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
|
||||||
|
{2, 1, -1, 1}, TensorType_INT32,
|
||||||
|
TensorType_INT4, GetParam());
|
||||||
|
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
||||||
|
ASSERT_EQ(m.Invoke(), kTfLiteOk);
|
||||||
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
|
||||||
|
const TfLiteTensor* output_tensor = m.GetOutputTensor();
|
||||||
|
int num_elements = NumElements(output_tensor);
|
||||||
|
std::vector<int8_t> unpacked_output(num_elements);
|
||||||
|
tensor_utils::UnpackPackedIntToInt8(GetTensorData<int8_t>(output_tensor),
|
||||||
|
num_elements,
|
||||||
|
/*bit_width=*/4, unpacked_output.data());
|
||||||
|
EXPECT_THAT(unpacked_output, ElementsAreArray({3, 3, 3, 5, 5, 5}));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(SliceOpTest, SliceInt16) {
|
TEST_P(SliceOpTest, SliceInt16) {
|
||||||
SliceOpModel<int16_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
|
SliceOpModel<int16_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
|
||||||
{2, 1, -1, 1}, TensorType_INT32,
|
{2, 1, -1, 1}, TensorType_INT32,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user