mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Add support for using SignatureRunner with models with no signatures
In the absence of a model signature, with this change a SignatureRunner and a AsyncSignatureRunner would take the names for their input/output tensors from the model. This is necessary in order to ensure correct functionality for the several (Async)SignatureRunner methods that take an input/output name as argument. Note that this change alters the behavior of AsyncSignatureRunner, an experimental API, in that it originally would assume nullptr for input/output names for models with no signatures. PiperOrigin-RevId: 662908295
This commit is contained in:
parent
bb3c8e9a7b
commit
0df7ec86af
|
|
@ -11,7 +11,13 @@
|
|||
|
||||
* `tf.lite`
|
||||
* C API:
|
||||
* An optional, fourth parameter was added `TfLiteOperatorCreate` as a step forward towards a cleaner API for `TfLiteOperator`. Function `TfLiteOperatorCreate` was added recently, in TensorFlow Lite version 2.17.0, released on 7/11/2024, and we do not expect there will be much code using this function yet. Any code breakages can be easily resolved by passing nullptr as the new, 4th parameter.
|
||||
* An optional, fourth parameter was added `TfLiteOperatorCreate` as a step
|
||||
forward towards a cleaner API for `TfLiteOperator`. Function
|
||||
`TfLiteOperatorCreate` was added recently, in TensorFlow Lite version 2.17.0,
|
||||
released on 7/11/2024, and we do not expect there will be much code using this
|
||||
function yet. Any code breakages can be easily resolved by passing nullptr as
|
||||
the new, 4th parameter.
|
||||
* SignatureRunner is now supported for models with no signatures.
|
||||
|
||||
* TensorRT support is disabled in CUDA builds for code health improvement.
|
||||
|
||||
|
|
|
|||
|
|
@ -292,7 +292,10 @@ cc_test(
|
|||
size = "small",
|
||||
srcs = ["c_api_signature_runner_test.cc"],
|
||||
copts = tflite_copts(),
|
||||
data = ["//tensorflow/lite:testdata/multi_signatures.bin"],
|
||||
data = [
|
||||
"//tensorflow/lite:testdata/multi_signatures.bin",
|
||||
"//tensorflow/lite:testdata/no_signatures.bin",
|
||||
],
|
||||
deps = [
|
||||
":c_api",
|
||||
"//tensorflow/lite/core/c:c_api",
|
||||
|
|
|
|||
|
|
@ -24,6 +24,94 @@ limitations under the License.
|
|||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
TEST(SignatureRunnerTest, TestNoSignatures) {
|
||||
TfLiteModel* model = TfLiteModelCreateFromFile(
|
||||
"tensorflow/lite/testdata/no_signatures.bin");
|
||||
ASSERT_NE(model, nullptr);
|
||||
|
||||
TfLiteInterpreter* interpreter =
|
||||
TfLiteInterpreterCreate(model, /*optional_options=*/nullptr);
|
||||
ASSERT_NE(interpreter, nullptr);
|
||||
|
||||
int nun_signatures = TfLiteInterpreterGetSignatureCount(interpreter);
|
||||
ASSERT_EQ(nun_signatures, 0);
|
||||
|
||||
ASSERT_EQ(TfLiteInterpreterGetSignatureRunner(interpreter, "foo"), nullptr);
|
||||
|
||||
TfLiteSignatureRunner* runner =
|
||||
TfLiteInterpreterGetSignatureRunner(interpreter, nullptr);
|
||||
ASSERT_NE(runner, nullptr);
|
||||
|
||||
int num_interpreter_inputs =
|
||||
TfLiteInterpreterGetInputTensorCount(interpreter);
|
||||
int num_runner_inputs = TfLiteSignatureRunnerGetInputCount(runner);
|
||||
ASSERT_EQ(num_runner_inputs, num_interpreter_inputs);
|
||||
|
||||
for (int i = 0; i < num_interpreter_inputs; ++i) {
|
||||
auto* interpreter_input_tensor =
|
||||
TfLiteInterpreterGetInputTensor(interpreter, i);
|
||||
ASSERT_NE(interpreter_input_tensor, nullptr);
|
||||
auto* interpreter_input_name = TfLiteTensorName(interpreter_input_tensor);
|
||||
ASSERT_NE(interpreter_input_name, nullptr);
|
||||
auto* runner_input_name = TfLiteSignatureRunnerGetInputName(runner, i);
|
||||
ASSERT_NE(runner_input_name, nullptr);
|
||||
EXPECT_STREQ(runner_input_name, interpreter_input_name);
|
||||
auto* runner_input_tensor =
|
||||
TfLiteSignatureRunnerGetInputTensor(runner, interpreter_input_name);
|
||||
ASSERT_NE(runner_input_tensor, nullptr);
|
||||
ASSERT_EQ(runner_input_tensor, interpreter_input_tensor);
|
||||
}
|
||||
|
||||
int num_interpreter_outputs =
|
||||
TfLiteInterpreterGetOutputTensorCount(interpreter);
|
||||
int num_runner_outputs = TfLiteSignatureRunnerGetOutputCount(runner);
|
||||
ASSERT_EQ(num_runner_outputs, num_interpreter_outputs);
|
||||
|
||||
for (int i = 0; i < num_interpreter_outputs; ++i) {
|
||||
auto* interpreter_output_tensor =
|
||||
TfLiteInterpreterGetOutputTensor(interpreter, i);
|
||||
ASSERT_NE(interpreter_output_tensor, nullptr);
|
||||
auto* interpreter_output_name = TfLiteTensorName(interpreter_output_tensor);
|
||||
ASSERT_NE(interpreter_output_name, nullptr);
|
||||
auto* runner_output_name = TfLiteSignatureRunnerGetOutputName(runner, i);
|
||||
ASSERT_NE(runner_output_name, nullptr);
|
||||
EXPECT_STREQ(runner_output_name, interpreter_output_name);
|
||||
auto* runner_output_tensor =
|
||||
TfLiteSignatureRunnerGetOutputTensor(runner, interpreter_output_name);
|
||||
ASSERT_NE(runner_output_tensor, nullptr);
|
||||
ASSERT_EQ(runner_output_tensor, interpreter_output_tensor);
|
||||
}
|
||||
|
||||
std::array<int, 1> input_dims{2};
|
||||
ASSERT_EQ(TfLiteSignatureRunnerResizeInputTensor(
|
||||
runner, "x1", input_dims.data(), input_dims.size()),
|
||||
kTfLiteOk);
|
||||
ASSERT_EQ(TfLiteSignatureRunnerResizeInputTensor(
|
||||
runner, "x2", input_dims.data(), input_dims.size()),
|
||||
kTfLiteOk);
|
||||
ASSERT_EQ(TfLiteSignatureRunnerAllocateTensors(runner), kTfLiteOk);
|
||||
TfLiteTensor* input1 = TfLiteSignatureRunnerGetInputTensor(runner, "x1");
|
||||
ASSERT_NE(input1, nullptr);
|
||||
TfLiteTensor* input2 = TfLiteSignatureRunnerGetInputTensor(runner, "x2");
|
||||
ASSERT_NE(input2, nullptr);
|
||||
ASSERT_EQ(TfLiteSignatureRunnerGetInputTensor(runner, "foo"), nullptr);
|
||||
const TfLiteTensor* output =
|
||||
TfLiteSignatureRunnerGetOutputTensor(runner, "Identity");
|
||||
ASSERT_NE(output, nullptr);
|
||||
ASSERT_EQ(TfLiteSignatureRunnerGetOutputTensor(runner, "foo"), nullptr);
|
||||
input1->data.f[0] = -8;
|
||||
input1->data.f[1] = 0.5;
|
||||
input2->data.f[0] = -1;
|
||||
input2->data.f[1] = 1.5;
|
||||
ASSERT_EQ(TfLiteSignatureRunnerInvoke(runner), kTfLiteOk);
|
||||
ASSERT_EQ(output->data.f[0], 0);
|
||||
ASSERT_EQ(output->data.f[1], 2);
|
||||
|
||||
TfLiteSignatureRunnerDelete(runner);
|
||||
TfLiteInterpreterDelete(interpreter);
|
||||
TfLiteModelDelete(model);
|
||||
}
|
||||
|
||||
TEST(SignatureRunnerTest, TestMultiSignatures) {
|
||||
TfLiteModel* model = TfLiteModelCreateFromFile(
|
||||
"tensorflow/lite/testdata/multi_signatures.bin");
|
||||
|
|
|
|||
|
|
@ -183,7 +183,7 @@ TEST_F(AsyncSignatureRunnerNoSignatureDefTest, GetAsyncSignatureRunner) {
|
|||
TEST_F(AsyncSignatureRunnerNoSignatureDefTest, InputsTest) {
|
||||
signature_runner_ = interpreter_->GetAsyncSignatureRunner(nullptr);
|
||||
EXPECT_EQ(1, signature_runner_->input_size());
|
||||
EXPECT_EQ(0, signature_runner_->input_names().size());
|
||||
EXPECT_EQ(1, signature_runner_->input_names().size());
|
||||
|
||||
EXPECT_EQ(1, signature_runner_->inputs().size());
|
||||
EXPECT_NE(nullptr, signature_runner_->tensor(signature_runner_->inputs()[0]));
|
||||
|
|
@ -192,7 +192,7 @@ TEST_F(AsyncSignatureRunnerNoSignatureDefTest, InputsTest) {
|
|||
TEST_F(AsyncSignatureRunnerNoSignatureDefTest, OutputsTest) {
|
||||
signature_runner_ = interpreter_->GetAsyncSignatureRunner(nullptr);
|
||||
EXPECT_EQ(1, signature_runner_->output_size());
|
||||
EXPECT_EQ(0, signature_runner_->output_names().size());
|
||||
EXPECT_EQ(1, signature_runner_->output_names().size());
|
||||
|
||||
EXPECT_EQ(1, signature_runner_->outputs().size());
|
||||
EXPECT_NE(nullptr,
|
||||
|
|
|
|||
|
|
@ -118,6 +118,9 @@ cc_test(
|
|||
name = "async_signature_runner_test",
|
||||
srcs = ["async_signature_runner_test.cc"],
|
||||
copts = tflite_copts() + tflite_copts_warnings(),
|
||||
data = [
|
||||
"//tensorflow/lite:testdata/no_signatures.bin",
|
||||
],
|
||||
deps = [
|
||||
":async_signature_runner",
|
||||
":internal",
|
||||
|
|
|
|||
|
|
@ -182,9 +182,10 @@ TEST_P(AsyncSignatureRunnerTest, InputsTest) {
|
|||
"x", TfLiteOpaqueTensorName(
|
||||
TfLiteAsyncSignatureRunnerGetInputTensor(runner_, "input")));
|
||||
} else {
|
||||
EXPECT_EQ(nullptr, TfLiteAsyncSignatureRunnerGetInputName(runner_, 0));
|
||||
EXPECT_EQ(nullptr,
|
||||
TfLiteAsyncSignatureRunnerGetInputTensor(runner_, "input"));
|
||||
EXPECT_STREQ("x", TfLiteAsyncSignatureRunnerGetInputName(runner_, 0));
|
||||
EXPECT_STREQ("x",
|
||||
TfLiteOpaqueTensorName(
|
||||
TfLiteAsyncSignatureRunnerGetInputTensor(runner_, "x")));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -198,9 +199,10 @@ TEST_P(AsyncSignatureRunnerTest, OutputsTest) {
|
|||
"a", TfLiteOpaqueTensorName(
|
||||
TfLiteAsyncSignatureRunnerGetOutputTensor(runner_, "output")));
|
||||
} else {
|
||||
EXPECT_EQ(nullptr, TfLiteAsyncSignatureRunnerGetOutputName(runner_, 0));
|
||||
EXPECT_EQ(nullptr,
|
||||
TfLiteAsyncSignatureRunnerGetOutputTensor(runner_, "output"));
|
||||
EXPECT_STREQ("a", TfLiteAsyncSignatureRunnerGetOutputName(runner_, 0));
|
||||
EXPECT_STREQ("a",
|
||||
TfLiteOpaqueTensorName(
|
||||
TfLiteAsyncSignatureRunnerGetOutputTensor(runner_, "a")));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -229,5 +231,93 @@ TEST_P(AsyncSignatureRunnerTest, IndexOutOfBound) {
|
|||
EXPECT_EQ(nullptr, TfLiteAsyncSignatureRunnerGetTensor(runner_, 42));
|
||||
}
|
||||
|
||||
TEST(AsyncSignatureRunnerTest, TestNoSignatures) {
|
||||
TfLiteModel* model = TfLiteModelCreateFromFile(
|
||||
"third_party/tensorflow/lite/testdata/no_signatures.bin");
|
||||
ASSERT_NE(model, nullptr);
|
||||
|
||||
TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
|
||||
ASSERT_NE(options, nullptr);
|
||||
auto kernel =
|
||||
std::make_unique<::testing::StrictMock<testing::MockAsyncKernel>>();
|
||||
auto backend = std::make_unique<testing::TestBackend>(kernel->kernel());
|
||||
TfLiteInterpreterOptionsAddDelegate(options, backend->get_delegate());
|
||||
|
||||
TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
|
||||
ASSERT_NE(interpreter, nullptr);
|
||||
|
||||
TfLiteInterpreterOptionsDelete(options);
|
||||
|
||||
int nun_signatures = TfLiteInterpreterGetSignatureCount(interpreter);
|
||||
ASSERT_EQ(nun_signatures, 0);
|
||||
|
||||
ASSERT_EQ(TfLiteInterpreterGetAsyncSignatureRunner(interpreter, "foo"),
|
||||
nullptr);
|
||||
|
||||
TfLiteAsyncSignatureRunner* runner =
|
||||
TfLiteInterpreterGetAsyncSignatureRunner(interpreter, nullptr);
|
||||
ASSERT_NE(runner, nullptr);
|
||||
|
||||
int num_interpreter_inputs =
|
||||
TfLiteInterpreterGetInputTensorCount(interpreter);
|
||||
int num_runner_inputs = TfLiteAsyncSignatureRunnerGetInputCount(runner);
|
||||
ASSERT_EQ(num_runner_inputs, num_interpreter_inputs);
|
||||
|
||||
for (int i = 0; i < num_interpreter_inputs; ++i) {
|
||||
auto* interpreter_input_tensor =
|
||||
TfLiteInterpreterGetInputTensor(interpreter, i);
|
||||
ASSERT_NE(interpreter_input_tensor, nullptr);
|
||||
auto* interpreter_input_name = TfLiteTensorName(interpreter_input_tensor);
|
||||
ASSERT_NE(interpreter_input_name, nullptr);
|
||||
auto* runner_input_name = TfLiteAsyncSignatureRunnerGetInputName(runner, i);
|
||||
ASSERT_NE(runner_input_name, nullptr);
|
||||
EXPECT_STREQ(runner_input_name, interpreter_input_name);
|
||||
auto* runner_input_tensor = TfLiteAsyncSignatureRunnerGetInputTensor(
|
||||
runner, interpreter_input_name);
|
||||
ASSERT_NE(runner_input_tensor, nullptr);
|
||||
ASSERT_EQ(runner_input_tensor, reinterpret_cast<const TfLiteOpaqueTensor*>(
|
||||
interpreter_input_tensor));
|
||||
}
|
||||
|
||||
int num_interpreter_outputs =
|
||||
TfLiteInterpreterGetOutputTensorCount(interpreter);
|
||||
int num_runner_outputs = TfLiteAsyncSignatureRunnerGetOutputCount(runner);
|
||||
ASSERT_EQ(num_runner_outputs, num_interpreter_outputs);
|
||||
|
||||
for (int i = 0; i < num_interpreter_outputs; ++i) {
|
||||
auto* interpreter_output_tensor =
|
||||
TfLiteInterpreterGetOutputTensor(interpreter, i);
|
||||
ASSERT_NE(interpreter_output_tensor, nullptr);
|
||||
auto* interpreter_output_name = TfLiteTensorName(interpreter_output_tensor);
|
||||
ASSERT_NE(interpreter_output_name, nullptr);
|
||||
auto* runner_output_name =
|
||||
TfLiteAsyncSignatureRunnerGetOutputName(runner, i);
|
||||
ASSERT_NE(runner_output_name, nullptr);
|
||||
EXPECT_STREQ(runner_output_name, interpreter_output_name);
|
||||
auto* runner_output_tensor = TfLiteAsyncSignatureRunnerGetOutputTensor(
|
||||
runner, interpreter_output_name);
|
||||
ASSERT_NE(runner_output_tensor, nullptr);
|
||||
ASSERT_EQ(runner_output_tensor, reinterpret_cast<const TfLiteOpaqueTensor*>(
|
||||
interpreter_output_tensor));
|
||||
}
|
||||
|
||||
EXPECT_CALL(*kernel, Prepare(_, _)).WillOnce(Return(kTfLiteOk));
|
||||
EXPECT_CALL(*kernel, Eval(_, _, _)).WillOnce(Return(kTfLiteOk));
|
||||
EXPECT_CALL(*kernel, Wait(_, _)).WillOnce(Return(kTfLiteOk));
|
||||
EXPECT_CALL(*kernel, Finish(_, _)).WillOnce(Return(kTfLiteOk));
|
||||
|
||||
EXPECT_EQ(kTfLiteOk, TfLiteAsyncSignatureRunnerPrepareBackends(runner));
|
||||
|
||||
auto* task = TfLiteAsyncSignatureRunnerCreateTask(runner);
|
||||
|
||||
EXPECT_EQ(kTfLiteOk, TfLiteAsyncSignatureRunnerInvokeAsync(runner, task));
|
||||
EXPECT_EQ(kTfLiteOk, TfLiteAsyncSignatureRunnerWait(runner, task));
|
||||
EXPECT_EQ(kTfLiteOk, TfLiteAsyncSignatureRunnerFinish(runner, task));
|
||||
|
||||
TfLiteAsyncSignatureRunnerDelete(runner);
|
||||
TfLiteInterpreterDelete(interpreter);
|
||||
TfLiteModelDelete(model);
|
||||
}
|
||||
|
||||
} // namespace async
|
||||
} // namespace tflite
|
||||
|
|
|
|||
|
|
@ -520,7 +520,13 @@ void Interpreter::AddProfiler(std::unique_ptr<Profiler> profiler) {
|
|||
}
|
||||
|
||||
impl::SignatureRunner* Interpreter::GetSignatureRunner(
|
||||
const char* signature_key) {
|
||||
const char* signature_key_) {
|
||||
auto [signature_key, empty_signature_fallback] =
|
||||
ReplaceWithPlaceholderSignatureKeyIfNeeded(signature_key_);
|
||||
if (!signature_key) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto iter = signature_runner_map_.find(signature_key);
|
||||
if (iter != signature_runner_map_.end()) {
|
||||
return &(iter->second);
|
||||
|
|
@ -533,6 +539,14 @@ impl::SignatureRunner* Interpreter::GetSignatureRunner(
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
if (empty_signature_fallback) {
|
||||
placeholder_signature_def_ = CreatePlaceholderSignatureDef();
|
||||
auto status = signature_runner_map_.insert(
|
||||
{signature_key, SignatureRunner(placeholder_signature_def_.get(),
|
||||
&primary_subgraph())});
|
||||
return &(status.first->second);
|
||||
}
|
||||
|
||||
for (const auto& signature : signature_defs_) {
|
||||
if (signature.signature_key == signature_key) {
|
||||
auto status = signature_runner_map_.insert(
|
||||
|
|
@ -541,7 +555,56 @@ impl::SignatureRunner* Interpreter::GetSignatureRunner(
|
|||
return &(status.first->second);
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<internal::SignatureDef>
|
||||
Interpreter::CreatePlaceholderSignatureDef() {
|
||||
auto placeholder_signature_def = std::make_unique<internal::SignatureDef>();
|
||||
for (auto i = 0; i < inputs().size(); ++i) {
|
||||
auto* name = GetInputName(i);
|
||||
placeholder_signature_def->inputs[name] = inputs()[i];
|
||||
}
|
||||
for (auto i = 0; i < outputs().size(); ++i) {
|
||||
auto* name = GetOutputName(i);
|
||||
placeholder_signature_def->outputs[name] = outputs()[i];
|
||||
}
|
||||
placeholder_signature_def->signature_key = kPlaceholderSignatureDefKey;
|
||||
placeholder_signature_def->subgraph_index = 0;
|
||||
return placeholder_signature_def;
|
||||
}
|
||||
|
||||
std::pair<const char*, bool>
|
||||
Interpreter::ReplaceWithPlaceholderSignatureKeyIfNeeded(
|
||||
const char* signature_key) {
|
||||
// Handles nullptr signature key.
|
||||
// If the model does not have signature def, use default name as placeholder.
|
||||
// Otherwise use the first signature key that points to primary subgraph.
|
||||
bool empty_signature_fallback = false;
|
||||
if (signature_key == nullptr) {
|
||||
if (signature_defs_.empty()) {
|
||||
signature_key = kPlaceholderSignatureDefKey;
|
||||
empty_signature_fallback = true;
|
||||
} else {
|
||||
for (const auto& signature : signature_defs_) {
|
||||
if (signature.subgraph_index == 0) {
|
||||
signature_key = signature.signature_key.c_str();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (signature_key == nullptr) {
|
||||
// The model has signature def but none of those points to primary subgraph.
|
||||
TF_LITE_REPORT_ERROR(error_reporter_,
|
||||
"The model has signature def but none of those points "
|
||||
"to primary subgraph.");
|
||||
return {nullptr, empty_signature_fallback};
|
||||
} else {
|
||||
return {signature_key, empty_signature_fallback};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
|
|
|||
|
|
@ -335,21 +335,25 @@ class Interpreter {
|
|||
}
|
||||
|
||||
/// \brief Returns a pointer to the SignatureRunner instance to run the part
|
||||
/// of the graph identified by a SignatureDef. The nullptr is returned if the
|
||||
/// given signature key is not valid.
|
||||
/// of the graph identified by a SignatureDef. If the model does not have any
|
||||
/// signature defs, pass nullptr as signature_key and a SignatureRunner will
|
||||
/// be created using the primary subgraph (0). A nullptr is returned if the
|
||||
/// given signature_key is not valid. Note, the returned SignatureRunner
|
||||
/// instance is owned by and has the same lifetime as the Interpreter object;
|
||||
/// additionally, class SignatureRunner is *not* thread-safe.
|
||||
/// If you need to specify delegates, you have to do that before calling this
|
||||
/// function. This function will additionally apply default delegates. Thus,
|
||||
/// applying delegates after that might lead to undesirable behaviors.
|
||||
/// Note, the pointed instance has lifetime same as the Interpreter object
|
||||
/// and the SignatureRunner class is *not* thread-safe.
|
||||
SignatureRunner* GetSignatureRunner(const char* signature_key);
|
||||
|
||||
/// \warning Experimental interface, subject to change. \n
|
||||
/// \brief Returns a pointer to the AsyncSignatureRunner instance to run the
|
||||
/// part of the graph identified by a SignatureDef. The nullptr is returned if
|
||||
/// the given signature key is not valid.
|
||||
/// if the model does not have signature def, pass nullptr to signature_key
|
||||
/// and AsyncSignatureRunner will be created using primary subgraph (0).
|
||||
/// \warning Experimental interface, subject to change. \n \brief Returns a
|
||||
/// pointer to the AsyncSignatureRunner instance to run the part of the graph
|
||||
/// identified by a SignatureDef. If the model does not have any signature
|
||||
/// defs, pass nullptr as signature_key and an AsyncSignatureRunner will be
|
||||
/// created using the primary subgraph (0). A nullptr is returned if the
|
||||
/// given signature_key is not valid. Note, the returned AsyncSignatureRunner
|
||||
/// instance is owned by and has the same lifetime as the Interpreter object;
|
||||
/// additionally, class AsyncSignatureRunner is *not* thread-safe.
|
||||
/// The async delegate should be applied before calling this function.
|
||||
async::AsyncSignatureRunner* GetAsyncSignatureRunner(
|
||||
const char* signature_key);
|
||||
|
|
@ -905,6 +909,10 @@ class Interpreter {
|
|||
|
||||
TfLiteStatus ApplyOptionsImpl(InterpreterOptions* options);
|
||||
|
||||
std::unique_ptr<internal::SignatureDef> CreatePlaceholderSignatureDef();
|
||||
std::pair<const char*, bool> ReplaceWithPlaceholderSignatureKeyIfNeeded(
|
||||
const char* signature_key);
|
||||
|
||||
// A pure C data structure used to communicate with the pure C plugin
|
||||
// interface. To avoid copying tensor metadata, this is also the definitive
|
||||
// structure to store tensors.
|
||||
|
|
@ -964,6 +972,13 @@ class Interpreter {
|
|||
// List of SignatureDefs obtained from the model.
|
||||
std::vector<internal::SignatureDef> signature_defs_;
|
||||
|
||||
// Default signature key to use when the model has no signatures.
|
||||
static constexpr char kPlaceholderSignatureDefKey[] =
|
||||
"<placeholder signature>";
|
||||
|
||||
// Placeholder SignatureDef for legacy models with no signatures.
|
||||
std::unique_ptr<internal::SignatureDef> placeholder_signature_def_;
|
||||
|
||||
// Map of signature key to its corresponding SignatureRunner object.
|
||||
// A SignatureRunner is basically a wrapper of the Subgraph corresponding to
|
||||
// its SignatureDef.
|
||||
|
|
|
|||
|
|
@ -34,10 +34,6 @@ limitations under the License.
|
|||
|
||||
namespace tflite {
|
||||
|
||||
namespace {
|
||||
static constexpr char kDefaultServingSignatureDefKey[] = "serving_default";
|
||||
} // namespace
|
||||
|
||||
TfLiteStatus Interpreter::SetCustomAllocationForTensor(
|
||||
int tensor_index, const TfLiteCustomAllocation& allocation, int64_t flags) {
|
||||
return primary_subgraph().SetCustomAllocationForTensor(tensor_index,
|
||||
|
|
@ -145,27 +141,10 @@ TfLiteStatus Interpreter::ApplyOptions(InterpreterOptions* options) {
|
|||
}
|
||||
|
||||
async::AsyncSignatureRunner* Interpreter::GetAsyncSignatureRunner(
|
||||
const char* signature_key) {
|
||||
// Handles nullptr signature key.
|
||||
// If the model does not have signature def, use default name as placeholder.
|
||||
// Otherwise use the first signature key that points to primary subgraph.
|
||||
bool empty_signature_fallback = false;
|
||||
if (signature_key == nullptr) {
|
||||
if (signature_defs_.empty()) {
|
||||
signature_key = kDefaultServingSignatureDefKey;
|
||||
empty_signature_fallback = true;
|
||||
} else {
|
||||
for (const auto& signature : signature_defs_) {
|
||||
if (signature.subgraph_index == 0) {
|
||||
signature_key = signature.signature_key.c_str();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (signature_key == nullptr) {
|
||||
// The model has signature def but none of those points to primary subgraph.
|
||||
const char* signature_key_) {
|
||||
auto [signature_key, empty_signature_fallback] =
|
||||
ReplaceWithPlaceholderSignatureKeyIfNeeded(signature_key_);
|
||||
if (!signature_key) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
@ -175,11 +154,14 @@ async::AsyncSignatureRunner* Interpreter::GetAsyncSignatureRunner(
|
|||
}
|
||||
|
||||
if (empty_signature_fallback) {
|
||||
placeholder_signature_def_ = CreatePlaceholderSignatureDef();
|
||||
auto status = async_signature_runner_map_.insert(
|
||||
{signature_key,
|
||||
async::AsyncSignatureRunner(nullptr, &primary_subgraph())});
|
||||
async::AsyncSignatureRunner(placeholder_signature_def_.get(),
|
||||
&primary_subgraph())});
|
||||
return &(status.first->second);
|
||||
}
|
||||
|
||||
for (const auto& signature : signature_defs_) {
|
||||
if (signature.signature_key == signature_key) {
|
||||
auto status = async_signature_runner_map_.insert(
|
||||
|
|
|
|||
BIN
tensorflow/lite/testdata/no_signatures.bin
vendored
Normal file
BIN
tensorflow/lite/testdata/no_signatures.bin
vendored
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user