[1/N] Enable Wunused-result and Wunused-variable in torch targets (#110722)

They are useful for checking results of function calls.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110722
Approved by: https://github.com/Skylion007
This commit is contained in:
cyy 2023-10-08 23:43:45 +00:00 committed by PyTorch MergeBot
parent e1f0f9c64e
commit 3ec33957eb
17 changed files with 63 additions and 46 deletions

View File

@ -121,7 +121,11 @@ struct Workspace {
constexpr size_t nnpack_memory_alignment_boundary = 64; constexpr size_t nnpack_memory_alignment_boundary = 64;
// Won't work on Windows, but NNPACK doesn't support Windows either // Won't work on Windows, but NNPACK doesn't support Windows either
posix_memalign(&buffer, nnpack_memory_alignment_boundary, size); auto res = posix_memalign(&buffer, nnpack_memory_alignment_boundary, size);
if (res != 0) {
TORCH_CHECK(false, "posix_memalign failed:", strerror(errno), " (", errno, ")");
}
return;
} }
~Workspace() { ~Workspace() {

View File

@ -231,14 +231,16 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T> template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 256; constexpr int Headdim = 256;
int device; int device = -1;
cudaGetDevice(&device); cudaGetDevice(&device);
int max_smem_per_sm, max_smem_per_block; int max_smem_per_sm = 0, max_smem_per_block = 0;
cudaError status_ = cudaDeviceGetAttribute( cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device); &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
C10_CUDA_CHECK(status_);
status_ = cudaDeviceGetAttribute( status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
// printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
C10_CUDA_CHECK(status_);
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// For A100, we want to run with 128 x 64 (128KB smem). // For A100, we want to run with 128 x 64 (128KB smem).

View File

@ -1,12 +1,12 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/test/test_assert.h>
#include <cmath>
#include <iostream> #include <iostream>
#include <limits> #include <limits>
#include <sstream> #include <sstream>
#include <cmath>
#include <type_traits> #include <type_traits>
#include <ATen/test/test_assert.h>
using namespace at; using namespace at;
@ -118,7 +118,9 @@ ASSERT_SAME_TYPE(traps);
ASSERT_SAME_TYPE(tinyness_before); ASSERT_SAME_TYPE(tinyness_before);
TEST(TestHalf, CommonMath) { TEST(TestHalf, CommonMath) {
#ifndef NDEBUG
float threshold = 0.00001; float threshold = 0.00001;
#endif
assert(std::abs(std::lgamma(Half(10.0)) - std::lgamma(10.0f)) <= threshold); assert(std::abs(std::lgamma(Half(10.0)) - std::lgamma(10.0f)) <= threshold);
assert(std::abs(std::exp(Half(1.0)) - std::exp(1.0f)) <= threshold); assert(std::abs(std::exp(Half(1.0)) - std::exp(1.0f)) <= threshold);
assert(std::abs(std::log(Half(1.0)) - std::log(1.0f)) <= threshold); assert(std::abs(std::log(Half(1.0)) - std::log(1.0f)) <= threshold);

View File

@ -62,9 +62,6 @@ constexpr int64_t _max_to() {
template<typename RNG, c10::ScalarType S, typename T> template<typename RNG, c10::ScalarType S, typename T>
void test_random_from_to(const at::Device& device) { void test_random_from_to(const at::Device& device) {
constexpr int64_t min_val = _min_val<T>();
constexpr int64_t min_from = _min_from<T>();
constexpr int64_t max_val = _max_val<T>(); constexpr int64_t max_val = _max_val<T>();
constexpr int64_t max_to = _max_to<T>(); constexpr int64_t max_to = _max_to<T>();
@ -81,6 +78,7 @@ void test_random_from_to(const at::Device& device) {
static_cast<c10::optional<int64_t>>(c10::nullopt) static_cast<c10::optional<int64_t>>(c10::nullopt)
}; };
} else if constexpr (::std::is_signed<T>::value) { } else if constexpr (::std::is_signed<T>::value) {
constexpr int64_t min_from = _min_from<T>();
froms = { froms = {
min_from, min_from,
-42L, -42L,
@ -161,6 +159,8 @@ void test_random_from_to(const at::Device& device) {
} }
if constexpr (::std::is_same_v<T, int64_t>) { if constexpr (::std::is_same_v<T, int64_t>) {
ASSERT_TRUE(full_64_bit_range_case_covered); ASSERT_TRUE(full_64_bit_range_case_covered);
} else {
(void)full_64_bit_range_case_covered;
} }
ASSERT_TRUE(from_to_case_covered); ASSERT_TRUE(from_to_case_covered);
ASSERT_TRUE(from_case_covered); ASSERT_TRUE(from_case_covered);

View File

@ -11,18 +11,21 @@ namespace serialize {
FileAdapter::RAIIFile::RAIIFile(const std::string& file_name) { FileAdapter::RAIIFile::RAIIFile(const std::string& file_name) {
fp_ = fopen(file_name.c_str(), "rb"); fp_ = fopen(file_name.c_str(), "rb");
if (fp_ == nullptr) { if (fp_ == nullptr) {
auto old_errno = errno;
#if defined(_WIN32) && (defined(__MINGW32__) || defined(_MSC_VER))
char buf[1024]; char buf[1024];
buf[0] = '\0'; buf[0] = '\0';
#if defined(_WIN32) && (defined(__MINGW32__) || defined(_MSC_VER)) char* error_msg = buf;
strerror_s(buf, sizeof(buf), errno); strerror_s(buf, sizeof(buf), old_errno);
#else #else
strerror_r(errno, buf, sizeof(buf)); auto error_msg =
std::system_category().default_error_condition(old_errno).message();
#endif #endif
AT_ERROR( AT_ERROR(
"open file failed because of errno ", "open file failed because of errno ",
errno, old_errno,
" on fopen: ", " on fopen: ",
buf, error_msg,
", file path: ", ", file path: ",
file_name); file_name);
} }
@ -35,7 +38,7 @@ FileAdapter::RAIIFile::~RAIIFile() {
} }
// FileAdapter directly calls C file API. // FileAdapter directly calls C file API.
FileAdapter::FileAdapter(const std::string& file_name): file_(file_name) { FileAdapter::FileAdapter(const std::string& file_name) : file_(file_name) {
const int fseek_ret = fseek(file_.fp_, 0L, SEEK_END); const int fseek_ret = fseek(file_.fp_, 0L, SEEK_END);
TORCH_CHECK(fseek_ret == 0, "fseek returned ", fseek_ret); TORCH_CHECK(fseek_ret == 0, "fseek returned ", fseek_ret);
#if defined(_MSC_VER) #if defined(_MSC_VER)
@ -68,11 +71,7 @@ size_t FileAdapter::read(uint64_t pos, void* buf, size_t n, const char* what)
const int fseek_ret = fseeko(file_.fp_, pos, SEEK_SET); const int fseek_ret = fseeko(file_.fp_, pos, SEEK_SET);
#endif #endif
TORCH_CHECK( TORCH_CHECK(
fseek_ret == 0, fseek_ret == 0, "fseek returned ", fseek_ret, ", context: ", what);
"fseek returned ",
fseek_ret,
", context: ",
what);
return fread(buf, 1, n, file_.fp_); return fread(buf, 1, n, file_.fp_);
} }

View File

@ -21,8 +21,7 @@ using transform::Graph;
*/ */
TEST(CommonSubexpressionEliminationTest, TestSimple) { TEST(CommonSubexpressionEliminationTest, TestSimple) {
NetDef netdef; NetDef netdef;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables) OperatorDef* op [[maybe_unused]] = nullptr;
OperatorDef* op;
// This operator simply reads input and outputs it. // This operator simply reads input and outputs it.
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
@ -76,8 +75,7 @@ TEST(CommonSubexpressionEliminationTest, TestSimple) {
*/ */
TEST(CommonSubexpressionEliminationTest, TestFromExternal) { TEST(CommonSubexpressionEliminationTest, TestFromExternal) {
NetDef netdef; NetDef netdef;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables) OperatorDef* op [[maybe_unused]] = nullptr;
OperatorDef* op;
// This operator simply reads input and outputs it. // This operator simply reads input and outputs it.
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)

View File

@ -40,7 +40,8 @@ struct AllocAligned {
#elif defined(_MSC_VER) #elif defined(_MSC_VER)
p = _aligned_malloc(sizeof(T), kGEMMLOWPCacheLineSize); p = _aligned_malloc(sizeof(T), kGEMMLOWPCacheLineSize);
#else #else
posix_memalign((void**)&p, kGEMMLOWPCacheLineSize, sizeof(T)); auto res = posix_memalign((void**)&p, kGEMMLOWPCacheLineSize, sizeof(T));
(void)res;
#endif #endif
if (p) { if (p) {

View File

@ -2926,7 +2926,6 @@ void testWelford(DataType dtype, int red_axis, int odim, int rdim) {
fusion.addOutput(tv_N); fusion.addOutput(tv_N);
auto options = at::TensorOptions().dtype(aten_dtype).device(at::kCUDA, 0); auto options = at::TensorOptions().dtype(aten_dtype).device(at::kCUDA, 0);
auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
at::manual_seed(0); at::manual_seed(0);
std::vector<TensorView*> outputs_of_red; std::vector<TensorView*> outputs_of_red;
at::Tensor aten_input = at::Tensor aten_input =
@ -7704,7 +7703,6 @@ TEST_F(NVFuserTest, FusionIssue970_CUDA) {
tv1->split(1, 4); tv1->split(1, 4);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
at::manual_seed(0); at::manual_seed(0);
at::Tensor t0 = at::randn({nelm, nelm}, options); at::Tensor t0 = at::randn({nelm, nelm}, options);

View File

@ -3149,7 +3149,6 @@ TEST_F(NVFuserTest, FusionPropagateParallelTypesToSiblings_CUDA) {
} }
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
at::manual_seed(0); at::manual_seed(0);
at::Tensor t0 = at::randn({9999}, options); at::Tensor t0 = at::randn({9999}, options);
@ -4943,8 +4942,6 @@ TEST_F(NVFuserTest, FusionIssueRepro1844_CUDA) {
const auto options = const auto options =
at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
const auto mask_options =
at::TensorOptions().dtype(at::kBool).device(at::kCUDA, 0);
at::manual_seed(0); at::manual_seed(0);
at::Tensor a = at::randn(shape, options); at::Tensor a = at::randn(shape, options);

View File

@ -4012,7 +4012,6 @@ TEST_F(NVFuserTest, FusionShiftNoPaddingChain_CUDA) {
int numel_y = 101; int numel_y = 101;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
at::manual_seed(0); at::manual_seed(0);
at::Tensor t0 = at::randn({numel_x, numel_y}, options); at::Tensor t0 = at::randn({numel_x, numel_y}, options);
std::vector<IValue> inputs = {t0}; std::vector<IValue> inputs = {t0};
@ -4162,7 +4161,6 @@ TEST_F(NVFuserTest, FusionPartialSplit1_CUDA) {
"Invalid extent of outer domain of partial split"); "Invalid extent of outer domain of partial split");
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
at::manual_seed(0); at::manual_seed(0);
at::Tensor t0 = at::randn({numel_x}, options); at::Tensor t0 = at::randn({numel_x}, options);
std::vector<IValue> inputs = {t0}; std::vector<IValue> inputs = {t0};
@ -4242,7 +4240,6 @@ TEST_F(NVFuserTest, FusionPartialSplit3_CUDA) {
const int numel_y = 32 + 3; const int numel_y = 32 + 3;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
at::manual_seed(0); at::manual_seed(0);
at::Tensor t0 = at::randn({numel_x, numel_y}, options); at::Tensor t0 = at::randn({numel_x, numel_y}, options);
std::vector<IValue> inputs = {t0}; std::vector<IValue> inputs = {t0};

View File

@ -960,8 +960,6 @@ void addStorageDeleterFns(
if (storage_pair != storages.end()) { if (storage_pair != storages.end()) {
auto ctx = storage_pair->second->data_ptr().get_context(); auto ctx = storage_pair->second->data_ptr().get_context();
TORCH_CHECK(ctx == nullptr, " Not expecting deleter function"); TORCH_CHECK(ctx == nullptr, " Not expecting deleter function");
auto curr_deleter = storage_pair->second->data_ptr().get_deleter();
storage_pair->second->set_data_ptr_noswap(std::move(data_ptr)); storage_pair->second->set_data_ptr_noswap(std::move(data_ptr));
} else { } else {
data_ptr.release_context(); data_ptr.release_context();
@ -1102,7 +1100,6 @@ static void registerCudaPluggableAllocator(PyObject* module) {
m.def("_has_Standard_Deleter", [](size_t storage_impl_ptr) { m.def("_has_Standard_Deleter", [](size_t storage_impl_ptr) {
c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr; c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr;
auto alloc = c10::cuda::CUDACachingAllocator::get(); auto alloc = c10::cuda::CUDACachingAllocator::get();
auto data_ptr = storage_impl->data_ptr().get();
return (storage_impl->data_ptr().get_deleter() == alloc->raw_deleter()); return (storage_impl->data_ptr().get_deleter() == alloc->raw_deleter());
}); });
@ -1197,7 +1194,6 @@ static void registerCudaPluggableAllocator(PyObject* module) {
auto delta = c10::cuda::CUDACachingAllocator::setCheckpointPoolState( auto delta = c10::cuda::CUDACachingAllocator::setCheckpointPoolState(
device, pps); device, pps);
auto& freed_pointers = delta.ptrs_freed; auto& freed_pointers = delta.ptrs_freed;
auto& allocd_pointers = delta.dataptrs_allocd;
std::unordered_set<void*> allocd_set; std::unordered_set<void*> allocd_set;
for (auto& data_ptr : delta.dataptrs_allocd) { for (auto& data_ptr : delta.dataptrs_allocd) {

View File

@ -165,7 +165,21 @@ void TCPStoreMasterDaemon::closeStopSignal() {
void TCPStoreMasterDaemon::stop() { void TCPStoreMasterDaemon::stop() {
if (controlPipeFd_[1] != -1) { if (controlPipeFd_[1] != -1) {
::write(controlPipeFd_[1], "\0", 1); ssize_t written_bytes = -1;
while (true) {
written_bytes = ::write(controlPipeFd_[1], "\0", 1);
if (written_bytes < 0) {
if (errno == EAGAIN) {
continue;
}
TORCH_CHECK(false, "Failed to write the control pipe:", errno);
}
break;
}
if (written_bytes == 0) {
TORCH_CHECK(false, "Failed to write the control pipe");
}
// close the write end of the pipe // close the write end of the pipe
::close(controlPipeFd_[1]); ::close(controlPipeFd_[1]);
controlPipeFd_[1] = -1; controlPipeFd_[1] = -1;

View File

@ -489,7 +489,6 @@ c10::optional<at::Tensor> runTorchBackendForOnnx(
} }
// If the device of indices tensor is not the same with it of the input // If the device of indices tensor is not the same with it of the input
// tensor, move it to the device of the input tensor // tensor, move it to the device of the input tensor
auto indices_val = node->input(1);
if (inputTensorValues[0].device() != indices.device()) { if (inputTensorValues[0].device() != indices.device()) {
indices = indices.to(inputTensorValues[0].device()); indices = indices.to(inputTensorValues[0].device());
} }

View File

@ -154,7 +154,6 @@ TensorTypePtr TorchTensorTypeFromONNX(
ListTypePtr TorchListTypeFromONNX( ListTypePtr TorchListTypeFromONNX(
const onnx::TypeProto_Sequence& onnx_sequence_type, const onnx::TypeProto_Sequence& onnx_sequence_type,
SymbolDimMap& symbol_dim_map) { SymbolDimMap& symbol_dim_map) {
c10::optional<at::ScalarType> scalar_type;
if (onnx_sequence_type.has_elem_type()) { if (onnx_sequence_type.has_elem_type()) {
const auto& onnx_seq_elem_type = onnx_sequence_type.elem_type(); const auto& onnx_seq_elem_type = onnx_sequence_type.elem_type();
if (onnx_seq_elem_type.has_tensor_type()) { if (onnx_seq_elem_type.has_tensor_type()) {

View File

@ -41,7 +41,7 @@ BenchmarkExecutionStats BenchmarkHelper<Input, Output, Model>::benchmark(
for (const auto thread_id : c10::irange(config.num_calling_threads)) { for (const auto thread_id : c10::irange(config.num_calling_threads)) {
// Just in case we generate num_iters inputs for each of the threads // Just in case we generate num_iters inputs for each of the threads
// This was if one thread does all the work we will be fine // This was if one thread does all the work we will be fine
for (const auto i : for (const auto i [[maybe_unused]] :
c10::irange(config.num_iters + config.num_warmup_iters)) { c10::irange(config.num_iters + config.num_warmup_iters)) {
thread_inputs[thread_id].push_back(cloneInput(inputs_[dist(engine)])); thread_inputs[thread_id].push_back(cloneInput(inputs_[dist(engine)]));
} }

View File

@ -38,7 +38,8 @@ void start_manager() {
std::string msg("ERROR: execl failed: "); std::string msg("ERROR: execl failed: ");
msg += std::strerror(errno); msg += std::strerror(errno);
msg += '\n'; msg += '\n';
write(1, msg.c_str(), msg.size()); auto res = write(1, msg.c_str(), msg.size());
(void)res;
exit(1); exit(1);
} }

View File

@ -54,9 +54,19 @@ void unregister_fd(int fd) {
client_sessions.erase(fd); client_sessions.erase(fd);
} }
void print_init_message(const char* message) { void print_init_message(std::string_view message) {
write(1, message, strlen(message)); ssize_t written_bytes = -1;
write(1, "\n", 1); while (!message.empty()) {
// NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
SYSCHECK_ERR_RETURN_NEG1(
written_bytes = write(1, message.data(), message.size()));
message.remove_prefix(written_bytes);
}
written_bytes = 0;
while (written_bytes != 1) {
// NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
SYSCHECK_ERR_RETURN_NEG1(written_bytes = write(1, "\n", 1));
}
} }
bool object_exists(const char* name) { bool object_exists(const char* name) {
@ -111,10 +121,10 @@ int main(int argc, char* argv[]) {
std::vector<int> to_add; std::vector<int> to_add;
std::vector<int> to_remove; std::vector<int> to_remove;
for (;;) { for (;;) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables) int nevents = -1;
int nevents;
if (client_sessions.empty()) if (client_sessions.empty())
timeout = SHUTDOWN_TIMEOUT; timeout = SHUTDOWN_TIMEOUT;
// NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
SYSCHECK_ERR_RETURN_NEG1( SYSCHECK_ERR_RETURN_NEG1(
nevents = poll(pollfds.data(), pollfds.size(), timeout)); nevents = poll(pollfds.data(), pollfds.size(), timeout));
timeout = -1; timeout = -1;