Add support for using SignatureRunner with models with no signatures

In the absence of a model signature, with this change a SignatureRunner and an
AsyncSignatureRunner would take the names for their input/output tensors from
the model. This is necessary in order to ensure correct functionality for the
several (Async)SignatureRunner methods that take an input/output name as
argument.

Note that this change alters the behavior of AsyncSignatureRunner, an experimental API, in that it originally would assume nullptr for input/output names for models with no signatures.

PiperOrigin-RevId: 662908295
This commit is contained in:
A. Unique TensorFlower 2024-08-14 07:16:12 -07:00 committed by TensorFlower Gardener
parent bb3c8e9a7b
commit 0df7ec86af
10 changed files with 297 additions and 47 deletions

View File

@ -11,7 +11,13 @@
* `tf.lite`
* C API:
* An optional, fourth parameter was added `TfLiteOperatorCreate` as a step forward towards a cleaner API for `TfLiteOperator`. Function `TfLiteOperatorCreate` was added recently, in TensorFlow Lite version 2.17.0, released on 7/11/2024, and we do not expect there will be much code using this function yet. Any code breakages can be easily resolved by passing nullptr as the new, 4th parameter.
* An optional, fourth parameter was added to `TfLiteOperatorCreate` as a step
forward towards a cleaner API for `TfLiteOperator`. Function
`TfLiteOperatorCreate` was added recently, in TensorFlow Lite version 2.17.0,
released on 7/11/2024, and we do not expect there will be much code using this
function yet. Any code breakages can be easily resolved by passing nullptr as
the new, 4th parameter.
* SignatureRunner is now supported for models with no signatures.
* TensorRT support is disabled in CUDA builds for code health improvement.

View File

@ -292,7 +292,10 @@ cc_test(
size = "small",
srcs = ["c_api_signature_runner_test.cc"],
copts = tflite_copts(),
data = ["//tensorflow/lite:testdata/multi_signatures.bin"],
data = [
"//tensorflow/lite:testdata/multi_signatures.bin",
"//tensorflow/lite:testdata/no_signatures.bin",
],
deps = [
":c_api",
"//tensorflow/lite/core/c:c_api",

View File

@ -24,6 +24,94 @@ limitations under the License.
namespace tflite {
namespace {
// Exercises the C SignatureRunner API on a model that has no signatures: a
// runner obtained with a null signature key must expose the interpreter's own
// input/output tensors under the tensors' names, and must be usable for
// resizing, allocation and invocation.
TEST(SignatureRunnerTest, TestNoSignatures) {
  TfLiteModel* model = TfLiteModelCreateFromFile(
      "tensorflow/lite/testdata/no_signatures.bin");
  ASSERT_NE(model, nullptr);
  TfLiteInterpreter* interpreter =
      TfLiteInterpreterCreate(model, /*optional_options=*/nullptr);
  ASSERT_NE(interpreter, nullptr);

  // The model carries no signatures, so only the null key yields a runner.
  ASSERT_EQ(TfLiteInterpreterGetSignatureCount(interpreter), 0);
  ASSERT_EQ(TfLiteInterpreterGetSignatureRunner(interpreter, "foo"), nullptr);
  TfLiteSignatureRunner* runner =
      TfLiteInterpreterGetSignatureRunner(interpreter, nullptr);
  ASSERT_NE(runner, nullptr);

  // Every interpreter input must be reachable through the runner by name.
  const int interpreter_input_count =
      TfLiteInterpreterGetInputTensorCount(interpreter);
  const int runner_input_count = TfLiteSignatureRunnerGetInputCount(runner);
  ASSERT_EQ(runner_input_count, interpreter_input_count);
  for (int index = 0; index < interpreter_input_count; ++index) {
    auto* expected_tensor = TfLiteInterpreterGetInputTensor(interpreter, index);
    ASSERT_NE(expected_tensor, nullptr);
    const char* expected_name = TfLiteTensorName(expected_tensor);
    ASSERT_NE(expected_name, nullptr);
    const char* actual_name = TfLiteSignatureRunnerGetInputName(runner, index);
    ASSERT_NE(actual_name, nullptr);
    EXPECT_STREQ(actual_name, expected_name);
    auto* actual_tensor =
        TfLiteSignatureRunnerGetInputTensor(runner, expected_name);
    ASSERT_NE(actual_tensor, nullptr);
    ASSERT_EQ(actual_tensor, expected_tensor);
  }

  // Same check for the outputs.
  const int interpreter_output_count =
      TfLiteInterpreterGetOutputTensorCount(interpreter);
  const int runner_output_count = TfLiteSignatureRunnerGetOutputCount(runner);
  ASSERT_EQ(runner_output_count, interpreter_output_count);
  for (int index = 0; index < interpreter_output_count; ++index) {
    auto* expected_tensor =
        TfLiteInterpreterGetOutputTensor(interpreter, index);
    ASSERT_NE(expected_tensor, nullptr);
    const char* expected_name = TfLiteTensorName(expected_tensor);
    ASSERT_NE(expected_name, nullptr);
    const char* actual_name = TfLiteSignatureRunnerGetOutputName(runner, index);
    ASSERT_NE(actual_name, nullptr);
    EXPECT_STREQ(actual_name, expected_name);
    auto* actual_tensor =
        TfLiteSignatureRunnerGetOutputTensor(runner, expected_name);
    ASSERT_NE(actual_tensor, nullptr);
    ASSERT_EQ(actual_tensor, expected_tensor);
  }

  // Resize both inputs to two elements and allocate through the runner.
  std::array<int, 1> input_dims{2};
  ASSERT_EQ(TfLiteSignatureRunnerResizeInputTensor(
                runner, "x1", input_dims.data(), input_dims.size()),
            kTfLiteOk);
  ASSERT_EQ(TfLiteSignatureRunnerResizeInputTensor(
                runner, "x2", input_dims.data(), input_dims.size()),
            kTfLiteOk);
  ASSERT_EQ(TfLiteSignatureRunnerAllocateTensors(runner), kTfLiteOk);

  // Known tensor names resolve; an unknown name does not.
  TfLiteTensor* input1 = TfLiteSignatureRunnerGetInputTensor(runner, "x1");
  ASSERT_NE(input1, nullptr);
  TfLiteTensor* input2 = TfLiteSignatureRunnerGetInputTensor(runner, "x2");
  ASSERT_NE(input2, nullptr);
  ASSERT_EQ(TfLiteSignatureRunnerGetInputTensor(runner, "foo"), nullptr);
  const TfLiteTensor* output =
      TfLiteSignatureRunnerGetOutputTensor(runner, "Identity");
  ASSERT_NE(output, nullptr);
  ASSERT_EQ(TfLiteSignatureRunnerGetOutputTensor(runner, "foo"), nullptr);

  // Feed the inputs, run the model and check the two output values.
  input1->data.f[0] = -8;
  input1->data.f[1] = 0.5;
  input2->data.f[0] = -1;
  input2->data.f[1] = 1.5;
  ASSERT_EQ(TfLiteSignatureRunnerInvoke(runner), kTfLiteOk);
  ASSERT_EQ(output->data.f[0], 0);
  ASSERT_EQ(output->data.f[1], 2);

  TfLiteSignatureRunnerDelete(runner);
  TfLiteInterpreterDelete(interpreter);
  TfLiteModelDelete(model);
}
TEST(SignatureRunnerTest, TestMultiSignatures) {
TfLiteModel* model = TfLiteModelCreateFromFile(
"tensorflow/lite/testdata/multi_signatures.bin");

View File

@ -183,7 +183,7 @@ TEST_F(AsyncSignatureRunnerNoSignatureDefTest, GetAsyncSignatureRunner) {
TEST_F(AsyncSignatureRunnerNoSignatureDefTest, InputsTest) {
signature_runner_ = interpreter_->GetAsyncSignatureRunner(nullptr);
EXPECT_EQ(1, signature_runner_->input_size());
EXPECT_EQ(0, signature_runner_->input_names().size());
EXPECT_EQ(1, signature_runner_->input_names().size());
EXPECT_EQ(1, signature_runner_->inputs().size());
EXPECT_NE(nullptr, signature_runner_->tensor(signature_runner_->inputs()[0]));
@ -192,7 +192,7 @@ TEST_F(AsyncSignatureRunnerNoSignatureDefTest, InputsTest) {
TEST_F(AsyncSignatureRunnerNoSignatureDefTest, OutputsTest) {
signature_runner_ = interpreter_->GetAsyncSignatureRunner(nullptr);
EXPECT_EQ(1, signature_runner_->output_size());
EXPECT_EQ(0, signature_runner_->output_names().size());
EXPECT_EQ(1, signature_runner_->output_names().size());
EXPECT_EQ(1, signature_runner_->outputs().size());
EXPECT_NE(nullptr,

View File

@ -118,6 +118,9 @@ cc_test(
name = "async_signature_runner_test",
srcs = ["async_signature_runner_test.cc"],
copts = tflite_copts() + tflite_copts_warnings(),
data = [
"//tensorflow/lite:testdata/no_signatures.bin",
],
deps = [
":async_signature_runner",
":internal",

View File

@ -182,9 +182,10 @@ TEST_P(AsyncSignatureRunnerTest, InputsTest) {
"x", TfLiteOpaqueTensorName(
TfLiteAsyncSignatureRunnerGetInputTensor(runner_, "input")));
} else {
EXPECT_EQ(nullptr, TfLiteAsyncSignatureRunnerGetInputName(runner_, 0));
EXPECT_EQ(nullptr,
TfLiteAsyncSignatureRunnerGetInputTensor(runner_, "input"));
EXPECT_STREQ("x", TfLiteAsyncSignatureRunnerGetInputName(runner_, 0));
EXPECT_STREQ("x",
TfLiteOpaqueTensorName(
TfLiteAsyncSignatureRunnerGetInputTensor(runner_, "x")));
}
}
@ -198,9 +199,10 @@ TEST_P(AsyncSignatureRunnerTest, OutputsTest) {
"a", TfLiteOpaqueTensorName(
TfLiteAsyncSignatureRunnerGetOutputTensor(runner_, "output")));
} else {
EXPECT_EQ(nullptr, TfLiteAsyncSignatureRunnerGetOutputName(runner_, 0));
EXPECT_EQ(nullptr,
TfLiteAsyncSignatureRunnerGetOutputTensor(runner_, "output"));
EXPECT_STREQ("a", TfLiteAsyncSignatureRunnerGetOutputName(runner_, 0));
EXPECT_STREQ("a",
TfLiteOpaqueTensorName(
TfLiteAsyncSignatureRunnerGetOutputTensor(runner_, "a")));
}
}
@ -229,5 +231,93 @@ TEST_P(AsyncSignatureRunnerTest, IndexOutOfBound) {
EXPECT_EQ(nullptr, TfLiteAsyncSignatureRunnerGetTensor(runner_, 42));
}
// Exercises the C AsyncSignatureRunner API on a model that has no signatures:
// a runner obtained with a null signature key must expose the interpreter's
// own input/output tensors under the tensors' names, and must be usable to
// prepare backends and run a task through the mocked async kernel.
TEST(AsyncSignatureRunnerTest, TestNoSignatures) {
  // Path agrees with the `//tensorflow/lite:testdata/no_signatures.bin` data
  // dependency declared for this test and with the path used by the
  // non-async SignatureRunner test for the same model.
  TfLiteModel* model = TfLiteModelCreateFromFile(
      "tensorflow/lite/testdata/no_signatures.bin");
  ASSERT_NE(model, nullptr);
  TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
  ASSERT_NE(options, nullptr);
  auto kernel =
      std::make_unique<::testing::StrictMock<testing::MockAsyncKernel>>();
  auto backend = std::make_unique<testing::TestBackend>(kernel->kernel());
  TfLiteInterpreterOptionsAddDelegate(options, backend->get_delegate());
  TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
  ASSERT_NE(interpreter, nullptr);
  TfLiteInterpreterOptionsDelete(options);

  // The model carries no signatures, so only the null key yields a runner.
  int num_signatures = TfLiteInterpreterGetSignatureCount(interpreter);
  ASSERT_EQ(num_signatures, 0);
  ASSERT_EQ(TfLiteInterpreterGetAsyncSignatureRunner(interpreter, "foo"),
            nullptr);
  TfLiteAsyncSignatureRunner* runner =
      TfLiteInterpreterGetAsyncSignatureRunner(interpreter, nullptr);
  ASSERT_NE(runner, nullptr);

  // Every interpreter input must be reachable through the runner by name.
  int num_interpreter_inputs =
      TfLiteInterpreterGetInputTensorCount(interpreter);
  int num_runner_inputs = TfLiteAsyncSignatureRunnerGetInputCount(runner);
  ASSERT_EQ(num_runner_inputs, num_interpreter_inputs);
  for (int i = 0; i < num_interpreter_inputs; ++i) {
    auto* interpreter_input_tensor =
        TfLiteInterpreterGetInputTensor(interpreter, i);
    ASSERT_NE(interpreter_input_tensor, nullptr);
    auto* interpreter_input_name = TfLiteTensorName(interpreter_input_tensor);
    ASSERT_NE(interpreter_input_name, nullptr);
    auto* runner_input_name = TfLiteAsyncSignatureRunnerGetInputName(runner, i);
    ASSERT_NE(runner_input_name, nullptr);
    EXPECT_STREQ(runner_input_name, interpreter_input_name);
    auto* runner_input_tensor = TfLiteAsyncSignatureRunnerGetInputTensor(
        runner, interpreter_input_name);
    ASSERT_NE(runner_input_tensor, nullptr);
    // The runner hands out opaque tensor pointers; compare addresses.
    ASSERT_EQ(runner_input_tensor, reinterpret_cast<const TfLiteOpaqueTensor*>(
                                       interpreter_input_tensor));
  }

  // Same check for the outputs.
  int num_interpreter_outputs =
      TfLiteInterpreterGetOutputTensorCount(interpreter);
  int num_runner_outputs = TfLiteAsyncSignatureRunnerGetOutputCount(runner);
  ASSERT_EQ(num_runner_outputs, num_interpreter_outputs);
  for (int i = 0; i < num_interpreter_outputs; ++i) {
    auto* interpreter_output_tensor =
        TfLiteInterpreterGetOutputTensor(interpreter, i);
    ASSERT_NE(interpreter_output_tensor, nullptr);
    auto* interpreter_output_name = TfLiteTensorName(interpreter_output_tensor);
    ASSERT_NE(interpreter_output_name, nullptr);
    auto* runner_output_name =
        TfLiteAsyncSignatureRunnerGetOutputName(runner, i);
    ASSERT_NE(runner_output_name, nullptr);
    EXPECT_STREQ(runner_output_name, interpreter_output_name);
    auto* runner_output_tensor = TfLiteAsyncSignatureRunnerGetOutputTensor(
        runner, interpreter_output_name);
    ASSERT_NE(runner_output_tensor, nullptr);
    ASSERT_EQ(runner_output_tensor, reinterpret_cast<const TfLiteOpaqueTensor*>(
                                        interpreter_output_tensor));
  }

  // The strict mock requires each kernel entry point to be hit exactly once.
  EXPECT_CALL(*kernel, Prepare(_, _)).WillOnce(Return(kTfLiteOk));
  EXPECT_CALL(*kernel, Eval(_, _, _)).WillOnce(Return(kTfLiteOk));
  EXPECT_CALL(*kernel, Wait(_, _)).WillOnce(Return(kTfLiteOk));
  EXPECT_CALL(*kernel, Finish(_, _)).WillOnce(Return(kTfLiteOk));
  EXPECT_EQ(kTfLiteOk, TfLiteAsyncSignatureRunnerPrepareBackends(runner));
  auto* task = TfLiteAsyncSignatureRunnerCreateTask(runner);
  EXPECT_EQ(kTfLiteOk, TfLiteAsyncSignatureRunnerInvokeAsync(runner, task));
  EXPECT_EQ(kTfLiteOk, TfLiteAsyncSignatureRunnerWait(runner, task));
  EXPECT_EQ(kTfLiteOk, TfLiteAsyncSignatureRunnerFinish(runner, task));

  TfLiteAsyncSignatureRunnerDelete(runner);
  TfLiteInterpreterDelete(interpreter);
  TfLiteModelDelete(model);
}
} // namespace async
} // namespace tflite

View File

@ -520,7 +520,13 @@ void Interpreter::AddProfiler(std::unique_ptr<Profiler> profiler) {
}
impl::SignatureRunner* Interpreter::GetSignatureRunner(
const char* signature_key) {
const char* signature_key_) {
auto [signature_key, empty_signature_fallback] =
ReplaceWithPlaceholderSignatureKeyIfNeeded(signature_key_);
if (!signature_key) {
return nullptr;
}
auto iter = signature_runner_map_.find(signature_key);
if (iter != signature_runner_map_.end()) {
return &(iter->second);
@ -533,6 +539,14 @@ impl::SignatureRunner* Interpreter::GetSignatureRunner(
return nullptr;
}
if (empty_signature_fallback) {
placeholder_signature_def_ = CreatePlaceholderSignatureDef();
auto status = signature_runner_map_.insert(
{signature_key, SignatureRunner(placeholder_signature_def_.get(),
&primary_subgraph())});
return &(status.first->second);
}
for (const auto& signature : signature_defs_) {
if (signature.signature_key == signature_key) {
auto status = signature_runner_map_.insert(
@ -541,7 +555,56 @@ impl::SignatureRunner* Interpreter::GetSignatureRunner(
return &(status.first->second);
}
}
return nullptr;
}
// Builds a stand-in SignatureDef for a model that has no signatures.
//
// Each input and output of the interpreter is registered under the name
// reported by GetInputName()/GetOutputName() and mapped to its tensor index.
// The result uses kPlaceholderSignatureDefKey as its signature key and refers
// to subgraph 0.
std::unique_ptr<internal::SignatureDef>
Interpreter::CreatePlaceholderSignatureDef() {
  auto placeholder_signature_def = std::make_unique<internal::SignatureDef>();

  // Index with an unsigned counter so the comparison against size() does not
  // mix signed and unsigned operands.
  const auto& input_tensors = inputs();
  for (size_t i = 0; i < input_tensors.size(); ++i) {
    const char* name = GetInputName(static_cast<int>(i));
    // NOTE(review): assumes the map key is a std::string, for which a null
    // char pointer is not a valid initializer; fall back to an empty name.
    placeholder_signature_def->inputs[name != nullptr ? name : ""] =
        input_tensors[i];
  }

  const auto& output_tensors = outputs();
  for (size_t i = 0; i < output_tensors.size(); ++i) {
    const char* name = GetOutputName(static_cast<int>(i));
    placeholder_signature_def->outputs[name != nullptr ? name : ""] =
        output_tensors[i];
  }

  placeholder_signature_def->signature_key = kPlaceholderSignatureDefKey;
  placeholder_signature_def->subgraph_index = 0;
  return placeholder_signature_def;
}
std::pair<const char*, bool>
Interpreter::ReplaceWithPlaceholderSignatureKeyIfNeeded(
const char* signature_key) {
// Handles nullptr signature key.
// If the model does not have signature def, use default name as placeholder.
// Otherwise use the first signature key that points to primary subgraph.
bool empty_signature_fallback = false;
if (signature_key == nullptr) {
if (signature_defs_.empty()) {
signature_key = kPlaceholderSignatureDefKey;
empty_signature_fallback = true;
} else {
for (const auto& signature : signature_defs_) {
if (signature.subgraph_index == 0) {
signature_key = signature.signature_key.c_str();
break;
}
}
}
}
if (signature_key == nullptr) {
// The model has signature def but none of those points to primary subgraph.
TF_LITE_REPORT_ERROR(error_reporter_,
"The model has signature def but none of those points "
"to primary subgraph.");
return {nullptr, empty_signature_fallback};
} else {
return {signature_key, empty_signature_fallback};
}
}
} // namespace tflite

View File

@ -335,21 +335,25 @@ class Interpreter {
}
/// \brief Returns a pointer to the SignatureRunner instance to run the part
/// of the graph identified by a SignatureDef. The nullptr is returned if the
/// given signature key is not valid.
/// of the graph identified by a SignatureDef. If the model does not have any
/// signature defs, pass nullptr as signature_key and a SignatureRunner will
/// be created using the primary subgraph (0). A nullptr is returned if the
/// given signature_key is not valid. Note, the returned SignatureRunner
/// instance is owned by and has the same lifetime as the Interpreter object;
/// additionally, class SignatureRunner is *not* thread-safe.
/// If you need to specify delegates, you have to do that before calling this
/// function. This function will additionally apply default delegates. Thus,
/// applying delegates after that might lead to undesirable behaviors.
/// Note, the pointed instance has lifetime same as the Interpreter object
/// and the SignatureRunner class is *not* thread-safe.
SignatureRunner* GetSignatureRunner(const char* signature_key);
/// \warning Experimental interface, subject to change. \n
/// \brief Returns a pointer to the AsyncSignatureRunner instance to run the
/// part of the graph identified by a SignatureDef. The nullptr is returned if
/// the given signature key is not valid.
/// if the model does not have signature def, pass nullptr to signature_key
/// and AsyncSignatureRunner will be created using primary subgraph (0).
/// \warning Experimental interface, subject to change. \n \brief Returns a
/// pointer to the AsyncSignatureRunner instance to run the part of the graph
/// identified by a SignatureDef. If the model does not have any signature
/// defs, pass nullptr as signature_key and an AsyncSignatureRunner will be
/// created using the primary subgraph (0). A nullptr is returned if the
/// given signature_key is not valid. Note, the returned AsyncSignatureRunner
/// instance is owned by and has the same lifetime as the Interpreter object;
/// additionally, class AsyncSignatureRunner is *not* thread-safe.
/// The async delegate should be applied before calling this function.
async::AsyncSignatureRunner* GetAsyncSignatureRunner(
const char* signature_key);
@ -905,6 +909,10 @@ class Interpreter {
TfLiteStatus ApplyOptionsImpl(InterpreterOptions* options);
std::unique_ptr<internal::SignatureDef> CreatePlaceholderSignatureDef();
std::pair<const char*, bool> ReplaceWithPlaceholderSignatureKeyIfNeeded(
const char* signature_key);
// A pure C data structure used to communicate with the pure C plugin
// interface. To avoid copying tensor metadata, this is also the definitive
// structure to store tensors.
@ -964,6 +972,13 @@ class Interpreter {
// List of SignatureDefs obtained from the model.
std::vector<internal::SignatureDef> signature_defs_;
// Default signature key to use when the model has no signatures.
static constexpr char kPlaceholderSignatureDefKey[] =
"<placeholder signature>";
// Placeholder SignatureDef for legacy models with no signatures.
std::unique_ptr<internal::SignatureDef> placeholder_signature_def_;
// Map of signature key to its corresponding SignatureRunner object.
// A SignatureRunner is basically a wrapper of the Subgraph corresponding to
// its SignatureDef.

View File

@ -34,10 +34,6 @@ limitations under the License.
namespace tflite {
namespace {
static constexpr char kDefaultServingSignatureDefKey[] = "serving_default";
} // namespace
TfLiteStatus Interpreter::SetCustomAllocationForTensor(
int tensor_index, const TfLiteCustomAllocation& allocation, int64_t flags) {
return primary_subgraph().SetCustomAllocationForTensor(tensor_index,
@ -145,27 +141,10 @@ TfLiteStatus Interpreter::ApplyOptions(InterpreterOptions* options) {
}
async::AsyncSignatureRunner* Interpreter::GetAsyncSignatureRunner(
const char* signature_key) {
// Handles nullptr signature key.
// If the model does not have signature def, use default name as placeholder.
// Otherwise use the first signature key that points to primary subgraph.
bool empty_signature_fallback = false;
if (signature_key == nullptr) {
if (signature_defs_.empty()) {
signature_key = kDefaultServingSignatureDefKey;
empty_signature_fallback = true;
} else {
for (const auto& signature : signature_defs_) {
if (signature.subgraph_index == 0) {
signature_key = signature.signature_key.c_str();
break;
}
}
}
}
if (signature_key == nullptr) {
// The model has signature def but none of those points to primary subgraph.
const char* signature_key_) {
auto [signature_key, empty_signature_fallback] =
ReplaceWithPlaceholderSignatureKeyIfNeeded(signature_key_);
if (!signature_key) {
return nullptr;
}
@ -175,11 +154,14 @@ async::AsyncSignatureRunner* Interpreter::GetAsyncSignatureRunner(
}
if (empty_signature_fallback) {
placeholder_signature_def_ = CreatePlaceholderSignatureDef();
auto status = async_signature_runner_map_.insert(
{signature_key,
async::AsyncSignatureRunner(nullptr, &primary_subgraph())});
async::AsyncSignatureRunner(placeholder_signature_def_.get(),
&primary_subgraph())});
return &(status.first->second);
}
for (const auto& signature : signature_defs_) {
if (signature.signature_key == signature_key) {
auto status = async_signature_runner_map_.insert(

Binary file not shown.