Merge pull request #50087 from tensorflow/mihaimaruseac-patch-2

Update conv.cc
This commit is contained in:
Mihai Maruseac 2021-06-04 13:32:24 -07:00 committed by GitHub
commit fc407649b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -749,11 +749,12 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
}
template <KernelType kernel_type>
void EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, OpData* data,
TfLiteTensor* input, TfLiteTensor* filter,
TfLiteTensor* bias, TfLiteTensor* im2col,
TfLiteTensor* output) {
TfLiteStatus EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, OpData* data,
TfLiteTensor* input,
TfLiteTensor* filter,
TfLiteTensor* bias,
TfLiteTensor* im2col, TfLiteTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
@ -828,10 +829,11 @@ void EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
}
template <KernelType kernel_type>
void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, OpData* data, TfLiteTensor* input,
TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* im2col,
TfLiteTensor* accum_scratch, TfLiteTensor* output) {
TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, OpData* data,
TfLiteTensor* input, TfLiteTensor* filter,
TfLiteTensor* bias, TfLiteTensor* im2col,
TfLiteTensor* accum_scratch, TfLiteTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
@ -917,14 +919,17 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteFloat32:
if (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8) {
if (data->is_hybrid_per_channel) {
EvalHybridPerChannel<kernel_type>(context, node, params, data, input,
filter, bias, im2col, output);
TF_LITE_ENSURE_OK(context, EvalHybridPerChannel<kernel_type>(
context, node, params, data, input,
filter, bias, im2col, output));
} else {
TfLiteTensor* accum_scratch =
&context->tensors[node->temporaries
->data[data->accum_scratch_index]];
EvalHybrid<kernel_type>(context, node, params, data, input, filter,
bias, im2col, accum_scratch, output);
TF_LITE_ENSURE_OK(context,
EvalHybrid<kernel_type>(context, node, params, data,
input, filter, bias, im2col,
accum_scratch, output));
}
} else {
EvalFloat<kernel_type>(context, node, params, data, input, filter, bias,