Compile without -Wno-unused-variable (#65954)

Summary:
Delete `-Wno-unused-variable` from top level `CMakeLists.txt`
Still suppress those warnings for tests and `torch_python`
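
For reference, the per-target suppression added for test binaries follows this pattern (shown with a placeholder target name `my_test`; the real targets are in the diff below):

    if(NOT MSVC)
      target_compile_options(my_test PRIVATE -Wno-unused-variable)
    endif()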

Delete a number of unused variables from caffe2 code
Use `(void)var;` to suppress unused-variable warnings in range loops
Use `C10_UNUSED` for global constructors and `constexpr` instead of `static` for global constants
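
A rough sketch of the C++ patterns applied (simplified; `kRetries` and `do_work()` are placeholders, not code from this PR):

    // Loop variable only drives the repetition count: cast it to void to silence -Wunused-variable.
    for (const auto i : c10::irange(kRetries)) {
      (void)i;
      do_work();
    }

    // Global object constructed only for its side effect: mark it C10_UNUSED.
    C10_UNUSED Registerer& dummy = registerer();

    // Compile-time constant: constexpr instead of a static const that may end up unused.
    constexpr size_t IR_VERSION = 7;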

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65954

Reviewed By: ngimel

Differential Revision: D31326599

Pulled By: malfet

fbshipit-source-id: 924155f1257a2ba1896c50512f615e45ca1f61f3
Authored by Nikita Shulga on 2021-10-01 17:36:36 -07:00; committed by Facebook GitHub Bot
parent 10f6294281
commit a6280ab653
56 changed files with 73 additions and 94 deletions

View File

@@ -744,7 +744,6 @@ if(NOT MSVC)
 string(APPEND CMAKE_CXX_FLAGS " -Wno-unknown-pragmas")
 string(APPEND CMAKE_CXX_FLAGS " -Wno-sign-compare")
 string(APPEND CMAKE_CXX_FLAGS " -Wno-unused-parameter")
-string(APPEND CMAKE_CXX_FLAGS " -Wno-unused-variable")
 string(APPEND CMAKE_CXX_FLAGS " -Wno-unused-function")
 string(APPEND CMAKE_CXX_FLAGS " -Wno-unused-result")
 string(APPEND CMAKE_CXX_FLAGS " -Wno-unused-local-typedefs")

View File

@@ -6,6 +6,6 @@ namespace at { namespace native {
 // Since size of MKL_LONG varies on different platforms (linux 64 bit, windows
 // 32 bit), we need to programmatically calculate the max.
-static int64_t MKL_LONG_MAX = ((1LL << (sizeof(MKL_LONG) * 8 - 2)) - 1) * 2 + 1;
+constexpr int64_t MKL_LONG_MAX = ((1LL << (sizeof(MKL_LONG) * 8 - 2)) - 1) * 2 + 1;
 }} // namespace

View File

@@ -173,6 +173,7 @@ TEST(TestStream, StreamPoolTest) {
 if (!at::cuda::is_available()) return;
 std::vector<at::cuda::CUDAStream> streams{};
 for (const auto i : c10::irange(200)) {
+(void)i;
 streams.emplace_back(at::cuda::getStreamFromPool());
 }

View File

@@ -84,6 +84,8 @@ TEST(CUDAPytorchToCaffe2, Op) {
 auto* c2_tensor_a = BlobSetTensor(workspace.CreateBlob("a"), caffe2::Tensor(at_tensor_a));
 auto* c2_tensor_b = BlobSetTensor(workspace.CreateBlob("b"), caffe2::Tensor(at_tensor_b));
+(void)c2_tensor_a;
+(void)c2_tensor_b;
 // Test Alias
 {

View File

@@ -54,8 +54,6 @@ TEST(MathKernelTest, NativeGroupNorm) {
 TEST(MathKernelTest, NativeLayerNorm) {
 const auto input = rand({20, 10, 10, 10});
 const auto input_shape = input.sizes();
-// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
-const auto input_ndim = input.dim();
 double eps = 1e-05;
 for (bool undef_weight: {true, false}) {

View File

@@ -15,4 +15,8 @@ if(USE_CUDA)
 main.cpp)
 target_link_libraries(nvfuser_bench PRIVATE torch_library benchmark)
+if(NOT MSVC)
+target_compile_options(nvfuser_bench PRIVATE -Wno-unused-variable)
+endif()
 endif()

View File

@@ -6,6 +6,9 @@ if(BUILD_TEST)
 get_filename_component(test_file_name ${test_src} NAME_WE)
 set(test_name "c10_${test_file_name}")
 add_executable(${test_name} "${test_src}")
+if(NOT MSVC)
+target_compile_options(${test_name} PRIVATE -Wno-unused-variable)
+endif()
 target_link_libraries(${test_name} c10 gmock gtest gtest_main)
 add_test(NAME ${test_name} COMMAND $<TARGET_FILE:${test_name}>)
 if(INSTALL_TEST)

View File

@@ -1762,6 +1762,9 @@ if(BUILD_TEST)
 target_include_directories(${test_name} PRIVATE $<INSTALL_INTERFACE:include>)
 target_include_directories(${test_name} PRIVATE $<BUILD_INTERFACE:${CMAKE_BINARY_DIR}/include>)
 target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE})
+if(NOT MSVC)
+target_compile_options(${test_name} PRIVATE -Wno-unused-variable)
+endif()
 add_test(NAME ${test_name} COMMAND $<TARGET_FILE:${test_name}>)
 if(INSTALL_TEST)
 install(TARGETS ${test_name} DESTINATION test)

View File

@@ -399,8 +399,8 @@ void TensorSerializer::SerializeWithOptions(
 std::vector<std::future<void>> futures;
 if (tensor.numel() > chunk_size) {
 futures.reserve(FLAGS_caffe2_max_tensor_serializer_threads);
-// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,clang-diagnostic-unused-variable)
 for (const auto i : c10::irange(FLAGS_caffe2_max_tensor_serializer_threads)) {
+(void)i;
 futures.emplace_back(std::async(std::launch::async, task));
 }
 }

View File

@@ -19,7 +19,6 @@ class CastOp : public Operator<Context> {
 : Operator<Context>(operator_def, ws) {
 const ArgumentHelper helper(operator_def);
 TensorProto_DataType to = cast::GetCastDataType(helper, "to");
-TensorProto_DataType from = cast::GetCastDataType(helper, "from_type");
 SetBody(to);
 }

View File

@@ -576,7 +576,9 @@ bool CudnnConvOp::DoRunWithType() {
 return true;
 }
+#if !CUDNN_VERSION_MIN(7, 0, 0)
 int group_offset_filter = filter.numel() / group_;
+#endif
 // Set up the cudnn algorithms & workspace if necessary
 bool input_changed = (X.sizes() != cudnn_input_dims_);
@@ -951,7 +953,9 @@ bool CudnnConvGradientOp::DoRunWithType() {
 "If you set group, the number of output channels should be divisible "
 "by group.");
+#if !CUDNN_VERSION_MIN(7, 0, 0)
 int group_offset_filter = filter.numel() / group_;
+#endif
 if (kernel_.size() == 1) {
 ConvPoolOpBase<CUDAContext>::ComputePads({H});
 } else if (kernel_.size() == 2) {

View File

@@ -25,11 +25,10 @@ float compress_uniform_simplified_(
 float inverse_scale = 1.0f / scale;
 float norm = 0.0f;
-// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
-constexpr int VLEN = 8;
 int i = 0;
 #ifdef __AVX__
+constexpr int VLEN = 8;
 // vectorized loop
 __m256 norm_v = _mm256_setzero_ps();
 for (; i < N / VLEN * VLEN; i += VLEN) {

View File

@@ -72,7 +72,6 @@ class FloatToFusedNBitFakeRowwiseQuantizedOp final
 CAFFE_THROW("Unsupported data type");
 }
-bool use_openmp = GREEDY;
 #ifdef _OPENMP
 vector<float> tmp_vec(input_columns * (GREEDY ? omp_get_max_threads() : 1));
 #else

View File

@@ -30,7 +30,6 @@ class GatherFused8BitRowwiseOp : public Operator<Context> {
 const std::vector<int64_t> shape = {indices.size(0), data.size(1) - 8};
 auto* output = Output(0, shape, at::dtype<float>());
-int block_size = shape[1];
 auto block_bytesize = data.size_from_dim(1) * data.dtype().itemsize();
 int N = indices.numel();

View File

@@ -133,8 +133,7 @@ std::vector<int> soft_nms_cpu_upright(
 // Find proposal with max score among remaining proposals
 int max_pos;
-// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
-auto max_score = GetSubArray(*out_scores, pending).maxCoeff(&max_pos);
+GetSubArray(*out_scores, pending).maxCoeff(&max_pos);
 int i = pending[max_pos];
 keep.push_back(i);
@@ -635,8 +634,7 @@ std::vector<int> soft_nms_cpu_rotated(
 // Find proposal with max score among remaining proposals
 int max_pos;
-// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
-auto max_score = GetSubArray(*out_scores, pending).maxCoeff(&max_pos);
+GetSubArray(*out_scores, pending).maxCoeff(&max_pos);
 int i = pending[max_pos];
 keep.push_back(i);

View File

@@ -458,8 +458,6 @@ bool HuffmanTreeHierarchyOp<T, Context>::RunOnDevice() {
 std::vector<int> labelIndices;
 labelIndices.resize(num_classes_);
-// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
-int current_node_index = 0;
 for (int i = 0; i < num_classes_; ++i) {
 Node node(i, labelCounts[i]);
 nodes.push(node);

View File

@@ -132,8 +132,6 @@ class LayerNormGradientOp final : public Operator<Context> {
 template <typename T>
 bool DoRunWithType() {
 const auto& dY = Input(0);
-// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
-const auto& Y = Input(1);
 const auto& mean = Input(2);
 const auto& sigma = Input(3);
 const auto& X = Input(4);

View File

@@ -53,7 +53,6 @@ class SparseLengths8BitsRowwiseOp : public Operator<Context> {
 "the second dim of scale_bias has to be equal to 2");
 CAFFE_ENFORCE_EQ(1, indicesInput.dim(), "INDICES must be a vector");
 const IndexType* indices = indicesInput.template data<IndexType>();
-int64_t dataToReduceSize = indicesInput.size(0);
 const int* lengths = lengthsInput.template data<int>();
 vector<int64_t> shape = dataInput.sizes().vec();

View File

@@ -193,8 +193,6 @@ bool CuDNNLRNGradientOp::DoRunWithType() {
 bool CuDNNLRNGradientOp::RunOnDevice() {
 // dispatch based on contents of tensor(s)
-const auto& X = Input(0);
-const auto& Y = Input(1);
 const auto& dY = Input(2);
 auto* dX = Output(0);

View File

@@ -55,8 +55,10 @@ class Int8AddOp final : public Operator<CPUContext> {
 initQNNPACK();
+#if !defined(FBCODE_CAFFE2) && defined(USE_INTERNAL_PTHREADPOOL_IMPL)
 pthreadpool_t threadpool =
 reinterpret_cast<pthreadpool_t>(ws_->GetThreadPool());
+#endif
 if (this->qnnpackOperator_ == nullptr) {
 const qnnp_status createStatus = qnnp_create_add_nc_q8(

View File

@@ -47,7 +47,6 @@ class Int8ChannelShuffleOp final : public ConvPoolOpBase<CPUContext> {
 const auto C = X.t.dim32(3);
 const auto G = this->group_;
 CAFFE_ENFORCE(C % G == 0, "");
-const auto B = X.t.numel() / C;
 initQNNPACK();

View File

@@ -60,8 +60,10 @@ class Int8ConvOp final : public ConvPoolOpBase<CPUContext> {
 runWithSharedBuffer<CPUContext>(ws_, [&](Tensor* buffer) {
 initQNNPACK();
+#if !defined(FBCODE_CAFFE2) && defined(USE_INTERNAL_PTHREADPOOL_IMPL)
 pthreadpool_t threadpool =
 reinterpret_cast<pthreadpool_t>(ws_->GetThreadPool());
+#endif
 if (this->qnnpackObject_ == nullptr) {
 CAFFE_ENFORCE(

View File

@@ -39,17 +39,12 @@ class Int8ConvTransposeOp final : public ConvTransposeUnpoolBase<CPUContext> {
 const auto& W = Inputs()[1]->template Get<Int8TensorCPU>();
 const auto& B = Inputs()[2]->template Get<Int8TensorCPU>();
 auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
-const auto X_offset = -X.zero_point;
-const auto W_offset = -W.zero_point;
 const int32_t Y_offset =
 this->template GetSingleArgument<int>("Y_zero_point", 0);
 double Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
 Y->scale = Y_scale;
 Y->zero_point = Y_offset;
-const auto N = X.t.size(0);
-const auto IH = X.t.size(1);
-const auto IW = X.t.size(2);
 const auto IC = X.t.size(3);
 CHECK_EQ(IC, W.t.size(0));
@@ -64,8 +59,10 @@ class Int8ConvTransposeOp final : public ConvTransposeUnpoolBase<CPUContext> {
 runWithSharedBuffer<CPUContext>(ws_, [&](Tensor* buffer) {
 initQNNPACK();
+#if !defined(FBCODE_CAFFE2) && defined(USE_INTERNAL_PTHREADPOOL_IMPL)
 pthreadpool_t threadpool =
 reinterpret_cast<pthreadpool_t>(ws_->GetThreadPool());
+#endif
 if (this->qnnpackObject_ == nullptr) {
 const qnnp_status createStatus = qnnp_create_deconvolution2d_nhwc_q8(

View File

@@ -47,8 +47,10 @@ class Int8FCOp final : public Operator<CPUContext> {
 runWithSharedBuffer<CPUContext>(ws_, [&](Tensor* buffer) {
 initQNNPACK();
+#if !defined(FBCODE_CAFFE2) && defined(USE_INTERNAL_PTHREADPOOL_IMPL)
 pthreadpool_t threadpool =
 reinterpret_cast<pthreadpool_t>(ws_->GetThreadPool());
+#endif
 if (this->qnnpackObject_ == nullptr) {
 const qnnp_status createStatus = qnnp_create_fully_connected_nc_q8(

View File

@@ -19,10 +19,10 @@ void Int8Quantize(
 const int64_t N,
 const float Y_scale,
 const int32_t Y_offset) {
-const float inv_scale = 1.0f / Y_scale;
 uint32_t i = 0;
 #ifdef INT8_NEON_SIMD
+const float inv_scale = 1.0f / Y_scale;
 const float32x4_t vinv_scale = vdupq_n_f32(inv_scale);
 // magic float and magic int to take care of rounding
 // int magic_round(float f): interpret_int32(f + 12582912.0f) - 0x4B400000

View File

@@ -38,7 +38,6 @@ class Int8SoftmaxOp final : public Operator<CPUContext> {
 * in-place, we may overwrite these parameters later, when we set
 * quantization parameters for output tensor.
 */
-const uint8_t X_zero_point = X.zero_point;
 const float X_scale = X.scale;
 Y->scale = Y_scale;

View File

@@ -141,8 +141,6 @@ class TextFileReaderReadOp : public Operator<CPUContext> {
 (field > 0 && token.startDelimId == 1),
 "Invalid number of columns at row ",
 instance->rowsRead + rowsRead + 1);
-// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
-const auto& meta = instance->fieldMetas[field];
 char*& data = datas[field];
 convert(
 (TensorProto_DataType)instance->fieldTypes[field],

View File

@@ -686,8 +686,6 @@ class ScatterAssignOp : public Operator<Context> {
 const auto dataType = TypeMetaToDataType(data.dtype());
 const auto slicesType = TypeMetaToDataType(slices.dtype());
 const auto indicesType = TypeMetaToDataType(indices.dtype());
-// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
-auto* output = Output(0);
 auto runner = GetRunner(dataType, slicesType, indicesType);
 (this->*runner)();

View File

@@ -57,8 +57,6 @@ bool WeightedSampleOp<float, CPUContext>::RunOnDevice() {
 }
 }
 } else {
-// NOLINTNEXTLINE(clang-diagnostic-unused-variable,clang-analyzer-deadcode.DeadStores)
-auto* out_idx = Output(0, {0}, at::dtype<int>());
 if (OutputSize() == 2) {
 auto* out_value = Output(1, {0}, at::dtype<float>());
 out_value->template mutable_data<float>();

View File

@@ -436,13 +436,6 @@ void BoundShapeInferencer::InferSparseLengthsSum(const OperatorDef& op) {
 op.type() == "SparseLengthsWeightedSum4BitRowwiseSparse" ||
 op.type() == "SparseLengthsSum4BitRowwiseSparse");
-// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,clang-diagnostic-unused-variable)
-const bool isSparse =
-(op.type() == "SparseLengthsSum4BitRowwiseSparse" ||
-op.type() == "SparseLengthsWeightedSum4BitRowwiseSparse" ||
-op.type() == "SparseLengthsSum8BitRowwiseSparse" ||
-op.type() == "SparseLengthsWeightedSum8BitRowwiseSparse");
 if (weight) {
 CAFFE_ENFORCE_GE(
 op.input_size(),

View File

@@ -533,8 +533,6 @@ bool fuseActivation(repr::NNModule* nn, caffe2::Workspace* ws) {
 continue;
 }
 auto relu_node = consumers.front();
-// NOLINTNEXTLINE(clang-diagnostic-unused-variable,clang-analyzer-deadcode.DeadStores)
-auto relu = repr::nn::get<repr::Relu>(relu_node);
 auto relu_outputs = repr::nn::getOutputs(relu_node);
 if (relu_outputs.size() != 1) {
@@ -893,10 +891,6 @@ void preConvertFiltersFormat(repr::NNModule* nn, caffe2::Workspace* ws) {
 initValue(strides, {1, 1});
 auto pads = convTranspose->getPads();
 initValue(pads, {0, 0, 0, 0});
-// NOLINTNEXTLINE(clang-diagnostic-unused-variable,clang-analyzer-deadcode.DeadStores)
-auto* op = getMutableOpDef(*convTranspose);
-// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
-auto aalgorithm = ialgo::deconvolution_direct;
 auto dataType = filter->get_data_type();
 ideep::tensor::dims filter_dims_mkldnn{filter->get_dim(1),
 filter->get_dim(0),

View File

@@ -64,12 +64,12 @@ void RemapHistograms(Histogram& src_hist, Histogram& dst_hist) {
 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
 float dst_bin_begin = dst_hist.Min() + dst_bin_width * dst_bin;
 float dst_bin_end = dst_bin_begin + dst_bin_width;
-// NOLINTNEXTLINE(clang-diagnostic-unused-variable,clang-analyzer-deadcode.DeadStores)
 int dst_bin2 =
 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
 dst_bin_width == 0 ? 0 : (src_bin_end - dst_hist.Min()) / dst_bin_width;
 // 1 src_bin is mapped to at most 2 dst bin
 assert(dst_bin2 <= dst_bin + 2);
+(void)dst_bin2;
 // dst_bin_cnt is the count from src_bin that should go to dst_bin
 // The remainder should go to dst_bin2

View File

@@ -698,9 +698,8 @@ TypeIdentifier Int8ConvDNNLowpPackedWeightBlobShapeFunctions::GetTypeMetaId() {
 TypeMeta Int8FCDNNLowpPackedWeightBlobShapeFunctions::GetExternalTensorType(
 const void* c) {
-// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
-const Int8FCDNNLowPPackedWeightBlob* int8_tensor =
-reinterpret_cast<const Int8FCDNNLowPPackedWeightBlob*>(c);
+// const Int8FCDNNLowPPackedWeightBlob* int8_tensor =
+//     reinterpret_cast<const Int8FCDNNLowPPackedWeightBlob*>(c);
 // We forced the output type to be uint8_t since we know it always is.
 // If it is going to be implemented elsewhere, we might need to change here.
 // return (int8_tensor->original_tensor).dtype();
@@ -709,9 +708,8 @@ TypeMeta Int8FCDNNLowpPackedWeightBlobShapeFunctions::GetExternalTensorType(
 TypeMeta Int8ConvDNNLowpPackedWeightBlobShapeFunctions::GetExternalTensorType(
 const void* c) {
-// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
-const Int8ConvDNNLowPPackedWeightBlob* int8_tensor =
-reinterpret_cast<const Int8ConvDNNLowPPackedWeightBlob*>(c);
+// const Int8ConvDNNLowPPackedWeightBlob* int8_tensor =
+//     reinterpret_cast<const Int8ConvDNNLowPPackedWeightBlob*>(c);
 // return (int8_tensor->original_tensor).dtype();
 return TypeMeta::Make<uint8_t>();
 }

View File

@@ -21,11 +21,7 @@ TensorQuantizationParams P99::ChooseQuantizationParams(
 float org_min = min;
 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
 float bin_width = (max - min) / nbins;
-// NOLINTNEXTLINE(clang-diagnostic-unused-variable,bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,clang-analyzer-deadcode.DeadStores)
-int zero_bin = round(-min / bin_width);
-// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
-int best_width = 0;
 double total_sum = 0;
 for (int i = 0; i < nbins; ++i) {
 total_sum += bins_f[i];

View File

@@ -152,8 +152,7 @@ bool NNPACKConvOp::RunOnDeviceWithOrderNCHW() {
 auto& filter = Input(1);
 auto* Y = Output(0);
 CAFFE_ENFORCE(X.ndim() == 4, "Input dim should be 4");
-// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,clang-diagnostic-unused-variable)
-const int N = X.dim32(0), C = X.dim32(1), H = X.dim32(2), W = X.dim32(3);
+const int C = X.dim32(1), H = X.dim32(2), W = X.dim32(3);
 CAFFE_ENFORCE(filter.ndim() == 4, "");
 const int M = filter.dim32(0);
 CAFFE_ENFORCE(C % this->group_ == 0, "");
@@ -181,12 +180,6 @@ bool NNPACKConvOp::RunOnDeviceWithOrderNCHW() {
 biasData = dummyBias_.data();
 }
-// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,clang-diagnostic-unused-variable)
-const size_t batch_size = X.dim32(0);
-// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,clang-diagnostic-unused-variable)
-const size_t input_channels = X.dim32(1);
-// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,clang-diagnostic-unused-variable)
-const size_t output_channels = Y->dim32(1);
 const nnp_size input_size = {.width = static_cast<size_t>(X.dim32(3)),
 .height = static_cast<size_t>(X.dim32(2))};
 // filter is MCHW

View File

@@ -53,6 +53,9 @@ endif()
 add_executable(test_api ${TORCH_API_TEST_SOURCES})
 target_include_directories(test_api PRIVATE ${ATen_CPU_INCLUDE})
 target_link_libraries(test_api PRIVATE torch gtest)
+if(NOT MSVC)
+target_compile_options(test_api PRIVATE -Wno-unused-variable)
+endif()
 if(USE_CUDA)
 target_link_libraries(test_api PRIVATE

View File

@@ -97,6 +97,9 @@ endif(MSVC)
 target_link_libraries(test_jit PRIVATE ${JIT_TEST_DEPENDENCIES})
 target_include_directories(test_jit PRIVATE ${ATen_CPU_INCLUDE})
+if(NOT MSVC)
+target_compile_options(test_jit PRIVATE -Wno-unused-variable)
+endif()
 if(LINUX)
 #Update to target_link_options when CMake version can be upgraded

View File

@@ -89,7 +89,6 @@ class BackendWithCompiler : public PyTorchBackendInterface {
 at::Tensor h = val1.toTensor();
 c10::List<at::Tensor> output_list;
-double scalar_val = 1.0;
 for (const auto& token : handle.toList()) {
 IValue val = token;
 auto instruction = val.toTuple()->elements()[0].toStringRef();

View File

@@ -38,6 +38,9 @@ add_executable(test_tensorexpr
 target_link_libraries(test_tensorexpr PRIVATE torch gtest)
 target_include_directories(test_tensorexpr PRIVATE ${ATen_CPU_INCLUDE})
 target_compile_definitions(test_tensorexpr PRIVATE USE_GTEST)
+if(NOT MSVC)
+target_compile_options(test_tensorexpr PRIVATE -Wno-unused-variable)
+endif()
 add_executable(tutorial_tensorexpr ${TENSOREXPR_TEST_ROOT}/tutorial.cpp)
 target_link_libraries(tutorial_tensorexpr PRIVATE torch)

View File

@@ -354,6 +354,10 @@ if(USE_PRECOMPILED_HEADERS)
 "$<$<COMPILE_LANGUAGE:CXX>:ATen/ATen.h>")
 endif()
+if(NOT MSVC)
+target_compile_options(torch_python PRIVATE -Wno-unused-variable)
+endif()
 # Required workaround for generated sources
 # See https://samthursfield.wordpress.com/2015/11/21/cmake-dependencies-between-targets-and-files-and-custom-commands/#custom-commands-in-different-directories
 add_dependencies(torch_python generate-torch-sources)

View File

@@ -270,6 +270,7 @@ c10::intrusive_ptr<ProcessGroupMPI> ProcessGroupMPI::createProcessGroupMPI(
 bool groupComm_updated = false;
 MPI_Barrier(MPI_COMM_WORLD);
 for (const auto i : c10::irange(kMaxNumRetries)) {
+(void)i;
 if (MPI_Comm_create(MPI_COMM_WORLD, ranksGroup, &groupComm)) {
 groupComm_updated = true;
 break;

View File

@@ -939,6 +939,7 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
 // created before encountering any communication calls. This is why we need
 // the following for loop.
 for (const auto i : c10::irange(ncclActiveGroupCounter_)) {
+(void)i;
 C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
 }
@@ -978,6 +979,7 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
 // See [Group Start/End Note]
 for (const auto i : c10::irange(ncclActiveGroupCounter_)) {
+(void)i;
 C10D_NCCL_CHECK(ncclGroupStart(), c10::nullopt);
 }

View File

@@ -1383,8 +1383,6 @@ void Reducer::finalize_bucket_dense(Bucket& bucket) {
 auto& replica = bucket.replicas[replica_index];
 for (const auto intra_bucket_index : c10::irange(replica.variables.size())) {
 auto& variable = replica.variables[intra_bucket_index];
-const auto offset = replica.offsets[intra_bucket_index];
-const auto length = replica.lengths[intra_bucket_index];
 bool global_unused = false;
 // See Note [Skip allreducing local_used_map_dev]
@@ -1634,6 +1632,7 @@ void Reducer::sync_bucket_indices(
 std::vector<size_t> bucket;
 bucket.reserve(bucket_size);
 for (const auto j : c10::irange(bucket_size)) {
+(void)j;
 bucket.push_back(indices_accessor[indices_accessor_Index++]);
 }
 bucket_indices.emplace_back(std::move(bucket));

View File

@@ -280,7 +280,7 @@ void ThreadPredicateMap::insert(
 kir::Bool* ThreadPredicateMap::getPredicate(const TensorView* tv) const {
 // No thread predicate is needed when tv is an output of a
 // parallel broadcast expression.
-if (auto bop = dynamic_cast<BroadcastOp*>(tv->definition())) {
+if (dynamic_cast<BroadcastOp*>(tv->definition())) {
 if (getParallelBroadcastDomains(tv).any()) {
 return kir::IrBuilder(GpuLower::current()->kernel()).trueVal();
 }

View File

@@ -5163,7 +5163,7 @@ std::unique_ptr<Function> CompilationUnit::define(
 if (shouldMangle) {
 // If `shouldMangle` is set, we should generate a unique name for this
 // function if there is already an existing one.
-if (auto fn = find_function(name)) {
+if (find_function(name)) {
 name = mangle(name);
 }
 }

View File

@@ -153,7 +153,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
 }
 } else if (auto classType = value_->type()->cast<ClassType>()) {
 // This is a class, emit the proper attribute lookup
-if (auto method = classType->findMethod(field)) {
+if (classType->findMethod(field)) {
 return std::make_shared<MethodValue>(getValue(), field);
 }
 if (classType->hasAttribute(field)) {
@@ -169,7 +169,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
 }
 } else if (auto iface = value_->type()->cast<InterfaceType>()) {
 // accessing methods of interfaces
-if (auto schema = iface->getMethod(field)) {
+if (iface->getMethod(field)) {
 return std::make_shared<MethodValue>(getValue(), field);
 }
 } else if (auto enum_type = value_->type()->cast<EnumType>()) {

View File

@@ -166,7 +166,7 @@ void numToTensorBool(Stack& stack) {
 push(stack, at::scalar_to_tensor(b));
 }
-static const std::array<mobile::prim_op_fn_register, 14> op_reg = {
+static const C10_UNUSED std::array<mobile::prim_op_fn_register, 14> op_reg = {
 mobile::prim_op_fn_register("prim::TupleIndex", tupleIndex),
 mobile::prim_op_fn_register("aten::Bool.Tensor", boolTensor),
 mobile::prim_op_fn_register("aten::format", aten_format),

View File

@@ -3,6 +3,9 @@
 namespace torch {
 namespace jit {
+// Start UUID at 1
+static GraphPassNameType graphPassID = 1;
 std::vector<GraphPassEntry>& getCustomPostPasses() {
 static std::vector<GraphPassEntry> passes;
 return passes;

View File

@@ -26,8 +26,6 @@ using GraphPass = std::function<void(std::shared_ptr<Graph>&)>;
 // Since Passes are std::functions, we associate a UUID to each pass, this way
 // if we want to deregister a pass, we have something to reference it by.
 using GraphPassNameType = unsigned int;
-// Start UUID at 1
-static GraphPassNameType graphPassID = 1;
 // Graph pass entries have a name associated with them
 using GraphPassEntry = std::pair<GraphPass, GraphPassNameType>;

View File

@@ -158,8 +158,8 @@ struct CaptureList {
 case CAPTURE_LIST: {
 c10::List<at::Tensor> lst;
 auto size = *size_it++;
-// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,clang-diagnostic-unused-variable)
 for (const auto i : c10::irange(size)) {
+(void)i;
 lst.emplace_back(var_capture_it->unpack(saved_for));
 var_capture_it++;
 }

View File

@@ -54,7 +54,7 @@ Registerer& registerer() {
 }
 // global instance to run its constructor on startup
-Registerer& dummy = registerer();
+C10_UNUSED Registerer& dummy = registerer();
 } // namespace

View File

@@ -103,7 +103,7 @@ auto initBindings() {
 return nullptr;
 }
-const auto torchBindInitializer = initBindings();
+const auto C10_UNUSED torchBindInitializer = initBindings();
 } // namespace

View File

@@ -25,8 +25,8 @@ std::vector<at::Tensor> constructTensors(
 for (const auto i : c10::irange(bufs_num)) {
 buf_data_vec.push_back(buf_data[i]);
 buf_dims_vec.emplace_back();
-// NOLINTNEXTLINE(clang-diagnostic-unused-variable,clang-analyzer-deadcode.DeadStores)
 for (const auto dim : c10::irange(buf_ranks[i])) {
+(void)dim;
 buf_dims_vec[i].push_back(buf_dims[buf_dims_idx++]);
 }
 buf_dtypes_vec.push_back(static_cast<c10::ScalarType>(buf_dtypes[i]));

View File

@@ -1151,9 +1151,6 @@ void LLVMCodeGenImpl::visit(LoadPtr v) {
 throw std::runtime_error("invalid dtype in Load");
 }
-// Detect whether the vector mask is all true
-bool unmasked_load = true;
 // Handle the case where the load is contiguous and unmasked efficiently
 auto idx_ramp = to<Ramp>(v->flat_index());
 if (idx_ramp) {
@@ -1805,9 +1802,6 @@ void LLVMCodeGenImpl::visit(IntrinsicsPtr v) {
 }
 void LLVMCodeGenImpl::visit(ExternalCallPtr v) {
-constexpr int max_buffers = 10;
-constexpr int max_dimensions = 40;
 auto& func_registry = getNNCFunctionRegistry();
 if (!func_registry.count(v->func_name())) {
 throw unimplemented_lowering(v);

View File

@@ -19,7 +19,7 @@ enum class TrainingMode {
 // We pin IR version instead of using onnx::IR_VERSION so that the
 // test_operators.py will be more stable. Only bump it when
 // necessary.
-static const size_t IR_VERSION = 7;
-static const char* PRODUCER_VERSION = "1.11";
+constexpr size_t IR_VERSION = 7;
+constexpr const char* PRODUCER_VERSION = "1.11";
 } // namespace onnx
 } // namespace torch

View File

@@ -236,8 +236,8 @@ void THP_encodeInt16Buffer(uint8_t* dst, const int16_t* src, THPByteOrder order,
 {
 memcpy(dst, src, sizeof(int16_t) * len);
 if (order != THP_nativeByteOrder()) {
-// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,clang-diagnostic-unused-variable)
 for(const auto i : c10::irange(len)) {
+(void)i;
 swapBytes16(dst);
 dst += sizeof(int16_t);
 }
@@ -248,8 +248,8 @@ void THP_encodeInt32Buffer(uint8_t* dst, const int32_t* src, THPByteOrder order,
 {
 memcpy(dst, src, sizeof(int32_t) * len);
 if (order != THP_nativeByteOrder()) {
-// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,clang-diagnostic-unused-variable)
 for(const auto i : c10::irange(len)) {
+(void)i;
 swapBytes32(dst);
 dst += sizeof(int32_t);
 }
@@ -260,8 +260,8 @@ void THP_encodeInt64Buffer(uint8_t* dst, const int64_t* src, THPByteOrder order,
 {
 memcpy(dst, src, sizeof(int64_t) * len);
 if (order != THP_nativeByteOrder()) {
-// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,clang-diagnostic-unused-variable)
 for(const auto i : c10::irange(len)) {
+(void)i;
 swapBytes64(dst);
 dst += sizeof(int64_t);
 }
@@ -272,8 +272,8 @@ void THP_encodeFloatBuffer(uint8_t* dst, const float* src, THPByteOrder order, s
 {
 memcpy(dst, src, sizeof(float) * len);
 if (order != THP_nativeByteOrder()) {
-// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,clang-diagnostic-unused-variable)
 for(const auto i : c10::irange(len)) {
+(void)i;
 swapBytes32(dst);
 dst += sizeof(float);
 }
@@ -284,8 +284,8 @@ void THP_encodeDoubleBuffer(uint8_t* dst, const double* src, THPByteOrder order,
 {
 memcpy(dst, src, sizeof(double) * len);
 if (order != THP_nativeByteOrder()) {
-// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,clang-diagnostic-unused-variable)
 for(const auto i : c10::irange(len)) {
+(void)i;
 swapBytes64(dst);
 dst += sizeof(double);
 }