mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Merge pull request #49995 from geetachavan1/cherrypicks_G9C1B
Cherry pick 2.2 TFLite: Error out when the graph has a recurion.
This commit is contained in:
commit
cee8e6c359
|
|
@ -357,6 +357,7 @@ cc_test(
|
|||
"testdata/test_min_runtime.bin",
|
||||
"testdata/test_model.bin",
|
||||
"testdata/test_model_broken.bin",
|
||||
"testdata/unsupported_recursion.bin",
|
||||
],
|
||||
tags = [
|
||||
"tflite_not_portable",
|
||||
|
|
|
|||
|
|
@ -139,6 +139,42 @@ const char* GetTFLiteOpName(const TfLiteRegistration& op_reg) {
|
|||
return tflite::EnumNamesBuiltinOperator()[op_reg.builtin_code];
|
||||
}
|
||||
|
||||
// An utility test to detect if the subgraph is abused:
|
||||
// 1. Detects if recursion exists in the graph (recursion is not currently
|
||||
// supported.
|
||||
// 2. Detects if the interpreter / subgraph is used in multiple subgraphs.
|
||||
// Note: It's clearly documented that the interpreter / subgraph are not
|
||||
// thread-safe. This serves as a check with possible false negatives
|
||||
// unless we switch to atomic boolean flags.
|
||||
class SubgraphGuard {
|
||||
public:
|
||||
SubgraphGuard(TfLiteContext* context, bool* is_subgraph_in_use)
|
||||
: is_subgraph_in_use_(is_subgraph_in_use) {
|
||||
if (*is_subgraph_in_use_) {
|
||||
TF_LITE_KERNEL_LOG(
|
||||
context,
|
||||
"Subgraph is already in use. Using an interpreter or a subgraph in "
|
||||
"multiple threads is not supported. Recursion in the graph is not "
|
||||
"supported.");
|
||||
status_ = kTfLiteError;
|
||||
} else {
|
||||
*is_subgraph_in_use_ = true;
|
||||
}
|
||||
}
|
||||
~SubgraphGuard() {
|
||||
// If tht original status was OK, recover the boolean flag.
|
||||
if (status_ == kTfLiteOk) {
|
||||
*is_subgraph_in_use_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteStatus status() const { return status_; }
|
||||
|
||||
private:
|
||||
TfLiteStatus status_ = kTfLiteOk;
|
||||
bool* is_subgraph_in_use_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// A trivial implementation of GraphInfo around the Interpreter.
|
||||
|
|
@ -630,6 +666,7 @@ TfLiteStatus Subgraph::BytesRequired(TfLiteType type, const int* dims,
|
|||
|
||||
TfLiteStatus Subgraph::AllocateTensors() {
|
||||
TFLITE_SCOPED_TAGGED_DEFAULT_PROFILE(profiler_.get(), "AllocateTensors");
|
||||
|
||||
if (!consistent_) {
|
||||
ReportError("AllocateTensors() called on inconsistent model.");
|
||||
return kTfLiteError;
|
||||
|
|
@ -653,6 +690,12 @@ TfLiteStatus Subgraph::AllocateTensors() {
|
|||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// Note `AllocateTensors` sometimes calls itself recursively above
|
||||
// for delegates. Therefore only the logic below need to be guarded
|
||||
// by `SubgraphGuard`.
|
||||
SubgraphGuard guard(&context_, &is_subgraph_in_use_);
|
||||
TF_LITE_ENSURE_OK(&context_, guard.status());
|
||||
|
||||
next_execution_plan_index_to_prepare_ = 0;
|
||||
next_execution_plan_index_to_plan_allocation_ = 0;
|
||||
if (memory_planner_) {
|
||||
|
|
@ -880,6 +923,9 @@ TfLiteStatus Subgraph::PrepareOpsAndTensors() {
|
|||
}
|
||||
|
||||
TfLiteStatus Subgraph::Invoke() {
|
||||
SubgraphGuard guard(&context_, &is_subgraph_in_use_);
|
||||
TF_LITE_ENSURE_OK(&context_, guard.status());
|
||||
|
||||
if (!consistent_) {
|
||||
ReportError("Invoke called on model that is not consistent.");
|
||||
return kTfLiteError;
|
||||
|
|
|
|||
|
|
@ -682,6 +682,10 @@ class Subgraph {
|
|||
|
||||
// A map of resources. Owned by interpreter and shared by multiple subgraphs.
|
||||
resource::ResourceMap* resources_ = nullptr;
|
||||
|
||||
// Whether the subgraph is currently in use (e.g. running the `Invoke`
|
||||
// or `AllocateTensors` functions).
|
||||
bool is_subgraph_in_use_ = false;
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
|
|
|||
|
|
@ -132,8 +132,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||
auto* subgraphs = this_subgraph->GetSubgraphs();
|
||||
TF_LITE_ENSURE(context, op_data->cond_subgraph_index < subgraphs->size());
|
||||
TF_LITE_ENSURE(context, op_data->body_subgraph_index < subgraphs->size());
|
||||
TF_LITE_ENSURE(context,
|
||||
op_data->cond_subgraph_index != op_data->body_subgraph_index);
|
||||
|
||||
Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
|
||||
Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();
|
||||
|
|
|
|||
|
|
@ -442,6 +442,24 @@ TEST(BasicFlatBufferModel, TestParseModelWithSparseTensor) {
|
|||
}
|
||||
|
||||
// TODO(b/150072943): Add malformed model with sparse tensor tests.
|
||||
// Recursion & reentrant are not supported in TFLite.
|
||||
// The test ensures it fails gracefullly instead of crashing with
|
||||
// a stack overflow.
|
||||
TEST(BasicFlatBufferModel, TestUnsupportedRecursion) {
|
||||
const auto model_path =
|
||||
"tensorflow/lite/testdata/unsupported_recursion.bin";
|
||||
|
||||
std::unique_ptr<tflite::FlatBufferModel> model =
|
||||
FlatBufferModel::BuildFromFile(model_path);
|
||||
ASSERT_NE(model, nullptr);
|
||||
|
||||
tflite::ops::builtin::BuiltinOpResolver resolver;
|
||||
InterpreterBuilder builder(*model, resolver);
|
||||
std::unique_ptr<Interpreter> interpreter;
|
||||
ASSERT_EQ(builder(&interpreter), kTfLiteOk);
|
||||
ASSERT_NE(interpreter, nullptr);
|
||||
ASSERT_NE(interpreter->AllocateTensors(), kTfLiteOk);
|
||||
}
|
||||
|
||||
// TODO(aselle): Add tests for serialization of builtin op data types.
|
||||
// These tests will occur with the evaluation tests of individual operators,
|
||||
|
|
|
|||
BIN
tensorflow/lite/testdata/unsupported_recursion.bin
vendored
Normal file
BIN
tensorflow/lite/testdata/unsupported_recursion.bin
vendored
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user