Merge pull request #49995 from geetachavan1/cherrypicks_G9C1B

Cherry pick 2.2 TFLite: Error out when the graph has a recursion.
Mihai Maruseac 2021-06-02 16:43:21 -07:00 committed by GitHub
commit cee8e6c359
6 changed files with 69 additions and 2 deletions


@@ -357,6 +357,7 @@ cc_test(
"testdata/test_min_runtime.bin",
"testdata/test_model.bin",
"testdata/test_model_broken.bin",
"testdata/unsupported_recursion.bin",
],
tags = [
"tflite_not_portable",


@@ -139,6 +139,42 @@ const char* GetTFLiteOpName(const TfLiteRegistration& op_reg) {
return tflite::EnumNamesBuiltinOperator()[op_reg.builtin_code];
}
// A utility class to detect if the subgraph is abused:
// 1. Detects if recursion exists in the graph (recursion is not currently
//    supported).
// 2. Detects if the interpreter / subgraph is used by multiple subgraphs or
//    threads at the same time.
// Note: It's clearly documented that the interpreter / subgraph are not
// thread-safe. This serves as a best-effort check, with possible false
// negatives unless we switch to atomic boolean flags.
class SubgraphGuard {
public:
SubgraphGuard(TfLiteContext* context, bool* is_subgraph_in_use)
: is_subgraph_in_use_(is_subgraph_in_use) {
if (*is_subgraph_in_use_) {
TF_LITE_KERNEL_LOG(
context,
"Subgraph is already in use. Using an interpreter or a subgraph in "
"multiple threads is not supported. Recursion in the graph is not "
"supported.");
status_ = kTfLiteError;
} else {
*is_subgraph_in_use_ = true;
}
}
~SubgraphGuard() {
// If the original status was OK, restore the boolean flag.
if (status_ == kTfLiteOk) {
*is_subgraph_in_use_ = false;
}
}
TfLiteStatus status() const { return status_; }
private:
TfLiteStatus status_ = kTfLiteOk;
bool* is_subgraph_in_use_;
};
} // namespace
// A trivial implementation of GraphInfo around the Interpreter.
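The comment above leaves the atomic-flag variant as future work. Below is a minimal, hypothetical sketch of what that could look like (not part of this commit; `AtomicSubgraphGuard` and its members are illustrative names): a compare-exchange makes the "already in use" check and the flag set a single atomic step, closing the false-negative window while keeping the same RAII shape as the committed `SubgraphGuard`.

#include <atomic>

// Hypothetical variant of SubgraphGuard using an atomic flag, as suggested by
// the "unless we switch to atomic boolean flags" note above.
class AtomicSubgraphGuard {
 public:
  explicit AtomicSubgraphGuard(std::atomic<bool>* is_subgraph_in_use)
      : is_subgraph_in_use_(is_subgraph_in_use) {
    bool expected = false;
    // Atomically flips the flag from false to true; fails if the subgraph is
    // already in use (another thread or a recursive call owns it).
    acquired_ = is_subgraph_in_use_->compare_exchange_strong(expected, true);
  }
  ~AtomicSubgraphGuard() {
    // Only the guard that actually acquired the flag releases it.
    if (acquired_) is_subgraph_in_use_->store(false);
  }
  bool ok() const { return acquired_; }

 private:
  std::atomic<bool>* is_subgraph_in_use_;
  bool acquired_ = false;
};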
@@ -630,6 +666,7 @@ TfLiteStatus Subgraph::BytesRequired(TfLiteType type, const int* dims,
TfLiteStatus Subgraph::AllocateTensors() {
TFLITE_SCOPED_TAGGED_DEFAULT_PROFILE(profiler_.get(), "AllocateTensors");
if (!consistent_) {
ReportError("AllocateTensors() called on inconsistent model.");
return kTfLiteError;
@@ -653,6 +690,12 @@ TfLiteStatus Subgraph::AllocateTensors() {
return kTfLiteOk;
}
// Note that `AllocateTensors` sometimes calls itself recursively above
// for delegates. Therefore only the logic below needs to be guarded
// by `SubgraphGuard`.
SubgraphGuard guard(&context_, &is_subgraph_in_use_);
TF_LITE_ENSURE_OK(&context_, guard.status());
next_execution_plan_index_to_prepare_ = 0;
next_execution_plan_index_to_plan_allocation_ = 0;
if (memory_planner_) {
@@ -880,6 +923,9 @@ TfLiteStatus Subgraph::PrepareOpsAndTensors() {
}
TfLiteStatus Subgraph::Invoke() {
SubgraphGuard guard(&context_, &is_subgraph_in_use_);
TF_LITE_ENSURE_OK(&context_, guard.status());
if (!consistent_) {
ReportError("Invoke called on model that is not consistent.");
return kTfLiteError;


@@ -682,6 +682,10 @@ class Subgraph {
// A map of resources. Owned by interpreter and shared by multiple subgraphs.
resource::ResourceMap* resources_ = nullptr;
// Whether the subgraph is currently in use (e.g. running the `Invoke`
// or `AllocateTensors` functions).
bool is_subgraph_in_use_ = false;
};
} // namespace tflite


@@ -132,8 +132,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto* subgraphs = this_subgraph->GetSubgraphs();
TF_LITE_ENSURE(context, op_data->cond_subgraph_index < subgraphs->size());
TF_LITE_ENSURE(context, op_data->body_subgraph_index < subgraphs->size());
TF_LITE_ENSURE(context,
op_data->cond_subgraph_index != op_data->body_subgraph_index);
Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();


@@ -442,6 +442,24 @@ TEST(BasicFlatBufferModel, TestParseModelWithSparseTensor) {
}
// TODO(b/150072943): Add malformed model with sparse tensor tests.
// Recursion and reentrancy are not supported in TFLite.
// This test ensures the interpreter fails gracefully instead of crashing with
// a stack overflow.
TEST(BasicFlatBufferModel, TestUnsupportedRecursion) {
const auto model_path =
"tensorflow/lite/testdata/unsupported_recursion.bin";
std::unique_ptr<tflite::FlatBufferModel> model =
FlatBufferModel::BuildFromFile(model_path);
ASSERT_NE(model, nullptr);
tflite::ops::builtin::BuiltinOpResolver resolver;
InterpreterBuilder builder(*model, resolver);
std::unique_ptr<Interpreter> interpreter;
ASSERT_EQ(builder(&interpreter), kTfLiteOk);
ASSERT_NE(interpreter, nullptr);
ASSERT_NE(interpreter->AllocateTensors(), kTfLiteOk);
}
// TODO(aselle): Add tests for serialization of builtin op data types.
// These tests will occur with the evaluation tests of individual operators,
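For reference, a minimal standalone sketch (assuming the same testdata path used by the test above) of what application code observes after this change: the recursive model now surfaces as a `kTfLiteError` from `AllocateTensors()` / `Invoke()` rather than a stack overflow.

#include <cstdio>
#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

int main() {
  // Same recursive model that the test above loads.
  auto model = tflite::FlatBufferModel::BuildFromFile(
      "tensorflow/lite/testdata/unsupported_recursion.bin");
  if (model == nullptr) return 1;
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  // With this change the recursion is rejected here instead of crashing.
  if (interpreter == nullptr || interpreter->AllocateTensors() != kTfLiteOk) {
    std::fprintf(stderr, "Model rejected: recursion is not supported.\n");
    return 1;
  }
  return 0;
}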

Binary file not shown.