mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Enable execution of sparse models on MLDrift backend
PiperOrigin-RevId: 377987985 Change-Id: Icd7e29ee125f689dfd3400d52e7f6d63b6b5ef69
This commit is contained in:
parent
4cae16a306
commit
1d536b3d79
|
|
@ -3275,16 +3275,17 @@ class DelegateContext {
|
|||
std::vector<int> input_ids;
|
||||
std::vector<int> output_ids;
|
||||
GraphFloat32* graph;
|
||||
std::unique_ptr<absl::flat_hash_map<int, int>> quant_conversion_map;
|
||||
};
|
||||
bool Init(TfLiteContext* context,
|
||||
const TfLiteDelegateParams* delegate_params) {
|
||||
const auto* delegate_data =
|
||||
reinterpret_cast<DelegateData*>(delegate_params->delegate->data_);
|
||||
|
||||
return delegate_data->graph &&
|
||||
BuildModelEnforceIO(context, delegate_params,
|
||||
delegate_data->input_ids,
|
||||
delegate_data->output_ids, delegate_data->graph)
|
||||
delegate_data->output_ids, delegate_data->graph,
|
||||
delegate_data->quant_conversion_map.get())
|
||||
.ok();
|
||||
}
|
||||
};
|
||||
|
|
@ -3311,7 +3312,10 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
|
|||
registration.invoke = nullptr;
|
||||
registration.custom_name = nullptr;
|
||||
|
||||
TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
|
||||
const auto* delegate_data =
|
||||
reinterpret_cast<const DelegateContext::DelegateData*>(delegate->data_);
|
||||
TfLiteIntArray* ops_to_replace = GetOpsToReplace(
|
||||
context, static_cast<bool>(delegate_data->quant_conversion_map));
|
||||
const auto status = context->ReplaceNodeSubsetsWithDelegateKernels(
|
||||
context, registration, ops_to_replace, delegate);
|
||||
TfLiteIntArrayFree(ops_to_replace);
|
||||
|
|
@ -3322,7 +3326,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
|
|||
|
||||
absl::Status BuildFromFlatBuffer(const tflite::FlatBufferModel& flatbuffer,
|
||||
const tflite::OpResolver& op_resolver,
|
||||
GraphFloat32* graph) {
|
||||
GraphFloat32* graph, bool allow_quant_ops) {
|
||||
std::unique_ptr<tflite::Interpreter> interpreter;
|
||||
tflite::InterpreterBuilder interpreter_builder(flatbuffer, op_resolver);
|
||||
if (interpreter_builder(&interpreter) != kTfLiteOk || !interpreter) {
|
||||
|
|
@ -3332,6 +3336,10 @@ absl::Status BuildFromFlatBuffer(const tflite::FlatBufferModel& flatbuffer,
|
|||
|
||||
DelegateContext::DelegateData delegate_data{interpreter->inputs(),
|
||||
interpreter->outputs(), graph};
|
||||
if (allow_quant_ops) {
|
||||
delegate_data.quant_conversion_map =
|
||||
absl::make_unique<absl::flat_hash_map<int, int>>();
|
||||
}
|
||||
|
||||
delegate.data_ = &delegate_data;
|
||||
delegate.flags = kTfLiteDelegateFlagsNone;
|
||||
|
|
|
|||
|
|
@ -78,7 +78,8 @@ absl::Status BuildFinalModel(
|
|||
// FlatBufferModel.
|
||||
absl::Status BuildFromFlatBuffer(const FlatBufferModel& flatbuffer,
|
||||
const OpResolver& op_resolver,
|
||||
GraphFloat32* graph);
|
||||
GraphFloat32* graph,
|
||||
bool allow_quant_ops = false);
|
||||
|
||||
// Module-internal converter, exposed for unit testing purpose only.
|
||||
absl::Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user