diff --git a/RELEASE.md b/RELEASE.md index bc81fb56f33..9918f27442f 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -44,6 +44,7 @@ `equal` * Add 8-bit and 16-bit support for `floor_div` and `floor_mod`. * Add 16-bit and 32-bit int support for the built-in op `bitcast`. + * Add 8-bit/16-bit/32-bit int/uint support for the built-in op `bitwise_xor` * Add int16 indices support for built-in op `gather` and `gather_nd`. * Add reference implementation for 16-bit int unquantized `add`. * Add reference implementation for 16-bit int and 32-bit unsigned int unquantized `mul`. diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 20e6316b510..4ed963e3932 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -4053,6 +4053,26 @@ def TFL_BitcastOp : TFL_Op<"bitcast", [Pure]> { let hasVerifier = 1; } +def TFL_BitwiseXorOp : TFL_Op<"bitwise_xor", [ + Commutative, + SameOperandsAndResultElementType, + Pure]> { + let summary = "Bitwise Xor operator"; + + let description = [{ + Elementwise computes the bitwise XOR of `x` and `y`. + }]; + + let arguments = (ins + TFL_TensorOf<[I8, UI8, I16, UI16, I32, UI32]>:$lhs, + TFL_TensorOf<[I8, UI8, I16, UI16, I32, UI32]>:$rhs + ); + + let results = (outs + TFL_TensorOf<[I8, UI8, I16, UI16, I32, UI32]>:$output + ); +} + //===----------------------------------------------------------------------===// // Quantization ops. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index f6b96350ffd..82ef14822c4 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -2578,6 +2578,15 @@ func.func @bitcastI16ToFloat(%arg0: tensor<8x2xi16>) -> tensor<8xf32> { // CHECK: return %[[RES0]] : tensor<8xf32> } +func.func @testBitwiseXor(%arg0: tensor<8xui32>, %arg1: tensor<8xui32>) -> tensor<8xui32> { + %0 = "tf.BitwiseXor"(%arg0, %arg1) : (tensor<8xui32>, tensor<8xui32>) -> tensor<8xui32> + func.return %0 : tensor<8xui32> + + // CHECK-LABEL: testBitwiseXor + // CHECK: %[[RES0:.*]] = "tfl.bitwise_xor"(%arg0, %arg1) : (tensor<8xui32>, tensor<8xui32>) -> tensor<8xui32> + // CHECK: return %[[RES0]] : tensor<8xui32> +} + // ============================================================================= // Training OPs // ============================================================================= diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index e101dbe3fda..afb1d35b0de 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -3148,3 +3148,13 @@ func.func @testBitcast(%arg0: tensor<8xui32>) -> tensor<8xi32> { func.return %0 : tensor<8xi32> // CHECK: return %0 : tensor<8xi32> } + +// ----- + +// CHECK-LABEL: testBitwiseXor +func.func @testBitwiseXor(%arg0: tensor<8xui32>, %arg1: tensor<8xui32>) -> tensor<8xui32> { + // CHECK: "tfl.bitwise_xor"(%arg0, %arg1) + %0 = "tfl.bitwise_xor"(%arg0, %arg1) : (tensor<8xui32>, tensor<8xui32>) -> tensor<8xui32> + func.return %0 : tensor<8xui32> + // CHECK: return %0 : tensor<8xui32> +} diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index b47d4f9e85f..0db4bb7cf63 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -579,6 +579,9 @@ def LegalizeSign : Pat<(TF_SignOp $x), (TFL_SignOp $x)>; def LegalizeBitcast : Pat<(TF_BitcastOp $x), (TFL_BitcastOp $x)>; +def LegalizeBitwiseXor : Pat<(TF_BitwiseXorOp $l, $r), + (TFL_BitwiseXorOp $l, $r)>; + // ============================================================================= // Training OPs // ============================================================================= diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h index b94fac62d94..996aaebbd90 100644 --- a/tensorflow/lite/builtin_ops.h +++ b/tensorflow/lite/builtin_ops.h @@ -187,6 +187,7 @@ typedef enum { kTfLiteBuiltinUnsortedSegmentMin = 157, kTfLiteBuiltinSign = 158, kTfLiteBuiltinBitcast = 159, + kTfLiteBuiltinBitwiseXor = 160, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index d8bc6a2c233..8d1bdf66bda 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -546,6 +546,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, return ParseZerosLike(op, error_reporter, allocator, builtin_data); } + case BuiltinOperator_BITWISE_XOR: { + return ParseBitwiseXor(op, error_reporter, allocator, builtin_data); + } + case BuiltinOperator_CAST: { return ParseCast(op, error_reporter, allocator, builtin_data); } @@ -849,6 +853,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, *builtin_data = params.release(); return kTfLiteOk; } + // Below are the ops with no builtin_data structure. // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are // ok for now, since there is no call implementation either. @@ -2453,6 +2458,14 @@ TfLiteStatus ParseZerosLike(const Operator*, ErrorReporter*, return kTfLiteOk; } +// We have this parse function instead of directly returning kTfLiteOk from the +// switch-case in ParseOpData because this function is used as part of the +// selective registration for the OpResolver implementation in micro. +TfLiteStatus ParseBitwiseXor(const Operator*, ErrorReporter*, + BuiltinDataAllocator*, void**) { + return kTfLiteOk; +} + TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, ErrorReporter* error_reporter, BuiltinDataAllocator* allocator, void** builtin_data) { diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.h b/tensorflow/lite/core/api/flatbuffer_conversions.h index 2b117c279c6..fc4d97d97d6 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.h +++ b/tensorflow/lite/core/api/flatbuffer_conversions.h @@ -412,6 +412,10 @@ TfLiteStatus ParseZerosLike(const Operator* op, ErrorReporter* error_reporter, BuiltinDataAllocator* allocator, void** builtin_data); +TfLiteStatus ParseBitwiseXor(const Operator* op, ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, + void** builtin_data); + } // namespace tflite #endif // TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_ diff --git a/tensorflow/lite/core/kernels/builtin_op_kernels.h b/tensorflow/lite/core/kernels/builtin_op_kernels.h index c52c0948656..f290a7d7945 100644 --- a/tensorflow/lite/core/kernels/builtin_op_kernels.h +++ b/tensorflow/lite/core/kernels/builtin_op_kernels.h @@ -192,6 +192,7 @@ TfLiteRegistration* Register_WHERE(); TfLiteRegistration* Register_WHILE(); TfLiteRegistration* Register_ZEROS_LIKE(); TfLiteRegistration* Register_BITCAST(); +TfLiteRegistration* Register_BITWISE_XOR(); } // namespace builtin } // namespace ops diff --git a/tensorflow/lite/core/kernels/register.cc b/tensorflow/lite/core/kernels/register.cc index 1685dbaa3e0..2a0716564cc 100644 --- a/tensorflow/lite/core/kernels/register.cc +++ b/tensorflow/lite/core/kernels/register.cc @@ -360,6 +360,7 @@ BuiltinOpResolver::BuiltinOpResolver() { /* min_version = */ 1, /* max_version = */ 2); AddBuiltin(BuiltinOperator_BITCAST, Register_BITCAST()); + AddBuiltin(BuiltinOperator_BITWISE_XOR, Register_BITWISE_XOR()); AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 4ec186d16d2..48ea76feae3 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -587,6 +587,7 @@ BUILTIN_KERNEL_SRCS = [ "bidirectional_sequence_lstm.cc", "bidirectional_sequence_rnn.cc", "bitcast.cc", + "bitwise_xor.cc", "broadcast_args.cc", "broadcast_to.cc", "bucketize.cc", @@ -2983,6 +2984,17 @@ cc_test( ], ) +cc_test( + name = "bitwise_xor_test", + size = "small", + srcs = ["bitwise_xor_test.cc"], + deps = [ + ":test_util", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_googletest//:gtest_main", + ], +) + tflite_portable_test_suite_combined( combine_conditions = {"deps": [":test_main"]}, # TODO(b/229985981) : Remove `nnapi_args` after adding Relu0To1 is completed. diff --git a/tensorflow/lite/kernels/bitwise_xor.cc b/tensorflow/lite/kernels/bitwise_xor.cc new file mode 100644 index 00000000000..2650b1ee4e1 --- /dev/null +++ b/tensorflow/lite/kernels/bitwise_xor.cc @@ -0,0 +1,168 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/core/c/c_api_types.h" +#include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/kernels/internal/reference/binary_function.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace bitwise_xor { + +// Input/output tensor index. +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +// Op data for bitwise xor op. +struct OpData { + bool requires_broadcast = false; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + OpData* data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input1; + TF_LITE_ENSURE_OK(context, + GetInputSafe(context, node, kInputTensor1, &input1)); + const TfLiteTensor* input2; + TF_LITE_ENSURE_OK(context, + GetInputSafe(context, node, kInputTensor2, &input2)); + TfLiteTensor* output; + TF_LITE_ENSURE_OK(context, + GetOutputSafe(context, node, kOutputTensor, &output)); + + TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type); + + output->type = input1->type; + + data->requires_broadcast = !HaveSameShapes(input1, input2); + + TfLiteIntArray* output_size = nullptr; + if (data->requires_broadcast) { + TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast( + context, input1, input2, &output_size)); + } else { + output_size = TfLiteIntArrayCopy(input1->dims); + } + + return context->ResizeTensor(context, output, output_size); +} + +template +T BitwiseXor(T x, T y) { + return x ^ y; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input1; + TF_LITE_ENSURE_OK(context, + GetInputSafe(context, node, kInputTensor1, &input1)); + const TfLiteTensor* input2; + TF_LITE_ENSURE_OK(context, + GetInputSafe(context, node, kInputTensor2, &input2)); + TfLiteTensor* output; + TF_LITE_ENSURE_OK(context, + GetOutputSafe(context, node, kOutputTensor, &output)); + + const TfLiteType type = output->type; + switch (type) { + // The fallthrough is indended. Since bitwise xor function operates on the + // underlying binary representation of the integers, both integers and + // unsigned integers will have the same behavior + case kTfLiteUInt8: + case kTfLiteInt8: { + if (data->requires_broadcast) { + reference_ops::BroadcastBinaryFunction4DSlow( + GetTensorShape(input1), GetTensorData(input1), + GetTensorShape(input2), GetTensorData(input2), + GetTensorShape(output), GetTensorData(output), BitwiseXor); + } else { + reference_ops::BinaryFunction( + GetTensorShape(input1), GetTensorData(input1), + GetTensorShape(input2), GetTensorData(input2), + GetTensorShape(output), GetTensorData(output), BitwiseXor); + } + break; + } + case kTfLiteUInt16: + case kTfLiteInt16: { + if (data->requires_broadcast) { + reference_ops::BroadcastBinaryFunction4DSlow( + GetTensorShape(input1), GetTensorData(input1), + GetTensorShape(input2), GetTensorData(input2), + GetTensorShape(output), GetTensorData(output), BitwiseXor); + } else { + reference_ops::BinaryFunction( + GetTensorShape(input1), GetTensorData(input1), + GetTensorShape(input2), GetTensorData(input2), + GetTensorShape(output), GetTensorData(output), BitwiseXor); + } + break; + } + case kTfLiteUInt32: + case kTfLiteInt32: { + if (data->requires_broadcast) { + reference_ops::BroadcastBinaryFunction4DSlow( + GetTensorShape(input1), GetTensorData(input1), + GetTensorShape(input2), GetTensorData(input2), + GetTensorShape(output), GetTensorData(output), BitwiseXor); + } else { + reference_ops::BinaryFunction( + GetTensorShape(input1), GetTensorData(input1), + GetTensorShape(input2), GetTensorData(input2), + GetTensorShape(output), GetTensorData(output), BitwiseXor); + } + break; + } + default: + TF_LITE_KERNEL_LOG(context, + "BitwiseXor currently only supports " + "8-bit/16-bit/32-bit integer/unsigned integer, got %s", + TfLiteTypeGetName(type)); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace bitwise_xor + +TfLiteRegistration* Register_BITWISE_XOR() { + static TfLiteRegistration r = {bitwise_xor::Init, bitwise_xor::Free, + bitwise_xor::Prepare, bitwise_xor::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/bitwise_xor_test.cc b/tensorflow/lite/kernels/bitwise_xor_test.cc new file mode 100644 index 00000000000..f508cb53142 --- /dev/null +++ b/tensorflow/lite/kernels/bitwise_xor_test.cc @@ -0,0 +1,140 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include +#include +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BitwiseXorOpModel : public SingleOpModel { + public: + BitwiseXorOpModel(std::initializer_list input1_shape, + std::initializer_list input2_shape, + TensorType tensor_type) { + input1_ = AddInput(tensor_type); + input2_ = AddInput(tensor_type); + output_ = AddOutput(tensor_type); + SetBuiltinOp(BuiltinOperator_BITWISE_XOR, BuiltinOptions_BitwiseXorOptions, + CreateBitwiseXorOptions(builder_).Union()); + BuildInterpreter({input1_shape, input2_shape}); + } + + int input1() const { return input1_; } + int input2() const { return input2_; } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input1_; + int input2_; + int output_; +}; + +TEST(BitwiseXorOpTest, SimpleTestInt8) { + BitwiseXorOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT8); + model.PopulateTensor(model.input1(), {0, 5, 3, 14}); + model.PopulateTensor(model.input2(), {5, 0, 7, 11}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 5, 4, 5})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(BitwiseXorOpTest, SimpleTestInt16) { + BitwiseXorOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT16); + model.PopulateTensor(model.input1(), {0, 5, 3, 14}); + model.PopulateTensor(model.input2(), {5, 0, 7, 11}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 5, 4, 5})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(BitwiseXorOpTest, SimpleTestInt32) { + BitwiseXorOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + model.PopulateTensor(model.input1(), {0, 5, 3, 14}); + model.PopulateTensor(model.input2(), {5, 0, 7, 11}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 5, 4, 5})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(BitwiseXorOpTest, SimpleTestUInt8) { + BitwiseXorOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_UINT8); + model.PopulateTensor(model.input1(), {0, 5, 3, 14}); + model.PopulateTensor(model.input2(), {5, 0, 7, 11}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 5, 4, 5})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(BitwiseXorOpTest, SimpleTestUInt16) { + BitwiseXorOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_UINT16); + model.PopulateTensor(model.input1(), {0, 5, 3, 14}); + model.PopulateTensor(model.input2(), {5, 0, 7, 11}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 5, 4, 5})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(BitwiseXorOpTest, SimpleTestUInt32) { + BitwiseXorOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_UINT32); + model.PopulateTensor(model.input1(), {0, 5, 3, 14}); + model.PopulateTensor(model.input2(), {5, 0, 7, 11}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 5, 4, 5})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(BitwiseXorOpTest, BroadcastLhs) { + BitwiseXorOpModel model({1, 1, 1, 1}, {1, 1, 1, 4}, TensorType_INT32); + model.PopulateTensor(model.input1(), {5}); + model.PopulateTensor(model.input2(), {0, -5, -3, 14}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, -2, -8, 11})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(BitwiseXorOpTest, BroadcastRhs) { + BitwiseXorOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_UINT32); + model.PopulateTensor(model.input1(), {0, 5, 3, 14}); + model.PopulateTensor(model.input2(), {5}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 0, 6, 11})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +} // namespace +} // namespace tflite diff --git a/tensorflow/lite/kernels/builtin_ops_list.inc b/tensorflow/lite/kernels/builtin_ops_list.inc index f21d59b4b54..e33a7cab5e7 100644 --- a/tensorflow/lite/kernels/builtin_ops_list.inc +++ b/tensorflow/lite/kernels/builtin_ops_list.inc @@ -172,3 +172,4 @@ TFLITE_OP(Register_ATAN2) TFLITE_OP(Register_UNSORTED_SEGMENT_MIN) TFLITE_OP(Register_SIGN) TFLITE_OP(Register_BITCAST) +TFLITE_OP(Register_BITWISE_XOR) diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc index 35af5beb65f..63300926c4b 100644 --- a/tensorflow/lite/kernels/register_ref.cc +++ b/tensorflow/lite/kernels/register_ref.cc @@ -182,6 +182,7 @@ TfLiteRegistration* Register_VAR_HANDLE(); TfLiteRegistration* Register_READ_VARIABLE(); TfLiteRegistration* Register_ASSIGN_VARIABLE(); TfLiteRegistration* Register_BITCAST(); +TfLiteRegistration* Register_BITWISE_XOR(); namespace { @@ -530,6 +531,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() { AddBuiltin(BuiltinOperator_READ_VARIABLE, Register_READ_VARIABLE()); AddBuiltin(BuiltinOperator_ASSIGN_VARIABLE, Register_ASSIGN_VARIABLE()); AddBuiltin(BuiltinOperator_BITCAST, Register_BITCAST()); + AddBuiltin(BuiltinOperator_BITWISE_XOR, Register_BITWISE_XOR()); AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY_REF()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index 4affd55f69d..81e8729d13a 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -414,7 +414,8 @@ enum BuiltinOperator : int32 { ATAN2 = 156, UNSORTED_SEGMENT_MIN = 157, SIGN = 158, - BITCAST = 159 + BITCAST = 159, + BITWISE_XOR = 160, } // LINT.ThenChange(nnapi_linter/linter.proto) @@ -543,7 +544,8 @@ union BuiltinOptions { UnsortedSegmentSumOptions, ATan2Options, SignOptions, - BitcastOptions + BitcastOptions, + BitwiseXorOptions, } // LINT.IfChange @@ -1183,6 +1185,8 @@ table SignOptions { table BitcastOptions { } +table BitwiseXorOptions { +} // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index ab544e2ae0d..beced895a4b 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -561,6 +561,10 @@ struct BitcastOptions; struct BitcastOptionsBuilder; struct BitcastOptionsT; +struct BitwiseXorOptions; +struct BitwiseXorOptionsBuilder; +struct BitwiseXorOptionsT; + struct OperatorCode; struct OperatorCodeBuilder; struct OperatorCodeT; @@ -1078,11 +1082,12 @@ enum BuiltinOperator : int32_t { BuiltinOperator_UNSORTED_SEGMENT_MIN = 157, BuiltinOperator_SIGN = 158, BuiltinOperator_BITCAST = 159, + BuiltinOperator_BITWISE_XOR = 160, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_BITCAST + BuiltinOperator_MAX = BuiltinOperator_BITWISE_XOR }; -inline const BuiltinOperator (&EnumValuesBuiltinOperator())[160] { +inline const BuiltinOperator (&EnumValuesBuiltinOperator())[161] { static const BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -1243,13 +1248,14 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[160] { BuiltinOperator_ATAN2, BuiltinOperator_UNSORTED_SEGMENT_MIN, BuiltinOperator_SIGN, - BuiltinOperator_BITCAST + BuiltinOperator_BITCAST, + BuiltinOperator_BITWISE_XOR }; return values; } inline const char * const *EnumNamesBuiltinOperator() { - static const char * const names[161] = { + static const char * const names[162] = { "ADD", "AVERAGE_POOL_2D", "CONCATENATION", @@ -1410,13 +1416,14 @@ inline const char * const *EnumNamesBuiltinOperator() { "UNSORTED_SEGMENT_MIN", "SIGN", "BITCAST", + "BITWISE_XOR", nullptr }; return names; } inline const char *EnumNameBuiltinOperator(BuiltinOperator e) { - if (::flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_BITCAST)) return ""; + if (::flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_BITWISE_XOR)) return ""; const size_t index = static_cast(e); return EnumNamesBuiltinOperator()[index]; } @@ -1547,11 +1554,12 @@ enum BuiltinOptions : uint8_t { BuiltinOptions_ATan2Options = 122, BuiltinOptions_SignOptions = 123, BuiltinOptions_BitcastOptions = 124, + BuiltinOptions_BitwiseXorOptions = 125, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_BitcastOptions + BuiltinOptions_MAX = BuiltinOptions_BitwiseXorOptions }; -inline const BuiltinOptions (&EnumValuesBuiltinOptions())[125] { +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[126] { static const BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -1677,13 +1685,14 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[125] { BuiltinOptions_UnsortedSegmentSumOptions, BuiltinOptions_ATan2Options, BuiltinOptions_SignOptions, - BuiltinOptions_BitcastOptions + BuiltinOptions_BitcastOptions, + BuiltinOptions_BitwiseXorOptions }; return values; } inline const char * const *EnumNamesBuiltinOptions() { - static const char * const names[126] = { + static const char * const names[127] = { "NONE", "Conv2DOptions", "DepthwiseConv2DOptions", @@ -1809,13 +1818,14 @@ inline const char * const *EnumNamesBuiltinOptions() { "ATan2Options", "SignOptions", "BitcastOptions", + "BitwiseXorOptions", nullptr }; return names; } inline const char *EnumNameBuiltinOptions(BuiltinOptions e) { - if (::flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_BitcastOptions)) return ""; + if (::flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_BitwiseXorOptions)) return ""; const size_t index = static_cast(e); return EnumNamesBuiltinOptions()[index]; } @@ -2320,6 +2330,10 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_BitcastOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_BitwiseXorOptions; +}; + template struct BuiltinOptionsUnionTraits { static const BuiltinOptions enum_value = BuiltinOptions_NONE; }; @@ -2820,6 +2834,10 @@ template<> struct BuiltinOptionsUnionTraits { static const BuiltinOptions enum_value = BuiltinOptions_BitcastOptions; }; +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_BitwiseXorOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -3842,6 +3860,14 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_BitcastOptions ? reinterpret_cast(value) : nullptr; } + tflite::BitwiseXorOptionsT *AsBitwiseXorOptions() { + return type == BuiltinOptions_BitwiseXorOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::BitwiseXorOptionsT *AsBitwiseXorOptions() const { + return type == BuiltinOptions_BitwiseXorOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(::flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -11635,6 +11661,45 @@ inline ::flatbuffers::Offset CreateBitcastOptions( ::flatbuffers::Offset CreateBitcastOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BitcastOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct BitwiseXorOptionsT : public ::flatbuffers::NativeTable { + typedef BitwiseXorOptions TableType; +}; + +struct BitwiseXorOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef BitwiseXorOptionsT NativeTableType; + typedef BitwiseXorOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + BitwiseXorOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BitwiseXorOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BitwiseXorOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct BitwiseXorOptionsBuilder { + typedef BitwiseXorOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit BitwiseXorOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateBitwiseXorOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + BitwiseXorOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateBitwiseXorOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BitwiseXorOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public ::flatbuffers::NativeTable { typedef OperatorCode TableType; int8_t deprecated_builtin_code = 0; @@ -12150,6 +12215,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { const tflite::BitcastOptions *builtin_options_as_BitcastOptions() const { return builtin_options_type() == tflite::BuiltinOptions_BitcastOptions ? static_cast(builtin_options()) : nullptr; } + const tflite::BitwiseXorOptions *builtin_options_as_BitwiseXorOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_BitwiseXorOptions ? static_cast(builtin_options()) : nullptr; + } const ::flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -12682,6 +12750,10 @@ template<> inline const tflite::BitcastOptions *Operator::builtin_options_as inline const tflite::BitwiseXorOptions *Operator::builtin_options_as() const { + return builtin_options_as_BitwiseXorOptions(); +} + struct OperatorBuilder { typedef Operator Table; ::flatbuffers::FlatBufferBuilder &fbb_; @@ -17040,6 +17112,29 @@ inline ::flatbuffers::Offset CreateBitcastOptions(::flatbuffers: _fbb); } +inline BitwiseXorOptionsT *BitwiseXorOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new BitwiseXorOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void BitwiseXorOptions::UnPackTo(BitwiseXorOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset BitwiseXorOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BitwiseXorOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateBitwiseXorOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateBitwiseXorOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BitwiseXorOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const BitwiseXorOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateBitwiseXorOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { auto _o = std::unique_ptr(new OperatorCodeT()); UnPackTo(_o.get(), _resolver); @@ -18079,6 +18174,10 @@ inline bool VerifyBuiltinOptions(::flatbuffers::Verifier &verifier, const void * auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_BitwiseXorOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return true; } } @@ -18594,6 +18693,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_BitwiseXorOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -19097,6 +19200,10 @@ inline ::flatbuffers::Offset BuiltinOptionsUnion::Pack(::flatbuffers::Flat auto ptr = reinterpret_cast(value); return CreateBitcastOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_BitwiseXorOptions: { + auto ptr = reinterpret_cast(value); + return CreateBitwiseXorOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -19599,6 +19706,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) : value = new tflite::BitcastOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_BitwiseXorOptions: { + value = new tflite::BitwiseXorOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -20226,6 +20337,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_BitwiseXorOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/lite/testing/build_def.bzl b/tensorflow/lite/testing/build_def.bzl index 05fe5f25431..e37ff1d8c72 100644 --- a/tensorflow/lite/testing/build_def.bzl +++ b/tensorflow/lite/testing/build_def.bzl @@ -22,6 +22,7 @@ def generated_test_models(): "batch_to_space_nd", "batchmatmul", "bitcast", + "bitwise_xor", "broadcast_args", "broadcast_gradient_args", "broadcast_to", diff --git a/tensorflow/lite/testing/generate_examples_lib.py b/tensorflow/lite/testing/generate_examples_lib.py index 4f4147aa70d..d65e81e3bdb 100644 --- a/tensorflow/lite/testing/generate_examples_lib.py +++ b/tensorflow/lite/testing/generate_examples_lib.py @@ -46,6 +46,7 @@ from tensorflow.lite.testing.op_tests.batch_to_space_nd import make_batch_to_spa from tensorflow.lite.testing.op_tests.batchmatmul import make_batchmatmul_tests from tensorflow.lite.testing.op_tests.binary_op import make_add_tests, make_div_tests, make_sub_tests, make_mul_tests, make_pow_tests, make_floor_div_tests, make_floor_mod_tests, make_squared_difference_tests from tensorflow.lite.testing.op_tests.bitcast import make_bitcast_tests +from tensorflow.lite.testing.op_tests.bitwise_xor import make_bitwise_xor_tests from tensorflow.lite.testing.op_tests.broadcast_args import make_broadcast_args_tests from tensorflow.lite.testing.op_tests.broadcast_gradient_args import make_broadcast_gradient_args_tests from tensorflow.lite.testing.op_tests.broadcast_to import make_broadcast_to_tests diff --git a/tensorflow/lite/testing/op_tests/bitwise_xor.py b/tensorflow/lite/testing/op_tests/bitwise_xor.py new file mode 100644 index 00000000000..3c6f6a2fa4c --- /dev/null +++ b/tensorflow/lite/testing/op_tests/bitwise_xor.py @@ -0,0 +1,78 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test configs for bitwise_xor operator.""" +import tensorflow as tf +from tensorflow.lite.testing.zip_test_utils import create_tensor_data +from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests +from tensorflow.lite.testing.zip_test_utils import register_make_test_function + + +@register_make_test_function() +def make_bitwise_xor_tests(options): + """Generate examples for bitwise_xor.""" + test_parameters = [ + { + "input_dtype": [ + tf.uint8, + tf.int8, + tf.uint16, + tf.int16, + tf.uint32, + tf.int32, + ], + "input_shape_pair": [ + ([], []), + ([2, 3, 4], [2, 3, 4]), + ([1, 1, 1, 3], [1, 1, 1, 3]), + ([5, 5], [1]), + ([10], [2, 4, 10]), + ([2, 3, 3], [2, 3]), # this test case is intended to fail + ], + }, + ] + + def build_graph(parameters): + """Build the bitwise_xor testing graph.""" + input_value1 = tf.compat.v1.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape_pair"][0], + ) + input_value2 = tf.compat.v1.placeholder( + dtype=parameters["input_dtype"], + name="input2", + shape=parameters["input_shape_pair"][1], + ) + out = tf.bitwise.bitwise_xor(input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data( + parameters["input_dtype"], parameters["input_shape_pair"][0] + ) + input_value2 = create_tensor_data( + parameters["input_dtype"], parameters["input_shape_pair"][1] + ) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])) + ) + + make_zip_of_tests( + options, + test_parameters, + build_graph, + build_inputs, + expected_tf_failures=6, + ) diff --git a/tensorflow/lite/tools/serialization/option_writer_generator.cc b/tensorflow/lite/tools/serialization/option_writer_generator.cc index 8c374e3eade..520c708e28a 100644 --- a/tensorflow/lite/tools/serialization/option_writer_generator.cc +++ b/tensorflow/lite/tools/serialization/option_writer_generator.cc @@ -98,6 +98,7 @@ static const char* param_structs[] = {"TfLiteAddParams", "TfLiteVarHandleParams", "TfLiteUnsortedSegmentSumParams", "TfLiteUnsortedSegmentMinParams", + "TfLiteBitwiseXorParams", nullptr}; } // namespace @@ -221,6 +222,7 @@ class OpOptionData { op_to_option_["GELU"] = ""; op_to_option_["DYNAMIC_UPDATE_SLICE"] = ""; op_to_option_["BITCAST"] = ""; + op_to_option_["BITWISE_XOR"] = ""; // TODO(aselle): These are undesirable hacks. Consider changing C structs option_to_struct_["Pool2DOptions"] = "TfLitePoolParams"; diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc index 54ed812c8e8..a7e3cfc559c 100644 --- a/tensorflow/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/lite/tools/versioning/runtime_version.cc @@ -409,7 +409,8 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_ATAN2, 1}, "2.10.0"}, {{BuiltinOperator_SIGN, 1}, "2.11.0"}, {{BuiltinOperator_SIGN, 2}, "2.12.0"}, - {{BuiltinOperator_BITCAST, 1}, "2.13.0"}}); + {{BuiltinOperator_BITCAST, 1}, "2.13.0"}, + {{BuiltinOperator_BITWISE_XOR, 1}, "2.13.0"}}); std::pair version_key = {op_code, op_version}; auto it = op_version_map->find(version_key);