mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
[tfl] Add (unquantized) Int16 and Uint32 support for Mul
PiperOrigin-RevId: 518603054
This commit is contained in:
parent
e51d876d4d
commit
d0227956e9
|
|
@ -45,6 +45,7 @@
|
|||
* Add 8-bit and 16-bit support for `floor_div` and `floor_mod`.
|
||||
* Add int16 indices support for built-in op `gather` and `gather_nd`.
|
||||
* Add reference implementation for 16-bit int unquantized `add`.
|
||||
* Add reference implementation for 16-bit int and 32-bit unsigned int unquantized `mul`.
|
||||
|
||||
* `tf.keras`
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
|
@ -282,6 +282,11 @@ bool IsI32Type(Type element_type) {
|
|||
return element_type.isInteger(32) && !element_type.isUnsignedInteger();
|
||||
}
|
||||
|
||||
// Return true when the given element_type is UI32.
|
||||
bool IsUI32Type(Type element_type) {
|
||||
return element_type.isInteger(32) && element_type.isUnsignedInteger();
|
||||
}
|
||||
|
||||
// Return true when the given element_type is I64.
|
||||
bool IsI64Type(Type element_type) {
|
||||
return element_type.isInteger(64) && !element_type.isUnsignedInteger();
|
||||
|
|
@ -389,8 +394,9 @@ bool VerifyMulOpShapeConstraints(MulOp op) {
|
|||
|
||||
// Allows I32, I64, QI16 and F32 outputs when the operands have valid shapes,
|
||||
// which are broadcastable shapes up to four dimension or have same shapes.
|
||||
if (IsI32Type(element_type) || IsI64Type(element_type) ||
|
||||
IsQI16Type(element_type) || element_type.isa<ComplexType>() ||
|
||||
if (IsI32Type(element_type) || IsUI32Type(element_type) ||
|
||||
IsI64Type(element_type) || IsQI16Type(element_type) ||
|
||||
IsI16Type(element_type) || element_type.isa<ComplexType>() ||
|
||||
element_type.isF32()) {
|
||||
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
||||
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
|
@ -2528,11 +2528,11 @@ def TFL_MulOp : TFL_Op<"mul", [
|
|||
}];
|
||||
|
||||
let arguments = (
|
||||
ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16, Complex<F<32>>]>:$lhs,
|
||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16, Complex<F<32>>]>:$rhs,
|
||||
ins TFL_TensorOf<[F32, I32, UI32, I64, QI8, QUI8, QI16, I16, Complex<F<32>>]>:$lhs,
|
||||
TFL_TensorOf<[F32, I32, UI32, I64, QI8, QUI8, QI16, I16, Complex<F<32>>]>:$rhs,
|
||||
TFL_AFAttr:$fused_activation_function);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16, Complex<F<32>>]>:$output);
|
||||
let results = (outs TFL_TensorOf<[F32, I32, UI32, I64, QI8, QUI8, QI16, I16, Complex<F<32>>]>:$output);
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
|
|
|
|||
|
|
@ -2338,6 +2338,24 @@ func.func @mul_i64(%arg0: tensor<14xi64>, %arg1: tensor<14xi64>) -> tensor<14xi6
|
|||
// CHECK: return
|
||||
}
|
||||
|
||||
func.func @mul_i16(%arg0: tensor<14xi16>, %arg1: tensor<14xi16>) -> tensor<14xi16> {
|
||||
%0 = "tf.Mul"(%arg0, %arg1) : (tensor<14xi16>, tensor<14xi16>) -> tensor<14xi16>
|
||||
func.return %0: tensor<14xi16>
|
||||
|
||||
// CHECK-LABEL: mul_i16
|
||||
// CHECK: tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<14xi16>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func.func @mul_ui32(%arg0: tensor<14xui32>, %arg1: tensor<14xui32>) -> tensor<14xui32> {
|
||||
%0 = "tf.Mul"(%arg0, %arg1) : (tensor<14xui32>, tensor<14xui32>) -> tensor<14xui32>
|
||||
func.return %0: tensor<14xui32>
|
||||
|
||||
// CHECK-LABEL: mul_ui32
|
||||
// CHECK: tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<14xui32>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func.func @mul_complex32(%arg0: tensor<14xcomplex<f32>>, %arg1: tensor<14xcomplex<f32>>) -> tensor<14xcomplex<f32>> {
|
||||
%0 = "tf.Mul"(%arg0, %arg1) : (tensor<14xcomplex<f32>>, tensor<14xcomplex<f32>>) -> tensor<14xcomplex<f32>>
|
||||
func.return %0: tensor<14xcomplex<f32>>
|
||||
|
|
|
|||
|
|
@ -352,6 +352,22 @@ func.func @testMul(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
|
|||
func.return %0#0 : tensor<? x i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testMul32BitUInt
|
||||
func.func @testMul32BitUInt(tensor<? x ui32>, tensor<? x ui32>) -> tensor<? x ui32> {
|
||||
^bb0(%arg0: tensor<? x ui32>, %arg1: tensor<? x ui32>):
|
||||
// CHECK: tfl.mul %arg0, %arg1 {fused_activation_function = "RELU6"}
|
||||
%0 = tfl.mul %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<? x ui32>
|
||||
func.return %0#0 : tensor<? x ui32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testMul16BitInt
|
||||
func.func @testMul16BitInt(tensor<? x i16>, tensor<? x i16>) -> tensor<? x i16> {
|
||||
^bb0(%arg0: tensor<? x i16>, %arg1: tensor<? x i16>):
|
||||
// CHECK: tfl.mul %arg0, %arg1 {fused_activation_function = "RELU6"}
|
||||
%0 = tfl.mul %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<? x i16>
|
||||
func.return %0#0 : tensor<? x i16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testMulComplex
|
||||
func.func @testMulComplex(tensor<? x complex<f32>>, tensor<? x complex<f32>>) -> tensor<? x complex<f32>> {
|
||||
^bb0(%arg0: tensor<? x complex<f32>>, %arg1: tensor<? x complex<f32>>):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
|
@ -101,7 +101,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||
/* min_version = */ 1,
|
||||
/* max_version = */ 3);
|
||||
AddBuiltin(BuiltinOperator_MUL, Register_MUL(), /* min_version = */ 1,
|
||||
/* max_version = */ 6);
|
||||
/* max_version = */ 7);
|
||||
AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION(),
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 2);
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
|
@ -56,7 +56,7 @@ inline void Mul(const ArithmeticParams& params,
|
|||
const int flat_size =
|
||||
MatchingExtendedShapeFlatSize(input1_shape, input2_shape, output_shape);
|
||||
for (int i = 0; i < flat_size; ++i) {
|
||||
output_data[i] = ActivationFunctionWithMinMax(
|
||||
output_data[i] = ActivationFunctionWithMinMax<T>(
|
||||
input1_data[i] * input2_data[i], output_activation_min,
|
||||
output_activation_max);
|
||||
}
|
||||
|
|
@ -128,14 +128,18 @@ inline void BroadcastMul4DSlow(const ArithmeticParams& params,
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void BroadcastMul4DSlow(const ArithmeticParams& params,
|
||||
const RuntimeShape& unextended_input1_shape,
|
||||
const T* input1_data,
|
||||
const RuntimeShape& unextended_input2_shape,
|
||||
const T* input2_data,
|
||||
const RuntimeShape& unextended_output_shape,
|
||||
T* output_data) {
|
||||
template <typename T,
|
||||
// For unquantized mul on small integers, explictly set to true.
|
||||
bool enable_for_short_integers = false>
|
||||
inline typename std::enable_if<
|
||||
!is_small_integer<T>::value || enable_for_short_integers, void>::type
|
||||
BroadcastMul4DSlow(const ArithmeticParams& params,
|
||||
const RuntimeShape& unextended_input1_shape,
|
||||
const T* input1_data,
|
||||
const RuntimeShape& unextended_input2_shape,
|
||||
const T* input2_data,
|
||||
const RuntimeShape& unextended_output_shape,
|
||||
T* output_data) {
|
||||
T output_activation_min;
|
||||
T output_activation_max;
|
||||
GetActivationParams(params, &output_activation_min, &output_activation_max);
|
||||
|
|
@ -167,7 +171,7 @@ void BroadcastMul4DSlow(const ArithmeticParams& params,
|
|||
for (int x = 0; x < output_shape.Dims(2); ++x) {
|
||||
for (int c = 0; c < output_shape.Dims(3); ++c) {
|
||||
output_data[Offset(output_shape, b, y, x, c)] =
|
||||
ActivationFunctionWithMinMax(
|
||||
ActivationFunctionWithMinMax<T>(
|
||||
input1_data[SubscriptToIndex(desc1, b, y, x, c)] *
|
||||
input2_data[SubscriptToIndex(desc2, b, y, x, c)],
|
||||
output_activation_min, output_activation_max);
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
|
@ -1025,6 +1025,12 @@ inline void SetActivationParams(int32_t min, int32_t max, P* params) {
|
|||
params->quantized_activation_max = max;
|
||||
}
|
||||
|
||||
template <typename P>
|
||||
inline void SetActivationParams(uint32_t min, uint32_t max, P* params) {
|
||||
params->quantized_activation_min = min;
|
||||
params->quantized_activation_max = max;
|
||||
}
|
||||
|
||||
template <typename P>
|
||||
inline void SetActivationParams(int16_t min, int16_t max, P* params) {
|
||||
params->int16_activation_min = min;
|
||||
|
|
@ -1043,6 +1049,12 @@ inline void GetActivationParams(const P& params, int32_t* min, int32_t* max) {
|
|||
*max = params.quantized_activation_max;
|
||||
}
|
||||
|
||||
template <typename P>
|
||||
inline void GetActivationParams(const P& params, uint32_t* min, uint32_t* max) {
|
||||
*min = params.quantized_activation_min;
|
||||
*max = params.quantized_activation_max;
|
||||
}
|
||||
|
||||
template <typename P>
|
||||
inline void GetActivationParams(const P& params, int16_t* min, int16_t* max) {
|
||||
*min = params.int16_activation_min;
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||
#include <complex>
|
||||
|
||||
#include "tensorflow/lite/core/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/core/c/c_api_types.h"
|
||||
#include "tensorflow/lite/core/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
|
||||
|
|
@ -124,7 +125,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||
}
|
||||
|
||||
if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ||
|
||||
output->type == kTfLiteInt16) {
|
||||
(output->quantization.type != kTfLiteNoQuantization &&
|
||||
output->type == kTfLiteInt16)) {
|
||||
TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
|
||||
context, params->activation, output, &data->output_activation_min,
|
||||
&data->output_activation_max));
|
||||
|
|
@ -177,6 +179,12 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
|
|||
TF_LITE_MUL(optimized_ops, Mul, int32_t);
|
||||
}
|
||||
}
|
||||
} else if (output->type == kTfLiteUInt32) {
|
||||
if (need_broadcast) {
|
||||
TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, uint32_t);
|
||||
} else {
|
||||
TF_LITE_MUL(reference_ops, Mul, uint32_t);
|
||||
}
|
||||
} else if (output->type == kTfLiteFloat32) {
|
||||
if (kernel_type == kReference) {
|
||||
if (need_broadcast) {
|
||||
|
|
@ -235,6 +243,23 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
|
|||
TF_LITE_MUL(optimized_ops, Mul, float);
|
||||
}
|
||||
}
|
||||
} else if (output->type == kTfLiteInt16) {
|
||||
int16_t output_activation_min, output_activation_max;
|
||||
CalculateActivationRange(params->activation, &output_activation_min,
|
||||
&output_activation_max);
|
||||
SetActivationParams(output_activation_min, output_activation_max,
|
||||
&op_params);
|
||||
if (need_broadcast) {
|
||||
reference_ops::BroadcastMul4DSlow<int16_t, true>(
|
||||
op_params, GetTensorShape(input1), GetTensorData<int16_t>(input1),
|
||||
GetTensorShape(input2), GetTensorData<int16_t>(input2),
|
||||
GetTensorShape(output), GetTensorData<int16_t>(output));
|
||||
} else {
|
||||
reference_ops::Mul<int16_t>(
|
||||
op_params, GetTensorShape(input1), GetTensorData<int16_t>(input1),
|
||||
GetTensorShape(input2), GetTensorData<int16_t>(input2),
|
||||
GetTensorShape(output), GetTensorData<int16_t>(output));
|
||||
}
|
||||
} else if (output->type == kTfLiteInt64) {
|
||||
if (need_broadcast) {
|
||||
TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, int64_t);
|
||||
|
|
@ -362,8 +387,11 @@ template <KernelType kernel_type>
|
|||
TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node, OpData* data,
|
||||
TfLiteMulParams* params, const TfLiteTensor* input1,
|
||||
const TfLiteTensor* input2, TfLiteTensor* output) {
|
||||
bool output_quantized = output->quantization.type != kTfLiteNoQuantization;
|
||||
if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32 ||
|
||||
output->type == kTfLiteInt64 || output->type == kTfLiteComplex64) {
|
||||
output->type == kTfLiteInt64 || output->type == kTfLiteComplex64 ||
|
||||
(!output_quantized && output->type == kTfLiteInt16) ||
|
||||
output->type == kTfLiteUInt32) {
|
||||
EvalMul<kernel_type>(context, node, params, data, input1, input2, output);
|
||||
} else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ||
|
||||
output->type == kTfLiteInt16) {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
|
@ -88,6 +88,20 @@ class IntegerMulOpModel : public BaseMulOpModel<int32_t> {
|
|||
std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
|
||||
};
|
||||
|
||||
class UnsignedInteger32BitMulOpModel : public BaseMulOpModel<uint32_t> {
|
||||
public:
|
||||
using BaseMulOpModel::BaseMulOpModel;
|
||||
|
||||
std::vector<uint32_t> GetOutput() { return ExtractVector<uint32_t>(output_); }
|
||||
};
|
||||
|
||||
class Integer16BitMulOpModel : public BaseMulOpModel<int16_t> {
|
||||
public:
|
||||
using BaseMulOpModel::BaseMulOpModel;
|
||||
|
||||
std::vector<int16_t> GetOutput() { return ExtractVector<int16_t>(output_); }
|
||||
};
|
||||
|
||||
// For quantized Mul, the error shouldn't exceed (2*step + step^2).
|
||||
// The param min=-1.0 & max=1.0 is used in the following tests.
|
||||
// The tolerance value is ~0.0157.
|
||||
|
|
@ -346,6 +360,86 @@ TEST_P(MulOpTest, IntegerNoActivation) {
|
|||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({-20, 4, 21, 40}));
|
||||
}
|
||||
|
||||
TEST_P(MulOpTest, Int16ActivationRELU_N1_TO_1) {
|
||||
bool constant_tensors = GetParam();
|
||||
if (SingleOpModel::GetForceUseNnapi() && constant_tensors) {
|
||||
// NNAPI does not support graphs with all constant inputs.
|
||||
return;
|
||||
}
|
||||
Integer16BitMulOpModel m(
|
||||
{TensorType_INT16, {1, 2, 2, 1}}, {TensorType_INT16, {1, 2, 2, 1}},
|
||||
{TensorType_INT16, {}}, ActivationFunctionType_RELU_N1_TO_1,
|
||||
{-20, 2, 7, 8}, {1, 2, 3, 5}, constant_tensors);
|
||||
ASSERT_EQ(m.Invoke(), kTfLiteOk);
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 1, 1, 1}));
|
||||
}
|
||||
|
||||
TEST_P(MulOpTest, Int16VariousInputShapes) {
|
||||
bool constant_tensors = GetParam();
|
||||
if (SingleOpModel::GetForceUseNnapi() && constant_tensors) {
|
||||
// NNAPI does not support graphs with all constant inputs.
|
||||
return;
|
||||
}
|
||||
const std::vector<std::vector<int>> test_shapes = {
|
||||
{6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
|
||||
for (int i = 0; i < test_shapes.size(); ++i) {
|
||||
Integer16BitMulOpModel m(
|
||||
{TensorType_INT16, test_shapes[i]}, {TensorType_INT16, test_shapes[i]},
|
||||
{TensorType_INT16, {}}, ActivationFunctionType_NONE,
|
||||
{-20, 2, 7, 8, 11, 20}, {1, 2, 3, 5, 11, 1}, constant_tensors);
|
||||
ASSERT_EQ(m.Invoke(), kTfLiteOk);
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({-20, 4, 21, 40, 121, 20}))
|
||||
<< "With shape number " << i;
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(MulOpTest, Int16WithBroadcast) {
|
||||
bool constant_tensors = GetParam();
|
||||
if (SingleOpModel::GetForceUseNnapi() && constant_tensors) {
|
||||
// NNAPI does not support graphs with all constant inputs.
|
||||
return;
|
||||
}
|
||||
const std::vector<std::vector<int>> test_shapes = {
|
||||
{6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
|
||||
for (int i = 0; i < test_shapes.size(); ++i) {
|
||||
Integer16BitMulOpModel m({TensorType_INT16, test_shapes[i]},
|
||||
{TensorType_INT16, {}}, // always a scalar
|
||||
{TensorType_INT16, {}},
|
||||
ActivationFunctionType_NONE,
|
||||
{-20, 2, 7, 8, 11, 20}, {1}, constant_tensors);
|
||||
ASSERT_EQ(m.Invoke(), kTfLiteOk);
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({-20, 2, 7, 8, 11, 20}))
|
||||
<< "With shape number " << i;
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(MulOpTest, 16BitIntegerNoActivation) {
|
||||
bool constant_tensors = GetParam();
|
||||
if (SingleOpModel::GetForceUseNnapi() && constant_tensors) {
|
||||
// NNAPI does not support graphs with all constant inputs.
|
||||
return;
|
||||
}
|
||||
Integer16BitMulOpModel m({TensorType_INT16, {4}}, {TensorType_INT16, {4}},
|
||||
{TensorType_INT16, {}}, ActivationFunctionType_NONE,
|
||||
{-20, 2, 7, 8}, {1, 2, 3, 5}, constant_tensors);
|
||||
ASSERT_EQ(m.Invoke(), kTfLiteOk);
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({-20, 4, 21, 40}));
|
||||
}
|
||||
|
||||
TEST_P(MulOpTest, 32BitUnsignedIntegerNoActivation) {
|
||||
bool constant_tensors = GetParam();
|
||||
if (SingleOpModel::GetForceUseNnapi() && constant_tensors) {
|
||||
// NNAPI does not support graphs with all constant inputs.
|
||||
return;
|
||||
}
|
||||
UnsignedInteger32BitMulOpModel m(
|
||||
{TensorType_UINT32, {1, 2, 2, 1}}, {TensorType_UINT32, {1, 2, 2, 1}},
|
||||
{TensorType_UINT32, {}}, ActivationFunctionType_NONE, {20, 2, 7, 8},
|
||||
{1, 2, 3, 5}, constant_tensors);
|
||||
ASSERT_EQ(m.Invoke(), kTfLiteOk);
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({20, 4, 21, 40}));
|
||||
}
|
||||
|
||||
TEST_P(MulOpTest, ComplexBaseTest) {
|
||||
bool constant_tensors = GetParam();
|
||||
if (SingleOpModel::GetForceUseNnapi() && constant_tensors) {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
|
@ -284,7 +284,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
|
|||
/* min_version = */ 1,
|
||||
/* max_version = */ 3);
|
||||
AddBuiltin(BuiltinOperator_MUL, Register_MUL_REF(), /* min_version = */ 1,
|
||||
/* max_version = */ 6);
|
||||
/* max_version = */ 7);
|
||||
AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2NORM_REF(),
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 2);
|
||||
|
|
|
|||
|
|
@ -322,6 +322,22 @@ def make_mul_tests(options):
|
|||
"fully_quantize": [False],
|
||||
"dynamic_range_quantize": [False],
|
||||
},
|
||||
{
|
||||
"dtype": [tf.int16],
|
||||
"input_shape_1": [[1, 3, 3, 3]],
|
||||
"input_shape_2": [[3], [1, 3, 3, 3]],
|
||||
"activation": [False],
|
||||
"fully_quantize": [False],
|
||||
"dynamic_range_quantize": [False],
|
||||
},
|
||||
{
|
||||
"dtype": [tf.uint32],
|
||||
"input_shape_1": [[1, 3, 3, 3]],
|
||||
"input_shape_2": [[3], [1, 3, 3, 3]],
|
||||
"activation": [False],
|
||||
"fully_quantize": [False],
|
||||
"dynamic_range_quantize": [False],
|
||||
},
|
||||
]
|
||||
make_binary_op_tests(
|
||||
options,
|
||||
|
|
|
|||
|
|
@ -662,11 +662,13 @@ class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
|
|||
const float input1_scale = input1_quant ? input1_quant->scale : 0.0f;
|
||||
const float input2_scale = input2_quant ? input2_quant->scale : 0.0f;
|
||||
const float output_scale = output_quant ? output_quant->scale : 0.0f;
|
||||
const bool input_quantized = input1_quant || input2_quant;
|
||||
::tflite::OpSignature op_sig =
|
||||
GetVersioningOpSig(builtin_op(), op_signature);
|
||||
op_sig.ext_options.mul.input1_scale = input1_scale;
|
||||
op_sig.ext_options.mul.input2_scale = input2_scale;
|
||||
op_sig.ext_options.mul.output_scale = output_scale;
|
||||
op_sig.ext_options.mul.input_quantized = input_quantized;
|
||||
return ::tflite::GetBuiltinOperatorVersion(op_sig);
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ tf_cc_test(
|
|||
deps = [
|
||||
":versioning",
|
||||
"//tensorflow/lite:builtin_op_data",
|
||||
"//tensorflow/lite/core/c:c_api_types",
|
||||
"//tensorflow/lite/core/kernels:builtin_ops",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite/schema:schema_fbs_with_mutable",
|
||||
|
|
|
|||
|
|
@ -189,6 +189,9 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
|
|||
op_sig.ext_options.mul.input2_scale = input2_qunt->scale()->Get(0);
|
||||
op_sig.ext_options.mul.output_scale = output_quant->scale()->Get(0);
|
||||
}
|
||||
if (input1_quant || input2_qunt) {
|
||||
op_sig.ext_options.mul.input_quantized = true;
|
||||
}
|
||||
} break;
|
||||
|
||||
case BuiltinOperator_CONV_2D: {
|
||||
|
|
|
|||
|
|
@ -57,6 +57,7 @@ typedef struct {
|
|||
float input1_scale;
|
||||
float input2_scale;
|
||||
float output_scale;
|
||||
bool input_quantized;
|
||||
} mul;
|
||||
struct {
|
||||
int32_t num_dims;
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
|
@ -267,6 +267,12 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
|||
return 1;
|
||||
|
||||
case BuiltinOperator_MUL:
|
||||
// Version 7 supports int16 and uint32 inputs
|
||||
if ((op_sig.inputs.at(0).type == kTfLiteInt16 &&
|
||||
!op_sig.ext_options.mul.input_quantized) ||
|
||||
op_sig.inputs.at(0).type == kTfLiteUInt32) {
|
||||
return 7;
|
||||
}
|
||||
// Version 6 supports complex32 inputs
|
||||
if (op_sig.inputs.at(0).type == kTfLiteComplex64) {
|
||||
return 6;
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/builtin_op_data.h"
|
||||
#include "tensorflow/lite/core/c/c_api_types.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
|
|
@ -469,6 +470,21 @@ TEST(OpVersionTest, VersioningSubTest) {
|
|||
SimpleVersioningTest(BuiltinOperator_SUB);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningMUL7TestInt16) {
|
||||
OpSignature fake_op_sig;
|
||||
fake_op_sig.op = BuiltinOperator_MUL;
|
||||
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16);
|
||||
fake_op_sig.ext_options.mul.input_quantized = false;
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 7);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningMUL7TestUInt32) {
|
||||
OpSignature fake_op_sig;
|
||||
fake_op_sig.op = BuiltinOperator_MUL;
|
||||
fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32);
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 7);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningMUL6Test) {
|
||||
OpSignature fake_op_sig;
|
||||
fake_op_sig.op = BuiltinOperator_MUL;
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
|
@ -162,6 +162,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
|
|||
{{BuiltinOperator_MUL, 4}, "2.3.0"},
|
||||
{{BuiltinOperator_MUL, 5}, "2.6.0"},
|
||||
{{BuiltinOperator_MUL, 6}, "2.11.0"},
|
||||
{{BuiltinOperator_MUL, 7}, "2.13.0"},
|
||||
{{BuiltinOperator_NON_MAX_SUPPRESSION_V4, 1}, "2.1.0"},
|
||||
{{BuiltinOperator_NON_MAX_SUPPRESSION_V5, 1}, "2.1.0"},
|
||||
{{BuiltinOperator_PAD, 1}, "1.5.0"},
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user