mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Update conv.cc
This commit is contained in:
parent
35b9cd4004
commit
89928fa1c7
|
|
@ -749,11 +749,12 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
|
|||
}
|
||||
|
||||
template <KernelType kernel_type>
|
||||
void EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteConvParams* params, OpData* data,
|
||||
TfLiteTensor* input, TfLiteTensor* filter,
|
||||
TfLiteTensor* bias, TfLiteTensor* im2col,
|
||||
TfLiteTensor* output) {
|
||||
TfLiteStatus EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteConvParams* params, OpData* data,
|
||||
TfLiteTensor* input,
|
||||
TfLiteTensor* filter,
|
||||
TfLiteTensor* bias,
|
||||
TfLiteTensor* im2col, TfLiteTensor* output) {
|
||||
float output_activation_min, output_activation_max;
|
||||
CalculateActivationRange(params->activation, &output_activation_min,
|
||||
&output_activation_max);
|
||||
|
|
@ -828,10 +829,11 @@ void EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
|
|||
}
|
||||
|
||||
template <KernelType kernel_type>
|
||||
void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteConvParams* params, OpData* data, TfLiteTensor* input,
|
||||
TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* im2col,
|
||||
TfLiteTensor* accum_scratch, TfLiteTensor* output) {
|
||||
TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteConvParams* params, OpData* data,
|
||||
TfLiteTensor* input, TfLiteTensor* filter,
|
||||
TfLiteTensor* bias, TfLiteTensor* im2col,
|
||||
TfLiteTensor* accum_scratch, TfLiteTensor* output) {
|
||||
float output_activation_min, output_activation_max;
|
||||
CalculateActivationRange(params->activation, &output_activation_min,
|
||||
&output_activation_max);
|
||||
|
|
@ -917,14 +919,17 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {
|
|||
case kTfLiteFloat32:
|
||||
if (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8) {
|
||||
if (data->is_hybrid_per_channel) {
|
||||
EvalHybridPerChannel<kernel_type>(context, node, params, data, input,
|
||||
filter, bias, im2col, output);
|
||||
TF_LITE_ENSURE_OK(context,EvalHybridPerChannel<kernel_type>(
|
||||
context, node, params, data, input,
|
||||
filter, bias, im2col, output));
|
||||
} else {
|
||||
TfLiteTensor* accum_scratch =
|
||||
&context->tensors[node->temporaries
|
||||
->data[data->accum_scratch_index]];
|
||||
EvalHybrid<kernel_type>(context, node, params, data, input, filter,
|
||||
bias, im2col, accum_scratch, output);
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
EvalHybrid<kernel_type>(context, node, params, data,
|
||||
input, filter, bias, im2col,
|
||||
accum_scratch, output));
|
||||
}
|
||||
} else {
|
||||
EvalFloat<kernel_type>(context, node, params, data, input, filter, bias,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user