[19/N] Fix extra warnings brought by clang-tidy-17 (#144448)

Apply more clang-tidy fixes. This also reverts an incorrect namespace concatenation introduced by #144014.
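The namespace issue is visible in the XPU Blas hunk further down: #144014 had collapsed `namespace at::native { namespace xpu {` into `namespace at::native::xpu`, but the TORCH_IMPL_FUNC definitions in that file sit after the inner `} // namespace xpu` and are kept directly in at::native, so the collapsed form put them in the wrong namespace. A minimal, self-contained sketch of the pitfall, with illustrative names rather than the real ones:

#include <iostream>

namespace outer {
namespace inner {
void in_inner() { std::cout << "outer::inner\n"; }
} // namespace inner

// Still inside `outer` here: this is the spot where the real file keeps its
// TORCH_IMPL_FUNC-style definitions, in at::native but not in at::native::xpu.
void in_outer() { std::cout << "outer\n"; }
} // namespace outer

// Had the two openings been merged into a single `namespace outer::inner {`,
// the one closing brace would leave in_outer() inside outer::inner as well,
// silently changing its fully qualified name.

int main() {
  outer::inner::in_inner();
  outer::in_outer();
}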

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144448
Approved by: https://github.com/albanD
cyy 2025-01-09 15:58:05 +00:00 committed by PyTorch MergeBot
parent 1353f3beb4
commit b0be30dd79
27 changed files with 89 additions and 64 deletions

View File

@ -86,14 +86,14 @@ TaskThreadPoolBase& _get_intraop_pool() {
#endif // C10_MOBILE
// Run lambda function `fn` over `task_id` in [0, `range`) with threadpool.
// `fn` will be called with params: (thread_pool_task_id, task_id).
void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range) {
// `fn` will be called with params: task_id.
static void _run_with_pool(const std::function<void(size_t)>& fn, size_t range) {
#ifndef C10_MOBILE
for (const auto i : c10::irange(1, range)) {
_get_intraop_pool().run([fn, i]() { fn((int)i, i); });
_get_intraop_pool().run([fn, i]() { fn(i); });
}
// Run the first task on the current thread directly.
fn(0, 0);
fn(0);
#else
caffe2::PThreadPool* const pool = caffe2::pthreadpool();
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
@ -102,7 +102,7 @@ void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range) {
// PThreadPool::run() is blocking. A std::function [const] reference to
// this lambda cannot go out of scope before PThreadPool::run() returns.
[&fn](const size_t task_id) {
fn(0 /* unused */, task_id);
fn(task_id);
}, range);
#endif // C10_MOBILE
}
@ -113,6 +113,10 @@ struct ParallelRegionGuard {
internal::set_thread_num(task_id);
_set_in_parallel_region(true);
}
ParallelRegionGuard(const ParallelRegionGuard&) = delete;
ParallelRegionGuard(ParallelRegionGuard&&) = delete;
ParallelRegionGuard& operator=(const ParallelRegionGuard&) = delete;
ParallelRegionGuard& operator=(ParallelRegionGuard&&) = delete;
~ParallelRegionGuard() {
_set_in_parallel_region(false);
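The four deleted operations added above are the standard response to clang-tidy's special-member-function checks (e.g. cppcoreguidelines-special-member-functions): a type with a user-provided destructor should state explicitly what copying and moving mean, and for a scope guard the honest answer is that they are not allowed, since a stray copy would run the exit action twice. A minimal sketch of the pattern with a hypothetical guard, not the real ParallelRegionGuard:

// Illustrative scope guard: pairs an "enter" with an "exit", so accidental
// copies or moves must be impossible or the exit would run more than once.
struct region_guard {
  region_guard() { /* mark the current thread as inside a region */ }
  region_guard(const region_guard&) = delete;
  region_guard(region_guard&&) = delete;
  region_guard& operator=(const region_guard&) = delete;
  region_guard& operator=(region_guard&&) = delete;
  ~region_guard() { /* clear the flag again */ }
};

int main() {
  region_guard g;         // constructed once, destroyed once
  // region_guard g2 = g; // would not compile: copying is deleted
}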
@ -124,16 +128,16 @@ struct ParallelRegionGuard {
namespace internal {
inline std::tuple<size_t, size_t> calc_num_tasks_and_chunk_size(
static std::tuple<size_t, size_t> calc_num_tasks_and_chunk_size(
int64_t begin, int64_t end, int64_t grain_size) {
if ((end - begin) < grain_size) {
return std::make_tuple(1, std::max((int64_t)0, end - begin));
}
// Choose number of tasks based on grain size and number of threads.
size_t chunk_size = divup((end - begin), get_num_threads());
int64_t chunk_size = divup((end - begin), get_num_threads());
// Make sure each task is at least grain_size size.
chunk_size = std::max((size_t)grain_size, chunk_size);
size_t num_tasks = divup((end - begin), chunk_size);
chunk_size = std::max(grain_size, chunk_size);
size_t num_tasks = static_cast<size_t>(divup((end - begin), chunk_size));
return std::make_tuple(num_tasks, chunk_size);
}
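Switching chunk_size to int64_t and dropping the (size_t)grain_size cast also removes a latent signed/unsigned trap: if a negative signed value is ever cast to size_t before std::max, it wraps to a huge number and wins the comparison. A small self-contained illustration of that hazard, with made-up values rather than the actual parallel-for inputs:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>

int main() {
  int64_t grain_size = -1;      // pathological but representable input
  std::size_t chunk_size = 8;

  // Old shape: cast first, then compare. -1 wraps to SIZE_MAX, so "max" is wrong.
  std::size_t bad = std::max(static_cast<std::size_t>(grain_size), chunk_size);

  // New shape: compare as signed, cast once at the end if needed.
  int64_t good = std::max(grain_size, static_cast<int64_t>(chunk_size));

  std::cout << bad << " vs " << good << '\n';  // 18446744073709551615 vs 8 on 64-bit
}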
@ -157,12 +161,12 @@ void invoke_parallel(
} state;
auto task = [f, &state, begin, end, chunk_size]
(int /* unused */, size_t task_id) {
int64_t local_start = begin + task_id * chunk_size;
(size_t task_id) {
int64_t local_start = static_cast<int64_t>(begin + task_id * chunk_size);
if (local_start < end) {
int64_t local_end = std::min(end, (int64_t)(chunk_size + local_start));
int64_t local_end = std::min(end, static_cast<int64_t>(chunk_size + local_start));
try {
ParallelRegionGuard guard(task_id);
ParallelRegionGuard guard(static_cast<int>(task_id));
f(local_start, local_end);
} catch (...) {
if (!state.err_flag.test_and_set()) {

View File

@ -284,6 +284,7 @@ class CuBlasLtMatmulDescriptor : public CuBlasLtDescriptor<
}
template <typename T>
inline void setAttribute(cublasLtMatmulDescAttributes_t attr, const T value) {
// NOLINTNEXTLINE(bugprone-sizeof-expression)
TORCH_CUDABLAS_CHECK(::cublasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(value)));
}
};
@ -392,7 +393,7 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
#endif
auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
@ -901,12 +902,10 @@ void gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
#else
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 5) {
#ifndef USE_ROCM
cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
if (!at::globalContext().allowFP16ReductionCuBLAS()) {
cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
}
#endif
// Disallow fp16 reductions that could lead to unexpected overflow issues.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
TORCH_CUDABLAS_CHECK(cublasGemmEx(
@ -1284,7 +1283,7 @@ void gemm_and_bias(
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
#endif
auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
@ -1466,7 +1465,7 @@ void scaled_gemm(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
}
size_t workspaceSize = _getWorkspaceSize();
auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
CuBlasLtMatmulPreference preference;
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);

View File

@ -56,7 +56,6 @@ cusparseIndexType_t getCuSparseIndexType(const c10::ScalarType& scalar_type) {
}
}
#if AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
cusparseDnMatDescr_t createRawDnMatDescriptor(const Tensor& input, int64_t batch_offset, bool is_const=false) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.layout() == kStrided);
IntArrayRef input_strides = input.strides();
@ -121,7 +120,6 @@ CuSparseDnMatDescriptor::CuSparseDnMatDescriptor(const Tensor& input, int64_t ba
CuSparseConstDnMatDescriptor::CuSparseConstDnMatDescriptor(const Tensor& input, int64_t batch_offset) {
descriptor_.reset(createRawDnMatDescriptor(input, batch_offset, /*is_const*/true));
}
#endif // AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
CuSparseDnVecDescriptor::CuSparseDnVecDescriptor(const Tensor& input) {
// cuSPARSE doesn't support batched vectors

View File

@ -116,7 +116,7 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
virtual void recordMemoryHistory(
std::optional<std::string> enabled,
const std::optional<std::string>& enabled,
const std::string& stacks,
size_t max_entries) const {
FAIL_MTIAHOOKS_FUNC(__func__);
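Taking the optional by const reference is the usual fix for performance-unnecessary-value-param (the same check suppressed by name elsewhere in this commit): a std::optional<std::string> parameter taken by value copies the contained string for every caller that passes an lvalue. A minimal sketch with hypothetical function names:

#include <cstddef>
#include <optional>
#include <string>

// By value: callers pay for a std::string copy unless they std::move into it.
std::size_t record_by_value(std::optional<std::string> enabled) {
  return enabled ? enabled->size() : 0;
}

// By const reference: binds to an existing optional with no copy, and still
// accepts std::nullopt or a temporary.
std::size_t record_by_cref(const std::optional<std::string>& enabled) {
  return enabled ? enabled->size() : 0;
}

int main() {
  std::optional<std::string> mode{"all"};
  auto a = record_by_value(mode);  // copies "all"
  auto b = record_by_cref(mode);   // no copy
  return static_cast<int>(a - b);  // 0
}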

View File

@ -162,6 +162,7 @@ grid_sample_backward_helper_in(
static std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
grid_sample_backward_helper_out(
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::tuple<Tensor, Tensor> bw_out,
int64_t grad_input_out_bdim,
int64_t grad_grid_out_bdim,
@ -261,7 +262,7 @@ struct UpsampleBackwardBatchRuleHelper<F, Func, typelist<A, B, C, T...>> {
auto out = Func(
std::move(grad_output_),
std::move(output_size),
output_size,
std::move(physical_input_size),
std::forward<T>(extra_args)...);
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)

View File

@ -579,7 +579,7 @@ static void _rrelu_with_noise_train(
Tensor& noise,
const Scalar& lower_,
const Scalar& upper_,
std::optional<Generator> generator) {
const std::optional<Generator>& generator) {
using opmath_t = at::opmath_type<scalar_t>;
opmath_t lower = lower_.to<opmath_t>();
opmath_t upper = upper_.to<opmath_t>();

View File

@ -17,7 +17,8 @@
#include <ATen/ops/mm_native.h>
#endif
namespace at::native::xpu {
namespace at::native {
namespace xpu {
// result = beta * self + alpha * (mat1 * mat2)
Tensor& addmm_out(
@ -454,7 +455,7 @@ Tensor& tensordot_out(
TORCH_LIBRARY_IMPL(aten, XPU, m) {
m.impl("tensordot.out", TORCH_FN(tensordot_out));
}
} // namespace at::native::xpu
} // namespace xpu
TORCH_IMPL_FUNC(addmm_out_xpu)
(const Tensor& self,
@ -469,11 +470,13 @@ TORCH_IMPL_FUNC(addmm_out_xpu)
TORCH_IMPL_FUNC(mm_out_xpu)
(const Tensor& self, const Tensor& mat2, const Tensor& result) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
xpu::mm_out(self, mat2, const_cast<Tensor&>(result));
}
TORCH_IMPL_FUNC(bmm_out_xpu)
(const Tensor& self, const Tensor& batch2, const Tensor& result) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
xpu::bmm_out(self, batch2, const_cast<Tensor&>(result));
}
@ -498,7 +501,13 @@ TORCH_IMPL_FUNC(baddbmm_out_xpu)
const Scalar& alpha,
const Tensor& result) {
xpu::baddbmm_out(
self, batch1, batch2, beta, alpha, const_cast<Tensor&>(result));
self,
batch1,
batch2,
beta,
alpha,
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<Tensor&>(result));
}
TORCH_IMPL_FUNC(addmv_out_xpu)
@ -508,5 +517,8 @@ TORCH_IMPL_FUNC(addmv_out_xpu)
const Scalar& beta,
const Scalar& alpha,
const Tensor& result) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
xpu::addmv_out(self, mat, vec, beta, alpha, const_cast<Tensor&>(result));
}
} // namespace at::native

View File

@ -26,7 +26,7 @@ PyObject* THPGenerator_initDefaultGenerator(const at::Generator& cdata) {
if (!self)
throw python_error();
auto self_ = reinterpret_cast<THPGenerator*>(self.get());
self_->cdata = std::move(cdata);
self_->cdata = cdata;
return self.release();
}
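Removing the std::move here matches performance-move-const-arg: cdata is a const reference, so std::move(cdata) produces a const rvalue that still selects the copy constructor, and the cast only hides the fact that a copy is made either way. A tiny illustration with a plain std::string standing in for the Generator:

#include <string>
#include <utility>

struct holder {
  std::string value;
};

void assign(holder& h, const std::string& s) {
  // std::move(s) has type const std::string&&; the move assignment cannot
  // take it, so the copy assignment runs anyway:
  h.value = std::move(s);
  // Writing the copy plainly says what actually happens:
  h.value = s;
}

int main() {
  holder h;
  assign(h, "abc");
  return h.value == "abc" ? 0 : 1;
}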

View File

@ -118,8 +118,7 @@ struct TensorDataContainer {
type_(TensorDataContainerType::InitList) {}
#define TENSOR(T, S) \
TensorDataContainer(T value) \
: sizes_(), \
scalar_type_(at::k##S), \
: scalar_type_(at::k##S), \
type_(TensorDataContainerType::Scalar), \
scalar_(value) {}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)

View File

@ -136,6 +136,7 @@ namespace torch::autograd {
// NOTE: this function is written in a way that assumes it's only called for
// backward; it's used by engine.cpp. This is responsible for forwarding a call
// from C++'s Node::apply to a Python method "apply".
// NOLINTNEXTLINE(*-rvalue-reference*)
auto PyNode::apply(variable_list&& inputs) -> variable_list {
pybind11::gil_scoped_acquire gil;
at::OptionalDeviceGuard _device_guard;
@ -184,7 +185,7 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list {
}
auto PyNode::defer_to_dynamo(
variable_list&& inputs,
const variable_list& inputs,
const std::optional<PyObject*>& compiler) -> variable_list {
pybind11::gil_scoped_acquire gil;
at::OptionalDeviceGuard _device_guard;
@ -526,7 +527,7 @@ static void THPFunction_dealloc(THPFunction* self) {
Py_TYPE(self)->tp_free((PyObject*)self);
}
PyObject* THPFunction_new(
static PyObject* THPFunction_new(
PyTypeObject* type,
PyObject* args,
PyObject* kwargs) {
@ -875,6 +876,7 @@ struct InputFlags {
std::vector<bool> is_variable_input;
};
namespace {
template <bool enforce_variables>
std::pair<UnpackedInput, InputFlags> unpack_input(PyObject* args) {
UnpackedInput unpacked;
@ -938,7 +940,7 @@ std::pair<UnpackedInput, InputFlags> unpack_input(PyObject* args) {
// value is assigned by the prim::PythonOp node and helps to eventually route
// the outputs of the subgraph correctly This newly created subgraph is then
// added to the prim::PythonOp node as a subgraph attribute
static void _append_subgraph(
void _append_subgraph(
torch::jit::Node* node,
torch::jit::Graph* graph,
std::vector<torch::jit::Value*> trace_outputs,
@ -980,7 +982,7 @@ static void _append_subgraph(
}
}
static torch::jit::Node* _trace_pre_record(
torch::jit::Node* _trace_pre_record(
PyObject* op_obj,
PyObject* input_objects,
const variable_list& input_vars) {
@ -1011,7 +1013,7 @@ static torch::jit::Node* _trace_pre_record(
std::move(pyobj), arg_types, input_vars, std::move(scalar_args));
}
static void _trace_post_record(
void _trace_post_record(
torch::jit::Node* node,
PyObject* op_obj,
const variable_list& input_vars,
@ -1218,8 +1220,6 @@ PyObject* THPFunction_maybe_clear_saved_tensors(
END_HANDLE_TH_ERRORS
}
namespace {
THPObjectPtr make_ctx_input_tuple(
THPFunction* ctx,
const UnpackedInput& unpacked_input,
@ -1253,8 +1253,6 @@ THPObjectPtr make_ctx_input_output_tuple(
return result;
}
} // namespace
static PyObject* THPFunction_setup_context = nullptr;
static PyObject* get_base_setup_context() {
@ -1652,6 +1650,7 @@ PyObject* THPFunction_metadata(THPFunction* self, void* _unused) {
return metadata;
END_HANDLE_TH_ERRORS
}
} // namespace
using getter = PyObject* (*)(PyObject*, void*);
using setter = int (*)(PyObject*, PyObject*, void*);

View File

@ -36,7 +36,7 @@ struct PyNode : public Node {
variable_list apply(variable_list&& inputs) override;
variable_list defer_to_dynamo(
variable_list&& inputs,
const variable_list& inputs,
const std::optional<PyObject*>& compiler);
void release_variables() override;

View File

@ -1,7 +1,6 @@
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <mutex>
#include <unordered_map>
#include <utility>
#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
@ -29,8 +28,7 @@ int device_count = 0;
void custom_raw_deleter(void* ptr);
_AllocationMetadata::_AllocationMetadata()
: size(0), device_idx(-1), stream{} {}
_AllocationMetadata::_AllocationMetadata() : size(0), device_idx(-1) {}
_AllocationMetadata::_AllocationMetadata(
size_t size,

View File

@ -34,7 +34,7 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocatorDeleterContext {
void* data_;
size_t size_;
int device_;
cudaStream_t stream_;
cudaStream_t stream_{};
};
#if defined(TORCH_HIP_VERSION)
@ -63,7 +63,7 @@ struct _AllocationMetadata {
cudaStream_t stream);
size_t size;
c10::DeviceIndex device_idx;
cudaStream_t stream;
cudaStream_t stream{};
};
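The {} added to stream_ and stream are default member initializers, which is why the .cpp above can drop the explicit stream{} from _AllocationMetadata's default constructor: any constructor that does not mention the member still leaves it value-initialized, keeping checks like cppcoreguidelines-pro-type-member-init satisfied. A minimal sketch with an opaque handle type standing in for cudaStream_t:

#include <cstddef>

using stream_handle = void*;  // stand-in for cudaStream_t

struct allocation_metadata {
  allocation_metadata() : size(0), device_idx(-1) {}  // stream needs no mention

  std::size_t size;
  int device_idx;
  stream_handle stream{};  // default member initializer: always starts null
};

int main() {
  allocation_metadata m;
  return m.stream == nullptr ? 0 : 1;
}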
struct TORCH_CUDA_CPP_API CUDAPluggableAllocator

View File

@ -15,6 +15,9 @@ class DetectorMap {
public:
DetectorMap(const DetectorMap&) = delete;
DetectorMap& operator=(const DetectorMap&) = delete;
DetectorMap(DetectorMap&&) = delete;
DetectorMap& operator=(DetectorMap&&) = delete;
~DetectorMap() = default;
static DetectorMap& get() {
static DetectorMap instance;
return instance;

View File

@ -99,6 +99,7 @@ class Lock {
Lock(const Lock& that) = delete;
Lock& operator=(const Lock& other) = delete;
Lock& operator=(Lock&& other) noexcept {
if (this != &other) {
fd_ = other.fd_;
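The added `if (this != &other)` check guards the move assignment against self-assignment: without it, taking over other's descriptor and then resetting other would clobber the object whenever other happens to be *this. A minimal sketch of the hazard with a hypothetical descriptor owner, not the real Lock:

struct fd_owner {
  int fd_ = -1;

  fd_owner& operator=(fd_owner&& other) noexcept {
    if (this != &other) {  // without this check, a self-move would lose fd_
      release();           // drop whatever we currently own
      fd_ = other.fd_;
      other.fd_ = -1;
    }
    return *this;
  }

 private:
  void release() { /* ::close(fd_) in real code */ fd_ = -1; }
};

int main() {
  fd_owner a;
  a.fd_ = 3;
  a = static_cast<fd_owner&&>(a);  // self-move is now a no-op
  return a.fd_ == 3 ? 0 : 1;
}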

View File

@ -512,8 +512,7 @@ void TCPStoreMasterDaemon::run() {
tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);
// receive the queries
bool finished = false;
while (!finished) {
while (true) {
for (const auto i : c10::irange(sockets_.size())) {
fds[i].revents = 0;
}
@ -524,7 +523,6 @@ void TCPStoreMasterDaemon::run() {
if (res == 0) {
auto rv = WaitForSingleObject(ghStopEvent_, 0);
if (rv != WAIT_TIMEOUT) {
finished = true;
break;
}
continue;
@ -567,8 +565,7 @@ void TCPStoreMasterDaemon::run() {
tcputil::addPollfd(fds, controlPipeFd_[0], POLLIN | POLLHUP);
// receive the queries
bool finished = false;
while (!finished) {
while (true) {
for (const auto i : c10::irange(sockets_.size())) {
fds[i].revents = 0;
}
@ -602,7 +599,6 @@ void TCPStoreMasterDaemon::run() {
"Unexpected poll revent on the control pipe's reading fd: " +
std::to_string(fds[1].revents));
}
finished = true;
break;
}
queryFds(fds);

View File

@ -32,6 +32,7 @@ class RequestImpl : public Request {
}
private:
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const httplib::Request& req_;
};
@ -49,6 +50,7 @@ class ResponseImpl : public Response {
}
private:
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
httplib::Response& res_;
};
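cppcoreguidelines-avoid-const-or-ref-data-members fires here because reference members make a class non-copy-assignable and awkward to move; these adapters are deliberately thin, per-request views over objects that outlive them, so the warning is suppressed rather than restructuring the types. A minimal sketch of the pattern being kept, with hypothetical types:

#include <cstddef>
#include <string>

struct request {
  std::string body;
};

// Thin adapter that only reads through the reference; it is created per
// request and never copied or reassigned, so a reference member is fine.
class request_view {
 public:
  explicit request_view(const request& req) : req_(req) {}
  const std::string& body() const { return req_.body; }

 private:
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const request& req_;
};

int main() {
  request r{"payload"};
  request_view v(r);
  return v.body().size() == 7 ? 0 : 1;
}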

View File

@ -133,6 +133,7 @@ class TORCH_API Message final : public torch::CustomClassHolder {
Message(Message&& other) = delete;
Message& operator=(Message const& rhs) = delete;
Message& operator=(Message&& rhs) = delete;
~Message() override = default;
// Destructively retrieves the payload.
std::vector<char>&& movePayload() &&;

View File

@ -24,6 +24,9 @@ struct TORCH_API GloballyUniqueId final {
GloballyUniqueId(worker_id_t createdOn, local_id_t localId);
GloballyUniqueId(const GloballyUniqueId& other) = default;
GloballyUniqueId& operator=(const GloballyUniqueId& other) = delete;
GloballyUniqueId(GloballyUniqueId&& other) = default;
GloballyUniqueId& operator=(GloballyUniqueId&& other) = delete;
~GloballyUniqueId() = default;
bool operator==(const GloballyUniqueId& other) const;
bool operator!=(const GloballyUniqueId& other) const;

View File

@ -73,8 +73,7 @@ static PyObject* convert_hook_list(std::vector<c10::SafePyObject>& inputs) {
static void throw_python_error() {
python_error err;
err.persist();
// NOLINTNEXTLINE(misc-throw-by-value-catch-by-reference)
throw err;
throw std::move(err);
}
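Throwing std::move(err) lets the exception object be move-constructed from the local instead of copied, which matters for python_error since it carries Python state, and it is presumably what makes the old NOLINT for misc-throw-by-value-catch-by-reference unnecessary. A small sketch of the shape of the change, with a stand-in error type rather than python_error:

#include <stdexcept>
#include <string>
#include <utility>

struct heavy_error : std::runtime_error {
  using std::runtime_error::runtime_error;
  std::string extra_state;  // stands in for the Python references python_error holds
};

[[noreturn]] void raise() {
  heavy_error err("failure");
  err.extra_state = "captured context";
  throw std::move(err);  // exception object is move-constructed, not copied
}

int main() {
  try {
    raise();
  } catch (const heavy_error& e) {  // still thrown by value, caught by reference
    return e.extra_state.empty() ? 1 : 0;
  }
  return 1;
}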
static PyObject* check(PyObject* pyresult) {
@ -109,20 +108,21 @@ struct PythonLogger {
// must be called while GIL is held
void log(Level level, std::string_view msg) const {
THPObjectPtr pymethod(PyUnicode_FromString(levelNames_[level].data()));
THPObjectPtr pymethod(PyUnicode_FromString(levelNames_[level]));
TORCH_INTERNAL_ASSERT(pymethod != nullptr);
THPObjectPtr pyfunc(PyObject_GetAttr(logger_, pymethod.get()));
if (pyfunc == nullptr) {
throw_python_error();
}
PyObject* result = PyObject_CallFunction(pyfunc.get(), "s", msg.data());
PyObject* result =
PyObject_CallFunction(pyfunc.get(), "s", std::string(msg).c_str());
if (result == nullptr) {
throw_python_error();
}
}
private:
static constexpr std::array<std::string_view, COUNT> levelNames_ = {
static constexpr std::array<const char*, COUNT> levelNames_ = {
"debug", // Level::DEBUG
"info", // Level::INFO
"warning", // Level::WARNING
@ -421,7 +421,7 @@ static struct PyModuleDef _module = {
-1,
_methods};
PyObject* wrap_lifted_ivalue_args(
static PyObject* wrap_lifted_ivalue_args(
const std::vector<LiftedIValueArg>& lifted_ivalue_args) {
PyObject* pyivalueargs =
PyList_New(static_cast<Py_ssize_t>(lifted_ivalue_args.size()));
@ -440,7 +440,7 @@ PyObject* wrap_lifted_ivalue_args(
return pyivalueargs;
}
PyObject* wrap_node_origins(
static PyObject* wrap_node_origins(
const AutogradCompilerCall& compiler,
size_t dynamic_sizes) {
TORCH_INTERNAL_ASSERT(
@ -475,7 +475,7 @@ PyObject* wrap_node_origins(
return pyallorigins;
}
void set_ivalue_proxies(
static void set_ivalue_proxies(
PyObject* fake_ivalue_args,
std::vector<LiftedIValueArg>& lifted_ivalue_args) {
TORCH_INTERNAL_ASSERT(PyList_Check(fake_ivalue_args));
@ -569,7 +569,7 @@ static SizeInput::DynType get_default_dyn_type() {
}
// Only call this function while holding GIL
CacheNode* _compiled_autograd_impl(
static CacheNode* _compiled_autograd_impl(
const std::shared_ptr<Node>& graph_root,
GraphTask& graph_task,
bool accumulate_grad,
@ -808,10 +808,11 @@ struct LockGuardWithErrorLogs {
mtx_.unlock();
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
std::mutex& mtx_;
};
variable_list compiled_autograd(
static variable_list compiled_autograd(
const std::shared_ptr<Node>& graph_root,
GraphTask& graph_task,
bool accumulate_grad,

View File

@ -134,6 +134,7 @@ class Cache {
}
mutable std::mutex lock_;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const size_t max_size_ = 0;
ElementList element_list_;
ElementMap element_map_;

View File

@ -89,6 +89,7 @@ class MaybeRef {
private:
std::optional<T> storage_;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const T& ref_;
};

View File

@ -90,11 +90,11 @@ void initModule(PyObject* module) {
m.def(
"_mtia_recordMemoryHistory",
[](std::optional<std::string> enabled,
[](const std::optional<std::string>& enabled,
const std::string& stacks,
size_t max_entries) {
at::detail::getMTIAHooks().recordMemoryHistory(
std::move(enabled), stacks, max_entries);
enabled, stacks, max_entries);
});
m.def("_mtia_memorySnapshot", []() {

View File

@ -14,7 +14,7 @@ template <typename T>
struct strong_pointer_type_caster {
template <typename T_>
static handle cast(
T_&& src,
const T_& src,
return_value_policy /*policy*/,
handle /*parent*/) {
const auto* ptr = reinterpret_cast<const void*>(src.value_of());
@ -33,7 +33,7 @@ template <typename T>
struct strong_uint_type_caster {
template <typename T_>
static handle cast(
T_&& src,
const T_& src,
return_value_policy /*policy*/,
handle /*parent*/) {
return handle(THPUtils_packUInt64(src.value_of()));

View File

@ -47,6 +47,7 @@ struct DefaultStubs : public ProfilerStubs {
TORCH_CHECK(false, name_, " used in profiler but not enabled.");
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const char* const name_;
};
} // namespace

View File

@ -61,6 +61,7 @@ struct LineNumberProgram {
uint64_t form;
};
std::vector<Member> directory_members;
directory_members.reserve(directory_entry_format_count);
for (size_t i = 0; i < directory_entry_format_count; i++) {
directory_members.push_back({L.readULEB128(), L.readULEB128()});
}
@ -85,6 +86,7 @@ struct LineNumberProgram {
}
auto file_name_entry_format_count = L.read<uint8_t>();
std::vector<Member> file_members;
file_members.reserve(file_name_entry_format_count);
for (size_t i = 0; i < file_name_entry_format_count; i++) {
file_members.push_back({L.readULEB128(), L.readULEB128()});
}
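The reserve() calls added before these loops are the kind of change suggested by performance-inefficient-vector-operation: when the element count is known up front, reserving once replaces a series of reallocations (and element moves) during push_back. A minimal sketch of the same shape, with a stand-in for the ULEB128 reads:

#include <cstddef>
#include <cstdint>
#include <vector>

struct member {
  uint64_t content_type;
  uint64_t form;
};

std::vector<member> read_members(std::size_t count) {
  std::vector<member> members;
  members.reserve(count);  // one allocation instead of repeated regrowth
  for (std::size_t i = 0; i < count; i++) {
    members.push_back({i, i});  // stand-in for {L.readULEB128(), L.readULEB128()}
  }
  return members;
}

int main() {
  return read_members(8).size() == 8 ? 0 : 1;
}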
@ -314,6 +316,7 @@ struct LineNumberProgram {
uint64_t length_ = 0;
bool is_64bit_ = false;
std::vector<uint8_t> standard_opcode_lengths_;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
Sections& s_;
uint64_t offset_;
uint64_t start_address_ = 0;

View File

@ -85,6 +85,7 @@ ScriptModuleOutput ScriptModuleBenchmark::runOnce(
}
template <>
// NOLINTNEXTLINE(*-rvalue-reference-param-not-moved)
void ModuleBenchmark::runOnce(ModuleInput&& input) const {
CHECK(initialized_);
pybind11::gil_scoped_acquire gil_guard;
@ -101,6 +102,7 @@ ModuleOutput ModuleBenchmark::runOnce(
}
template <>
// NOLINTNEXTLINE(*-rvalue-reference-param-not-moved)
void ScriptModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs) {
jit::Stack stack = jit::createStackForSchema(
model_.get_method("forward").function().getSchema(),