mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Support TFLite BitwiseXor op for 8-bit/16-bit/32-bit integer/unsigned integer
PiperOrigin-RevId: 520507082
This commit is contained in:
parent
e1e39d4e92
commit
fd714e7701
|
|
@ -44,6 +44,7 @@
|
|||
`equal`
|
||||
* Add 8-bit and 16-bit support for `floor_div` and `floor_mod`.
|
||||
* Add 16-bit and 32-bit int support for the built-in op `bitcast`.
|
||||
* Add 8-bit/16-bit/32-bit int/uint support for the built-in op `bitwise_xor`
|
||||
* Add int16 indices support for built-in op `gather` and `gather_nd`.
|
||||
* Add reference implementation for 16-bit int unquantized `add`.
|
||||
* Add reference implementation for 16-bit int and 32-bit unsigned int unquantized `mul`.
|
||||
|
|
|
|||
|
|
@ -4053,6 +4053,26 @@ def TFL_BitcastOp : TFL_Op<"bitcast", [Pure]> {
|
|||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def TFL_BitwiseXorOp : TFL_Op<"bitwise_xor", [
|
||||
Commutative,
|
||||
SameOperandsAndResultElementType,
|
||||
Pure]> {
|
||||
let summary = "Bitwise Xor operator";
|
||||
|
||||
let description = [{
|
||||
Elementwise computes the bitwise XOR of `x` and `y`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[I8, UI8, I16, UI16, I32, UI32]>:$lhs,
|
||||
TFL_TensorOf<[I8, UI8, I16, UI16, I32, UI32]>:$rhs
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[I8, UI8, I16, UI16, I32, UI32]>:$output
|
||||
);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Quantization ops.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
|||
|
|
@ -2578,6 +2578,15 @@ func.func @bitcastI16ToFloat(%arg0: tensor<8x2xi16>) -> tensor<8xf32> {
|
|||
// CHECK: return %[[RES0]] : tensor<8xf32>
|
||||
}
|
||||
|
||||
func.func @testBitwiseXor(%arg0: tensor<8xui32>, %arg1: tensor<8xui32>) -> tensor<8xui32> {
|
||||
%0 = "tf.BitwiseXor"(%arg0, %arg1) : (tensor<8xui32>, tensor<8xui32>) -> tensor<8xui32>
|
||||
func.return %0 : tensor<8xui32>
|
||||
|
||||
// CHECK-LABEL: testBitwiseXor
|
||||
// CHECK: %[[RES0:.*]] = "tfl.bitwise_xor"(%arg0, %arg1) : (tensor<8xui32>, tensor<8xui32>) -> tensor<8xui32>
|
||||
// CHECK: return %[[RES0]] : tensor<8xui32>
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Training OPs
|
||||
// =============================================================================
|
||||
|
|
|
|||
|
|
@ -3148,3 +3148,13 @@ func.func @testBitcast(%arg0: tensor<8xui32>) -> tensor<8xi32> {
|
|||
func.return %0 : tensor<8xi32>
|
||||
// CHECK: return %0 : tensor<8xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testBitwiseXor
|
||||
func.func @testBitwiseXor(%arg0: tensor<8xui32>, %arg1: tensor<8xui32>) -> tensor<8xui32> {
|
||||
// CHECK: "tfl.bitwise_xor"(%arg0, %arg1)
|
||||
%0 = "tfl.bitwise_xor"(%arg0, %arg1) : (tensor<8xui32>, tensor<8xui32>) -> tensor<8xui32>
|
||||
func.return %0 : tensor<8xui32>
|
||||
// CHECK: return %0 : tensor<8xui32>
|
||||
}
|
||||
|
|
|
|||
|
|
@ -579,6 +579,9 @@ def LegalizeSign : Pat<(TF_SignOp $x), (TFL_SignOp $x)>;
|
|||
|
||||
def LegalizeBitcast : Pat<(TF_BitcastOp $x), (TFL_BitcastOp $x)>;
|
||||
|
||||
def LegalizeBitwiseXor : Pat<(TF_BitwiseXorOp $l, $r),
|
||||
(TFL_BitwiseXorOp $l, $r)>;
|
||||
|
||||
// =============================================================================
|
||||
// Training OPs
|
||||
// =============================================================================
|
||||
|
|
|
|||
|
|
@ -187,6 +187,7 @@ typedef enum {
|
|||
kTfLiteBuiltinUnsortedSegmentMin = 157,
|
||||
kTfLiteBuiltinSign = 158,
|
||||
kTfLiteBuiltinBitcast = 159,
|
||||
kTfLiteBuiltinBitwiseXor = 160,
|
||||
} TfLiteBuiltinOperator;
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
|||
|
|
@ -546,6 +546,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
|
|||
return ParseZerosLike(op, error_reporter, allocator, builtin_data);
|
||||
}
|
||||
|
||||
case BuiltinOperator_BITWISE_XOR: {
|
||||
return ParseBitwiseXor(op, error_reporter, allocator, builtin_data);
|
||||
}
|
||||
|
||||
case BuiltinOperator_CAST: {
|
||||
return ParseCast(op, error_reporter, allocator, builtin_data);
|
||||
}
|
||||
|
|
@ -849,6 +853,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
|
|||
*builtin_data = params.release();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// Below are the ops with no builtin_data structure.
|
||||
// TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
|
||||
// ok for now, since there is no call implementation either.
|
||||
|
|
@ -2453,6 +2458,14 @@ TfLiteStatus ParseZerosLike(const Operator*, ErrorReporter*,
|
|||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// We have this parse function instead of directly returning kTfLiteOk from the
|
||||
// switch-case in ParseOpData because this function is used as part of the
|
||||
// selective registration for the OpResolver implementation in micro.
|
||||
TfLiteStatus ParseBitwiseXor(const Operator*, ErrorReporter*,
|
||||
BuiltinDataAllocator*, void**) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data) {
|
||||
|
|
|
|||
|
|
@ -412,6 +412,10 @@ TfLiteStatus ParseZerosLike(const Operator* op, ErrorReporter* error_reporter,
|
|||
BuiltinDataAllocator* allocator,
|
||||
void** builtin_data);
|
||||
|
||||
TfLiteStatus ParseBitwiseXor(const Operator* op, ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator,
|
||||
void** builtin_data);
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
|
||||
|
|
|
|||
|
|
@ -192,6 +192,7 @@ TfLiteRegistration* Register_WHERE();
|
|||
TfLiteRegistration* Register_WHILE();
|
||||
TfLiteRegistration* Register_ZEROS_LIKE();
|
||||
TfLiteRegistration* Register_BITCAST();
|
||||
TfLiteRegistration* Register_BITWISE_XOR();
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
|
|
|
|||
|
|
@ -360,6 +360,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||
/* min_version = */ 1,
|
||||
/* max_version = */ 2);
|
||||
AddBuiltin(BuiltinOperator_BITCAST, Register_BITCAST());
|
||||
AddBuiltin(BuiltinOperator_BITWISE_XOR, Register_BITWISE_XOR());
|
||||
AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
|
||||
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
||||
// custom ops aren't always included by default.
|
||||
|
|
|
|||
|
|
@ -587,6 +587,7 @@ BUILTIN_KERNEL_SRCS = [
|
|||
"bidirectional_sequence_lstm.cc",
|
||||
"bidirectional_sequence_rnn.cc",
|
||||
"bitcast.cc",
|
||||
"bitwise_xor.cc",
|
||||
"broadcast_args.cc",
|
||||
"broadcast_to.cc",
|
||||
"bucketize.cc",
|
||||
|
|
@ -2983,6 +2984,17 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "bitwise_xor_test",
|
||||
size = "small",
|
||||
srcs = ["bitwise_xor_test.cc"],
|
||||
deps = [
|
||||
":test_util",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
tflite_portable_test_suite_combined(
|
||||
combine_conditions = {"deps": [":test_main"]},
|
||||
# TODO(b/229985981) : Remove `nnapi_args` after adding Relu0To1 is completed.
|
||||
|
|
|
|||
168
tensorflow/lite/kernels/bitwise_xor.cc
Normal file
168
tensorflow/lite/kernels/bitwise_xor.cc
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/core/c/c_api_types.h"
|
||||
#include "tensorflow/lite/core/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace builtin {
|
||||
namespace bitwise_xor {
|
||||
|
||||
// Input/output tensor index.
|
||||
constexpr int kInputTensor1 = 0;
|
||||
constexpr int kInputTensor2 = 1;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
// Op data for bitwise xor op.
|
||||
struct OpData {
|
||||
bool requires_broadcast = false;
|
||||
};
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
auto* data = new OpData;
|
||||
return data;
|
||||
}
|
||||
|
||||
void Free(TfLiteContext* context, void* buffer) {
|
||||
delete reinterpret_cast<OpData*>(buffer);
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
||||
|
||||
const TfLiteTensor* input1;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetInputSafe(context, node, kInputTensor1, &input1));
|
||||
const TfLiteTensor* input2;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetInputSafe(context, node, kInputTensor2, &input2));
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
|
||||
|
||||
output->type = input1->type;
|
||||
|
||||
data->requires_broadcast = !HaveSameShapes(input1, input2);
|
||||
|
||||
TfLiteIntArray* output_size = nullptr;
|
||||
if (data->requires_broadcast) {
|
||||
TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
|
||||
context, input1, input2, &output_size));
|
||||
} else {
|
||||
output_size = TfLiteIntArrayCopy(input1->dims);
|
||||
}
|
||||
|
||||
return context->ResizeTensor(context, output, output_size);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T BitwiseXor(T x, T y) {
|
||||
return x ^ y;
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
||||
|
||||
const TfLiteTensor* input1;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetInputSafe(context, node, kInputTensor1, &input1));
|
||||
const TfLiteTensor* input2;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetInputSafe(context, node, kInputTensor2, &input2));
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
|
||||
const TfLiteType type = output->type;
|
||||
switch (type) {
|
||||
// The fallthrough is indended. Since bitwise xor function operates on the
|
||||
// underlying binary representation of the integers, both integers and
|
||||
// unsigned integers will have the same behavior
|
||||
case kTfLiteUInt8:
|
||||
case kTfLiteInt8: {
|
||||
if (data->requires_broadcast) {
|
||||
reference_ops::BroadcastBinaryFunction4DSlow<int8_t, int8_t, int8_t>(
|
||||
GetTensorShape(input1), GetTensorData<int8_t>(input1),
|
||||
GetTensorShape(input2), GetTensorData<int8_t>(input2),
|
||||
GetTensorShape(output), GetTensorData<int8_t>(output), BitwiseXor);
|
||||
} else {
|
||||
reference_ops::BinaryFunction<int8_t, int8_t, int8_t>(
|
||||
GetTensorShape(input1), GetTensorData<int8_t>(input1),
|
||||
GetTensorShape(input2), GetTensorData<int8_t>(input2),
|
||||
GetTensorShape(output), GetTensorData<int8_t>(output), BitwiseXor);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case kTfLiteUInt16:
|
||||
case kTfLiteInt16: {
|
||||
if (data->requires_broadcast) {
|
||||
reference_ops::BroadcastBinaryFunction4DSlow<int16_t, int16_t, int16_t>(
|
||||
GetTensorShape(input1), GetTensorData<int16_t>(input1),
|
||||
GetTensorShape(input2), GetTensorData<int16_t>(input2),
|
||||
GetTensorShape(output), GetTensorData<int16_t>(output), BitwiseXor);
|
||||
} else {
|
||||
reference_ops::BinaryFunction<int16_t, int16_t, int16_t>(
|
||||
GetTensorShape(input1), GetTensorData<int16_t>(input1),
|
||||
GetTensorShape(input2), GetTensorData<int16_t>(input2),
|
||||
GetTensorShape(output), GetTensorData<int16_t>(output), BitwiseXor);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case kTfLiteUInt32:
|
||||
case kTfLiteInt32: {
|
||||
if (data->requires_broadcast) {
|
||||
reference_ops::BroadcastBinaryFunction4DSlow<int32_t, int32_t, int32_t>(
|
||||
GetTensorShape(input1), GetTensorData<int32_t>(input1),
|
||||
GetTensorShape(input2), GetTensorData<int32_t>(input2),
|
||||
GetTensorShape(output), GetTensorData<int32_t>(output), BitwiseXor);
|
||||
} else {
|
||||
reference_ops::BinaryFunction<int32_t, int32_t, int32_t>(
|
||||
GetTensorShape(input1), GetTensorData<int32_t>(input1),
|
||||
GetTensorShape(input2), GetTensorData<int32_t>(input2),
|
||||
GetTensorShape(output), GetTensorData<int32_t>(output), BitwiseXor);
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"BitwiseXor currently only supports "
|
||||
"8-bit/16-bit/32-bit integer/unsigned integer, got %s",
|
||||
TfLiteTypeGetName(type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace bitwise_xor
|
||||
|
||||
TfLiteRegistration* Register_BITWISE_XOR() {
|
||||
static TfLiteRegistration r = {bitwise_xor::Init, bitwise_xor::Free,
|
||||
bitwise_xor::Prepare, bitwise_xor::Eval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
140
tensorflow/lite/kernels/bitwise_xor_test.cc
Normal file
140
tensorflow/lite/kernels/bitwise_xor_test.cc
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
class BitwiseXorOpModel : public SingleOpModel {
|
||||
public:
|
||||
BitwiseXorOpModel(std::initializer_list<int> input1_shape,
|
||||
std::initializer_list<int> input2_shape,
|
||||
TensorType tensor_type) {
|
||||
input1_ = AddInput(tensor_type);
|
||||
input2_ = AddInput(tensor_type);
|
||||
output_ = AddOutput(tensor_type);
|
||||
SetBuiltinOp(BuiltinOperator_BITWISE_XOR, BuiltinOptions_BitwiseXorOptions,
|
||||
CreateBitwiseXorOptions(builder_).Union());
|
||||
BuildInterpreter({input1_shape, input2_shape});
|
||||
}
|
||||
|
||||
int input1() const { return input1_; }
|
||||
int input2() const { return input2_; }
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> GetOutput() {
|
||||
return ExtractVector<T>(output_);
|
||||
}
|
||||
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
|
||||
protected:
|
||||
int input1_;
|
||||
int input2_;
|
||||
int output_;
|
||||
};
|
||||
|
||||
TEST(BitwiseXorOpTest, SimpleTestInt8) {
|
||||
BitwiseXorOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT8);
|
||||
model.PopulateTensor<int8_t>(model.input1(), {0, 5, 3, 14});
|
||||
model.PopulateTensor<int8_t>(model.input2(), {5, 0, 7, 11});
|
||||
ASSERT_EQ(model.Invoke(), kTfLiteOk);
|
||||
|
||||
EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAreArray({5, 5, 4, 5}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
|
||||
}
|
||||
|
||||
TEST(BitwiseXorOpTest, SimpleTestInt16) {
|
||||
BitwiseXorOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT16);
|
||||
model.PopulateTensor<int16_t>(model.input1(), {0, 5, 3, 14});
|
||||
model.PopulateTensor<int16_t>(model.input2(), {5, 0, 7, 11});
|
||||
ASSERT_EQ(model.Invoke(), kTfLiteOk);
|
||||
|
||||
EXPECT_THAT(model.GetOutput<int16_t>(), ElementsAreArray({5, 5, 4, 5}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
|
||||
}
|
||||
|
||||
TEST(BitwiseXorOpTest, SimpleTestInt32) {
|
||||
BitwiseXorOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
|
||||
model.PopulateTensor<int32_t>(model.input1(), {0, 5, 3, 14});
|
||||
model.PopulateTensor<int32_t>(model.input2(), {5, 0, 7, 11});
|
||||
ASSERT_EQ(model.Invoke(), kTfLiteOk);
|
||||
|
||||
EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray({5, 5, 4, 5}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
|
||||
}
|
||||
|
||||
TEST(BitwiseXorOpTest, SimpleTestUInt8) {
|
||||
BitwiseXorOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_UINT8);
|
||||
model.PopulateTensor<uint8_t>(model.input1(), {0, 5, 3, 14});
|
||||
model.PopulateTensor<uint8_t>(model.input2(), {5, 0, 7, 11});
|
||||
ASSERT_EQ(model.Invoke(), kTfLiteOk);
|
||||
|
||||
EXPECT_THAT(model.GetOutput<uint8_t>(), ElementsAreArray({5, 5, 4, 5}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
|
||||
}
|
||||
|
||||
TEST(BitwiseXorOpTest, SimpleTestUInt16) {
|
||||
BitwiseXorOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_UINT16);
|
||||
model.PopulateTensor<uint16_t>(model.input1(), {0, 5, 3, 14});
|
||||
model.PopulateTensor<uint16_t>(model.input2(), {5, 0, 7, 11});
|
||||
ASSERT_EQ(model.Invoke(), kTfLiteOk);
|
||||
|
||||
EXPECT_THAT(model.GetOutput<uint16_t>(), ElementsAreArray({5, 5, 4, 5}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
|
||||
}
|
||||
|
||||
TEST(BitwiseXorOpTest, SimpleTestUInt32) {
|
||||
BitwiseXorOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_UINT32);
|
||||
model.PopulateTensor<uint32_t>(model.input1(), {0, 5, 3, 14});
|
||||
model.PopulateTensor<uint32_t>(model.input2(), {5, 0, 7, 11});
|
||||
ASSERT_EQ(model.Invoke(), kTfLiteOk);
|
||||
|
||||
EXPECT_THAT(model.GetOutput<uint32_t>(), ElementsAreArray({5, 5, 4, 5}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
|
||||
}
|
||||
|
||||
TEST(BitwiseXorOpTest, BroadcastLhs) {
|
||||
BitwiseXorOpModel model({1, 1, 1, 1}, {1, 1, 1, 4}, TensorType_INT32);
|
||||
model.PopulateTensor<int32_t>(model.input1(), {5});
|
||||
model.PopulateTensor<int32_t>(model.input2(), {0, -5, -3, 14});
|
||||
ASSERT_EQ(model.Invoke(), kTfLiteOk);
|
||||
|
||||
EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray({5, -2, -8, 11}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
|
||||
}
|
||||
|
||||
TEST(BitwiseXorOpTest, BroadcastRhs) {
|
||||
BitwiseXorOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_UINT32);
|
||||
model.PopulateTensor<uint32_t>(model.input1(), {0, 5, 3, 14});
|
||||
model.PopulateTensor<uint32_t>(model.input2(), {5});
|
||||
ASSERT_EQ(model.Invoke(), kTfLiteOk);
|
||||
|
||||
EXPECT_THAT(model.GetOutput<uint32_t>(), ElementsAreArray({5, 0, 6, 11}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
|
@ -172,3 +172,4 @@ TFLITE_OP(Register_ATAN2)
|
|||
TFLITE_OP(Register_UNSORTED_SEGMENT_MIN)
|
||||
TFLITE_OP(Register_SIGN)
|
||||
TFLITE_OP(Register_BITCAST)
|
||||
TFLITE_OP(Register_BITWISE_XOR)
|
||||
|
|
|
|||
|
|
@ -182,6 +182,7 @@ TfLiteRegistration* Register_VAR_HANDLE();
|
|||
TfLiteRegistration* Register_READ_VARIABLE();
|
||||
TfLiteRegistration* Register_ASSIGN_VARIABLE();
|
||||
TfLiteRegistration* Register_BITCAST();
|
||||
TfLiteRegistration* Register_BITWISE_XOR();
|
||||
|
||||
namespace {
|
||||
|
||||
|
|
@ -530,6 +531,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
|
|||
AddBuiltin(BuiltinOperator_READ_VARIABLE, Register_READ_VARIABLE());
|
||||
AddBuiltin(BuiltinOperator_ASSIGN_VARIABLE, Register_ASSIGN_VARIABLE());
|
||||
AddBuiltin(BuiltinOperator_BITCAST, Register_BITCAST());
|
||||
AddBuiltin(BuiltinOperator_BITWISE_XOR, Register_BITWISE_XOR());
|
||||
AddCustom("NumericVerify",
|
||||
tflite::ops::custom::Register_NUMERIC_VERIFY_REF());
|
||||
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
||||
|
|
|
|||
|
|
@ -414,7 +414,8 @@ enum BuiltinOperator : int32 {
|
|||
ATAN2 = 156,
|
||||
UNSORTED_SEGMENT_MIN = 157,
|
||||
SIGN = 158,
|
||||
BITCAST = 159
|
||||
BITCAST = 159,
|
||||
BITWISE_XOR = 160,
|
||||
}
|
||||
// LINT.ThenChange(nnapi_linter/linter.proto)
|
||||
|
||||
|
|
@ -543,7 +544,8 @@ union BuiltinOptions {
|
|||
UnsortedSegmentSumOptions,
|
||||
ATan2Options,
|
||||
SignOptions,
|
||||
BitcastOptions
|
||||
BitcastOptions,
|
||||
BitwiseXorOptions,
|
||||
}
|
||||
|
||||
// LINT.IfChange
|
||||
|
|
@ -1183,6 +1185,8 @@ table SignOptions {
|
|||
table BitcastOptions {
|
||||
}
|
||||
|
||||
table BitwiseXorOptions {
|
||||
}
|
||||
|
||||
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
|
||||
// builtin, or a string if the operator is custom.
|
||||
|
|
|
|||
|
|
@ -561,6 +561,10 @@ struct BitcastOptions;
|
|||
struct BitcastOptionsBuilder;
|
||||
struct BitcastOptionsT;
|
||||
|
||||
struct BitwiseXorOptions;
|
||||
struct BitwiseXorOptionsBuilder;
|
||||
struct BitwiseXorOptionsT;
|
||||
|
||||
struct OperatorCode;
|
||||
struct OperatorCodeBuilder;
|
||||
struct OperatorCodeT;
|
||||
|
|
@ -1078,11 +1082,12 @@ enum BuiltinOperator : int32_t {
|
|||
BuiltinOperator_UNSORTED_SEGMENT_MIN = 157,
|
||||
BuiltinOperator_SIGN = 158,
|
||||
BuiltinOperator_BITCAST = 159,
|
||||
BuiltinOperator_BITWISE_XOR = 160,
|
||||
BuiltinOperator_MIN = BuiltinOperator_ADD,
|
||||
BuiltinOperator_MAX = BuiltinOperator_BITCAST
|
||||
BuiltinOperator_MAX = BuiltinOperator_BITWISE_XOR
|
||||
};
|
||||
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[160] {
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[161] {
|
||||
static const BuiltinOperator values[] = {
|
||||
BuiltinOperator_ADD,
|
||||
BuiltinOperator_AVERAGE_POOL_2D,
|
||||
|
|
@ -1243,13 +1248,14 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[160] {
|
|||
BuiltinOperator_ATAN2,
|
||||
BuiltinOperator_UNSORTED_SEGMENT_MIN,
|
||||
BuiltinOperator_SIGN,
|
||||
BuiltinOperator_BITCAST
|
||||
BuiltinOperator_BITCAST,
|
||||
BuiltinOperator_BITWISE_XOR
|
||||
};
|
||||
return values;
|
||||
}
|
||||
|
||||
inline const char * const *EnumNamesBuiltinOperator() {
|
||||
static const char * const names[161] = {
|
||||
static const char * const names[162] = {
|
||||
"ADD",
|
||||
"AVERAGE_POOL_2D",
|
||||
"CONCATENATION",
|
||||
|
|
@ -1410,13 +1416,14 @@ inline const char * const *EnumNamesBuiltinOperator() {
|
|||
"UNSORTED_SEGMENT_MIN",
|
||||
"SIGN",
|
||||
"BITCAST",
|
||||
"BITWISE_XOR",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
}
|
||||
|
||||
inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
|
||||
if (::flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_BITCAST)) return "";
|
||||
if (::flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_BITWISE_XOR)) return "";
|
||||
const size_t index = static_cast<size_t>(e);
|
||||
return EnumNamesBuiltinOperator()[index];
|
||||
}
|
||||
|
|
@ -1547,11 +1554,12 @@ enum BuiltinOptions : uint8_t {
|
|||
BuiltinOptions_ATan2Options = 122,
|
||||
BuiltinOptions_SignOptions = 123,
|
||||
BuiltinOptions_BitcastOptions = 124,
|
||||
BuiltinOptions_BitwiseXorOptions = 125,
|
||||
BuiltinOptions_MIN = BuiltinOptions_NONE,
|
||||
BuiltinOptions_MAX = BuiltinOptions_BitcastOptions
|
||||
BuiltinOptions_MAX = BuiltinOptions_BitwiseXorOptions
|
||||
};
|
||||
|
||||
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[125] {
|
||||
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[126] {
|
||||
static const BuiltinOptions values[] = {
|
||||
BuiltinOptions_NONE,
|
||||
BuiltinOptions_Conv2DOptions,
|
||||
|
|
@ -1677,13 +1685,14 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[125] {
|
|||
BuiltinOptions_UnsortedSegmentSumOptions,
|
||||
BuiltinOptions_ATan2Options,
|
||||
BuiltinOptions_SignOptions,
|
||||
BuiltinOptions_BitcastOptions
|
||||
BuiltinOptions_BitcastOptions,
|
||||
BuiltinOptions_BitwiseXorOptions
|
||||
};
|
||||
return values;
|
||||
}
|
||||
|
||||
inline const char * const *EnumNamesBuiltinOptions() {
|
||||
static const char * const names[126] = {
|
||||
static const char * const names[127] = {
|
||||
"NONE",
|
||||
"Conv2DOptions",
|
||||
"DepthwiseConv2DOptions",
|
||||
|
|
@ -1809,13 +1818,14 @@ inline const char * const *EnumNamesBuiltinOptions() {
|
|||
"ATan2Options",
|
||||
"SignOptions",
|
||||
"BitcastOptions",
|
||||
"BitwiseXorOptions",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
}
|
||||
|
||||
inline const char *EnumNameBuiltinOptions(BuiltinOptions e) {
|
||||
if (::flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_BitcastOptions)) return "";
|
||||
if (::flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_BitwiseXorOptions)) return "";
|
||||
const size_t index = static_cast<size_t>(e);
|
||||
return EnumNamesBuiltinOptions()[index];
|
||||
}
|
||||
|
|
@ -2320,6 +2330,10 @@ template<> struct BuiltinOptionsTraits<tflite::BitcastOptions> {
|
|||
static const BuiltinOptions enum_value = BuiltinOptions_BitcastOptions;
|
||||
};
|
||||
|
||||
template<> struct BuiltinOptionsTraits<tflite::BitwiseXorOptions> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_BitwiseXorOptions;
|
||||
};
|
||||
|
||||
template<typename T> struct BuiltinOptionsUnionTraits {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_NONE;
|
||||
};
|
||||
|
|
@ -2820,6 +2834,10 @@ template<> struct BuiltinOptionsUnionTraits<tflite::BitcastOptionsT> {
|
|||
static const BuiltinOptions enum_value = BuiltinOptions_BitcastOptions;
|
||||
};
|
||||
|
||||
template<> struct BuiltinOptionsUnionTraits<tflite::BitwiseXorOptionsT> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_BitwiseXorOptions;
|
||||
};
|
||||
|
||||
struct BuiltinOptionsUnion {
|
||||
BuiltinOptions type;
|
||||
void *value;
|
||||
|
|
@ -3842,6 +3860,14 @@ struct BuiltinOptionsUnion {
|
|||
return type == BuiltinOptions_BitcastOptions ?
|
||||
reinterpret_cast<const tflite::BitcastOptionsT *>(value) : nullptr;
|
||||
}
|
||||
tflite::BitwiseXorOptionsT *AsBitwiseXorOptions() {
|
||||
return type == BuiltinOptions_BitwiseXorOptions ?
|
||||
reinterpret_cast<tflite::BitwiseXorOptionsT *>(value) : nullptr;
|
||||
}
|
||||
const tflite::BitwiseXorOptionsT *AsBitwiseXorOptions() const {
|
||||
return type == BuiltinOptions_BitwiseXorOptions ?
|
||||
reinterpret_cast<const tflite::BitwiseXorOptionsT *>(value) : nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
bool VerifyBuiltinOptions(::flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
|
||||
|
|
@ -11635,6 +11661,45 @@ inline ::flatbuffers::Offset<BitcastOptions> CreateBitcastOptions(
|
|||
|
||||
::flatbuffers::Offset<BitcastOptions> CreateBitcastOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BitcastOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct BitwiseXorOptionsT : public ::flatbuffers::NativeTable {
|
||||
typedef BitwiseXorOptions TableType;
|
||||
};
|
||||
|
||||
struct BitwiseXorOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
|
||||
typedef BitwiseXorOptionsT NativeTableType;
|
||||
typedef BitwiseXorOptionsBuilder Builder;
|
||||
bool Verify(::flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) &&
|
||||
verifier.EndTable();
|
||||
}
|
||||
BitwiseXorOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
void UnPackTo(BitwiseXorOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
static ::flatbuffers::Offset<BitwiseXorOptions> Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BitwiseXorOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
};
|
||||
|
||||
struct BitwiseXorOptionsBuilder {
|
||||
typedef BitwiseXorOptions Table;
|
||||
::flatbuffers::FlatBufferBuilder &fbb_;
|
||||
::flatbuffers::uoffset_t start_;
|
||||
explicit BitwiseXorOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
|
||||
: fbb_(_fbb) {
|
||||
start_ = fbb_.StartTable();
|
||||
}
|
||||
::flatbuffers::Offset<BitwiseXorOptions> Finish() {
|
||||
const auto end = fbb_.EndTable(start_);
|
||||
auto o = ::flatbuffers::Offset<BitwiseXorOptions>(end);
|
||||
return o;
|
||||
}
|
||||
};
|
||||
|
||||
inline ::flatbuffers::Offset<BitwiseXorOptions> CreateBitwiseXorOptions(
|
||||
::flatbuffers::FlatBufferBuilder &_fbb) {
|
||||
BitwiseXorOptionsBuilder builder_(_fbb);
|
||||
return builder_.Finish();
|
||||
}
|
||||
|
||||
::flatbuffers::Offset<BitwiseXorOptions> CreateBitwiseXorOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BitwiseXorOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct OperatorCodeT : public ::flatbuffers::NativeTable {
|
||||
typedef OperatorCode TableType;
|
||||
int8_t deprecated_builtin_code = 0;
|
||||
|
|
@ -12150,6 +12215,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
|
|||
const tflite::BitcastOptions *builtin_options_as_BitcastOptions() const {
|
||||
return builtin_options_type() == tflite::BuiltinOptions_BitcastOptions ? static_cast<const tflite::BitcastOptions *>(builtin_options()) : nullptr;
|
||||
}
|
||||
const tflite::BitwiseXorOptions *builtin_options_as_BitwiseXorOptions() const {
|
||||
return builtin_options_type() == tflite::BuiltinOptions_BitwiseXorOptions ? static_cast<const tflite::BitwiseXorOptions *>(builtin_options()) : nullptr;
|
||||
}
|
||||
const ::flatbuffers::Vector<uint8_t> *custom_options() const {
|
||||
return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
|
||||
}
|
||||
|
|
@ -12682,6 +12750,10 @@ template<> inline const tflite::BitcastOptions *Operator::builtin_options_as<tfl
|
|||
return builtin_options_as_BitcastOptions();
|
||||
}
|
||||
|
||||
template<> inline const tflite::BitwiseXorOptions *Operator::builtin_options_as<tflite::BitwiseXorOptions>() const {
|
||||
return builtin_options_as_BitwiseXorOptions();
|
||||
}
|
||||
|
||||
struct OperatorBuilder {
|
||||
typedef Operator Table;
|
||||
::flatbuffers::FlatBufferBuilder &fbb_;
|
||||
|
|
@ -17040,6 +17112,29 @@ inline ::flatbuffers::Offset<BitcastOptions> CreateBitcastOptions(::flatbuffers:
|
|||
_fbb);
|
||||
}
|
||||
|
||||
inline BitwiseXorOptionsT *BitwiseXorOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = std::unique_ptr<BitwiseXorOptionsT>(new BitwiseXorOptionsT());
|
||||
UnPackTo(_o.get(), _resolver);
|
||||
return _o.release();
|
||||
}
|
||||
|
||||
inline void BitwiseXorOptions::UnPackTo(BitwiseXorOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const {
|
||||
(void)_o;
|
||||
(void)_resolver;
|
||||
}
|
||||
|
||||
inline ::flatbuffers::Offset<BitwiseXorOptions> BitwiseXorOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BitwiseXorOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
|
||||
return CreateBitwiseXorOptions(_fbb, _o, _rehasher);
|
||||
}
|
||||
|
||||
inline ::flatbuffers::Offset<BitwiseXorOptions> CreateBitwiseXorOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BitwiseXorOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) {
|
||||
(void)_rehasher;
|
||||
(void)_o;
|
||||
struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const BitwiseXorOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
||||
return tflite::CreateBitwiseXorOptions(
|
||||
_fbb);
|
||||
}
|
||||
|
||||
inline OperatorCodeT *OperatorCode::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = std::unique_ptr<OperatorCodeT>(new OperatorCodeT());
|
||||
UnPackTo(_o.get(), _resolver);
|
||||
|
|
@ -18079,6 +18174,10 @@ inline bool VerifyBuiltinOptions(::flatbuffers::Verifier &verifier, const void *
|
|||
auto ptr = reinterpret_cast<const tflite::BitcastOptions *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
case BuiltinOptions_BitwiseXorOptions: {
|
||||
auto ptr = reinterpret_cast<const tflite::BitwiseXorOptions *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
default: return true;
|
||||
}
|
||||
}
|
||||
|
|
@ -18594,6 +18693,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
|
|||
auto ptr = reinterpret_cast<const tflite::BitcastOptions *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
case BuiltinOptions_BitwiseXorOptions: {
|
||||
auto ptr = reinterpret_cast<const tflite::BitwiseXorOptions *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
default: return nullptr;
|
||||
}
|
||||
}
|
||||
|
|
@ -19097,6 +19200,10 @@ inline ::flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(::flatbuffers::Flat
|
|||
auto ptr = reinterpret_cast<const tflite::BitcastOptionsT *>(value);
|
||||
return CreateBitcastOptions(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
case BuiltinOptions_BitwiseXorOptions: {
|
||||
auto ptr = reinterpret_cast<const tflite::BitwiseXorOptionsT *>(value);
|
||||
return CreateBitwiseXorOptions(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
|
|
@ -19599,6 +19706,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) :
|
|||
value = new tflite::BitcastOptionsT(*reinterpret_cast<tflite::BitcastOptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_BitwiseXorOptions: {
|
||||
value = new tflite::BitwiseXorOptionsT(*reinterpret_cast<tflite::BitwiseXorOptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
@ -20226,6 +20337,11 @@ inline void BuiltinOptionsUnion::Reset() {
|
|||
delete ptr;
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_BitwiseXorOptions: {
|
||||
auto ptr = reinterpret_cast<tflite::BitwiseXorOptionsT *>(value);
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
default: break;
|
||||
}
|
||||
value = nullptr;
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ def generated_test_models():
|
|||
"batch_to_space_nd",
|
||||
"batchmatmul",
|
||||
"bitcast",
|
||||
"bitwise_xor",
|
||||
"broadcast_args",
|
||||
"broadcast_gradient_args",
|
||||
"broadcast_to",
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ from tensorflow.lite.testing.op_tests.batch_to_space_nd import make_batch_to_spa
|
|||
from tensorflow.lite.testing.op_tests.batchmatmul import make_batchmatmul_tests
|
||||
from tensorflow.lite.testing.op_tests.binary_op import make_add_tests, make_div_tests, make_sub_tests, make_mul_tests, make_pow_tests, make_floor_div_tests, make_floor_mod_tests, make_squared_difference_tests
|
||||
from tensorflow.lite.testing.op_tests.bitcast import make_bitcast_tests
|
||||
from tensorflow.lite.testing.op_tests.bitwise_xor import make_bitwise_xor_tests
|
||||
from tensorflow.lite.testing.op_tests.broadcast_args import make_broadcast_args_tests
|
||||
from tensorflow.lite.testing.op_tests.broadcast_gradient_args import make_broadcast_gradient_args_tests
|
||||
from tensorflow.lite.testing.op_tests.broadcast_to import make_broadcast_to_tests
|
||||
|
|
|
|||
78
tensorflow/lite/testing/op_tests/bitwise_xor.py
Normal file
78
tensorflow/lite/testing/op_tests/bitwise_xor.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test configs for bitwise_xor operator."""
|
||||
import tensorflow as tf
|
||||
from tensorflow.lite.testing.zip_test_utils import create_tensor_data
|
||||
from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
|
||||
from tensorflow.lite.testing.zip_test_utils import register_make_test_function
|
||||
|
||||
|
||||
@register_make_test_function()
|
||||
def make_bitwise_xor_tests(options):
|
||||
"""Generate examples for bitwise_xor."""
|
||||
test_parameters = [
|
||||
{
|
||||
"input_dtype": [
|
||||
tf.uint8,
|
||||
tf.int8,
|
||||
tf.uint16,
|
||||
tf.int16,
|
||||
tf.uint32,
|
||||
tf.int32,
|
||||
],
|
||||
"input_shape_pair": [
|
||||
([], []),
|
||||
([2, 3, 4], [2, 3, 4]),
|
||||
([1, 1, 1, 3], [1, 1, 1, 3]),
|
||||
([5, 5], [1]),
|
||||
([10], [2, 4, 10]),
|
||||
([2, 3, 3], [2, 3]), # this test case is intended to fail
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
def build_graph(parameters):
|
||||
"""Build the bitwise_xor testing graph."""
|
||||
input_value1 = tf.compat.v1.placeholder(
|
||||
dtype=parameters["input_dtype"],
|
||||
name="input1",
|
||||
shape=parameters["input_shape_pair"][0],
|
||||
)
|
||||
input_value2 = tf.compat.v1.placeholder(
|
||||
dtype=parameters["input_dtype"],
|
||||
name="input2",
|
||||
shape=parameters["input_shape_pair"][1],
|
||||
)
|
||||
out = tf.bitwise.bitwise_xor(input_value1, input_value2)
|
||||
return [input_value1, input_value2], [out]
|
||||
|
||||
def build_inputs(parameters, sess, inputs, outputs):
|
||||
input_value1 = create_tensor_data(
|
||||
parameters["input_dtype"], parameters["input_shape_pair"][0]
|
||||
)
|
||||
input_value2 = create_tensor_data(
|
||||
parameters["input_dtype"], parameters["input_shape_pair"][1]
|
||||
)
|
||||
return [input_value1, input_value2], sess.run(
|
||||
outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))
|
||||
)
|
||||
|
||||
make_zip_of_tests(
|
||||
options,
|
||||
test_parameters,
|
||||
build_graph,
|
||||
build_inputs,
|
||||
expected_tf_failures=6,
|
||||
)
|
||||
|
|
@ -98,6 +98,7 @@ static const char* param_structs[] = {"TfLiteAddParams",
|
|||
"TfLiteVarHandleParams",
|
||||
"TfLiteUnsortedSegmentSumParams",
|
||||
"TfLiteUnsortedSegmentMinParams",
|
||||
"TfLiteBitwiseXorParams",
|
||||
nullptr};
|
||||
} // namespace
|
||||
|
||||
|
|
@ -221,6 +222,7 @@ class OpOptionData {
|
|||
op_to_option_["GELU"] = "";
|
||||
op_to_option_["DYNAMIC_UPDATE_SLICE"] = "";
|
||||
op_to_option_["BITCAST"] = "";
|
||||
op_to_option_["BITWISE_XOR"] = "";
|
||||
|
||||
// TODO(aselle): These are undesirable hacks. Consider changing C structs
|
||||
option_to_struct_["Pool2DOptions"] = "TfLitePoolParams";
|
||||
|
|
|
|||
|
|
@ -409,7 +409,8 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
|
|||
{{BuiltinOperator_ATAN2, 1}, "2.10.0"},
|
||||
{{BuiltinOperator_SIGN, 1}, "2.11.0"},
|
||||
{{BuiltinOperator_SIGN, 2}, "2.12.0"},
|
||||
{{BuiltinOperator_BITCAST, 1}, "2.13.0"}});
|
||||
{{BuiltinOperator_BITCAST, 1}, "2.13.0"},
|
||||
{{BuiltinOperator_BITWISE_XOR, 1}, "2.13.0"}});
|
||||
|
||||
std::pair<BuiltinOperator, int> version_key = {op_code, op_version};
|
||||
auto it = op_version_map->find(version_key);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user