diff --git a/aten/src/ATen/native/cpu/UpSampleKernel.cpp b/aten/src/ATen/native/cpu/UpSampleKernel.cpp index 1f471d495df..ac299a50752 100644 --- a/aten/src/ATen/native/cpu/UpSampleKernel.cpp +++ b/aten/src/ATen/native/cpu/UpSampleKernel.cpp @@ -742,7 +742,7 @@ struct HelperInterpBase { } template - static inline void _compute_weights_aa( + static inline scalar_t _compute_weights_aa( const int64_t i, const int64_t input_size, const scalar_t scale, const scalar_t support, scalar_t* wt_ptr, const int64_t max_interp_size, aa_filter_fn_t filter_fn, int64_t& xmin, int64_t& xsize, bool antialias, double align_corners_delta @@ -764,14 +764,19 @@ struct HelperInterpBase { wt_ptr[j] = w; total_w += w; } - for (j = 0; j < xsize; j++) { - if (total_w != 0.0) { + + scalar_t wt_max = 0.0; + if (total_w != 0.0) { + for (j = 0; j < xsize; j++) { wt_ptr[j] /= total_w; + wt_max = std::max(wt_max, wt_ptr[j]); } } + for (; j < max_interp_size; j++) { wt_ptr[j] = static_cast(0.0); } + return wt_max; } // Note [ Support for antialias=False as a subcase of antilias=True ] @@ -785,7 +790,7 @@ struct HelperInterpBase { // indices, but this can be optimized further when aa=False since we know // their actual dimensions. template - static inline std::tuple, int> _compute_indices_weights_aa( + static inline std::tuple, int, scalar_t> _compute_indices_weights_aa( int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims, int64_t reshape_dim, scalar_t scale, int interp_size, aa_filter_fn_t aa_filter_fn, bool antialias, double align_corners_delta @@ -834,10 +839,10 @@ struct HelperInterpBase { scalar_t* wt_ptr = output[3].data_ptr(); int64_t* wt_idx_ptr = output[4].data_ptr(); - int64_t xmin, xmax; - + scalar_t wt_max = 0.0; for (const auto i : c10::irange(output_size)) { - HelperInterpBase::_compute_weights_aa( + int64_t xmin, xmax; + auto wt_max_i = HelperInterpBase::_compute_weights_aa( i, input_size, scale, @@ -850,12 +855,14 @@ struct HelperInterpBase { antialias, align_corners_delta); + wt_max = std::max(wt_max, wt_max_i); + idx_ptr_xmin[i] = xmin * stride; idx_ptr_size[i] = xmax; idx_ptr_stride[i] = stride; wt_idx_ptr[i] = i * max_interp_size * weight_index_stride; } - return {output, max_interp_size}; + return {output, max_interp_size, wt_max}; } /* @@ -911,25 +918,17 @@ struct HelperInterpBase { std::vector indices_weights; auto align_corners_delta = (align_corners && !antialias) ? 
0.5 : 0.0; - std::tie(indices_weights, interp_size) = HelperInterpBase::_compute_indices_weights_aa( + double wt_max; + std::tie(indices_weights, interp_size, wt_max) = HelperInterpBase::_compute_indices_weights_aa( input_size, output_size, stride, ndims, reshape_dim, scale, interp_size, aa_filter_fn, antialias, align_corners_delta); // Rescale float weights to int16 and compute weights precision auto weights_f64 = indices_weights[3]; double * data_f64 = weights_f64.data_ptr(); - int64_t weights_f64_size = output_size * interp_size; - // can't use weights_f64.max() here as tensor is restrided - double w_max = data_f64[0]; - for (const auto i : c10::irange(weights_f64_size)) { - double v = data_f64[i]; - if (w_max < v) { - w_max = v; - } - } unsigned int weights_precision = 0; - for (weights_precision = 0; weights_precision < 22; weights_precision += 1) { - int next_value = (int) (0.5 + w_max * (1 << (weights_precision + 1))); + for (weights_precision = 0; weights_precision < 22; ++weights_precision) { + int next_value = (int) (0.5 + wt_max * (1 << (weights_precision + 1))); if (next_value >= (1 << 15)) break; } @@ -939,8 +938,7 @@ struct HelperInterpBase { auto aligned_interp_size = interp_size; if (align_i32) { - // We should respect int32 alignment as - // we will load data as int32 with AVX2 + // We should respect int32 alignment as we will load int16 data as int32 // See ImagingResampleHorizontalConvolution8u4x, mmk0 = _mm256_set1_epi32(*(int32_t*)&k[x]); // compute aligned_interp_size = nearest pair value to interp_size while (aligned_interp_size % sizeof(int32_t) != 0) { @@ -952,20 +950,13 @@ struct HelperInterpBase { for (const auto j : c10::irange(output_size)) { for (const auto k : c10::irange(interp_size)) { - double v = data_f64[j * interp_size + k]; - if (v < 0) { - data_i16[j * aligned_interp_size + k] = (int) (-0.5 + v * (1 << weights_precision)); - } else { - data_i16[j * aligned_interp_size + k] = (int) (0.5 + v * (1 << weights_precision)); - } + double v = data_f64[j * interp_size + k] * (1 << weights_precision); + data_i16[j * aligned_interp_size + k] = (v < 0) ? 
(int) (-0.5 + v) : (int) (0.5 + v); } } return {indices_weights, aligned_interp_size, weights_precision}; } - - - }; struct HelperInterpNearest : public HelperInterpBase { @@ -1175,8 +1166,9 @@ struct HelperInterpLinear : public HelperInterpBase { auto interp_size = HelperInterpLinear::interp_size; int unused; + scalar_t unused_2; - std::tie(indices_weights, unused) = HelperInterpLinear::_compute_indices_weights_aa( + std::tie(indices_weights, unused, unused_2) = HelperInterpLinear::_compute_indices_weights_aa( input_size, output_size, stride, @@ -1307,8 +1299,9 @@ struct HelperInterpCubic : public HelperInterpBase { auto interp_size = HelperInterpCubic::interp_size; int unused; + scalar_t unused_2; - std::tie(indices_weights, unused) = HelperInterpCubic::_compute_indices_weights_aa( + std::tie(indices_weights, unused, unused_2) = HelperInterpCubic::_compute_indices_weights_aa( input_size, output_size, stride, diff --git a/aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h b/aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h index e8239cf6b86..ef4c204f7bd 100644 --- a/aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h +++ b/aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h @@ -39,13 +39,10 @@ static __m128i inline mm_cvtepu8_epi32(const uint32_t* C10_RESTRICT ptr) { return _mm_cvtepu8_epi32(_mm_cvtsi32_si128(*(int32_t*)ptr)); } -// TODO: We may want to hard-code an unrolled version for the case where -// num_channels=3 to hint the compiler to vectorize this (looks at original -// PIL-SIMD's code). at::Tensor unpack_rgb(const at::Tensor& packed_tensor) { // Convert a "packed" tensor (typically RGBRGBRGB if channels_last) into - // RGBARGBARGBA format where A is hard-coded to 255. Each pixel is encoded - // into as 32bits. This generalizes to num_channels <= 4 and also works for + // RGBARGBARGBA format where A is hard-coded to 0. Each pixel is encoded + // into as 32 bits. This generalizes to num_channels <= 4 and also works for // non-channels_last tensors. const uint8_t* packed = (const uint8_t*)packed_tensor.data_ptr(); @@ -71,6 +68,8 @@ void pack_rgb( const at::Tensor& unpacked_tensor, // IN const at::Tensor& packed_tensor // OUT ) { + // Convert from unpacked channels last 4-channels tensor into original data layout. 
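For reference, the round trip that unpack_rgb / pack_rgb perform can be summarized by the scalar sketch below. The function names and flat-pointer signatures are hypothetical and ignore the stride handling of the real helpers; the sketch only illustrates the layout described in the comments above (RGBA with A hard-coded to 0, 32 bits per pixel, assuming a contiguous channels_last input with num_channels <= 4).

#include <cstdint>
#include <cstring>

// Sketch of the unpack step: each packed pixel of `num_channels` uint8 values
// becomes one 4-byte RGBA slot, with the missing channels (e.g. A) set to 0.
void unpack_rgb_sketch(const uint8_t* packed, uint8_t* rgba,
                       int64_t num_pixels, int64_t num_channels) {
  for (int64_t i = 0; i < num_pixels; i++) {
    uint8_t px[4] = {0, 0, 0, 0};
    std::memcpy(px, packed + i * num_channels, num_channels);
    std::memcpy(rgba + 4 * i, px, 4);  // one uint32 per pixel
  }
}

// Sketch of the pack step: drop the padding channels again.
void pack_rgb_sketch(const uint8_t* rgba, uint8_t* packed,
                     int64_t num_pixels, int64_t num_channels) {
  for (int64_t i = 0; i < num_pixels; i++) {
    std::memcpy(packed + i * num_channels, rgba + 4 * i, num_channels);
  }
}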
+ constexpr int rgba_size = 4; uint8_t* unpacked = (uint8_t*)unpacked_tensor.data_ptr(); uint8_t* packed = (uint8_t*)packed_tensor.data_ptr(); @@ -94,33 +93,35 @@ void ImagingResampleHorizontalConvolution8u4x( uint32_t* C10_RESTRICT lineOut1, uint32_t* C10_RESTRICT lineOut2, uint32_t* C10_RESTRICT lineOut3, + int64_t out_xsize, const uint32_t* C10_RESTRICT lineIn0, const uint32_t* C10_RESTRICT lineIn1, const uint32_t* C10_RESTRICT lineIn2, const uint32_t* C10_RESTRICT lineIn3, - int xsize, - int* xbounds, - int16_t* kk, + const int64_t* idx_ptr_xmin, + const int64_t* idx_ptr_size, + const int16_t* kk, int kmax, - int coefs_precision); + unsigned int coefs_precision); void ImagingResampleHorizontalConvolution8u( uint32_t* C10_RESTRICT lineOut, + int64_t out_xsize, const uint32_t* C10_RESTRICT lineIn, - int xsize, - int* xbounds, - int16_t* kk, + const int64_t* idx_ptr_xmin, + const int64_t* idx_ptr_size, + const int16_t* kk, int kmax, - int coefs_precision); + unsigned int coefs_precision); void ImagingResampleVerticalConvolution8u( uint32_t* C10_RESTRICT lineOut, - const uint32_t* C10_RESTRICT imIn, - int xmin, - int xmax, - int16_t* k, - int coefs_precision, - int xin); + const uint32_t* C10_RESTRICT lineIn, + int64_t xsize, + int64_t ids_min, + int64_t ids_size, + const int16_t* k, + unsigned int coefs_precision); void ImagingResampleHorizontal( const at::Tensor & unpacked_output, @@ -128,57 +129,66 @@ void ImagingResampleHorizontal( int ksize, const std::vector& horiz_indices_weights, unsigned int horiz_weights_precision) { + + // Interpolation horizontal pass: we compute x-axis (image width) interpolation outputs. + + // Input data is stored as + // input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...] + // Weights are float values computed for each output pixel and rescaled to uint16: + // weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]] + // We want to compute the output as following: + // output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...] + // where + // oR[yoffset + i] = r[yoffset + xmin[i]] * w[i, 0] + ... + r[yoffset + xmin[i] + K-1] * w[i, K-1] + // oG[yoffset + i] = g[yoffset + xmin[i]] * w[i, 0] + ... + g[yoffset + xmin[i] + K-1] * w[i, K-1] + // oB[yoffset + i] = b[yoffset + xmin[i]] * w[i, 0] + ... 
+ b[yoffset + xmin[i] + K-1] * w[i, K-1] + // + // TODO: we may want to merge that into the fallback code (currently called // basic_loop_aa_horizontal) // Although this may not be needed if / when we port all this code to use // Vec.h since this would potentially give us another fall-back implem - int yy; - int16_t* kk = (int16_t*)(horiz_indices_weights[3].data_ptr()); + const int16_t* kk = (int16_t*)(horiz_indices_weights[3].data_ptr()); auto xout = unpacked_output.size(2); auto yout = unpacked_output.size(1); auto xin = unpacked_input.size(2); - std::vector bounds_vec(2 * xout, 0); - int* bounds = bounds_vec.data(); + const int64_t* idx_ptr_xmin = horiz_indices_weights[0].data_ptr(); + const int64_t* idx_ptr_size = horiz_indices_weights[1].data_ptr(); - int64_t* idx_ptr_xmin = horiz_indices_weights[0].data_ptr(); - int64_t* idx_ptr_size = horiz_indices_weights[1].data_ptr(); - for (int i = 0; i < xout; i++) { - bounds[2 * i + 0] = idx_ptr_xmin[i]; - bounds[2 * i + 1] = idx_ptr_size[i]; - } - - uint32_t* unpacked_input_p = (uint32_t*) unpacked_input.data_ptr(); uint32_t* unpacked_output_p = (uint32_t*) unpacked_output.data_ptr(); + const uint32_t* unpacked_input_p = (uint32_t*) unpacked_input.data_ptr(); - yy = 0; + int64_t yy = 0; for (; yy < yout - 3; yy += 4) { ImagingResampleHorizontalConvolution8u4x( unpacked_output_p + yy * xout, unpacked_output_p + (yy + 1) * xout, unpacked_output_p + (yy + 2) * xout, unpacked_output_p + (yy + 3) * xout, + xout, unpacked_input_p + yy * xin, unpacked_input_p + (yy + 1) * xin, unpacked_input_p + (yy + 2) * xin, unpacked_input_p + (yy + 3) * xin, - xout, - bounds, + idx_ptr_xmin, + idx_ptr_size, kk, ksize, - (int)horiz_weights_precision); + horiz_weights_precision); } for (; yy < yout; yy++) { ImagingResampleHorizontalConvolution8u( unpacked_output_p + yy * xout, - unpacked_input_p + yy * xin, xout, - bounds, + unpacked_input_p + yy * xin, + idx_ptr_xmin, + idx_ptr_size, kk, ksize, - (int)horiz_weights_precision); + horiz_weights_precision); } } @@ -188,36 +198,46 @@ void ImagingResampleVertical( int ksize, const std::vector& vert_indices_weights, unsigned int vert_weights_precision) { + + // Interpolation vertical pass: we compute y-axis interpolation outputs. + // Input data is stored as + // input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...] + // Weights are float values computed for each output pixel and rescaled to uint16: + // weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]] + // We want to compute the output as following: + // output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...] + // where + // oR[xoffset + i] = r[xoffset + ymin[i]] * w[i, 0] + ... + r[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1] + // oG[xoffset + i] = g[xoffset + ymin[i]] * w[i, 0] + ... + g[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1] + // oB[xoffset + i] = b[xoffset + ymin[i]] * w[i, 0] + ... 
+ b[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1] + // TODO: we may want to merge that into the fallback code (currently called // basic_loop_aa_vertical) // Although this may not be needed if / when we port all this code to use // Vec.h since this would potentially give us another fall-back implem - int ymin, ymax; - int16_t* k = nullptr; - int16_t* kk = (int16_t*)(vert_indices_weights[3].data_ptr()); + const int16_t* kk = (int16_t*)(vert_indices_weights[3].data_ptr()); - int64_t* idx_ptr_xmin = vert_indices_weights[0].data_ptr(); - int64_t* idx_ptr_size = vert_indices_weights[1].data_ptr(); + const int64_t* idx_ptr_xmin = vert_indices_weights[0].data_ptr(); + const int64_t* idx_ptr_size = vert_indices_weights[1].data_ptr(); uint32_t* unpacked_output_p = (uint32_t*) unpacked_output.data_ptr(); - uint32_t* unpacked_input_p = (uint32_t*) unpacked_input.data_ptr(); + const uint32_t* unpacked_input_p = (uint32_t*) unpacked_input.data_ptr(); auto xout = unpacked_output.size(2); auto yout = unpacked_output.size(1); for (const auto yy : c10::irange(yout)) { - k = &kk[yy * ksize]; - - ymin = idx_ptr_xmin[yy]; - ymax = idx_ptr_size[yy]; + const auto* k = &kk[yy * ksize]; + auto ids_min = idx_ptr_xmin[yy]; + auto ids_size = idx_ptr_size[yy]; ImagingResampleVerticalConvolution8u( unpacked_output_p + yy * xout, unpacked_input_p, - ymin, - ymax, + xout, + ids_min, + ids_size, k, - (int)vert_weights_precision, - xout); + vert_weights_precision); } } @@ -283,7 +303,7 @@ void upsample_avx_bilinear_uint8( F::compute_indices_int16_weights_aa( /*input_size=*/yin, /*output_size=*/yout, - /*stride=*/1, + /*stride=*/xout, /*ndims=*/4, /*reshape_dim=*/interp_dim, /*align_corners=*/align_corners, @@ -292,13 +312,16 @@ void upsample_avx_bilinear_uint8( /*align_i32=*/true); } - bool is_rgba = num_channels == 4 && input.is_contiguous(at::MemoryFormat::ChannelsLast); + bool needs_unpacking = num_channels == 4 && input.is_contiguous(at::MemoryFormat::ChannelsLast); at::Tensor buffer_horiz, buffer_vert; - if (need_horizontal && !(is_rgba && !need_vertical)) { + // Minor optimization: we can avoid allocating an extra buffer if we're performing + // horizontal-only or vertical-only interpolation, and if the tensor doesn't + // need unpacking + if (need_horizontal && !(needs_unpacking && !need_vertical)) { buffer_horiz = at::empty({4, yin, xout}, input.options()); } - if (need_vertical && !is_rgba) { + if (need_vertical && !needs_unpacking) { buffer_vert = at::empty({4, yout, xout}, input.options()); } @@ -308,12 +331,12 @@ void upsample_avx_bilinear_uint8( // tensors and just copy part of them (line by line). for (const auto i : c10::irange(batch_size)) { - at::Tensor unpacked_input = (is_rgba) ? input[i] : unpack_rgb(input[i]); + at::Tensor unpacked_input = (needs_unpacking) ? input[i] : unpack_rgb(input[i]); at::Tensor unpacked_output; if (need_horizontal) { - at::Tensor unpacked_output_temp = (is_rgba && !need_vertical) ? output[i] : buffer_horiz; + at::Tensor unpacked_output_temp = (needs_unpacking && !need_vertical) ? output[i] : buffer_horiz; ImagingResampleHorizontal( unpacked_output_temp, @@ -324,7 +347,7 @@ void upsample_avx_bilinear_uint8( unpacked_output = unpacked_input = unpacked_output_temp; } if (need_vertical) { - unpacked_output = (is_rgba) ? output[i] : buffer_vert; + unpacked_output = (needs_unpacking) ? 
output[i] : buffer_vert; ImagingResampleVertical( unpacked_output, @@ -336,382 +359,588 @@ void upsample_avx_bilinear_uint8( TORCH_INTERNAL_ASSERT(unpacked_output.defined()); - if (!is_rgba) { + if (!needs_unpacking) { pack_rgb(unpacked_output, output[i]); } } } -// https://gist.github.com/NicolasHug/47c97d731f05eaad5694c173849b86f5 void ImagingResampleHorizontalConvolution8u4x( uint32_t* C10_RESTRICT lineOut0, uint32_t* C10_RESTRICT lineOut1, uint32_t* C10_RESTRICT lineOut2, uint32_t* C10_RESTRICT lineOut3, + int64_t out_xsize, const uint32_t* C10_RESTRICT lineIn0, const uint32_t* C10_RESTRICT lineIn1, const uint32_t* C10_RESTRICT lineIn2, const uint32_t* C10_RESTRICT lineIn3, - int xsize, - int* xbounds, - int16_t* kk, + const int64_t* idx_ptr_xmin, + const int64_t* idx_ptr_size, + const int16_t* kk, int kmax, - int coefs_precision) { - int xmin, xmax, x; - int16_t* k; + unsigned int coefs_precision) { + // Interpolation horizontal pass processing together 4 vertical lines. + // - Input data format is RGBA with R,G,B,A being uint8, we can encode 4 values as a single uint32 value. + // - We split the size of weight vector for a given output index as a sum: + // ids_size = num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1. + // - We load and process 4 weights values in a loop ("block 4") then we process 2 weights values + // in another loop ("block 2") and finally we process 1 weights value in the final loop ("block 1"). - for (const auto xx : c10::irange(xsize)) { - xmin = xbounds[xx * 2 + 0]; - xmax = xbounds[xx * 2 + 1]; - k = &kk[xx * kmax]; - x = 0; + // Define shuffling masks (low/high) for num_channels 4 + // Mask low casts lower half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA: + // [r1 g1 b1 a1 r2 g2 b2 a2 ... | R1 G1 B1 A1 R2 G2 B2 A2 ... ] -> + // [r1 0 r2 0 g1 0 g2 0 b1 0 b2 0 a1 0 a2 0 | R1 0 R2 0 G1 0 G2 0 B1 0 B2 0 A1 0 A2 0] + // Mask high casts upper half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA:: + // [ ... r3 g3 b3 a3 r4 g4 b4 a4 | ... 
R3 G3 B3 A3 R4 G4 B4 A4 ] -> + // [r3 0 r4 0 g3 0 g4 0 b3 0 b4 0 a3 0 a4 0 | R3 0 R4 0 G3 0 G4 0 B3 0 B4 0 A3 0 A4 0] - __m256i sss0, sss1; - __m256i zero = _mm256_setzero_si256(); - __m256i initial = _mm256_set1_epi32(1 << (coefs_precision - 1)); - sss0 = initial; - sss1 = initial; + const auto mask_low_c4 = _mm256_set_epi8( + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0, + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0); + const auto mask_high_c4 = _mm256_set_epi8( + -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8, + -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8); - for (; x < xmax - 3; x += 4) { - __m256i pix, mmk0, mmk1, source; + const auto mask_low = mask_low_c4; + const auto mask_high = mask_high_c4; - mmk0 = _mm256_set1_epi32(*(int32_t*)&k[x]); - mmk1 = _mm256_set1_epi32(*(int32_t*)&k[x + 2]); + const auto zero = _mm256_setzero_si256(); + const auto initial = _mm256_set1_epi32(1 << (coefs_precision - 1)); - source = _mm256_inserti128_si256( - _mm256_castsi128_si256(_mm_loadu_si128((__m128i*)&lineIn0[x + xmin])), - _mm_loadu_si128((__m128i*)&lineIn1[x + xmin]), - 1); - // clang-format off - pix = _mm256_shuffle_epi8(source, _mm256_set_epi8( - -1,7, -1,3, -1,6, -1,2, -1,5, -1,1, -1,4, -1,0, - -1,7, -1,3, -1,6, -1,2, -1,5, -1,1, -1,4, -1,0)); - sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix, mmk0)); - pix = _mm256_shuffle_epi8(source, _mm256_set_epi8( - -1,15, -1,11, -1,14, -1,10, -1,13, -1,9, -1,12, -1,8, - -1,15, -1,11, -1,14, -1,10, -1,13, -1,9, -1,12, -1,8)); - sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix, mmk1)); + for (const auto out_x : c10::irange(out_xsize)) { + const auto ids_min = idx_ptr_xmin[out_x]; + const auto ids_size = idx_ptr_size[out_x]; + const auto * k = &kk[out_x * kmax]; + int64_t i = 0; - source = _mm256_inserti128_si256( - _mm256_castsi128_si256(_mm_loadu_si128((__m128i*)&lineIn2[x + xmin])), - _mm_loadu_si128((__m128i*)&lineIn3[x + xmin]), - 1); - pix = _mm256_shuffle_epi8(source, _mm256_set_epi8( - -1,7, -1,3, -1,6, -1,2, -1,5, -1,1, -1,4, -1,0, - -1,7, -1,3, -1,6, -1,2, -1,5, -1,1, -1,4, -1,0)); - sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix, mmk0)); - pix = _mm256_shuffle_epi8(source, _mm256_set_epi8( - -1,15, -1,11, -1,14, -1,10, -1,13, -1,9, -1,12, -1,8, - -1,15, -1,11, -1,14, -1,10, -1,13, -1,9, -1,12, -1,8)); - sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix, mmk1)); + auto sss0 = initial; + auto sss1 = initial; + + const auto * lineIn0_min = lineIn0 + ids_min; + const auto * lineIn1_min = lineIn1 + ids_min; + const auto * lineIn2_min = lineIn2 + ids_min; + const auto * lineIn3_min = lineIn3 + ids_min; + + // block 4 + for (; i < ids_size - 3; i += 4) { + // Load 4 values from weight vector + // mmk0 = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...] + // mmk1 = [wl_2 wh_2 wl_3 wh_3 wl_2 wh_2 wl_3 wh_3 ...] + const auto mmk0 = _mm256_set1_epi32(*(int32_t*)&k[i]); + const auto mmk1 = _mm256_set1_epi32(*(int32_t*)&k[i + 2]); + + // Load 8 pixels (4 per line) from input lines 0 and 1: + // source = [ + // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3 + // R0 G0 B0 A0 R1 G1 B1 A1 R2 G2 B2 A2 R3 G3 B3 A3 + // ] + auto source = _mm256_inserti128_si256(_mm256_castsi128_si256( + _mm_loadu_si128((__m128i*)&lineIn0_min[i])), + _mm_loadu_si128((__m128i*)&lineIn1_min[i]), 1); + + // Apply mask_low: + // [r0 g0 b0 a0 r1 g1 b1 a1 ... | R0 G0 B0 A0 R1 G1 B1 A1 ... 
] -> + // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 A0 0 A1 0] + auto pix1 = _mm256_shuffle_epi8(source, mask_low); + // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk0)); + + // Apply mask_high: + // [ ... r2 g2 b2 a2 r3 g3 b3 a3 | ... R2 G2 B2 A2 R3 G3 B3 A3 ] -> + // [r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 a2 0 a3 0 | R2 0 R3 0 G2 0 G3 0 B2 0 B3 0 A2 0 A3 0] + auto pix2 = _mm256_shuffle_epi8(source, mask_high); + // Compute output value as C += w2 * C2 + w3 * C3 for each channel in 32-bit precision + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix2, mmk1)); + + // Same as above to next lines 2 and 3: + auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256( + _mm_loadu_si128((__m128i*)&lineIn2_min[i])), + _mm_loadu_si128((__m128i*)&lineIn3_min[i]), 1); + auto pix3 = _mm256_shuffle_epi8(source2, mask_low); + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix3, mmk0)); + auto pix4 = _mm256_shuffle_epi8(source2, mask_high); + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix4, mmk1)); } - for (; x < xmax - 1; x += 2) { - __m256i pix, mmk; + // block 2 + for (; i < ids_size - 1; i += 2) { + // Load 2 values from weight vector + // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...] + const auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]); - mmk = _mm256_set1_epi32(*(int32_t*)&k[x]); + // Load 4 pixels (2 per line) from input lines 0 and 1: + // source = [ + // r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0 + // R0 G0 B0 A0 R1 G1 B1 A1 0 0 0 0 0 0 0 0 + // ] + auto source1 = _mm256_inserti128_si256(_mm256_castsi128_si256( + _mm_loadl_epi64((__m128i*)&lineIn0_min[i])), + _mm_loadl_epi64((__m128i*)&lineIn1_min[i]), 1); - pix = _mm256_inserti128_si256( - _mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)&lineIn0[x + xmin])), - _mm_loadl_epi64((__m128i*)&lineIn1[x + xmin]), - 1); - pix = _mm256_shuffle_epi8(pix, _mm256_set_epi8( - -1,7, -1,3, -1,6, -1,2, -1,5, -1,1, -1,4, -1,0, - -1,7, -1,3, -1,6, -1,2, -1,5, -1,1, -1,4, -1,0)); - sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix, mmk)); + // Apply mask_low: + // [r0 g0 b0 a0 r1 g1 b1 a1 ... | R0 G0 B0 A0 R1 G1 B1 A1 ... ] -> + // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 A0 0 A1 0] + auto pix1 = _mm256_shuffle_epi8(source1, mask_low); + // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk)); - pix = _mm256_inserti128_si256( - _mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)&lineIn2[x + xmin])), - _mm_loadl_epi64((__m128i*)&lineIn3[x + xmin]), - 1); - pix = _mm256_shuffle_epi8(pix, _mm256_set_epi8( - -1,7, -1,3, -1,6, -1,2, -1,5, -1,1, -1,4, -1,0, - -1,7, -1,3, -1,6, -1,2, -1,5, -1,1, -1,4, -1,0)); - sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix, mmk)); - // clang-format on + // Same as above for lines 2 and 3: + auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256( + _mm_loadl_epi64((__m128i*)&lineIn2_min[i])), + _mm_loadl_epi64((__m128i*)&lineIn3_min[i]), 1); + auto pix2 = _mm256_shuffle_epi8(source2, mask_low); + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk)); } - for (; x < xmax; x++) { - __m256i pix, mmk; + // block 1 + for (; i < ids_size; i++) { + // Load 1 value from weight vector + // mmk = [wl_0 wh_0 0 0 wl_0 wh_0 0 0 ...] 
+ const auto mmk = _mm256_set1_epi32(k[i]); - // [16] xx k0 xx k0 xx k0 xx k0 xx k0 xx k0 xx k0 xx k0 - mmk = _mm256_set1_epi32(k[x]); + // Load 2 pixels (one per line) from input lines 0 and 1: + // source = [ + // r0 g0 b0 a0 0 0 0 0 0 0 0 0 0 0 0 0 + // R0 G0 B0 A0 0 0 0 0 0 0 0 0 0 0 0 0 + // ] + auto pix1 = _mm256_inserti128_si256(_mm256_castsi128_si256( + mm_cvtepu8_epi32(&lineIn0_min[i])), + mm_cvtepu8_epi32(&lineIn1_min[i]), 1); + // Compute output value as C += w0 * C0 for each channel in 32-bit precision + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk)); - // [16] xx a0 xx b0 xx g0 xx r0 xx a0 xx b0 xx g0 xx r0 - pix = _mm256_inserti128_si256( - _mm256_castsi128_si256(mm_cvtepu8_epi32(&lineIn0[x + xmin])), - mm_cvtepu8_epi32(&lineIn1[x + xmin]), - 1); - sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix, mmk)); - - pix = _mm256_inserti128_si256( - _mm256_castsi128_si256(mm_cvtepu8_epi32(&lineIn2[x + xmin])), - mm_cvtepu8_epi32(&lineIn3[x + xmin]), - 1); - sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix, mmk)); + // Same as above for lines 2 and 3 + auto pix2 = _mm256_inserti128_si256(_mm256_castsi128_si256( + mm_cvtepu8_epi32(&lineIn2_min[i])), + mm_cvtepu8_epi32(&lineIn3_min[i]), 1); + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk)); } + // Convert fixed point values back to integers (truncating) sss0 = _mm256_srai_epi32(sss0, coefs_precision); sss1 = _mm256_srai_epi32(sss1, coefs_precision); + // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation + // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0) sss0 = _mm256_packs_epi32(sss0, zero); sss1 = _mm256_packs_epi32(sss1, zero); + // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation + // (a a b b c c d d) -> (a b c d 0 0 0 0) sss0 = _mm256_packus_epi16(sss0, zero); sss1 = _mm256_packus_epi16(sss1, zero); - lineOut0[xx] = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss0, 0)); - lineOut1[xx] = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss0, 1)); - lineOut2[xx] = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss1, 0)); - lineOut3[xx] = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss1, 1)); + + // Write the output into single uint32 + // (a b c d) -> x_uint32 + lineOut0[out_x] = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss0, 0)); + lineOut1[out_x] = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss0, 1)); + lineOut2[out_x] = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss1, 0)); + lineOut3[out_x] = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss1, 1)); } } -// https://gist.github.com/NicolasHug/47c97d731f05eaad5694c173849b86f5 void ImagingResampleHorizontalConvolution8u( uint32_t* C10_RESTRICT lineOut, + int64_t out_xsize, const uint32_t* C10_RESTRICT lineIn, - int xsize, - int* xbounds, - int16_t* kk, + const int64_t* idx_ptr_xmin, + const int64_t* idx_ptr_size, + const int16_t* kk, int kmax, - int coefs_precision) { - int xmin, xmax, x; - int16_t* k; + unsigned int coefs_precision) { - for (const auto xx : c10::irange(xsize)) { + // Interpolation horizontal pass processing only one vertical line. + // - Input data format is RGBA with R,G,B,A being uint8, we can encode 4 values as a single uint32 value. 
+ // - We split the size of weight vector for a given output index as a sum: + // ids_size = num_blocks_8 * 8 + num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1 + // - We load and process 8 weights values in a loop ("block 8") then 4 weights and 2 weights values in + // in another loops ("block 4" and "block 2") and finally we process 1 weight value in the final loop ("block 1"). + + // Define various shuffling masks + const auto kmask_low = _mm256_set_epi8( + 11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8, + 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0); + const auto kmask_high = _mm256_set_epi8( + 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, + 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4); + const auto kmask_hl = _mm256_set_epi8( + 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, + 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0); + + const auto mask_low_c4 = _mm256_set_epi8( + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0, + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0); + const auto mask_high_c4 = _mm256_set_epi8( + -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8, + -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8); + const auto mask_hl_c4 = _mm256_set_epi8( + -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8, + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0); + const auto mask_low128_c4 = _mm_set_epi8( + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0); + + const auto mask_low = mask_low_c4; + const auto mask_high = mask_high_c4; + const auto mask_hl = mask_hl_c4; + const auto mask_low128 = mask_low128_c4; + + // out_xsize = output width, out_x = output x index + // ids_min is the input offset index corresponding to out_x + // ids_size is the interpolation size for out_x + + const auto zero = _mm_setzero_si128(); + + for (const auto out_x : c10::irange(out_xsize)) { __m128i sss; - xmin = xbounds[xx * 2 + 0]; - xmax = xbounds[xx * 2 + 1]; - k = &kk[xx * kmax]; - x = 0; + const auto ids_min = idx_ptr_xmin[out_x]; + const auto ids_size = idx_ptr_size[out_x]; + const auto * k = &kk[out_x * kmax]; + int64_t i = 0; - if (xmax < 8) { + const auto * lineIn_min = lineIn + ids_min; + + if (ids_size < 8) { sss = _mm_set1_epi32(1 << (coefs_precision - 1)); } else { // Lower part will be added to higher, use only half of the error - __m256i sss256 = _mm256_set1_epi32(1 << (coefs_precision - 2)); + auto sss256 = _mm256_set1_epi32(1 << (coefs_precision - 2)); - for (; x < xmax - 7; x += 8) { - __m256i pix, mmk, source; - __m128i tmp = _mm_loadu_si128((__m128i*)&k[x]); - __m256i ksource = - _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1); + // block 8 + for (; i < ids_size - 7; i += 8) { + // Load 8 values from weight vector + auto tmp = _mm_loadu_si128((__m128i*)&k[i]); + // ksource = [ + // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 wl_4 wh_4 wl_5 wh_5 wl_6 wh_6 wl_7 wh_7 + // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 wl_4 wh_4 wl_5 wh_5 wl_6 wh_6 wl_7 wh_7 + // ] + auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1); - // clang-format off - source = _mm256_loadu_si256((__m256i*)&lineIn[x + xmin]); - pix = _mm256_shuffle_epi8(source, _mm256_set_epi8( - -1,7, -1,3, -1,6, -1,2, -1,5, -1,1, -1,4, -1,0, - -1,7, -1,3, -1,6, -1,2, -1,5, -1,1, -1,4, -1,0)); - mmk = _mm256_shuffle_epi8(ksource, _mm256_set_epi8( - 11,10, 9,8, 11,10, 9,8, 11,10, 9,8, 11,10, 9,8, - 3,2, 1,0, 3,2, 1,0, 3,2, 1,0, 3,2, 1,0)); - sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix, mmk)); + // 
Load 8 pixels from input: + // source = [ + // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3 + // r4 g4 b4 a4 r5 g5 b5 a5 r6 g6 b6 a6 r7 g7 b7 a7 + // ] + auto source = _mm256_loadu_si256((__m256i*)&lineIn_min[i]); + // Extract lower part of each lane, cast to epi16 and reoder RGBARGBA -> RRGGBBAA + // pix1 = [ + // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 + // r4 0 r5 0 g4 0 g5 0 b4 0 b5 0 a4 0 a5 0 + // ] + auto pix1 = _mm256_shuffle_epi8(source, mask_low); + // mmk1 = [ + // wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ... + // wl_4 wh_4 wl_5 wh_5 wl_4 wh_4 wl_5 wh_5 ... ... + // ] + auto mmk1 = _mm256_shuffle_epi8(ksource, kmask_low); + // Compute output value as + // C += w0 * C0 + w1 * C1 + // C += w4 * C4 + w5 * C5 for each channel in 32-bit precision + sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix1, mmk1)); - pix = _mm256_shuffle_epi8(source, _mm256_set_epi8( - -1,15, -1,11, -1,14, -1,10, -1,13, -1,9, -1,12, -1,8, - -1,15, -1,11, -1,14, -1,10, -1,13, -1,9, -1,12, -1,8)); - mmk = _mm256_shuffle_epi8(ksource, _mm256_set_epi8( - 15,14, 13,12, 15,14, 13,12, 15,14, 13,12, 15,14, 13,12, - 7,6, 5,4, 7,6, 5,4, 7,6, 5,4, 7,6, 5,4)); - sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix, mmk)); - // clang-format on + // Same as above for higher part of each lane + auto pix2 = _mm256_shuffle_epi8(source, mask_high); + auto mmk2 = _mm256_shuffle_epi8(ksource, kmask_high); + // Compute output value as + // C += w2 * C2 + w3 * C3 + // C += w6 * C6 + w7 * C7 for each channel in 32-bit precision + sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix2, mmk2)); } - for (; x < xmax - 3; x += 4) { - __m256i pix, mmk, source; - __m128i tmp = _mm_loadl_epi64((__m128i*)&k[x]); - __m256i ksource = - _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1); + // block 4 + for (; i < ids_size - 3; i += 4) { + // Load 4 values from weight vector + auto tmp = _mm_loadl_epi64((__m128i *) &k[i]); + // ksource = [ + // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 0 0 0 0 0 0 0 0 + // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 0 0 0 0 0 0 0 0 + // ] + auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1); - tmp = _mm_loadu_si128((__m128i*)&lineIn[x + xmin]); - source = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1); + // Load 4 pixels from input line + tmp = _mm_loadu_si128((__m128i*)&lineIn_min[i]); + // source = [ + // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3 + // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3 + // ] + auto source = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1); - // clang-format off - pix = _mm256_shuffle_epi8(source, _mm256_set_epi8( - -1,15, -1,11, -1,14, -1,10, -1,13, -1,9, -1,12, -1,8, - -1,7, -1,3, -1,6, -1,2, -1,5, -1,1, -1,4, -1,0)); - mmk = _mm256_shuffle_epi8(ksource, _mm256_set_epi8( - 7,6, 5,4, 7,6, 5,4, 7,6, 5,4, 7,6, 5,4, - 3,2, 1,0, 3,2, 1,0, 3,2, 1,0, 3,2, 1,0)); + // Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA + // pix = [ + // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 + // r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 a2 0 a3 0 + // ] + auto pix = _mm256_shuffle_epi8(source, mask_hl); + // mmk = [ + // wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ... + // wl_2 wh_2 wl_3 wh_3 wl_2 wh_2 wl_3 wh_3 ... ... 
+ // ] + auto mmk = _mm256_shuffle_epi8(ksource, kmask_hl); + // Compute output value as + // C += w0 * C0 + w1 * C1 + // C += w2 * C2 + w3 * C3 for each channel in 32-bit precision sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix, mmk)); - // clang-format on } + // Sum results between the lanes sss = _mm_add_epi32( _mm256_extracti128_si256(sss256, 0), _mm256_extracti128_si256(sss256, 1)); } - for (; x < xmax - 1; x += 2) { - __m128i mmk = _mm_set1_epi32(*(int32_t*)&k[x]); - __m128i source = _mm_loadl_epi64((__m128i*)&lineIn[x + xmin]); - __m128i pix = _mm_shuffle_epi8( - source, - _mm_set_epi8(-1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0)); + // block 2 + for (; i < ids_size - 1; i += 2) { + // Load 2 values from weight vector + // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...] + auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]); + // Load 2 pixels from input line + // source = [ + // r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0 + // ] + auto source = _mm_loadl_epi64((__m128i*)&lineIn_min[i]); + // Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA + auto pix = _mm_shuffle_epi8(source, mask_low128); + // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); } - for (; x < xmax; x++) { - __m128i pix = mm_cvtepu8_epi32(&lineIn[x + xmin]); - __m128i mmk = _mm_set1_epi32(k[x]); + // block 1 + for (; i < ids_size; i++) { + // Load 1 value from weight vector + // mmk = [wl_0 wh_0 0 0 wl_0 wh_0 0 0 ...] + auto mmk = _mm_set1_epi32(k[i]); + // Load one pixel from input line + // pix = [ + // r0 0 0 0 g0 0 0 0 b0 0 0 0 a0 0 0 0 + // ] + auto pix = mm_cvtepu8_epi32(&lineIn_min[i]); + // Compute output value as C += w0 * C0 for each channel in 32-bit precision sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); } + + // Convert fixed point values back to integers (truncating) sss = _mm_srai_epi32(sss, coefs_precision); - sss = _mm_packs_epi32(sss, sss); - lineOut[xx] = _mm_cvtsi128_si32(_mm_packus_epi16(sss, sss)); + // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation + // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0) + sss = _mm_packs_epi32(sss, zero); + // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation + // (a a b b c c d d) -> (a b c d 0 0 0 0) + sss = _mm_packus_epi16(sss, zero); + // Write the output into single uint32 + // (a b c d) -> x_uint32 + lineOut[out_x] = _mm_cvtsi128_si32(sss); } } -// https://gist.github.com/NicolasHug/47c97d731f05eaad5694c173849b86f5 void ImagingResampleVerticalConvolution8u( uint32_t* C10_RESTRICT lineOut, - const uint32_t* C10_RESTRICT imIn, - int xmin, - int xmax, - int16_t* k, - int coefs_precision, - int xin) { - int x; - int xx = 0; - int xsize = xin; + const uint32_t* C10_RESTRICT lineIn, + int64_t xsize, + int64_t ids_min, + int64_t ids_size, + const int16_t* k, + unsigned int coefs_precision) { + // Interpolation vertical pass processing one line. + // - We process x-axis data with blocks of 8, 2 and 1 + // - We split the size of weight vector for a given output index as a sum: K = n * 2 + m. 
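Both the horizontal kernels above and the vertical kernel below implement the same per-channel fixed-point accumulation; the scalar sketch below (a hypothetical helper, not part of the patch) spells it out. The int16 weights carry coefs_precision fractional bits, 1 << (coefs_precision - 1) is the rounding bias loaded into `initial`, the arithmetic shift matches _mm256_srai_epi32, and the final clamp plays the role of the packs/packus saturation. In the horizontal pass consecutive taps are 4 bytes apart (next pixel); in the vertical pass they are 4 * xsize bytes apart (next row).

#include <algorithm>
#include <cstdint>

// Scalar sketch of one output channel value, accumulated in 32-bit fixed point.
uint8_t interp_channel_sketch(const uint8_t* src, int64_t stride,
                              const int16_t* w, int64_t ids_size,
                              unsigned int coefs_precision) {
  int32_t acc = 1 << (coefs_precision - 1);            // rounding bias ("initial")
  for (int64_t j = 0; j < ids_size; j++) {
    acc += (int32_t)w[j] * (int32_t)src[j * stride];   // like _mm256_madd_epi16 + add
  }
  acc >>= coefs_precision;                             // like _mm256_srai_epi32
  return (uint8_t)std::min(std::max(acc, (int32_t)0), (int32_t)255);  // packs/packus saturation
}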
- __m128i initial = _mm_set1_epi32(1 << (coefs_precision - 1)); - __m256i initial_256 = _mm256_set1_epi32(1 << (coefs_precision - 1)); + // xsize = output width, also equals to input width + // ids_size = interpolation size + // ids_min = input y start index + const auto initial = _mm_set1_epi32(1 << (coefs_precision - 1)); + const auto initial_256 = _mm256_set1_epi32(1 << (coefs_precision - 1)); + const auto zero = _mm_setzero_si128(); + const auto zero_256 = _mm256_setzero_si256(); + + int64_t xx = 0; + // block 8 for (; xx < xsize - 7; xx += 8) { - __m256i sss0 = initial_256; - __m256i sss1 = initial_256; - __m256i sss2 = initial_256; - __m256i sss3 = initial_256; - x = 0; - for (; x < xmax - 1; x += 2) { - __m256i source, source1, source2; - __m256i pix, mmk; + auto sss0 = initial_256; + auto sss1 = initial_256; + auto sss2 = initial_256; + auto sss3 = initial_256; + int64_t i = 0; + const auto * lineIn_min = lineIn + xx + ids_min; - // Load two coefficients at once - mmk = _mm256_set1_epi32(*(int32_t*)&k[x]); + for (; i < ids_size - 1; i += 2) { + // Load 2 values from weight vector + auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]); - // Load 2 lines - // (__m256i *) &imIn->image32[x + xmin][xx] - source1 = _mm256_loadu_si256((__m256i*)(imIn + (x + xmin) * xin + xx)); - // (__m256i *) &imIn->image32[x + 1 + xmin][xx] - source2 = - _mm256_loadu_si256((__m256i*)(imIn + (x + 1 + xmin) * xin + xx)); + // Load 8 pixels per line + // source1 = [ + // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3 + // r4 g4 b4 a4 r5 g5 b5 a5 r6 g6 b6 a6 r7 g7 b7 a7 + // ] + auto source1 = _mm256_loadu_si256((__m256i*)(lineIn_min + i * xsize)); + auto source2 = _mm256_loadu_si256((__m256i*)(lineIn_min + (i + 1) * xsize)); - source = _mm256_unpacklo_epi8(source1, source2); - pix = _mm256_unpacklo_epi8(source, _mm256_setzero_si256()); - sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix, mmk)); - pix = _mm256_unpackhi_epi8(source, _mm256_setzero_si256()); - sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix, mmk)); + // Interleave source1 and source2 from the low half of each 128-bit lane + // and cast the result to epi16 + // pix1 = [ + // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0 + // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 a1 0 A1 0 + // ] + auto source_lo = _mm256_unpacklo_epi8(source1, source2); + auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256); + // Compute output value as + // C += w0 * c0 + w1 * C0 + // C += w0 * c1 + w1 * C1 for each channel in 32-bit precision + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk)); - source = _mm256_unpackhi_epi8(source1, source2); - pix = _mm256_unpacklo_epi8(source, _mm256_setzero_si256()); - sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix, mmk)); - pix = _mm256_unpackhi_epi8(source, _mm256_setzero_si256()); - sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix, mmk)); + // pix2 = [ + // r2 0 R2 0 g2 0 G2 0 b2 0 B2 0 a2 0 A2 0 + // r3 0 R3 0 g3 0 G3 0 b3 0 B3 0 a3 0 A3 0 + // ] + auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256); + // Compute output value as + // C += w0 * c2 + w1 * C2 + // C += w0 * c3 + w1 * C3 for each channel in 32-bit precision + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk)); + + // Same as above for the high half of each 128-bit lane + auto source_hi = _mm256_unpackhi_epi8(source1, source2); + auto pix3 = _mm256_unpacklo_epi8(source_hi, zero_256); + sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk)); + auto pix4 = _mm256_unpackhi_epi8(source_hi, zero_256); + sss3 = _mm256_add_epi32(sss3, 
_mm256_madd_epi16(pix4, mmk)); } - for (; x < xmax; x += 1) { - __m256i source, source1, pix, mmk; - mmk = _mm256_set1_epi32(k[x]); + // Same processing as above but with a single weight value + for (; i < ids_size; i += 1) { + auto mmk = _mm256_set1_epi32(k[i]); - // (__m256i *) &imIn->image32[x + xmin][xx]) - source1 = _mm256_loadu_si256((__m256i*)(imIn + (x + xmin) * xin + xx)); + auto source1 = _mm256_loadu_si256((__m256i*)(lineIn_min + i * xsize)); - source = _mm256_unpacklo_epi8(source1, _mm256_setzero_si256()); - pix = _mm256_unpacklo_epi8(source, _mm256_setzero_si256()); - sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix, mmk)); - pix = _mm256_unpackhi_epi8(source, _mm256_setzero_si256()); - sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix, mmk)); + auto source_lo = _mm256_unpacklo_epi8(source1, zero_256); + auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256); + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk)); + auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256); + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk)); - source = _mm256_unpackhi_epi8(source1, _mm256_setzero_si256()); - pix = _mm256_unpacklo_epi8(source, _mm256_setzero_si256()); - sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix, mmk)); - pix = _mm256_unpackhi_epi8(source, _mm256_setzero_si256()); - sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix, mmk)); + auto source_hi = _mm256_unpackhi_epi8(source1, zero_256); + auto pix3 = _mm256_unpacklo_epi8(source_hi, _mm256_setzero_si256()); + sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk)); + auto pix4 = _mm256_unpackhi_epi8(source_hi, _mm256_setzero_si256()); + sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix4, mmk)); } + // Convert fixed point values back to integers (truncating) sss0 = _mm256_srai_epi32(sss0, coefs_precision); sss1 = _mm256_srai_epi32(sss1, coefs_precision); sss2 = _mm256_srai_epi32(sss2, coefs_precision); sss3 = _mm256_srai_epi32(sss3, coefs_precision); - + // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation + // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d) sss0 = _mm256_packs_epi32(sss0, sss1); sss2 = _mm256_packs_epi32(sss2, sss3); + // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation + // (a a b b c c d d) -> (a b c d) sss0 = _mm256_packus_epi16(sss0, sss2); + + // Store 8 pixels to the output _mm256_storeu_si256((__m256i*)&lineOut[xx], sss0); } + // block 2 for (; xx < xsize - 1; xx += 2) { - __m128i sss0 = initial; // left row - __m128i sss1 = initial; // right row - x = 0; - for (; x < xmax - 1; x += 2) { - __m128i source, source1, source2; - __m128i pix, mmk; + auto sss0 = initial; + auto sss1 = initial; + int64_t i = 0; + const auto * lineIn_min = lineIn + xx + ids_min; - // Load two coefficients at once - mmk = _mm_set1_epi32(*(int32_t*)&k[x]); + for (; i < ids_size - 1; i += 2) { + // Load 2 values from weight vector + // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... 
] + auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]); - // Load 2 lines - // (__m128i *) &imIn->image32[x + xmin][xx]) - source1 = _mm_loadl_epi64((__m128i*)(imIn + (x + xmin) * xin + xx)); - // (__m128i *) &imIn->image32[x + 1 + xmin][xx] - source2 = _mm_loadl_epi64((__m128i*)(imIn + (x + 1 + xmin) * xin + xx)); + // Load 2 pixels per line + // source1 = [ + // r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0 + // ] + auto source1 = _mm_loadl_epi64((__m128i*)(lineIn_min + i * xsize)); + auto source2 = _mm_loadl_epi64((__m128i*)(lineIn_min + (i + 1) * xsize)); - source = _mm_unpacklo_epi8(source1, source2); - pix = _mm_unpacklo_epi8(source, _mm_setzero_si128()); + // Interleave source1 and source2 and cast the result to epi16 + // pix = [ + // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0 + // ] + auto source = _mm_unpacklo_epi8(source1, source2); + auto pix = _mm_unpacklo_epi8(source, zero); + // Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix, mmk)); - pix = _mm_unpackhi_epi8(source, _mm_setzero_si128()); + // pix = [ + // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 a1 0 A1 0 + // ] + pix = _mm_unpackhi_epi8(source, zero); + // Compute output value as C += w0 * c1 + w1 * C1 for each channel in 32-bit precision sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix, mmk)); } - for (; x < xmax; x += 1) { - __m128i source, source1, pix, mmk; - mmk = _mm_set1_epi32(k[x]); + // Same processing as above but with a single weight value + for (; i < ids_size; i += 1) { + auto mmk = _mm_set1_epi32(k[i]); - // (__m128i *) &imIn->image32[x + xmin][xx]); - source1 = _mm_loadl_epi64((__m128i*)(imIn + (x + xmin) * xin + xx)); + auto source1 = _mm_loadl_epi64((__m128i*)(lineIn_min + i * xsize)); - source = _mm_unpacklo_epi8(source1, _mm_setzero_si128()); - pix = _mm_unpacklo_epi8(source, _mm_setzero_si128()); - sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix, mmk)); - pix = _mm_unpackhi_epi8(source, _mm_setzero_si128()); - sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix, mmk)); + auto source = _mm_unpacklo_epi8(source1, zero); + auto pix1 = _mm_unpacklo_epi8(source, zero); + sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix1, mmk)); + auto pix2 = _mm_unpackhi_epi8(source, zero); + sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix2, mmk)); } + // Convert fixed point values back to integers (truncating) sss0 = _mm_srai_epi32(sss0, coefs_precision); sss1 = _mm_srai_epi32(sss1, coefs_precision); - + // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation + // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d) sss0 = _mm_packs_epi32(sss0, sss1); + // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation + // (a a b b c c d d) -> (a b c d) sss0 = _mm_packus_epi16(sss0, sss0); + // Store 2 pixels to the output _mm_storel_epi64((__m128i*)&lineOut[xx], sss0); } + // block 1 for (; xx < xsize; xx++) { - __m128i sss = initial; - x = 0; - for (; x < xmax - 1; x += 2) { - __m128i source, source1, source2; - __m128i pix, mmk; + auto sss = initial; + int64_t i = 0; + const auto * lineIn_min = lineIn + xx + ids_min; - // Load two coefficients at once - mmk = _mm_set1_epi32(*(int32_t*)&k[x]); + for (; i < ids_size - 1; i += 2) { + // Load 2 values from weight vector + // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... 
] + auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]); - // Load 2 lines - // *(int *) &imIn->image32[x + xmin][xx] - source1 = _mm_cvtsi32_si128(*(int*)(imIn + (x + xmin) * xin + xx)); - // *(int *) &imIn->image32[x + 1 + xmin][xx] - source2 = _mm_cvtsi32_si128(*(int*)(imIn + (x + 1 + xmin) * xin + xx)); + // Load one pixel per line + // source1 = [ + // r0 g0 b0 a0 0 0 0 0 0 0 0 0 0 0 0 0 + // ] + auto source1 = _mm_cvtsi32_si128(*(int32_t*)(lineIn_min + i * xsize)); + auto source2 = _mm_cvtsi32_si128(*(int32_t*)(lineIn_min + (i + 1) * xsize)); - source = _mm_unpacklo_epi8(source1, source2); - pix = _mm_unpacklo_epi8(source, _mm_setzero_si128()); + // Interleave source1 and source2 and cast the result to epi16 + // pix = [ + // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0 + // ] + auto source = _mm_unpacklo_epi8(source1, source2); + auto pix = _mm_unpacklo_epi8(source, zero); + // Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision + sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); + } + // Same processing as above but with a single weight value + for (; i < ids_size; i++) { + auto mmk = _mm_set1_epi32(k[i]); + auto pix = mm_cvtepu8_epi32(lineIn_min + i * xsize); sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); } - for (; x < xmax; x++) { - // &imIn->image32[x + xmin][xx] - __m128i pix = mm_cvtepu8_epi32(imIn + (x + xmin) * xin + xx); - __m128i mmk = _mm_set1_epi32(k[x]); - sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); - } + // Convert fixed point values back to integers (truncating) sss = _mm_srai_epi32(sss, coefs_precision); - sss = _mm_packs_epi32(sss, sss); - lineOut[xx] = _mm_cvtsi128_si32(_mm_packus_epi16(sss, sss)); + // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation + // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d) + sss = _mm_packs_epi32(sss, zero); + // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation + // (a a b b c c d d) -> (a b c d) + sss = _mm_packus_epi16(sss, zero); + // Store one pixel to the output + lineOut[xx] = _mm_cvtsi128_si32(sss); } }
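The coefs_precision used throughout this file is the weights_precision computed in compute_indices_int16_weights_aa at the top of this diff. Restated as a stand-alone sketch (the free-function name and std::vector-based signature are hypothetical), the selection and int16 quantization work as follows: the precision is the largest value for which the maximal weight still fits into int16, and each double weight is then rounded half away from zero.

#include <cstdint>
#include <vector>

// Sketch of the weight precision selection and int16 quantization.
std::vector<int16_t> quantize_weights_sketch(
    const std::vector<double>& w, double wt_max, unsigned int& weights_precision) {
  // Pick the largest precision such that wt_max * 2^(p + 1), rounded, stays below 2^15.
  for (weights_precision = 0; weights_precision < 22; ++weights_precision) {
    int next_value = (int)(0.5 + wt_max * (1 << (weights_precision + 1)));
    if (next_value >= (1 << 15)) {
      break;
    }
  }
  // Round each weight half away from zero at the chosen precision.
  std::vector<int16_t> out(w.size());
  for (size_t i = 0; i < w.size(); i++) {
    double v = w[i] * (1 << weights_precision);
    out[i] = (int16_t)((v < 0) ? (-0.5 + v) : (0.5 + v));
  }
  return out;
}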