Fix suppressed cppcoreguidelines-init-variables warnings (#141795)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141795
Approved by: https://github.com/albanD
cyy 2025-01-28 17:11:37 +00:00 committed by PyTorch MergeBot
parent ac87388e61
commit c751541e79
39 changed files with 114 additions and 216 deletions

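For context: clang-tidy's cppcoreguidelines-init-variables check warns whenever a local variable is declared without an initializer. Instead of keeping the NOLINTNEXTLINE suppressions, this commit either initializes each variable at its declaration or narrows loop counters into the for-headers that use them. A minimal sketch of both fixes (illustrative code, not taken from the PR):

#include <cstdint>
#include <vector>

int64_t sum_all(const std::vector<int64_t>& data) {
  // Before: uninitialized locals, silenced with a suppression comment:
  //   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  //   int64_t i, sum;
  // After: initialize at the point of declaration...
  int64_t sum = 0;
  // ...and scope the counter to the loop that uses it.
  for (int64_t i = 0; i < static_cast<int64_t>(data.size()); ++i) {
    sum += data[i];
  }
  return sum;
}
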
View File

@@ -177,21 +177,18 @@ static void avg_pool3d_out_frame(
{
at::parallel_for(0, nslices, 0, [&](int64_t start, int64_t end) {
for (const auto k : c10::irange(start, end)) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t i, j, ti;
/* local pointers. */
const scalar_t *ip = input_p + k * itime * iwidth * iheight;
scalar_t *op = output_p + k * otime * owidth * oheight;
for (i = 0; i < otime * oheight * owidth; ++i)
for (int64_t i = 0; i < otime * oheight * owidth; ++i)
*(op + i) = 0;
/* loop over output */
for (ti = 0; ti < otime; ti++)
for (int64_t ti = 0; ti < otime; ti++)
{
for (i = 0; i < oheight; i++)
for (int64_t i = 0; i < oheight; i++)
{
for (j = 0; j < owidth; j++)
for (int64_t j = 0; j < owidth; j++)
{
/* compute pool range. */
int64_t tstart = ti * dT - padT;
@@ -226,14 +223,11 @@ static void avg_pool3d_out_frame(
/* compute local sum: */
scalar_t sum = 0.0;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t x, y, z;
for (z = tstart; z < tend; z++)
for (int64_t z = tstart; z < tend; z++)
{
for (y = hstart; y < hend; y++)
for (int64_t y = hstart; y < hend; y++)
{
for (x = wstart; x < wend; x++)
for (int64_t x = wstart; x < wend; x++)
{
sum += *(ip + z * iwidth * iheight + y * iwidth + x);
}

View File

@@ -1703,11 +1703,10 @@ static void apply_cholesky_solve(Tensor& b, Tensor& A, bool upper, Tensor& infos
auto ldab = std::max<int64_t>(1, n);
auto nrhs = b.size(-1);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int info;
for (const auto i : c10::irange(batch_size)) {
const scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
int info = 0;
lapackCholeskySolve<scalar_t>(uplo, n, nrhs, const_cast<scalar_t*>(A_working_ptr), ldab, b_working_ptr, ldab, &info);
infos_data[i] = info;
if (info != 0) {

View File

@@ -250,8 +250,7 @@ void apply_lapack_eigh(const Tensor& values, const Tensor& vectors, const Tensor
int liwork = -1;
scalar_t lwork_query;
value_t rwork_query;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int iwork_query;
int iwork_query = 0;
// call lapackSyevd once to get the optimal size for work data
lapackSyevd<scalar_t, value_t>(jobz, uplo, n, vectors_data, lda, values_data,
@@ -339,8 +338,7 @@ static void apply_geqrf(const Tensor& input, const Tensor& tau) {
auto n = input.size(-1);
auto lda = std::max<int64_t>(1, m);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int info;
int info = 0;
// Run once, first to get the optimum work size.
// Since we deal with batches of matrices with the same dimensions, doing this outside
// the loop saves (batch_size - 1) workspace queries which would provide the same result
@@ -410,8 +408,7 @@ inline void apply_orgqr(Tensor& self, const Tensor& tau) {
auto n = self.size(-1);
auto k = tau.size(-1);
auto lda = std::max<int64_t>(1, m);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int info;
int info = 0;
// LAPACK's requirement
TORCH_INTERNAL_ASSERT(m >= n);

View File

@@ -71,8 +71,7 @@ bool copy_transpose_valid(const Tensor& self, const Tensor& src) {
// special case copy where tensor is contiguous and src is a transposed matrix
// This can be generalized to most copies, but it's trickier
void copy_same_type_transpose_(Tensor& self, const Tensor& src) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t BLOCK_SZ;
int64_t BLOCK_SZ = 0;
if (self.scalar_type() == kByte) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
BLOCK_SZ = 120;

View File

@@ -1523,8 +1523,7 @@ void _embedding_bag_dense_backward_cpu_sum_mean(
auto offset2bag = offset2bag_.index_select(0, ind_sort);
std::optional<Tensor> per_sample_weights;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
const scalar_t* per_sample_weights_data;
const scalar_t* per_sample_weights_data = nullptr;
std::optional<int64_t> per_sample_weights_stride;
if (per_sample_weights_.defined()) {
per_sample_weights = per_sample_weights_.index_select(0, ind_sort);

View File

@@ -151,17 +151,14 @@ static void fractional_max_pool2d_out_single_batch_frame(
randomSamplesForPlane[1], inputH, outputH, poolSizeH);
/* loop over output */
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int h, w;
const scalar_t* inputForPlane = input + plane * inputW * inputH;
scalar_t* outputForPlane = output + plane * outputW * outputH;
int64_t* indicesForPlane = indices + plane * outputW * outputH;
for (h = 0; h < outputH; ++h) {
for (int h = 0; h < outputH; ++h) {
int inputHStart = sequenceH[h];
for (w = 0; w < outputW; ++w) {
for (int w = 0; w < outputW; ++w) {
int inputWStart = sequenceW[w];
int h2 = inputHStart, w2 = inputWStart;

View File

@@ -124,20 +124,18 @@ static void fractional_max_pool3d_out_single_batch_frame(
randomSamplesForPlane[2], inputW, outputW, poolSizeW);
/* loop over output */
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t t, h, w;
const scalar_t* inputForPlane = input + plane * inputT * inputH * inputW;
scalar_t* outputForPlane = output + plane * outputT * outputH * outputW;
int64_t* indicesForPlane = indices + plane * outputT * outputH * outputW;
for (t = 0; t < outputT; ++t) {
for (int64_t t = 0; t < outputT; ++t) {
int64_t inputTStart = sequenceT[t];
for (h = 0; h < outputH; ++h) {
for (int64_t h = 0; h < outputH; ++h) {
int64_t inputHStart = sequenceH[h];
for (w = 0; w < outputW; ++w) {
for (int64_t w = 0; w < outputW; ++w) {
int64_t inputWStart = sequenceW[w];
int64_t t2 = inputTStart, h2 = inputHStart, w2 = inputWStart;
@@ -274,11 +272,9 @@ static void fractional_max_pool3d_backward_out_single_batch_frame(
plane * outputT * outputH * outputW;
const int64_t* indicesForPlane = indices + plane * outputT * outputH * outputW;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t h, w, t;
for (t = 0; t < outputT; ++t) {
for (h = 0; h < outputH; ++h) {
for (w = 0; w < outputW; ++w) {
for (int64_t t = 0; t < outputT; ++t) {
for (int64_t h = 0; h < outputH; ++h) {
for (int64_t w = 0; w < outputW; ++w) {
int64_t outputIndex = t * outputH * outputW + h * outputW + w;
int64_t index = indicesForPlane[outputIndex];
AT_ASSERT(index >= 0 && index < inputT * inputH * inputW);

View File

@@ -777,8 +777,7 @@ _grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output,
scalar_t y = grid_ptr_NHW[grid_sCoor];
// multipliers for gradients on ix, iy
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
scalar_t gix_mult, giy_mult;
scalar_t gix_mult{}, giy_mult{};
scalar_t ix = grid_sampler_compute_source_index_set_grad(x, inp_W, padding_mode, align_corners, &gix_mult);
scalar_t iy = grid_sampler_compute_source_index_set_grad(y, inp_H, padding_mode, align_corners, &giy_mult);

View File

@@ -70,8 +70,7 @@ std::tuple<Tensor, Tensor, size_t, std::vector<int64_t>> ctc_loss_allocate_outpu
TORCH_CHECK((int64_t) input_lengths.size() == batch_size, "input_lengths must be of size batch_size");
TORCH_CHECK((int64_t) target_lengths.size() == batch_size, "target_lengths must be of size batch_size");
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
size_t tg_target_stride;
size_t tg_target_stride = 0;
int64_t max_target_length = 0;
std::vector<int64_t> tg_batch_offsets(batch_size);
if (targets.dim() == 1) { // concatenated targets
@@ -240,10 +239,8 @@ Tensor ctc_loss_backward_cpu_template(const Tensor& grad_out, const Tensor& log_
Tensor grad = at::full_like(log_probs, neginf, LEGACY_CONTIGUOUS_MEMORY_FORMAT); // at this point, this is log of empty sum
// The admin bits. We don't do much checking and assume that the forward did.
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t tg_target_stride;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t max_target_length;
int64_t tg_target_stride = 0;
int64_t max_target_length = 0;
std::vector<int64_t> tg_batch_offsets(batch_size);
if (targets.dim() == 1) { // concatenated targets

View File

@@ -117,8 +117,7 @@ static void multilabel_margin_loss_forward_out_cpu_template(
#ifndef STRIP_ERROR_MESSAGES
auto target_arg = TensorArg(target, "target", 2);
#endif
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t nframe, dim;
int64_t nframe = 0, dim = 0;
const int64_t ndims = input.dim();
multilabel_margin_loss_shape_check(nframe, dim, ndims, input, target);
@@ -230,8 +229,7 @@ static void multilabel_margin_loss_backward_out_cpu_template(
const Tensor& target,
int64_t reduction,
const Tensor& is_target) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t nframe, dim;
int64_t nframe = 0, dim = 0;
CheckedFrom c = "multilabel_margin_loss_backward_cpu_template";
auto target_arg = TensorArg(target, "target", 3);
auto is_target_arg = TensorArg(is_target, "is_target", 5);

View File

@@ -104,8 +104,7 @@ void multi_margin_loss_out_cpu_template(
const Scalar& margin,
const std::optional<Tensor>& weight,
int64_t reduction) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t nframe, dim;
int64_t nframe = 0, dim = 0;
const auto ndims = input.dim();
TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported");
@@ -216,8 +215,7 @@ void multi_margin_loss_backward_out_cpu_template(
const Scalar& margin,
const Tensor& weight,
int64_t reduction) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t nframe, dim;
int64_t nframe = 0, dim = 0;
const auto ndims = input.dim();
TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported");

View File

@@ -668,8 +668,7 @@ void slow_conv_transpose3d_acc_grad_parameters_cpu(
output_padding_height,
1);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t n_output_plane;
int64_t n_output_plane = 0;
if (grad_weight.defined()) {
n_output_plane = grad_weight.size(1);
} else if (grad_bias.defined()) {

View File

@@ -78,10 +78,8 @@ Tensor fbgemm_linear_int8_weight_fp32_activation(
TORCH_CHECK(weight_zero_point.isIntegral(false));
// Calculate statistics for quantization of the input Tensor
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
float x_min;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
float x_max;
float x_min = std::numeric_limits<float>::quiet_NaN();
float x_max = std::numeric_limits<float>::quiet_NaN();
fbgemm::FindMinMax(
/*m=*/input_ptr,
/*min=*/&x_min,
@@ -236,10 +234,8 @@ std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
const Tensor weight_contig = weight.contiguous();
// Calculate weight statistics
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
float w_min;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
float w_max;
float w_min = std::numeric_limits<float>::quiet_NaN();
float w_max = std::numeric_limits<float>::quiet_NaN();
fbgemm::FindMinMax(
/*m=*/weight_contig.data_ptr<float>(),
/*min=*/&w_min,

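The quiet_NaN initializers above are a poison value rather than a meaningful default: fbgemm::FindMinMax overwrites both out-parameters, and if a future code path ever skipped that write, a NaN is far more conspicuous than a plausible-looking zero. A self-contained illustration of the idea (assumed example, not code from this PR):

#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  float x_min = std::numeric_limits<float>::quiet_NaN();
  // Suppose the call that fills x_min were skipped by mistake:
  if (std::isnan(x_min)) {
    std::printf("x_min was never written; the NaN poison caught it\n");
  }
  return 0;
}
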
View File

@@ -73,8 +73,6 @@ Tensor& _sobol_engine_ff_(Tensor& quasi, int64_t n, const Tensor& sobolstate,
"quasi needs to be of type ", at::kLong);
// We deal with `data` and `strides` due to performance issues.
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t l;
int64_t* quasi_data = quasi.data_ptr<int64_t>();
int64_t* sobolstate_data = sobolstate.data_ptr<int64_t>();
@@ -82,7 +80,7 @@ Tensor& _sobol_engine_ff_(Tensor& quasi, int64_t n, const Tensor& sobolstate,
int64_t sobolstate_row_stride = sobolstate.stride(0), sobolstate_col_stride = sobolstate.stride(1);
for (int64_t i = 0; i < n; i++, num_generated++) {
l = rightmost_zero(num_generated);
auto l = rightmost_zero(num_generated);
for (const auto j : c10::irange(dimension)) {
quasi_data[j * quasi_stride] ^= sobolstate_data[j * sobolstate_row_stride + l * sobolstate_col_stride];
}

View File

@@ -134,11 +134,8 @@ void quick_select_template(
int64_t k,
Comp gt_or_nan,
Fn swap_fn) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t P, L, R, i, j;
scalar_t piv;
L = 0;
R = arr.size(0) - 1;
int64_t L = 0;
int64_t R = arr.size(0) - 1;
do {
if (R <= L) // One element only
@@ -152,7 +149,7 @@
}
// Use median of three for pivot choice
P = L + (R - L) / 2;
auto P = L + (R - L) / 2;
swap_fn(P, L + 1);
if (gt_or_nan(arr[L + 1], arr[R])) {
swap_fn(L + 1, R);
@@ -164,9 +161,9 @@
swap_fn(L + 1, L);
}
i = L + 1;
j = R;
piv = arr[L];
auto i = L + 1;
auto j = R;
auto piv = arr[L];
do {
do
i++;

View File

@@ -97,13 +97,11 @@ Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz) {
auto csr_accessor = csr.accessor<int64_t, 1>();
// Convert the sparse matrix to CSR format
at::parallel_for(0, nnz, 10000, [&](int64_t start, int64_t end) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t h, hp0, hp1;
for (const auto i : c10::irange(start, end)) {
hp0 = indices[i];
hp1 = (i+1 == nnz) ? dim : indices[i+1];
auto hp0 = indices[i];
auto hp1 = (i+1 == nnz) ? dim : indices[i+1];
if (hp0 != hp1) {
for (h = hp0; h < hp1; h++) {
for (int64_t h = hp0; h < hp1; h++) {
csr_accessor[h+1] = i+1;
}
}

View File

@@ -1234,8 +1234,7 @@ Tensor diagonal(
auto outnames = namedinference::compute_diagonal_outnames(self, dim1, dim2);
NoNamesGuard no_names_guard;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t diag_size;
int64_t diag_size = 0;
int64_t storage_offset = self.storage_offset();
// compute storage offset and size for the diagonal
// for positive values of offset (above the main diagonal)

View File

@@ -93,8 +93,7 @@ void apply_triu_tril(const Tensor& result, const Tensor& self, bool inplace, int
auto self_col_stride = self.stride(-1);
auto result_data = result.data_ptr<scalar_t>();
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t result_stride, result_row_stride, result_col_stride;
int64_t result_stride = 0, result_row_stride = 0, result_col_stride = 0;
if (result_data != self_data) {
result_stride = (result.dim() > 2 && result.stride(-3) > 0) ? result.stride(-3) : 1;
result_row_stride = result.stride(-2);

View File

@@ -227,7 +227,7 @@ inline void _vec_host_softmax_backward_lastdim(
scalar_t* grad_input_data = grad_input_data_base + i * dim_size;
const scalar_t* grad_data = grad_data_base + i * dim_size;
const scalar_t* output_data = output_data_base + i * dim_size;
if (log_softmax) {
if constexpr (log_softmax) {
auto sum = vec::reduce_all<scalar_t>(
[](Vec& x, Vec& y) { return x + y; }, grad_data, dim_size);
vec::map2(

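The change from if (log_softmax) to if constexpr (log_softmax) is more than initialization cleanup: log_softmax is a compile-time constant (a bool template parameter of this kernel), so if constexpr drops the untaken branch during instantiation instead of emitting and branching over it. A sketch of the mechanism (hypothetical function, not the PR's kernel):

#include <iostream>

template <bool log_softmax>
void backward_lastdim() {
  if constexpr (log_softmax) {
    std::cout << "log-softmax backward path\n";  // the only branch compiled for <true>
  } else {
    std::cout << "softmax backward path\n";      // the only branch compiled for <false>
  }
}

int main() {
  backward_lastdim<true>();
  backward_lastdim<false>();
  return 0;
}
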
View File

@@ -50,12 +50,8 @@ static void unfolded2d_acc(
int64_t output_width) {
at::parallel_for(0, n_input_plane, 0, [&](int64_t start, int64_t end) {
for (const auto nip : c10::irange(start, end)) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t kw, kh, y, x;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t ix, iy;
for (kh = 0; kh < kH; kh++) {
for (kw = 0; kw < kW; kw++) {
for (int64_t kh = 0; kh < kH; kh++) {
for (int64_t kw = 0; kw < kW; kw++) {
scalar_t* src = finput_data +
nip * ((size_t)kH * kW * output_height * output_width) +
kh * ((size_t)kW * output_height * output_width) +
@@ -63,16 +59,14 @@ static void unfolded2d_acc(
scalar_t* dst =
input_data + nip * ((size_t)input_height * input_width);
if (padW > 0 || padH > 0) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t lpad, rpad;
for (y = 0; y < output_height; y++) {
iy = (int64_t)y * dH - padH + kh;
for (int64_t y = 0; y < output_height; y++) {
auto iy = y * dH - padH + kh;
if (iy < 0 || iy >= input_height) {
} else {
if (dW == 1) {
ix = 0 - padW + kw;
lpad = std::max<int64_t>(0, padW - kw);
rpad = std::max<int64_t>(0, padW - (kW - kw - 1));
auto ix = 0 - padW + kw;
auto lpad = std::max<int64_t>(0, padW - kw);
auto rpad = std::max<int64_t>(0, padW - (kW - kw - 1));
scalar_t* dst_slice =
dst + (size_t)iy * input_width + ix + lpad;
cadd(
@@ -81,8 +75,8 @@ static void unfolded2d_acc(
src + (size_t)y * output_width + lpad,
output_width - lpad - rpad);
} else {
for (x = 0; x < output_width; x++) {
ix = (int64_t)x * dW - padW + kw;
for (int64_t x = 0; x < output_width; x++) {
auto ix = x * dW - padW + kw;
if (ix < 0 || ix >= input_width) {
} else {
scalar_t* dst_slice = dst + (size_t)iy * input_width + ix;
@@ -93,9 +87,9 @@ static void unfolded2d_acc(
}
}
} else {
for (y = 0; y < output_height; y++) {
iy = (int64_t)y * dH + kh;
ix = 0 + kw;
for (int64_t y = 0; y < output_height; y++) {
auto iy = y * dH + kh;
auto ix = 0 + kw;
if (dW == 1) {
scalar_t* dst_slice = dst + (size_t)iy * input_width + ix;
cadd(
@@ -104,7 +98,7 @@ static void unfolded2d_acc(
src + (size_t)y * output_width,
output_width);
} else {
for (x = 0; x < output_width; x++) {
for (int64_t x = 0; x < output_width; x++) {
scalar_t* dst_slice =
dst + (size_t)iy * input_width + ix + x * dW;
*dst_slice = *dst_slice + src[(size_t)y * output_width + x];
@@ -248,10 +242,6 @@ static void unfolded2d_copy(
int64_t rest = k % (kH * kW);
int64_t kh = rest / kW;
int64_t kw = rest % kW;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t x, y;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t ix, iy;
scalar_t* dst = finput_data +
nip * ((size_t)kH * kW * output_height * output_width) +
kh * ((size_t)kW * output_height * output_width) +
@@ -259,10 +249,8 @@ static void unfolded2d_copy(
const scalar_t* src =
input_data + nip * ((size_t)input_height * input_width);
if (padW > 0 || padH > 0) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t lpad, rpad;
for (y = 0; y < output_height; y++) {
iy = (int64_t)y * dH - padH + kh;
for (int64_t y = 0; y < output_height; y++) {
auto iy = y * dH - padH + kh;
if (iy < 0 || iy >= input_height) {
memset(
dst + (size_t)y * output_width,
@@ -270,9 +258,9 @@ static void unfolded2d_copy(
sizeof(scalar_t) * output_width);
} else {
if (dW == 1) {
ix = 0 - padW + kw;
lpad = std::max<int64_t>(0, padW - kw);
rpad = std::max<int64_t>(0, padW - (kW - kw - 1));
auto ix = 0 - padW + kw;
auto lpad = std::max<int64_t>(0, padW - kw);
auto rpad = std::max<int64_t>(0, padW - (kW - kw - 1));
if (output_width - rpad - lpad <= 0) {
memset(
dst + (size_t)y * output_width,
@@ -295,8 +283,8 @@ static void unfolded2d_copy(
sizeof(scalar_t) * rpad);
}
} else {
for (x = 0; x < output_width; x++) {
ix = (int64_t)x * dW - padW + kw;
for (int64_t x = 0; x < output_width; x++) {
auto ix = x * dW - padW + kw;
if (ix < 0 || ix >= input_width)
memset(
dst + (size_t)y * output_width + x,
@@ -312,16 +300,16 @@ static void unfolded2d_copy(
}
}
} else {
for (y = 0; y < output_height; y++) {
iy = (int64_t)y * dH + kh;
ix = 0 + kw;
for (int64_t y = 0; y < output_height; y++) {
auto iy = y * dH + kh;
auto ix = 0 + kw;
if (dW == 1)
memcpy(
dst + (size_t)y * output_width,
src + (size_t)iy * input_width + ix,
sizeof(scalar_t) * output_width);
else {
for (x = 0; x < output_width; x++)
for (int64_t x = 0; x < output_width; x++)
memcpy(
dst + (size_t)y * output_width + x,
src + (size_t)iy * input_width + ix + (int64_t)x * dW,

View File

@@ -462,12 +462,11 @@ void cpu_upsample_linear_backward(
const opmath_t width_scale = area_pixel_compute_scale<opmath_t>(
input_width, output_width, align_corners, scales[0]);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t iw0, iw1;
opmath_t w0lambda, w1lambda;
for (const auto c : c10::irange(begin, end)) {
int64_t input_offset = buffer_data.get() == nullptr ? c * input_slice_size : 0;
for (const auto ow : c10::irange(output_width)) {
int64_t iw0 = 0, iw1 = 0;
compute_source_index_and_lambda(
iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
opmath_t grad_output_value = grad_output_data[c * output_slice_size + ow];
@@ -497,12 +496,11 @@ void cpu_upsample_linear_backward(
const opmath_t width_scale = area_pixel_compute_scale<opmath_t>(
input_width, output_width, align_corners, scales[1]);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t ih0, ih1, iw0, iw1;
opmath_t h0lambda, h1lambda, w0lambda, w1lambda;
for (const auto c : c10::irange(begin, end)) {
int64_t input_offset = buffer_data.get() == nullptr ? c * input_slice_size : 0;
for (const auto oh : c10::irange(output_height)) {
int64_t ih0 = 0, ih1 = 0, iw0 = 0, iw1 = 0;
compute_source_index_and_lambda(
ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
for (const auto ow : c10::irange(output_width)) {
@@ -540,12 +538,11 @@ void cpu_upsample_linear_backward(
const opmath_t width_scale = area_pixel_compute_scale<opmath_t>(
input_width, output_width, align_corners, scales[2]);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t id0, id1, ih0, ih1, iw0, iw1;
opmath_t d0lambda, d1lambda, h0lambda, h1lambda, w0lambda, w1lambda;
for (const auto c : c10::irange(begin, end)) {
int64_t input_offset = buffer_data.get() == nullptr ? c * input_slice_size : 0;
for (const auto od : c10::irange(output_depth)) {
int64_t id0 = 0, id1 = 0, ih0 = 0, ih1 = 0, iw0 = 0, iw1 = 0;
compute_source_index_and_lambda(
id0, id1, d0lambda, d1lambda, depth_scale, od, input_depth, output_depth, align_corners);
for (const auto oh : c10::irange(output_height)) {
@@ -644,12 +641,11 @@ void cpu_upsample_linear_backward_channels_last(
return acc_data_ptr + offset + (h * input_width + w) * channels;
};
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t ih0, ih1, iw0, iw1;
opmath_t h0lambda, h1lambda, w0lambda, w1lambda;
for (const auto n : c10::irange(begin, end)) {
int64_t input_offset = buffer_data.get() == nullptr ? n * input_slice_size : 0;
for (const auto oh : c10::irange(output_height)) {
int64_t ih0 = 0, ih1 = 0, iw0 = 0, iw1 = 0;
compute_source_index_and_lambda(
ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
for (const auto ow : c10::irange(output_width)) {
@@ -693,12 +689,11 @@ void cpu_upsample_linear_backward_channels_last(
return acc_data_ptr + offset + (d * input_height * input_width + h * input_width + w) * channels;
};
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t id0, id1, ih0, ih1, iw0, iw1;
opmath_t d0lambda, d1lambda, h0lambda, h1lambda, w0lambda, w1lambda;
for (const auto n : c10::irange(begin, end)) {
int64_t input_offset = buffer_data.get() == nullptr ? n * input_slice_size : 0;
for (const auto od : c10::irange(output_depth)) {
int64_t id0 = 0, id1 = 0, ih0 = 0, ih1 = 0, iw0 = 0, iw1 = 0;
compute_source_index_and_lambda(
id0, id1, d0lambda, d1lambda, depth_scale, od, input_depth, output_depth, align_corners);
for (const auto oh : c10::irange(output_height)) {

View File

@@ -923,7 +923,6 @@ void codegenOutputQuery(
// TODO: try making the CUcontext thread local to see if that improves performance - why is this slow?
void initializeCudaContext() {
// lazily construct context if non-existing yet;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
CUcontext pctx = nullptr;
AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx));
if (!pctx) {
@@ -1594,7 +1593,6 @@ NvrtcFunction jit_pwise_function(
const std::string compute = std::string("--gpu-architecture=") +
(compile_to_sass ? "sm_" : "compute_") + std::to_string(cuda_major) +
std::to_string(cuda_minor);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<const char*> args = {
"--std=c++17", compute.c_str(), "-default-device"};
#endif

View File

@@ -620,8 +620,7 @@ Tensor& multinomial_out_mps(const Tensor& self,
// Sanity checks on `self`.
auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)).item();
TORCH_CHECK(is_valid.to<bool>(), "probability tensor contains either `inf`, `nan` or element < 0");
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
bool zero_prob_condition;
bool zero_prob_condition = false;
if (self.dim() == 1) {
zero_prob_condition = (self.sum() == 0).item().to<bool>();
} else {

View File

@@ -105,11 +105,9 @@ static void adaptive_avg_pool_single_out_frame(
/* compute local average: */
int64_t sum = 0;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int id, ih, iw;
for (id = 0; id < kD; id++) {
for (ih = 0; ih < kH; ih++) {
for (iw = 0; iw < kW; iw++) {
for (int id = 0; id < kD; id++) {
for (int ih = 0; ih < kH; ih++) {
for (int iw = 0; iw < kW; iw++) {
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
int64_t val = (ip +
id * istrideD +

View File

@@ -56,8 +56,6 @@ static void avg_pool2d_out_frame(
at::parallel_for(0, nInputPlane, 0, [&](int64_t start, int64_t end) {
for (const auto k : c10::irange(start, end)) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t xx, yy;
/* For all output pixels... */
scalar_t* ptr_output = output_data + k * outputWidth * outputHeight;
const scalar_t* ptr_input = input_data + k * inputWidth * inputHeight;
@@ -65,8 +63,8 @@
std::numeric_limits<typename scalar_t::underlying>::lowest();
auto maximum = std::numeric_limits<typename scalar_t::underlying>::max();
for (yy = 0; yy < outputHeight; yy++) {
for (xx = 0; xx < outputWidth; xx++) {
for (int64_t yy = 0; yy < outputHeight; yy++) {
for (int64_t xx = 0; xx < outputWidth; xx++) {
/* Compute the mean of the input image... */
int64_t hstart = yy * dH - padH;
int64_t wstart = xx * dW - padW;
@@ -81,8 +79,7 @@
int sum_int = 0;
ptr_output->val_ = 0;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t divide_factor;
int64_t divide_factor = 0;
int64_t size = (hend - hstart) * (wend - wstart);
if (divisor_override.has_value()) {
divide_factor = divisor_override.value();
@@ -94,10 +91,8 @@
}
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t kx, ky;
for (ky = hstart; ky < hend; ky++) {
for (kx = wstart; kx < wend; kx++)
for (int64_t ky = hstart; ky < hend; ky++) {
for (int64_t kx = wstart; kx < wend; kx++)
sum_int += (ptr_input + ky * inputWidth + kx)->val_;
}
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
@@ -185,7 +180,6 @@ Tensor q_avg_pool2d(
bool ceil_mode,
bool count_include_pad,
std::optional<int64_t> divisor_override) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
auto [kW, kH] = get_kernel(kernel_size);
auto [dW, dH] = get_stride(stride, kW, kH);
auto [padW, padH] = get_padding(padding);

View File

@@ -59,11 +59,9 @@ void spatial_dilated_max_pooling(
T* oData) { // output arrays (data and max-index)
at::parallel_for(0, iC, 0, [&](int64_t start, int64_t end) {
for (const auto p : c10::irange(start, end)) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t row, col;
const T* i_p = iData + p * iW * iH;
for (row = 0; row < oH; ++row) {
for (col = 0; col < oW; ++col) {
for (int64_t row = 0; row < oH; ++row) {
for (int64_t col = 0; col < oW; ++col) {
int64_t h_start = row * sH - pH;
int64_t w_start = col * sW - pW;
int64_t h_end = std::min(h_start + (kH - 1) * dH + 1, iH);
@@ -79,10 +77,8 @@
// local max
auto max_val = std::numeric_limits<typename T::underlying>::lowest();
int64_t tcntr = 0; // center point
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t x, y;
for (y = h_start; y < h_end; y += dH) {
for (x = w_start; x < w_end; x += dW) {
for (int64_t y = h_start; y < h_end; y += dH) {
for (int64_t x = w_start; x < w_end; x += dW) {
tcntr = y * iW + x;
auto val = (i_p + tcntr)->val_;
if (val > max_val) {
@@ -161,11 +157,9 @@ void spatial_dilated_max_pooling3d(
// local max
auto max_val = std::numeric_limits<typename T::underlying>::lowest();
int64_t tcntr = 0; // center point
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t t, x, y;
for (t = t_start; t < t_end; t += dT) {
for (y = h_start; y < h_end; y += dH) {
for (x = w_start; x < w_end; x += dW) {
for (int64_t t = t_start; t < t_end; t += dT) {
for (int64_t y = h_start; y < h_end; y += dH) {
for (int64_t x = w_start; x < w_end; x += dW) {
tcntr = t * iH * iW + y * iW + x;
auto val = (i_p + tcntr)->val_;
if (val > max_val) {

View File

@@ -981,8 +981,7 @@ at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply_impl(
output_zero_point,
channels_last);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
pytorch_qnnp_status run_status;
pytorch_qnnp_status run_status{};
if (transpose()) {
run_status = qnnpack::qnnpackDeConv(
convolution_op.get(),

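Here run_status{} uses value-initialization, which zero-initializes the enum; for a status enum like pytorch_qnnp_status the zero value is conventionally the success member, so the variable holds a well-defined value even before the branches below assign it. The behavior, sketched with a stand-in enum (illustrative, not PR code):

#include <iostream>

enum class status { success = 0, invalid_parameter = 1, out_of_memory = 2 };

int main() {
  status run_status{};  // value-initialized to zero, i.e. status::success
  std::cout << (run_status == status::success) << '\n';  // prints 1
  return 0;
}
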
View File

@@ -100,9 +100,8 @@ at::Tensor& embedding_lookup_fallback_impl(
if (per_sample_weights_.has_value()) {
weight_val = per_sample_weights_data[current];
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
float scale, bias;
if (BIT_RATE == 8) {
float scale = std::numeric_limits<float>::quiet_NaN(), bias = std::numeric_limits<float>::quiet_NaN();
if constexpr (BIT_RATE == 8) {
const uint8_t* scale_bias =
weight_data + (idx + 1) * weight_size - 2 * sizeof(float);
uint32_t scale_val_int32 = 0;
@@ -1077,6 +1076,8 @@ class QEmbedding final {
const auto offsets_size = indices.numel();
at::Tensor offsets = at::arange(0, offsets_size, indices.scalar_type());
at::Tensor output;
static_assert(bit_rate == 4 || bit_rate == 8,
"Currently only support 4-bit and 8-bit embedding quantization");
if (bit_rate == 8) {
return packed_weight->embeddingbag_byte(
indices,
@@ -1095,10 +1096,6 @@
std::nullopt,
false,
true);
} else {
TORCH_INTERNAL_ASSERT(
false,
"Currently only support 8-bit embedding quantization");
}
return output;
}

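Promoting the runtime TORCH_INTERNAL_ASSERT to a static_assert works because bit_rate is a compile-time constant here (a non-type template parameter), so an unsupported bit rate now fails at compile time instead of when the operator first runs. A minimal sketch of the idea (hypothetical type, not the PR's class):

#include <cstdint>

template <int bit_rate>
struct EmbeddingKernel {
  static_assert(bit_rate == 4 || bit_rate == 8,
                "only 4-bit and 8-bit embedding quantization is supported");
  static constexpr int values_per_byte = 8 / bit_rate;
};

int main() {
  static_assert(EmbeddingKernel<8>::values_per_byte == 1);
  static_assert(EmbeddingKernel<4>::values_per_byte == 2);
  // EmbeddingKernel<2> would fail to compile with the message above.
  return 0;
}
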
View File

@@ -47,8 +47,7 @@ c10::intrusive_ptr<EmbeddingPackedParamsBase> PackedEmbeddingBagWeight::prepack(
at::Tensor weight_contig =
qweight.contiguous(qweight.suggest_memory_format());
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int bit_width, scale_bias_bytes;
int bit_width = 0, scale_bias_bytes = 0;
uint8_t* weight_data = static_cast<uint8_t*>(weight_contig.data_ptr());
if (qweight.scalar_type() == c10::kQUInt8) {
bit_width = 8;
@@ -436,8 +435,7 @@ Tensor _qembeddingbag_nbit_prepack_helper(
const float* input_row = weight_data + row * embedding_cols;
std::uint8_t* output_row = output_data + row * output_columns;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
float Xmin, Xmax;
float Xmin = std::numeric_limits<float>::quiet_NaN(), Xmax = std::numeric_limits<float>::quiet_NaN();
if (optimized_qparams) {
auto [xmax_tensor, xmin_tensor] = at::choose_qparams_optimized(
float_weight[row], embedding_cols, nbins, ratio, bit_width);

View File

@@ -26,8 +26,7 @@ at::Tensor PackedEmbeddingBagWeight::unpack() {
if (bit_rate_ == 8 || bit_rate_ == 4) {
const auto input_rows = packed_weight.size(0);
const auto input_columns = packed_weight.size(1);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int scale_bias_bytes;
int scale_bias_bytes = 0;
const auto num_elem_per_byte = 8 / bit_rate_;
if (bit_rate_ == 8) {
// The last 2 values are used to store the FP32 scale and zero_point
@@ -51,8 +50,7 @@ at::Tensor PackedEmbeddingBagWeight::unpack() {
w_zp.data(), w_zp.size(), device(c10::kCPU).dtype(c10::kFloat));
auto output_columns = output_shape[1];
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint8_t* output_data;
uint8_t* output_data = nullptr;
// Allocate output weight tensor based on the bit_width
if (bit_rate_ == 8) {

View File

@@ -67,8 +67,7 @@ at::Tensor PackedLinearWeight::apply_dynamic_impl(
std::to_string(K));
// Calculate statistics for quantization of the input Tensor
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
float x_min, x_max;
float x_min = std::numeric_limits<float>::quiet_NaN(), x_max = std::numeric_limits<float>::quiet_NaN();
fbgemm::FindMinMax(
/*m=*/input_ptr,
/*min=*/&x_min,
@@ -274,18 +273,14 @@ at::Tensor PackedLinearWeightsQnnp::apply_dynamic_impl(
// Calculate statistics for quantization of input Tensor
// TODO: optimized kernel
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
float x_min;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
float x_max;
float x_min = 0;
float x_max = 0;
if (input.numel() > 0) {
x_min = input_contig.min().item<float>();
x_max = input_contig.max().item<float>();
} else {
// On empty input, no output data will be generated,
// so use arbitrary qparams.
x_min = 0;
x_max = 0;
}
auto q_params = quant_utils::ChooseQuantizationParams(

View File

@@ -220,10 +220,8 @@ Tensor _mul_scalar_out(Tensor& out, const Tensor& self, const Scalar& other) {
double self_scale = self.q_scale();
double other_val = other.toDouble();
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
double scale_prime;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t zero_point_prime;
double scale_prime = 0;
int64_t zero_point_prime = 0;
AT_DISPATCH_QINT_TYPES(out.scalar_type(), "qmul_scalar", [&]() {
// NOLINTNEXTLINE(bugprone-signed-char-misuse)

View File

@@ -201,8 +201,7 @@ ContextConv2D create(
xnn_operator_t convolution_op{};
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
xnn_status create_status;
xnn_status create_status{};
std::array<int64_t, 4> weight_sizes{};
if (transposed) {
@@ -323,8 +322,7 @@ Tensor run(
padded_input_nhwc.opt_names());
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
xnn_status setup_status;
xnn_status setup_status{};
/*
* Input Pointer Caching:

View File

@@ -112,8 +112,7 @@ void NnapiCompilation::init2(
void NnapiCompilation::run(
std::vector<at::Tensor> inputs,
std::vector<at::Tensor> outputs) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
ANeuralNetworksExecution* execution;
ANeuralNetworksExecution* execution = nullptr;
check_nnapi->Execution_create(compilation_.get(), &execution);
ExecutionPtr execution_unique_ptr(execution);
@@ -150,8 +149,7 @@ void NnapiCompilation::run(
// TODO: Maybe skip this for fixed-size outputs?
for (const auto i : c10::irange(outputs.size())) {
auto& t = outputs[i];
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint32_t rank;
uint32_t rank = 0;
check_nnapi->Execution_getOutputOperandRank(execution, i, &rank);
std::vector<uint32_t> dims(rank);
check_nnapi->Execution_getOutputOperandDimensions(execution, i, dims.data());

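The NNAPI changes apply the same discipline to out-parameters handed to a C interface: execution starts as nullptr and rank as 0, so a failed call cannot leave them holding indeterminate values that later reads would consume. Sketched with a stand-in C-style function (hypothetical, not the real NNAPI binding):

#include <cstdint>
#include <cstdio>

// Stand-in for a call like Execution_getOutputOperandRank: returns nonzero
// on failure and may leave *out_rank untouched.
static int get_rank(uint32_t* out_rank, bool fail) {
  if (fail) {
    return 1;
  }
  *out_rank = 4;
  return 0;
}

int main() {
  uint32_t rank = 0;  // defined fallback if the call fails
  if (get_rank(&rank, /*fail=*/true) != 0) {
    std::fprintf(stderr, "call failed; rank stays %u\n", rank);
  }
  return 0;
}
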
View File

@@ -174,8 +174,7 @@ int load_nnapi_model(
uint32_t len = values[i].source_length;
const uint8_t* stored_pointer = next_pointer;
const void* value_pointer = nullptr;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
size_t value_length;
size_t value_length = 0;
switch ((SourceType)values[i].source_type) {
case SOURCE_IMMEDIATE:

View File

@@ -82,8 +82,7 @@ QTensorImpl* get_qtensorimpl(const TensorBase& self) {
}
static int64_t get_sub_byte_tensor_size(IntArrayRef sizes, size_t dtype_itemsize, at::ScalarType t) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t element_per_byte;
int64_t element_per_byte = 1;
switch(t) {
case at::ScalarType::QUInt4x2:
element_per_byte = 2;

View File

@@ -6,7 +6,6 @@
namespace c10 {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
SymbolicShapeMeta::SymbolicShapeMeta(const SymbolicShapeMeta& other)
// Non-mutables can be accessed outside the mutex
: sizes_(other.sizes_),

View File

@@ -29,7 +29,6 @@ class C10_API SizesAndStrides {
using strides_iterator = int64_t*;
using strides_const_iterator = const int64_t*;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
SizesAndStrides() {
size_at_unchecked(0) = 0;
stride_at_unchecked(0) = 1;
@@ -42,7 +41,6 @@
}
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
SizesAndStrides(const SizesAndStrides& rhs) : size_(rhs.size_) {
if (C10_LIKELY(rhs.isInline())) {
copyDataInline(rhs);

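The last few files drop NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) suppressions without changing the code. That sibling check fires when a constructor can leave data members uninitialized; the usual way to satisfy it, sketched below, is default member initializers rather than per-constructor suppressions (illustrative class, not the c10 types above):

#include <cstdint>

class Extent {
 public:
  Extent() = default;  // fine: every member below carries an initializer
  int64_t size() const { return size_; }

 private:
  int64_t size_{0};
  int64_t stride_{1};
};

int main() {
  Extent e;
  return static_cast<int>(e.size());  // returns 0
}
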
View File

@@ -108,8 +108,7 @@ TEST(BFloat16Math, Addition) {
// 0 | 10000001 | 10010000000000000000000 = 6.25
float expected = float_from_bytes(0, 0, 0x40c80000);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
c10::BFloat16 b;
c10::BFloat16 b{};
b.x = c10::detail::bits_from_f32(input);
b = b + b;
@@ -131,8 +130,7 @@ TEST(BFloat16Math, Subtraction) {
// 0 | 10000000 | 01010000000000000000000 = 2.625
float expected = float_from_bytes(0, 0, 0x40280000);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
c10::BFloat16 b;
c10::BFloat16 b{};
b.x = c10::detail::bits_from_f32(input);
b = b - 5;
@@ -140,7 +138,6 @@
EXPECT_EQ(res, expected);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BFloat16Math, NextAfterZero) {
const c10::BFloat16 zero{0};