mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Support SQRT operator in XNNPACK delegate
PiperOrigin-RevId: 320429761 Change-Id: I70673fc5ecacb4b1f7ea039267c1560d84a767e9
This commit is contained in:
parent
5491487f9a
commit
a297145b87
|
|
@ -634,6 +634,21 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "sqrt_test",
|
||||
srcs = ["sqrt_test.cc"],
|
||||
linkopts = select({
|
||||
"//tensorflow:emscripten": EMSCRIPTEN_LINKOPTS,
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
deps = [
|
||||
":test_main",
|
||||
":unary_elementwise_tester",
|
||||
":xnnpack_delegate_test_mode",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "square_test",
|
||||
srcs = ["square_test.cc"],
|
||||
|
|
|
|||
|
|
@ -224,6 +224,10 @@ Below is the list of current operators and limitations:
|
|||
* Inputs and outputs must be in 32-bit floating-point format.
|
||||
* Only `beta = 1.0` is supported.
|
||||
|
||||
### `SQRT`
|
||||
|
||||
* Inputs and outputs must be in 32-bit floating-point format.
|
||||
|
||||
### `SQUARE`
|
||||
|
||||
* Inputs and outputs must be in 32-bit floating-point format.
|
||||
|
|
|
|||
120
tensorflow/lite/delegates/xnnpack/sqrt_test.cc
Normal file
120
tensorflow/lite/delegates/xnnpack/sqrt_test.cc
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/delegates/xnnpack/unary_elementwise_tester.h"
|
||||
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace xnnpack {
|
||||
|
||||
TEST(Sqrt, 4D) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto height = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
UnaryElementwiseTester()
|
||||
.Shape({batch, height, width, channels})
|
||||
.Test(BuiltinOperator_SQRT, xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(Sqrt, 3D) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
UnaryElementwiseTester()
|
||||
.Shape({batch, width, channels})
|
||||
.Test(BuiltinOperator_SQRT, xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(Sqrt, 2D) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
UnaryElementwiseTester()
|
||||
.Shape({batch, channels})
|
||||
.Test(BuiltinOperator_SQRT, xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(Sqrt, 1D) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
|
||||
UnaryElementwiseTester().Shape({batch}).Test(BuiltinOperator_SQRT,
|
||||
xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(Sqrt, MultiThreading) {
|
||||
TfLiteXNNPackDelegateOptions delegate_options =
|
||||
TfLiteXNNPackDelegateOptionsDefault();
|
||||
delegate_options.num_threads = 2;
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto height = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
UnaryElementwiseTester()
|
||||
.Shape({batch, height, width, channels})
|
||||
.Test(BuiltinOperator_SQRT, xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
} // namespace xnnpack
|
||||
} // namespace tflite
|
||||
|
|
@ -37,8 +37,15 @@ void UnaryElementwiseTester::Test(tflite::BuiltinOperator unary_op,
|
|||
TfLiteDelegate* delegate) const {
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto input_rng = std::bind(
|
||||
std::uniform_real_distribution<float>(-15.0f, 15.0f), std::ref(rng));
|
||||
std::uniform_real_distribution<float> input_distribution(-15.0f, 15.0f);
|
||||
switch (unary_op) {
|
||||
case BuiltinOperator_SQRT:
|
||||
input_distribution = std::uniform_real_distribution<float>(0.0f, 10.0f);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
auto input_rng = std::bind(input_distribution, std::ref(rng));
|
||||
|
||||
std::vector<char> buffer = CreateTfLiteModel(unary_op);
|
||||
const Model* model = GetModel(buffer.data());
|
||||
|
|
@ -96,6 +103,7 @@ void UnaryElementwiseTester::Test(tflite::BuiltinOperator unary_op,
|
|||
case BuiltinOperator_RELU6:
|
||||
case BuiltinOperator_ROUND:
|
||||
case BuiltinOperator_SQUARE:
|
||||
case BuiltinOperator_SQRT:
|
||||
for (size_t i = 0; i < Size(); i++) {
|
||||
ASSERT_EQ(default_output_data[i], delegate_output_data[i]);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -913,6 +913,9 @@ class Subgraph {
|
|||
context->tensors, softmax_params,
|
||||
xnnpack_tensors);
|
||||
}
|
||||
case kTfLiteBuiltinSqrt:
|
||||
return VisitSqrtNode(subgraph, logging_context, node_index, node,
|
||||
context->tensors, xnnpack_tensors);
|
||||
case kTfLiteBuiltinSquare:
|
||||
return VisitSquareNode(subgraph, logging_context, node_index, node,
|
||||
context->tensors, xnnpack_tensors);
|
||||
|
|
@ -2449,6 +2452,39 @@ class Subgraph {
|
|||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
static TfLiteStatus VisitSqrtNode(
|
||||
xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
|
||||
TfLiteNode* node, const TfLiteTensor* tensors,
|
||||
const std::vector<uint32_t>& xnnpack_tensors) {
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
|
||||
|
||||
const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
|
||||
logging_context, input_tensor, node->inputs->data[0], node_index));
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
|
||||
logging_context, input_tensor, node->inputs->data[0], node_index));
|
||||
|
||||
const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
|
||||
logging_context, output_tensor, node->outputs->data[0], node_index));
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
|
||||
logging_context, output_tensor, node->outputs->data[0], node_index));
|
||||
|
||||
if (subgraph != nullptr) {
|
||||
const xnn_status status = xnn_define_square_root(
|
||||
subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
|
||||
/*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
|
||||
if (status != xnn_status_success) {
|
||||
TF_LITE_KERNEL_LOG(logging_context, "failed to delegate SQRT node #%d",
|
||||
node_index);
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
static TfLiteStatus VisitSquaredDifferenceNode(
|
||||
xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
|
||||
TfLiteNode* node, const TfLiteTensor* tensors,
|
||||
|
|
|
|||
|
|
@ -164,11 +164,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
|
|||
|
||||
tf_http_archive(
|
||||
name = "XNNPACK",
|
||||
sha256 = "2527a30464b43bd03f137b2c455a0381e49eae63d09cfeee128a717dfbe962d5",
|
||||
strip_prefix = "XNNPACK-8b283aa30a3186c6e640aed520543e9c067132d2",
|
||||
sha256 = "e37a92154c2ff72c3ebf97247617ce2e159ccc23e648fd62ded44a71c3d68c6a",
|
||||
strip_prefix = "XNNPACK-51a01c66c78334c3d5abf4034e9a8a550a8ad4ad",
|
||||
urls = [
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/8b283aa30a3186c6e640aed520543e9c067132d2.zip",
|
||||
"https://github.com/google/XNNPACK/archive/8b283aa30a3186c6e640aed520543e9c067132d2.zip",
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/51a01c66c78334c3d5abf4034e9a8a550a8ad4ad.zip",
|
||||
"https://github.com/google/XNNPACK/archive/51a01c66c78334c3d5abf4034e9a8a550a8ad4ad.zip",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user