mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Support SQRT operator in XNNPACK delegate
PiperOrigin-RevId: 320429761 Change-Id: I70673fc5ecacb4b1f7ea039267c1560d84a767e9
This commit is contained in:
parent
5491487f9a
commit
a297145b87
|
|
@ -634,6 +634,21 @@ cc_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Test target built from sqrt_test.cc.
# Link options come from EMSCRIPTEN_LINKOPTS when the //tensorflow:emscripten
# condition matches; every other configuration adds no extra link options.
# Depends on the shared test main, the unary elementwise tester, the XNNPACK
# delegate in test mode, and googletest.
cc_test(
|
||||||
|
name = "sqrt_test",
|
||||||
|
srcs = ["sqrt_test.cc"],
|
||||||
|
linkopts = select({
|
||||||
|
"//tensorflow:emscripten": EMSCRIPTEN_LINKOPTS,
|
||||||
|
"//conditions:default": [],
|
||||||
|
}),
|
||||||
|
deps = [
|
||||||
|
":test_main",
|
||||||
|
":unary_elementwise_tester",
|
||||||
|
":xnnpack_delegate_test_mode",
|
||||||
|
"@com_google_googletest//:gtest",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "square_test",
|
name = "square_test",
|
||||||
srcs = ["square_test.cc"],
|
srcs = ["square_test.cc"],
|
||||||
|
|
|
||||||
|
|
@ -224,6 +224,10 @@ Below is the list of current operators and limitations:
|
||||||
* Inputs and outputs must be in 32-bit floating-point format.
|
* Inputs and outputs must be in 32-bit floating-point format.
|
||||||
* Only `beta = 1.0` is supported.
|
* Only `beta = 1.0` is supported.
|
||||||
|
|
||||||
|
### `SQRT`
|
||||||
|
|
||||||
|
* Inputs and outputs must be in 32-bit floating-point format.
|
||||||
|
|
||||||
### `SQUARE`
|
### `SQUARE`
|
||||||
|
|
||||||
* Inputs and outputs must be in 32-bit floating-point format.
|
* Inputs and outputs must be in 32-bit floating-point format.
|
||||||
|
|
|
||||||
120
tensorflow/lite/delegates/xnnpack/sqrt_test.cc
Normal file
120
tensorflow/lite/delegates/xnnpack/sqrt_test.cc
Normal file
|
|
@ -0,0 +1,120 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <functional>
|
||||||
|
#include <memory>
|
||||||
|
#include <random>
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "tensorflow/lite/delegates/xnnpack/unary_elementwise_tester.h"
|
||||||
|
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace xnnpack {
|
||||||
|
|
||||||
|
// Runs the SQRT unary tester on a 4-D tensor whose extents are drawn at random.
// The delegate is created with a null options pointer and released through
// TfLiteXNNPackDelegateDelete when the test scope ends.
TEST(Sqrt, 4D) {
  using DelegatePtr =
      std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>;
  DelegatePtr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
                               TfLiteXNNPackDelegateDelete);

  // Non-deterministically seeded engine; every extent lies in [2, 5].
  std::random_device random_device;
  std::mt19937 rng(random_device());
  std::uniform_int_distribution<int32_t> extent_distribution(2, 5);

  // One draw per dimension, taken in this order.
  const int32_t batch = extent_distribution(rng);
  const int32_t height = extent_distribution(rng);
  const int32_t width = extent_distribution(rng);
  const int32_t channels = extent_distribution(rng);

  UnaryElementwiseTester()
      .Shape({batch, height, width, channels})
      .Test(BuiltinOperator_SQRT, xnnpack_delegate.get());
}
|
||||||
|
|
||||||
|
// Runs the SQRT unary tester on a 3-D tensor whose extents are drawn at random.
// The delegate is created with a null options pointer and released through
// TfLiteXNNPackDelegateDelete when the test scope ends.
TEST(Sqrt, 3D) {
  using DelegatePtr =
      std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>;
  DelegatePtr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
                               TfLiteXNNPackDelegateDelete);

  // Non-deterministically seeded engine; every extent lies in [2, 5].
  std::random_device random_device;
  std::mt19937 rng(random_device());
  std::uniform_int_distribution<int32_t> extent_distribution(2, 5);

  // One draw per dimension, taken in this order.
  const int32_t batch = extent_distribution(rng);
  const int32_t width = extent_distribution(rng);
  const int32_t channels = extent_distribution(rng);

  UnaryElementwiseTester()
      .Shape({batch, width, channels})
      .Test(BuiltinOperator_SQRT, xnnpack_delegate.get());
}
|
||||||
|
|
||||||
|
// Runs the SQRT unary tester on a 2-D tensor whose extents are drawn at random.
// The delegate is created with a null options pointer and released through
// TfLiteXNNPackDelegateDelete when the test scope ends.
TEST(Sqrt, 2D) {
  using DelegatePtr =
      std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>;
  DelegatePtr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
                               TfLiteXNNPackDelegateDelete);

  // Non-deterministically seeded engine; every extent lies in [2, 5].
  std::random_device random_device;
  std::mt19937 rng(random_device());
  std::uniform_int_distribution<int32_t> extent_distribution(2, 5);

  // One draw per dimension, taken in this order.
  const int32_t batch = extent_distribution(rng);
  const int32_t channels = extent_distribution(rng);

  UnaryElementwiseTester()
      .Shape({batch, channels})
      .Test(BuiltinOperator_SQRT, xnnpack_delegate.get());
}
|
||||||
|
|
||||||
|
// Runs the SQRT unary tester on a 1-D tensor whose length is drawn at random.
// The delegate is created with a null options pointer and released through
// TfLiteXNNPackDelegateDelete when the test scope ends.
TEST(Sqrt, 1D) {
  using DelegatePtr =
      std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>;
  DelegatePtr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
                               TfLiteXNNPackDelegateDelete);

  // Non-deterministically seeded engine; the length lies in [2, 5].
  std::random_device random_device;
  std::mt19937 rng(random_device());
  std::uniform_int_distribution<int32_t> extent_distribution(2, 5);

  const int32_t batch = extent_distribution(rng);

  UnaryElementwiseTester()
      .Shape({batch})
      .Test(BuiltinOperator_SQRT, xnnpack_delegate.get());
}
|
||||||
|
|
||||||
|
// Same scenario as the 4-D case, but the delegate is built from explicit
// options: the defaults with num_threads overridden to 2.
TEST(Sqrt, MultiThreading) {
  TfLiteXNNPackDelegateOptions delegate_options =
      TfLiteXNNPackDelegateOptionsDefault();
  delegate_options.num_threads = 2;

  using DelegatePtr =
      std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>;
  DelegatePtr xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
                               TfLiteXNNPackDelegateDelete);

  // Non-deterministically seeded engine; every extent lies in [2, 5].
  std::random_device random_device;
  std::mt19937 rng(random_device());
  std::uniform_int_distribution<int32_t> extent_distribution(2, 5);

  // One draw per dimension, taken in this order.
  const int32_t batch = extent_distribution(rng);
  const int32_t height = extent_distribution(rng);
  const int32_t width = extent_distribution(rng);
  const int32_t channels = extent_distribution(rng);

  UnaryElementwiseTester()
      .Shape({batch, height, width, channels})
      .Test(BuiltinOperator_SQRT, xnnpack_delegate.get());
}
|
||||||
|
|
||||||
|
} // namespace xnnpack
|
||||||
|
} // namespace tflite
|
||||||
|
|
@ -37,8 +37,15 @@ void UnaryElementwiseTester::Test(tflite::BuiltinOperator unary_op,
|
||||||
TfLiteDelegate* delegate) const {
|
TfLiteDelegate* delegate) const {
|
||||||
std::random_device random_device;
|
std::random_device random_device;
|
||||||
auto rng = std::mt19937(random_device());
|
auto rng = std::mt19937(random_device());
|
||||||
auto input_rng = std::bind(
|
std::uniform_real_distribution<float> input_distribution(-15.0f, 15.0f);
|
||||||
std::uniform_real_distribution<float>(-15.0f, 15.0f), std::ref(rng));
|
switch (unary_op) {
|
||||||
|
case BuiltinOperator_SQRT:
|
||||||
|
input_distribution = std::uniform_real_distribution<float>(0.0f, 10.0f);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
auto input_rng = std::bind(input_distribution, std::ref(rng));
|
||||||
|
|
||||||
std::vector<char> buffer = CreateTfLiteModel(unary_op);
|
std::vector<char> buffer = CreateTfLiteModel(unary_op);
|
||||||
const Model* model = GetModel(buffer.data());
|
const Model* model = GetModel(buffer.data());
|
||||||
|
|
@ -96,6 +103,7 @@ void UnaryElementwiseTester::Test(tflite::BuiltinOperator unary_op,
|
||||||
case BuiltinOperator_RELU6:
|
case BuiltinOperator_RELU6:
|
||||||
case BuiltinOperator_ROUND:
|
case BuiltinOperator_ROUND:
|
||||||
case BuiltinOperator_SQUARE:
|
case BuiltinOperator_SQUARE:
|
||||||
|
case BuiltinOperator_SQRT:
|
||||||
for (size_t i = 0; i < Size(); i++) {
|
for (size_t i = 0; i < Size(); i++) {
|
||||||
ASSERT_EQ(default_output_data[i], delegate_output_data[i]);
|
ASSERT_EQ(default_output_data[i], delegate_output_data[i]);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -913,6 +913,9 @@ class Subgraph {
|
||||||
context->tensors, softmax_params,
|
context->tensors, softmax_params,
|
||||||
xnnpack_tensors);
|
xnnpack_tensors);
|
||||||
}
|
}
|
||||||
|
case kTfLiteBuiltinSqrt:
|
||||||
|
return VisitSqrtNode(subgraph, logging_context, node_index, node,
|
||||||
|
context->tensors, xnnpack_tensors);
|
||||||
case kTfLiteBuiltinSquare:
|
case kTfLiteBuiltinSquare:
|
||||||
return VisitSquareNode(subgraph, logging_context, node_index, node,
|
return VisitSquareNode(subgraph, logging_context, node_index, node,
|
||||||
context->tensors, xnnpack_tensors);
|
context->tensors, xnnpack_tensors);
|
||||||
|
|
@ -2449,6 +2452,39 @@ class Subgraph {
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validates a SQRT node and, when `subgraph` is non-null, adds the matching
// square-root node to the XNNPACK subgraph.
//
// Checks performed, in order: the node has exactly one input and one output;
// the input tensor passes the float-type and non-dynamic-allocation checks;
// the output tensor passes the same two checks. Each check is wrapped in
// TF_LITE_ENSURE_STATUS.
//
// With a null `subgraph` only the checks run and nothing is defined.
//
// Returns kTfLiteOk when the checks pass and (if requested) the node was
// defined, or kTfLiteError after logging if xnn_define_square_root does not
// report xnn_status_success.
static TfLiteStatus VisitSqrtNode(
    xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
    TfLiteNode* node, const TfLiteTensor* tensors,
    const std::vector<uint32_t>& xnnpack_tensors) {
  TF_LITE_ENSURE_STATUS(
      CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));

  // The tensor indices are read only after the arity check above.
  const int input_index = node->inputs->data[0];
  const TfLiteTensor& input_tensor = tensors[input_index];
  TF_LITE_ENSURE_STATUS(CheckTensorFloatType(logging_context, input_tensor,
                                             input_index, node_index));
  TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
      logging_context, input_tensor, input_index, node_index));

  const int output_index = node->outputs->data[0];
  const TfLiteTensor& output_tensor = tensors[output_index];
  TF_LITE_ENSURE_STATUS(CheckTensorFloatType(logging_context, output_tensor,
                                             output_index, node_index));
  TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
      logging_context, output_tensor, output_index, node_index));

  // No subgraph to populate: the checks above are the whole job.
  if (subgraph == nullptr) {
    return kTfLiteOk;
  }

  const xnn_status status = xnn_define_square_root(
      subgraph, /*input_id=*/xnnpack_tensors[input_index],
      /*output_id=*/xnnpack_tensors[output_index], /*flags=*/0);
  if (status != xnn_status_success) {
    TF_LITE_KERNEL_LOG(logging_context, "failed to delegate SQRT node #%d",
                       node_index);
    return kTfLiteError;
  }
  return kTfLiteOk;
}
|
||||||
|
|
||||||
static TfLiteStatus VisitSquaredDifferenceNode(
|
static TfLiteStatus VisitSquaredDifferenceNode(
|
||||||
xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
|
xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
|
||||||
TfLiteNode* node, const TfLiteTensor* tensors,
|
TfLiteNode* node, const TfLiteTensor* tensors,
|
||||||
|
|
|
||||||
|
|
@ -164,11 +164,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
|
||||||
|
|
||||||
tf_http_archive(
|
tf_http_archive(
|
||||||
name = "XNNPACK",
|
name = "XNNPACK",
|
||||||
sha256 = "2527a30464b43bd03f137b2c455a0381e49eae63d09cfeee128a717dfbe962d5",
|
sha256 = "e37a92154c2ff72c3ebf97247617ce2e159ccc23e648fd62ded44a71c3d68c6a",
|
||||||
strip_prefix = "XNNPACK-8b283aa30a3186c6e640aed520543e9c067132d2",
|
strip_prefix = "XNNPACK-51a01c66c78334c3d5abf4034e9a8a550a8ad4ad",
|
||||||
urls = [
|
urls = [
|
||||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/8b283aa30a3186c6e640aed520543e9c067132d2.zip",
|
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/51a01c66c78334c3d5abf4034e9a8a550a8ad4ad.zip",
|
||||||
"https://github.com/google/XNNPACK/archive/8b283aa30a3186c6e640aed520543e9c067132d2.zip",
|
"https://github.com/google/XNNPACK/archive/51a01c66c78334c3d5abf4034e9a8a550a8ad4ad.zip",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user