Enable execution of sparse models on MLDrift backend

PiperOrigin-RevId: 377987985
Change-Id: Icd7e29ee125f689dfd3400d52e7f6d63b6b5ef69
This commit is contained in:
A. Unique TensorFlower 2021-06-07 13:09:06 -07:00 committed by TensorFlower Gardener
parent 4cae16a306
commit 1d536b3d79
2 changed files with 14 additions and 5 deletions

View File

@ -3275,16 +3275,17 @@ class DelegateContext {
std::vector<int> input_ids; std::vector<int> input_ids;
std::vector<int> output_ids; std::vector<int> output_ids;
GraphFloat32* graph; GraphFloat32* graph;
std::unique_ptr<absl::flat_hash_map<int, int>> quant_conversion_map;
}; };
bool Init(TfLiteContext* context, bool Init(TfLiteContext* context,
const TfLiteDelegateParams* delegate_params) { const TfLiteDelegateParams* delegate_params) {
const auto* delegate_data = const auto* delegate_data =
reinterpret_cast<DelegateData*>(delegate_params->delegate->data_); reinterpret_cast<DelegateData*>(delegate_params->delegate->data_);
return delegate_data->graph && return delegate_data->graph &&
BuildModelEnforceIO(context, delegate_params, BuildModelEnforceIO(context, delegate_params,
delegate_data->input_ids, delegate_data->input_ids,
delegate_data->output_ids, delegate_data->graph) delegate_data->output_ids, delegate_data->graph,
delegate_data->quant_conversion_map.get())
.ok(); .ok();
} }
}; };
@ -3311,7 +3312,10 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
registration.invoke = nullptr; registration.invoke = nullptr;
registration.custom_name = nullptr; registration.custom_name = nullptr;
TfLiteIntArray* ops_to_replace = GetOpsToReplace(context); const auto* delegate_data =
reinterpret_cast<const DelegateContext::DelegateData*>(delegate->data_);
TfLiteIntArray* ops_to_replace = GetOpsToReplace(
context, static_cast<bool>(delegate_data->quant_conversion_map));
const auto status = context->ReplaceNodeSubsetsWithDelegateKernels( const auto status = context->ReplaceNodeSubsetsWithDelegateKernels(
context, registration, ops_to_replace, delegate); context, registration, ops_to_replace, delegate);
TfLiteIntArrayFree(ops_to_replace); TfLiteIntArrayFree(ops_to_replace);
@ -3322,7 +3326,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
absl::Status BuildFromFlatBuffer(const tflite::FlatBufferModel& flatbuffer, absl::Status BuildFromFlatBuffer(const tflite::FlatBufferModel& flatbuffer,
const tflite::OpResolver& op_resolver, const tflite::OpResolver& op_resolver,
GraphFloat32* graph) { GraphFloat32* graph, bool allow_quant_ops) {
std::unique_ptr<tflite::Interpreter> interpreter; std::unique_ptr<tflite::Interpreter> interpreter;
tflite::InterpreterBuilder interpreter_builder(flatbuffer, op_resolver); tflite::InterpreterBuilder interpreter_builder(flatbuffer, op_resolver);
if (interpreter_builder(&interpreter) != kTfLiteOk || !interpreter) { if (interpreter_builder(&interpreter) != kTfLiteOk || !interpreter) {
@ -3332,6 +3336,10 @@ absl::Status BuildFromFlatBuffer(const tflite::FlatBufferModel& flatbuffer,
DelegateContext::DelegateData delegate_data{interpreter->inputs(), DelegateContext::DelegateData delegate_data{interpreter->inputs(),
interpreter->outputs(), graph}; interpreter->outputs(), graph};
if (allow_quant_ops) {
delegate_data.quant_conversion_map =
absl::make_unique<absl::flat_hash_map<int, int>>();
}
delegate.data_ = &delegate_data; delegate.data_ = &delegate_data;
delegate.flags = kTfLiteDelegateFlagsNone; delegate.flags = kTfLiteDelegateFlagsNone;

View File

@ -78,7 +78,8 @@ absl::Status BuildFinalModel(
// FlatBufferModel. // FlatBufferModel.
absl::Status BuildFromFlatBuffer(const FlatBufferModel& flatbuffer, absl::Status BuildFromFlatBuffer(const FlatBufferModel& flatbuffer,
const OpResolver& op_resolver, const OpResolver& op_resolver,
GraphFloat32* graph); GraphFloat32* graph,
bool allow_quant_ops = false);
// Module-internal converter, exposed for unit testing purpose only. // Module-internal converter, exposed for unit testing purpose only.
absl::Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, absl::Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,