Support TFLite BitwiseXor op for 8-bit/16-bit/32-bit integer/unsigned integer

PiperOrigin-RevId: 520507082
This commit is contained in:
Eric Yang 2023-03-29 19:28:09 -07:00 committed by TensorFlower Gardener
parent e1e39d4e92
commit fd714e7701
22 changed files with 602 additions and 13 deletions

View File

@ -44,6 +44,7 @@
`equal`
* Add 8-bit and 16-bit support for `floor_div` and `floor_mod`.
* Add 16-bit and 32-bit int support for the built-in op `bitcast`.
* Add 8-bit/16-bit/32-bit int/uint support for the built-in op `bitwise_xor`
* Add int16 indices support for built-in op `gather` and `gather_nd`.
* Add reference implementation for 16-bit int unquantized `add`.
* Add reference implementation for 16-bit int and 32-bit unsigned int unquantized `mul`.

View File

@ -4053,6 +4053,26 @@ def TFL_BitcastOp : TFL_Op<"bitcast", [Pure]> {
let hasVerifier = 1;
}
def TFL_BitwiseXorOp : TFL_Op<"bitwise_xor", [
Commutative,
SameOperandsAndResultElementType,
Pure]> {
let summary = "Bitwise Xor operator";
let description = [{
Elementwise computes the bitwise XOR of `x` and `y`.
}];
let arguments = (ins
TFL_TensorOf<[I8, UI8, I16, UI16, I32, UI32]>:$lhs,
TFL_TensorOf<[I8, UI8, I16, UI16, I32, UI32]>:$rhs
);
let results = (outs
TFL_TensorOf<[I8, UI8, I16, UI16, I32, UI32]>:$output
);
}
//===----------------------------------------------------------------------===//
// Quantization ops.
//===----------------------------------------------------------------------===//

View File

@ -2578,6 +2578,15 @@ func.func @bitcastI16ToFloat(%arg0: tensor<8x2xi16>) -> tensor<8xf32> {
// CHECK: return %[[RES0]] : tensor<8xf32>
}
func.func @testBitwiseXor(%arg0: tensor<8xui32>, %arg1: tensor<8xui32>) -> tensor<8xui32> {
%0 = "tf.BitwiseXor"(%arg0, %arg1) : (tensor<8xui32>, tensor<8xui32>) -> tensor<8xui32>
func.return %0 : tensor<8xui32>
// CHECK-LABEL: testBitwiseXor
// CHECK: %[[RES0:.*]] = "tfl.bitwise_xor"(%arg0, %arg1) : (tensor<8xui32>, tensor<8xui32>) -> tensor<8xui32>
// CHECK: return %[[RES0]] : tensor<8xui32>
}
// =============================================================================
// Training OPs
// =============================================================================

View File

@ -3148,3 +3148,13 @@ func.func @testBitcast(%arg0: tensor<8xui32>) -> tensor<8xi32> {
func.return %0 : tensor<8xi32>
// CHECK: return %0 : tensor<8xi32>
}
// -----
// CHECK-LABEL: testBitwiseXor
func.func @testBitwiseXor(%arg0: tensor<8xui32>, %arg1: tensor<8xui32>) -> tensor<8xui32> {
// CHECK: "tfl.bitwise_xor"(%arg0, %arg1)
%0 = "tfl.bitwise_xor"(%arg0, %arg1) : (tensor<8xui32>, tensor<8xui32>) -> tensor<8xui32>
func.return %0 : tensor<8xui32>
// CHECK: return %0 : tensor<8xui32>
}

View File

@ -579,6 +579,9 @@ def LegalizeSign : Pat<(TF_SignOp $x), (TFL_SignOp $x)>;
def LegalizeBitcast : Pat<(TF_BitcastOp $x), (TFL_BitcastOp $x)>;
def LegalizeBitwiseXor : Pat<(TF_BitwiseXorOp $l, $r),
(TFL_BitwiseXorOp $l, $r)>;
// =============================================================================
// Training OPs
// =============================================================================

View File

@ -187,6 +187,7 @@ typedef enum {
kTfLiteBuiltinUnsortedSegmentMin = 157,
kTfLiteBuiltinSign = 158,
kTfLiteBuiltinBitcast = 159,
kTfLiteBuiltinBitwiseXor = 160,
} TfLiteBuiltinOperator;
#ifdef __cplusplus

View File

@ -546,6 +546,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
return ParseZerosLike(op, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_BITWISE_XOR: {
return ParseBitwiseXor(op, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_CAST: {
return ParseCast(op, error_reporter, allocator, builtin_data);
}
@ -849,6 +853,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
*builtin_data = params.release();
return kTfLiteOk;
}
// Below are the ops with no builtin_data structure.
// TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
// ok for now, since there is no call implementation either.
@ -2453,6 +2458,14 @@ TfLiteStatus ParseZerosLike(const Operator*, ErrorReporter*,
return kTfLiteOk;
}
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseBitwiseXor(const Operator*, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data) {

View File

@ -412,6 +412,10 @@ TfLiteStatus ParseZerosLike(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseBitwiseXor(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
} // namespace tflite
#endif // TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_

View File

@ -192,6 +192,7 @@ TfLiteRegistration* Register_WHERE();
TfLiteRegistration* Register_WHILE();
TfLiteRegistration* Register_ZEROS_LIKE();
TfLiteRegistration* Register_BITCAST();
TfLiteRegistration* Register_BITWISE_XOR();
} // namespace builtin
} // namespace ops

View File

@ -360,6 +360,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_BITCAST, Register_BITCAST());
AddBuiltin(BuiltinOperator_BITWISE_XOR, Register_BITWISE_XOR());
AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.

View File

@ -587,6 +587,7 @@ BUILTIN_KERNEL_SRCS = [
"bidirectional_sequence_lstm.cc",
"bidirectional_sequence_rnn.cc",
"bitcast.cc",
"bitwise_xor.cc",
"broadcast_args.cc",
"broadcast_to.cc",
"bucketize.cc",
@ -2983,6 +2984,17 @@ cc_test(
],
)
cc_test(
name = "bitwise_xor_test",
size = "small",
srcs = ["bitwise_xor_test.cc"],
deps = [
":test_util",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_googletest//:gtest_main",
],
)
tflite_portable_test_suite_combined(
combine_conditions = {"deps": [":test_main"]},
# TODO(b/229985981) : Remove `nnapi_args` after adding Relu0To1 is completed.

View File

@ -0,0 +1,168 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/core/c/c_api_types.h"
#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace bitwise_xor {
// Input/output tensor index.
constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;
// Op data for bitwise xor op.
struct OpData {
bool requires_broadcast = false;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* data = new OpData;
return data;
}
void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<OpData*>(buffer);
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
output->type = input1->type;
data->requires_broadcast = !HaveSameShapes(input1, input2);
TfLiteIntArray* output_size = nullptr;
if (data->requires_broadcast) {
TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
context, input1, input2, &output_size));
} else {
output_size = TfLiteIntArrayCopy(input1->dims);
}
return context->ResizeTensor(context, output, output_size);
}
template <typename T>
T BitwiseXor(T x, T y) {
return x ^ y;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
const TfLiteType type = output->type;
switch (type) {
// The fallthrough is indended. Since bitwise xor function operates on the
// underlying binary representation of the integers, both integers and
// unsigned integers will have the same behavior
case kTfLiteUInt8:
case kTfLiteInt8: {
if (data->requires_broadcast) {
reference_ops::BroadcastBinaryFunction4DSlow<int8_t, int8_t, int8_t>(
GetTensorShape(input1), GetTensorData<int8_t>(input1),
GetTensorShape(input2), GetTensorData<int8_t>(input2),
GetTensorShape(output), GetTensorData<int8_t>(output), BitwiseXor);
} else {
reference_ops::BinaryFunction<int8_t, int8_t, int8_t>(
GetTensorShape(input1), GetTensorData<int8_t>(input1),
GetTensorShape(input2), GetTensorData<int8_t>(input2),
GetTensorShape(output), GetTensorData<int8_t>(output), BitwiseXor);
}
break;
}
case kTfLiteUInt16:
case kTfLiteInt16: {
if (data->requires_broadcast) {
reference_ops::BroadcastBinaryFunction4DSlow<int16_t, int16_t, int16_t>(
GetTensorShape(input1), GetTensorData<int16_t>(input1),
GetTensorShape(input2), GetTensorData<int16_t>(input2),
GetTensorShape(output), GetTensorData<int16_t>(output), BitwiseXor);
} else {
reference_ops::BinaryFunction<int16_t, int16_t, int16_t>(
GetTensorShape(input1), GetTensorData<int16_t>(input1),
GetTensorShape(input2), GetTensorData<int16_t>(input2),
GetTensorShape(output), GetTensorData<int16_t>(output), BitwiseXor);
}
break;
}
case kTfLiteUInt32:
case kTfLiteInt32: {
if (data->requires_broadcast) {
reference_ops::BroadcastBinaryFunction4DSlow<int32_t, int32_t, int32_t>(
GetTensorShape(input1), GetTensorData<int32_t>(input1),
GetTensorShape(input2), GetTensorData<int32_t>(input2),
GetTensorShape(output), GetTensorData<int32_t>(output), BitwiseXor);
} else {
reference_ops::BinaryFunction<int32_t, int32_t, int32_t>(
GetTensorShape(input1), GetTensorData<int32_t>(input1),
GetTensorShape(input2), GetTensorData<int32_t>(input2),
GetTensorShape(output), GetTensorData<int32_t>(output), BitwiseXor);
}
break;
}
default:
TF_LITE_KERNEL_LOG(context,
"BitwiseXor currently only supports "
"8-bit/16-bit/32-bit integer/unsigned integer, got %s",
TfLiteTypeGetName(type));
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace bitwise_xor
TfLiteRegistration* Register_BITWISE_XOR() {
static TfLiteRegistration r = {bitwise_xor::Init, bitwise_xor::Free,
bitwise_xor::Prepare, bitwise_xor::Eval};
return &r;
}
} // namespace builtin
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,140 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <initializer_list>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
namespace {
using ::testing::ElementsAreArray;
class BitwiseXorOpModel : public SingleOpModel {
public:
BitwiseXorOpModel(std::initializer_list<int> input1_shape,
std::initializer_list<int> input2_shape,
TensorType tensor_type) {
input1_ = AddInput(tensor_type);
input2_ = AddInput(tensor_type);
output_ = AddOutput(tensor_type);
SetBuiltinOp(BuiltinOperator_BITWISE_XOR, BuiltinOptions_BitwiseXorOptions,
CreateBitwiseXorOptions(builder_).Union());
BuildInterpreter({input1_shape, input2_shape});
}
int input1() const { return input1_; }
int input2() const { return input2_; }
template <typename T>
std::vector<T> GetOutput() {
return ExtractVector<T>(output_);
}
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
protected:
int input1_;
int input2_;
int output_;
};
TEST(BitwiseXorOpTest, SimpleTestInt8) {
BitwiseXorOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT8);
model.PopulateTensor<int8_t>(model.input1(), {0, 5, 3, 14});
model.PopulateTensor<int8_t>(model.input2(), {5, 0, 7, 11});
ASSERT_EQ(model.Invoke(), kTfLiteOk);
EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAreArray({5, 5, 4, 5}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
TEST(BitwiseXorOpTest, SimpleTestInt16) {
BitwiseXorOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT16);
model.PopulateTensor<int16_t>(model.input1(), {0, 5, 3, 14});
model.PopulateTensor<int16_t>(model.input2(), {5, 0, 7, 11});
ASSERT_EQ(model.Invoke(), kTfLiteOk);
EXPECT_THAT(model.GetOutput<int16_t>(), ElementsAreArray({5, 5, 4, 5}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
TEST(BitwiseXorOpTest, SimpleTestInt32) {
BitwiseXorOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
model.PopulateTensor<int32_t>(model.input1(), {0, 5, 3, 14});
model.PopulateTensor<int32_t>(model.input2(), {5, 0, 7, 11});
ASSERT_EQ(model.Invoke(), kTfLiteOk);
EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray({5, 5, 4, 5}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
TEST(BitwiseXorOpTest, SimpleTestUInt8) {
BitwiseXorOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_UINT8);
model.PopulateTensor<uint8_t>(model.input1(), {0, 5, 3, 14});
model.PopulateTensor<uint8_t>(model.input2(), {5, 0, 7, 11});
ASSERT_EQ(model.Invoke(), kTfLiteOk);
EXPECT_THAT(model.GetOutput<uint8_t>(), ElementsAreArray({5, 5, 4, 5}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
TEST(BitwiseXorOpTest, SimpleTestUInt16) {
BitwiseXorOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_UINT16);
model.PopulateTensor<uint16_t>(model.input1(), {0, 5, 3, 14});
model.PopulateTensor<uint16_t>(model.input2(), {5, 0, 7, 11});
ASSERT_EQ(model.Invoke(), kTfLiteOk);
EXPECT_THAT(model.GetOutput<uint16_t>(), ElementsAreArray({5, 5, 4, 5}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
TEST(BitwiseXorOpTest, SimpleTestUInt32) {
BitwiseXorOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_UINT32);
model.PopulateTensor<uint32_t>(model.input1(), {0, 5, 3, 14});
model.PopulateTensor<uint32_t>(model.input2(), {5, 0, 7, 11});
ASSERT_EQ(model.Invoke(), kTfLiteOk);
EXPECT_THAT(model.GetOutput<uint32_t>(), ElementsAreArray({5, 5, 4, 5}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
TEST(BitwiseXorOpTest, BroadcastLhs) {
BitwiseXorOpModel model({1, 1, 1, 1}, {1, 1, 1, 4}, TensorType_INT32);
model.PopulateTensor<int32_t>(model.input1(), {5});
model.PopulateTensor<int32_t>(model.input2(), {0, -5, -3, 14});
ASSERT_EQ(model.Invoke(), kTfLiteOk);
EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray({5, -2, -8, 11}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
TEST(BitwiseXorOpTest, BroadcastRhs) {
BitwiseXorOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_UINT32);
model.PopulateTensor<uint32_t>(model.input1(), {0, 5, 3, 14});
model.PopulateTensor<uint32_t>(model.input2(), {5});
ASSERT_EQ(model.Invoke(), kTfLiteOk);
EXPECT_THAT(model.GetOutput<uint32_t>(), ElementsAreArray({5, 0, 6, 11}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
} // namespace
} // namespace tflite

View File

@ -172,3 +172,4 @@ TFLITE_OP(Register_ATAN2)
TFLITE_OP(Register_UNSORTED_SEGMENT_MIN)
TFLITE_OP(Register_SIGN)
TFLITE_OP(Register_BITCAST)
TFLITE_OP(Register_BITWISE_XOR)

View File

@ -182,6 +182,7 @@ TfLiteRegistration* Register_VAR_HANDLE();
TfLiteRegistration* Register_READ_VARIABLE();
TfLiteRegistration* Register_ASSIGN_VARIABLE();
TfLiteRegistration* Register_BITCAST();
TfLiteRegistration* Register_BITWISE_XOR();
namespace {
@ -530,6 +531,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
AddBuiltin(BuiltinOperator_READ_VARIABLE, Register_READ_VARIABLE());
AddBuiltin(BuiltinOperator_ASSIGN_VARIABLE, Register_ASSIGN_VARIABLE());
AddBuiltin(BuiltinOperator_BITCAST, Register_BITCAST());
AddBuiltin(BuiltinOperator_BITWISE_XOR, Register_BITWISE_XOR());
AddCustom("NumericVerify",
tflite::ops::custom::Register_NUMERIC_VERIFY_REF());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that

View File

@ -414,7 +414,8 @@ enum BuiltinOperator : int32 {
ATAN2 = 156,
UNSORTED_SEGMENT_MIN = 157,
SIGN = 158,
BITCAST = 159
BITCAST = 159,
BITWISE_XOR = 160,
}
// LINT.ThenChange(nnapi_linter/linter.proto)
@ -543,7 +544,8 @@ union BuiltinOptions {
UnsortedSegmentSumOptions,
ATan2Options,
SignOptions,
BitcastOptions
BitcastOptions,
BitwiseXorOptions,
}
// LINT.IfChange
@ -1183,6 +1185,8 @@ table SignOptions {
table BitcastOptions {
}
table BitwiseXorOptions {
}
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.

View File

@ -561,6 +561,10 @@ struct BitcastOptions;
struct BitcastOptionsBuilder;
struct BitcastOptionsT;
struct BitwiseXorOptions;
struct BitwiseXorOptionsBuilder;
struct BitwiseXorOptionsT;
struct OperatorCode;
struct OperatorCodeBuilder;
struct OperatorCodeT;
@ -1078,11 +1082,12 @@ enum BuiltinOperator : int32_t {
BuiltinOperator_UNSORTED_SEGMENT_MIN = 157,
BuiltinOperator_SIGN = 158,
BuiltinOperator_BITCAST = 159,
BuiltinOperator_BITWISE_XOR = 160,
BuiltinOperator_MIN = BuiltinOperator_ADD,
BuiltinOperator_MAX = BuiltinOperator_BITCAST
BuiltinOperator_MAX = BuiltinOperator_BITWISE_XOR
};
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[160] {
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[161] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@ -1243,13 +1248,14 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[160] {
BuiltinOperator_ATAN2,
BuiltinOperator_UNSORTED_SEGMENT_MIN,
BuiltinOperator_SIGN,
BuiltinOperator_BITCAST
BuiltinOperator_BITCAST,
BuiltinOperator_BITWISE_XOR
};
return values;
}
inline const char * const *EnumNamesBuiltinOperator() {
static const char * const names[161] = {
static const char * const names[162] = {
"ADD",
"AVERAGE_POOL_2D",
"CONCATENATION",
@ -1410,13 +1416,14 @@ inline const char * const *EnumNamesBuiltinOperator() {
"UNSORTED_SEGMENT_MIN",
"SIGN",
"BITCAST",
"BITWISE_XOR",
nullptr
};
return names;
}
inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
if (::flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_BITCAST)) return "";
if (::flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_BITWISE_XOR)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesBuiltinOperator()[index];
}
@ -1547,11 +1554,12 @@ enum BuiltinOptions : uint8_t {
BuiltinOptions_ATan2Options = 122,
BuiltinOptions_SignOptions = 123,
BuiltinOptions_BitcastOptions = 124,
BuiltinOptions_BitwiseXorOptions = 125,
BuiltinOptions_MIN = BuiltinOptions_NONE,
BuiltinOptions_MAX = BuiltinOptions_BitcastOptions
BuiltinOptions_MAX = BuiltinOptions_BitwiseXorOptions
};
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[125] {
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[126] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@ -1677,13 +1685,14 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[125] {
BuiltinOptions_UnsortedSegmentSumOptions,
BuiltinOptions_ATan2Options,
BuiltinOptions_SignOptions,
BuiltinOptions_BitcastOptions
BuiltinOptions_BitcastOptions,
BuiltinOptions_BitwiseXorOptions
};
return values;
}
inline const char * const *EnumNamesBuiltinOptions() {
static const char * const names[126] = {
static const char * const names[127] = {
"NONE",
"Conv2DOptions",
"DepthwiseConv2DOptions",
@ -1809,13 +1818,14 @@ inline const char * const *EnumNamesBuiltinOptions() {
"ATan2Options",
"SignOptions",
"BitcastOptions",
"BitwiseXorOptions",
nullptr
};
return names;
}
inline const char *EnumNameBuiltinOptions(BuiltinOptions e) {
if (::flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_BitcastOptions)) return "";
if (::flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_BitwiseXorOptions)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesBuiltinOptions()[index];
}
@ -2320,6 +2330,10 @@ template<> struct BuiltinOptionsTraits<tflite::BitcastOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_BitcastOptions;
};
template<> struct BuiltinOptionsTraits<tflite::BitwiseXorOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_BitwiseXorOptions;
};
template<typename T> struct BuiltinOptionsUnionTraits {
static const BuiltinOptions enum_value = BuiltinOptions_NONE;
};
@ -2820,6 +2834,10 @@ template<> struct BuiltinOptionsUnionTraits<tflite::BitcastOptionsT> {
static const BuiltinOptions enum_value = BuiltinOptions_BitcastOptions;
};
template<> struct BuiltinOptionsUnionTraits<tflite::BitwiseXorOptionsT> {
static const BuiltinOptions enum_value = BuiltinOptions_BitwiseXorOptions;
};
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@ -3842,6 +3860,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_BitcastOptions ?
reinterpret_cast<const tflite::BitcastOptionsT *>(value) : nullptr;
}
tflite::BitwiseXorOptionsT *AsBitwiseXorOptions() {
return type == BuiltinOptions_BitwiseXorOptions ?
reinterpret_cast<tflite::BitwiseXorOptionsT *>(value) : nullptr;
}
const tflite::BitwiseXorOptionsT *AsBitwiseXorOptions() const {
return type == BuiltinOptions_BitwiseXorOptions ?
reinterpret_cast<const tflite::BitwiseXorOptionsT *>(value) : nullptr;
}
};
bool VerifyBuiltinOptions(::flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@ -11635,6 +11661,45 @@ inline ::flatbuffers::Offset<BitcastOptions> CreateBitcastOptions(
::flatbuffers::Offset<BitcastOptions> CreateBitcastOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BitcastOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct BitwiseXorOptionsT : public ::flatbuffers::NativeTable {
typedef BitwiseXorOptions TableType;
};
struct BitwiseXorOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
typedef BitwiseXorOptionsT NativeTableType;
typedef BitwiseXorOptionsBuilder Builder;
bool Verify(::flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
verifier.EndTable();
}
BitwiseXorOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(BitwiseXorOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
static ::flatbuffers::Offset<BitwiseXorOptions> Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BitwiseXorOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct BitwiseXorOptionsBuilder {
typedef BitwiseXorOptions Table;
::flatbuffers::FlatBufferBuilder &fbb_;
::flatbuffers::uoffset_t start_;
explicit BitwiseXorOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
::flatbuffers::Offset<BitwiseXorOptions> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = ::flatbuffers::Offset<BitwiseXorOptions>(end);
return o;
}
};
inline ::flatbuffers::Offset<BitwiseXorOptions> CreateBitwiseXorOptions(
::flatbuffers::FlatBufferBuilder &_fbb) {
BitwiseXorOptionsBuilder builder_(_fbb);
return builder_.Finish();
}
::flatbuffers::Offset<BitwiseXorOptions> CreateBitwiseXorOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BitwiseXorOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct OperatorCodeT : public ::flatbuffers::NativeTable {
typedef OperatorCode TableType;
int8_t deprecated_builtin_code = 0;
@ -12150,6 +12215,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
const tflite::BitcastOptions *builtin_options_as_BitcastOptions() const {
return builtin_options_type() == tflite::BuiltinOptions_BitcastOptions ? static_cast<const tflite::BitcastOptions *>(builtin_options()) : nullptr;
}
const tflite::BitwiseXorOptions *builtin_options_as_BitwiseXorOptions() const {
return builtin_options_type() == tflite::BuiltinOptions_BitwiseXorOptions ? static_cast<const tflite::BitwiseXorOptions *>(builtin_options()) : nullptr;
}
const ::flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@ -12682,6 +12750,10 @@ template<> inline const tflite::BitcastOptions *Operator::builtin_options_as<tfl
return builtin_options_as_BitcastOptions();
}
template<> inline const tflite::BitwiseXorOptions *Operator::builtin_options_as<tflite::BitwiseXorOptions>() const {
return builtin_options_as_BitwiseXorOptions();
}
struct OperatorBuilder {
typedef Operator Table;
::flatbuffers::FlatBufferBuilder &fbb_;
@ -17040,6 +17112,29 @@ inline ::flatbuffers::Offset<BitcastOptions> CreateBitcastOptions(::flatbuffers:
_fbb);
}
inline BitwiseXorOptionsT *BitwiseXorOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
auto _o = std::unique_ptr<BitwiseXorOptionsT>(new BitwiseXorOptionsT());
UnPackTo(_o.get(), _resolver);
return _o.release();
}
inline void BitwiseXorOptions::UnPackTo(BitwiseXorOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
}
inline ::flatbuffers::Offset<BitwiseXorOptions> BitwiseXorOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BitwiseXorOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
return CreateBitwiseXorOptions(_fbb, _o, _rehasher);
}
inline ::flatbuffers::Offset<BitwiseXorOptions> CreateBitwiseXorOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BitwiseXorOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const BitwiseXorOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
return tflite::CreateBitwiseXorOptions(
_fbb);
}
inline OperatorCodeT *OperatorCode::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
auto _o = std::unique_ptr<OperatorCodeT>(new OperatorCodeT());
UnPackTo(_o.get(), _resolver);
@ -18079,6 +18174,10 @@ inline bool VerifyBuiltinOptions(::flatbuffers::Verifier &verifier, const void *
auto ptr = reinterpret_cast<const tflite::BitcastOptions *>(obj);
return verifier.VerifyTable(ptr);
}
case BuiltinOptions_BitwiseXorOptions: {
auto ptr = reinterpret_cast<const tflite::BitwiseXorOptions *>(obj);
return verifier.VerifyTable(ptr);
}
default: return true;
}
}
@ -18594,6 +18693,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const tflite::BitcastOptions *>(obj);
return ptr->UnPack(resolver);
}
case BuiltinOptions_BitwiseXorOptions: {
auto ptr = reinterpret_cast<const tflite::BitwiseXorOptions *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr;
}
}
@ -19097,6 +19200,10 @@ inline ::flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(::flatbuffers::Flat
auto ptr = reinterpret_cast<const tflite::BitcastOptionsT *>(value);
return CreateBitcastOptions(_fbb, ptr, _rehasher).Union();
}
case BuiltinOptions_BitwiseXorOptions: {
auto ptr = reinterpret_cast<const tflite::BitwiseXorOptionsT *>(value);
return CreateBitwiseXorOptions(_fbb, ptr, _rehasher).Union();
}
default: return 0;
}
}
@ -19599,6 +19706,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) :
value = new tflite::BitcastOptionsT(*reinterpret_cast<tflite::BitcastOptionsT *>(u.value));
break;
}
case BuiltinOptions_BitwiseXorOptions: {
value = new tflite::BitwiseXorOptionsT(*reinterpret_cast<tflite::BitwiseXorOptionsT *>(u.value));
break;
}
default:
break;
}
@ -20226,6 +20337,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
case BuiltinOptions_BitwiseXorOptions: {
auto ptr = reinterpret_cast<tflite::BitwiseXorOptionsT *>(value);
delete ptr;
break;
}
default: break;
}
value = nullptr;

View File

@ -22,6 +22,7 @@ def generated_test_models():
"batch_to_space_nd",
"batchmatmul",
"bitcast",
"bitwise_xor",
"broadcast_args",
"broadcast_gradient_args",
"broadcast_to",

View File

@ -46,6 +46,7 @@ from tensorflow.lite.testing.op_tests.batch_to_space_nd import make_batch_to_spa
from tensorflow.lite.testing.op_tests.batchmatmul import make_batchmatmul_tests
from tensorflow.lite.testing.op_tests.binary_op import make_add_tests, make_div_tests, make_sub_tests, make_mul_tests, make_pow_tests, make_floor_div_tests, make_floor_mod_tests, make_squared_difference_tests
from tensorflow.lite.testing.op_tests.bitcast import make_bitcast_tests
from tensorflow.lite.testing.op_tests.bitwise_xor import make_bitwise_xor_tests
from tensorflow.lite.testing.op_tests.broadcast_args import make_broadcast_args_tests
from tensorflow.lite.testing.op_tests.broadcast_gradient_args import make_broadcast_gradient_args_tests
from tensorflow.lite.testing.op_tests.broadcast_to import make_broadcast_to_tests

View File

@ -0,0 +1,78 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test configs for bitwise_xor operator."""
import tensorflow as tf
from tensorflow.lite.testing.zip_test_utils import create_tensor_data
from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
from tensorflow.lite.testing.zip_test_utils import register_make_test_function
@register_make_test_function()
def make_bitwise_xor_tests(options):
"""Generate examples for bitwise_xor."""
test_parameters = [
{
"input_dtype": [
tf.uint8,
tf.int8,
tf.uint16,
tf.int16,
tf.uint32,
tf.int32,
],
"input_shape_pair": [
([], []),
([2, 3, 4], [2, 3, 4]),
([1, 1, 1, 3], [1, 1, 1, 3]),
([5, 5], [1]),
([10], [2, 4, 10]),
([2, 3, 3], [2, 3]), # this test case is intended to fail
],
},
]
def build_graph(parameters):
"""Build the bitwise_xor testing graph."""
input_value1 = tf.compat.v1.placeholder(
dtype=parameters["input_dtype"],
name="input1",
shape=parameters["input_shape_pair"][0],
)
input_value2 = tf.compat.v1.placeholder(
dtype=parameters["input_dtype"],
name="input2",
shape=parameters["input_shape_pair"][1],
)
out = tf.bitwise.bitwise_xor(input_value1, input_value2)
return [input_value1, input_value2], [out]
def build_inputs(parameters, sess, inputs, outputs):
input_value1 = create_tensor_data(
parameters["input_dtype"], parameters["input_shape_pair"][0]
)
input_value2 = create_tensor_data(
parameters["input_dtype"], parameters["input_shape_pair"][1]
)
return [input_value1, input_value2], sess.run(
outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))
)
make_zip_of_tests(
options,
test_parameters,
build_graph,
build_inputs,
expected_tf_failures=6,
)

View File

@ -98,6 +98,7 @@ static const char* param_structs[] = {"TfLiteAddParams",
"TfLiteVarHandleParams",
"TfLiteUnsortedSegmentSumParams",
"TfLiteUnsortedSegmentMinParams",
"TfLiteBitwiseXorParams",
nullptr};
} // namespace
@ -221,6 +222,7 @@ class OpOptionData {
op_to_option_["GELU"] = "";
op_to_option_["DYNAMIC_UPDATE_SLICE"] = "";
op_to_option_["BITCAST"] = "";
op_to_option_["BITWISE_XOR"] = "";
// TODO(aselle): These are undesirable hacks. Consider changing C structs
option_to_struct_["Pool2DOptions"] = "TfLitePoolParams";

View File

@ -409,7 +409,8 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_ATAN2, 1}, "2.10.0"},
{{BuiltinOperator_SIGN, 1}, "2.11.0"},
{{BuiltinOperator_SIGN, 2}, "2.12.0"},
{{BuiltinOperator_BITCAST, 1}, "2.13.0"}});
{{BuiltinOperator_BITCAST, 1}, "2.13.0"},
{{BuiltinOperator_BITWISE_XOR, 1}, "2.13.0"}});
std::pair<BuiltinOperator, int> version_key = {op_code, op_version};
auto it = op_version_map->find(version_key);