Code update for vectorized interpolate cpu uint8 (#96847)

- code style update
- use idx_ptr_xmin/idx_ptr_size instead of bounds
- compute wt_max inside _compute_indices_weights_aa (no significant overhead)
- added comments and explanations
- renamed xmin/xmax into ids_min, ids_size

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96847
Approved by: https://github.com/peterbell10, https://github.com/NicolasHug, https://github.com/lezcano
This commit is contained in:
vfdev-5 2023-03-23 21:14:55 +00:00 committed by PyTorch MergeBot
parent 4ff71c91d3
commit a96ccaa362
2 changed files with 566 additions and 344 deletions

View File

@ -742,7 +742,7 @@ struct HelperInterpBase {
}
template <typename scalar_t, typename aa_filter_fn_t>
static inline void _compute_weights_aa(
static inline scalar_t _compute_weights_aa(
const int64_t i, const int64_t input_size, const scalar_t scale, const scalar_t support,
scalar_t* wt_ptr, const int64_t max_interp_size, aa_filter_fn_t filter_fn,
int64_t& xmin, int64_t& xsize, bool antialias, double align_corners_delta
@ -764,14 +764,19 @@ struct HelperInterpBase {
wt_ptr[j] = w;
total_w += w;
}
for (j = 0; j < xsize; j++) {
if (total_w != 0.0) {
scalar_t wt_max = 0.0;
if (total_w != 0.0) {
for (j = 0; j < xsize; j++) {
wt_ptr[j] /= total_w;
wt_max = std::max(wt_max, wt_ptr[j]);
}
}
for (; j < max_interp_size; j++) {
wt_ptr[j] = static_cast<scalar_t>(0.0);
}
return wt_max;
}
// Note [ Support for antialias=False as a subcase of antilias=True ]
@ -785,7 +790,7 @@ struct HelperInterpBase {
// indices, but this can be optimized further when aa=False since we know
// their actual dimensions.
template <typename scalar_t, typename aa_filter_fn_t, int weight_index_stride=sizeof(scalar_t)>
static inline std::tuple<std::vector<Tensor>, int> _compute_indices_weights_aa(
static inline std::tuple<std::vector<Tensor>, int, scalar_t> _compute_indices_weights_aa(
int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims,
int64_t reshape_dim, scalar_t scale,
int interp_size, aa_filter_fn_t aa_filter_fn, bool antialias, double align_corners_delta
@ -834,10 +839,10 @@ struct HelperInterpBase {
scalar_t* wt_ptr = output[3].data_ptr<scalar_t>();
int64_t* wt_idx_ptr = output[4].data_ptr<int64_t>();
int64_t xmin, xmax;
scalar_t wt_max = 0.0;
for (const auto i : c10::irange(output_size)) {
HelperInterpBase::_compute_weights_aa(
int64_t xmin, xmax;
auto wt_max_i = HelperInterpBase::_compute_weights_aa(
i,
input_size,
scale,
@ -850,12 +855,14 @@ struct HelperInterpBase {
antialias,
align_corners_delta);
wt_max = std::max(wt_max, wt_max_i);
idx_ptr_xmin[i] = xmin * stride;
idx_ptr_size[i] = xmax;
idx_ptr_stride[i] = stride;
wt_idx_ptr[i] = i * max_interp_size * weight_index_stride;
}
return {output, max_interp_size};
return {output, max_interp_size, wt_max};
}
/*
@ -911,25 +918,17 @@ struct HelperInterpBase {
std::vector<Tensor> indices_weights;
auto align_corners_delta = (align_corners && !antialias) ? 0.5 : 0.0;
std::tie(indices_weights, interp_size) = HelperInterpBase::_compute_indices_weights_aa<double, aa_filter_fn_t, sizeof(int16_t)>(
double wt_max;
std::tie(indices_weights, interp_size, wt_max) = HelperInterpBase::_compute_indices_weights_aa<double, aa_filter_fn_t, sizeof(int16_t)>(
input_size, output_size, stride, ndims, reshape_dim, scale, interp_size, aa_filter_fn, antialias, align_corners_delta);
// Rescale float weights to int16 and compute weights precision
auto weights_f64 = indices_weights[3];
double * data_f64 = weights_f64.data_ptr<double>();
int64_t weights_f64_size = output_size * interp_size;
// can't use weights_f64.max() here as tensor is restrided
double w_max = data_f64[0];
for (const auto i : c10::irange(weights_f64_size)) {
double v = data_f64[i];
if (w_max < v) {
w_max = v;
}
}
unsigned int weights_precision = 0;
for (weights_precision = 0; weights_precision < 22; weights_precision += 1) {
int next_value = (int) (0.5 + w_max * (1 << (weights_precision + 1)));
for (weights_precision = 0; weights_precision < 22; ++weights_precision) {
int next_value = (int) (0.5 + wt_max * (1 << (weights_precision + 1)));
if (next_value >= (1 << 15))
break;
}
@ -939,8 +938,7 @@ struct HelperInterpBase {
auto aligned_interp_size = interp_size;
if (align_i32) {
// We should respect int32 alignment as
// we will load data as int32 with AVX2
// We should respect int32 alignment as we will load int16 data as int32
// See ImagingResampleHorizontalConvolution8u4x, mmk0 = _mm256_set1_epi32(*(int32_t*)&k[x]);
// compute aligned_interp_size = nearest pair value to interp_size
while (aligned_interp_size % sizeof(int32_t) != 0) {
@ -952,20 +950,13 @@ struct HelperInterpBase {
for (const auto j : c10::irange(output_size)) {
for (const auto k : c10::irange(interp_size)) {
double v = data_f64[j * interp_size + k];
if (v < 0) {
data_i16[j * aligned_interp_size + k] = (int) (-0.5 + v * (1 << weights_precision));
} else {
data_i16[j * aligned_interp_size + k] = (int) (0.5 + v * (1 << weights_precision));
}
double v = data_f64[j * interp_size + k] * (1 << weights_precision);
data_i16[j * aligned_interp_size + k] = (v < 0) ? (int) (-0.5 + v) : (int) (0.5 + v);
}
}
return {indices_weights, aligned_interp_size, weights_precision};
}
};
struct HelperInterpNearest : public HelperInterpBase {
@ -1175,8 +1166,9 @@ struct HelperInterpLinear : public HelperInterpBase {
auto interp_size = HelperInterpLinear::interp_size;
int unused;
scalar_t unused_2;
std::tie(indices_weights, unused) = HelperInterpLinear::_compute_indices_weights_aa<scalar_t>(
std::tie(indices_weights, unused, unused_2) = HelperInterpLinear::_compute_indices_weights_aa<scalar_t>(
input_size,
output_size,
stride,
@ -1307,8 +1299,9 @@ struct HelperInterpCubic : public HelperInterpBase {
auto interp_size = HelperInterpCubic::interp_size;
int unused;
scalar_t unused_2;
std::tie(indices_weights, unused) = HelperInterpCubic::_compute_indices_weights_aa<scalar_t>(
std::tie(indices_weights, unused, unused_2) = HelperInterpCubic::_compute_indices_weights_aa<scalar_t>(
input_size,
output_size,
stride,

File diff suppressed because it is too large Load Diff