[19/N] Fix extra warnings brought by clang-tidy-17 (#144448)
Apply more clang-tidy fixes. A bug introduced by #144014 through incorrect namespace concatenation is reverted here.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144448
Approved by: https://github.com/albanD
parent 1353f3beb4
commit b0be30dd79
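For context, the namespace problem mentioned in the commit message comes down to how closing braces pair up. A minimal sketch with hypothetical names (not the PyTorch sources), assuming C++17 nested namespace definitions:

// Hypothetical illustration; only the brace structure mirrors the bug.
namespace outer::inner {    // concatenated form from the earlier change
void helper() {}
} // namespace outer::inner // closes BOTH outer and inner at once

void meant_for_outer() {}   // oops: ends up in the global namespace

namespace outer {           // nested form restored by the revert
namespace inner {
void helper2() {}
} // namespace inner        // only inner is closed here
void still_in_outer() {}    // stays inside outer, as intended
} // namespace outer

The XPU Blas hunks below restore exactly this kind of nesting: the inner xpu namespace is closed on its own and a final `} // namespace at::native` is added, so the TORCH_IMPL_FUNC definitions that follow stay inside at::native.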
@@ -86,14 +86,14 @@ TaskThreadPoolBase& _get_intraop_pool() {
 #endif // C10_MOBILE
 
 // Run lambda function `fn` over `task_id` in [0, `range`) with threadpool.
-// `fn` will be called with params: (thread_pool_task_id, task_id).
-void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range) {
+// `fn` will be called with params: task_id.
+static void _run_with_pool(const std::function<void(size_t)>& fn, size_t range) {
 #ifndef C10_MOBILE
   for (const auto i : c10::irange(1, range)) {
-    _get_intraop_pool().run([fn, i]() { fn((int)i, i); });
+    _get_intraop_pool().run([fn, i]() { fn(i); });
   }
   // Run the first task on the current thread directly.
-  fn(0, 0);
+  fn(0);
 #else
   caffe2::PThreadPool* const pool = caffe2::pthreadpool();
   TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");

@@ -102,7 +102,7 @@ void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range) {
       // PThreadPool::run() is blocking. A std::function [const] reference to
       // this lambda cannot go out of scope before PThreadPool::run() returns.
       [&fn](const size_t task_id) {
-        fn(0 /* unused */, task_id);
+        fn(task_id);
       }, range);
 #endif // C10_MOBILE
 }

@@ -113,6 +113,10 @@ struct ParallelRegionGuard {
     internal::set_thread_num(task_id);
     _set_in_parallel_region(true);
   }
+  ParallelRegionGuard(const ParallelRegionGuard&) = delete;
+  ParallelRegionGuard(ParallelRegionGuard&&) = delete;
+  ParallelRegionGuard& operator=(const ParallelRegionGuard&) = delete;
+  ParallelRegionGuard& operator=(ParallelRegionGuard&&) = delete;
 
   ~ParallelRegionGuard() {
     _set_in_parallel_region(false);

@@ -124,16 +128,16 @@ struct ParallelRegionGuard {
 
 namespace internal {
 
-inline std::tuple<size_t, size_t> calc_num_tasks_and_chunk_size(
+static std::tuple<size_t, size_t> calc_num_tasks_and_chunk_size(
     int64_t begin, int64_t end, int64_t grain_size) {
   if ((end - begin) < grain_size) {
     return std::make_tuple(1, std::max((int64_t)0, end - begin));
   }
   // Choose number of tasks based on grain size and number of threads.
-  size_t chunk_size = divup((end - begin), get_num_threads());
+  int64_t chunk_size = divup((end - begin), get_num_threads());
   // Make sure each task is at least grain_size size.
-  chunk_size = std::max((size_t)grain_size, chunk_size);
-  size_t num_tasks = divup((end - begin), chunk_size);
+  chunk_size = std::max(grain_size, chunk_size);
+  size_t num_tasks = static_cast<size_t>(divup((end - begin), chunk_size));
   return std::make_tuple(num_tasks, chunk_size);
 }
 

@@ -157,12 +161,12 @@ void invoke_parallel(
   } state;
 
   auto task = [f, &state, begin, end, chunk_size]
-      (int /* unused */, size_t task_id) {
-    int64_t local_start = begin + task_id * chunk_size;
+      (size_t task_id) {
+    int64_t local_start = static_cast<int64_t>(begin + task_id * chunk_size);
     if (local_start < end) {
-      int64_t local_end = std::min(end, (int64_t)(chunk_size + local_start));
+      int64_t local_end = std::min(end, static_cast<int64_t>(chunk_size + local_start));
       try {
-        ParallelRegionGuard guard(task_id);
+        ParallelRegionGuard guard(static_cast<int>(task_id));
         f(local_start, local_end);
       } catch (...) {
         if (!state.err_flag.test_and_set()) {

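As a side note on the chunking arithmetic in calc_num_tasks_and_chunk_size above, here is a small standalone sketch of the same computation with assumed inputs; divup and the thread count are stand-ins, not the ATen implementations:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Assumed helper matching the usual "round up" integer division used above.
static int64_t divup(int64_t x, int64_t y) { return (x + y - 1) / y; }

int main() {
  const int64_t begin = 0, end = 1000, grain_size = 300;
  const int64_t num_threads = 8; // assumed value of get_num_threads()

  // chunk_size = ceil(range / num_threads), then clamped up to grain_size
  int64_t chunk_size = divup(end - begin, num_threads);      // 125
  chunk_size = std::max(grain_size, chunk_size);             // 300
  const auto num_tasks =
      static_cast<size_t>(divup(end - begin, chunk_size));   // 4 tasks
  std::printf("num_tasks=%zu chunk_size=%lld\n", num_tasks,
              static_cast<long long>(chunk_size));
  return 0;
}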
@@ -284,6 +284,7 @@ class CuBlasLtMatmulDescriptor : public CuBlasLtDescriptor<
   }
   template <typename T>
   inline void setAttribute(cublasLtMatmulDescAttributes_t attr, const T value) {
+    // NOLINTNEXTLINE(bugprone-sizeof-expression)
     TORCH_CUDABLAS_CHECK(::cublasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(value)));
   }
 };

@@ -392,7 +393,7 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
 #endif
 
-  auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
+  auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
 
   cublasLtMatmulHeuristicResult_t heuristicResult = {};
   int returnedResult = 0;

@@ -901,12 +902,10 @@ void gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
 #else
   cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
   if (prop->major >= 5) {
-#ifndef USE_ROCM
     cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
     if (!at::globalContext().allowFP16ReductionCuBLAS()) {
       cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
     }
-#endif
     // Disallow fp16 reductions that could lead to unexpected overflow issues.
     TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
     TORCH_CUDABLAS_CHECK(cublasGemmEx(

@@ -1284,7 +1283,7 @@ void gemm_and_bias(
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
 #endif
 
-  auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
+  auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
 
   cublasLtMatmulHeuristicResult_t heuristicResult = {};
   int returnedResult = 0;

@@ -1466,7 +1465,7 @@ void scaled_gemm(
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
   }
   size_t workspaceSize = _getWorkspaceSize();
-  auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
+  auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
 
   CuBlasLtMatmulPreference preference;
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);

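The recurring at::empty(static_cast<int64_t>(workspaceSize), ...) change above addresses the same warning each time: workspaceSize is a size_t, while tensor sizes in ATen are signed 64-bit, so the unsigned-to-signed conversion is made explicit. A tiny sketch of the pattern outside ATen, where make_buffer is a hypothetical stand-in:

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Stand-in for an API that, like at::empty, takes a signed 64-bit length.
static void make_buffer(int64_t numel) {
  std::printf("allocating %lld elements\n", static_cast<long long>(numel));
}

int main() {
  size_t workspaceSize = 1u << 20;
  // make_buffer(workspaceSize);                    // implicit sign conversion, warns
  make_buffer(static_cast<int64_t>(workspaceSize)); // explicit cast, as in the diff
  return 0;
}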
@@ -56,7 +56,6 @@ cusparseIndexType_t getCuSparseIndexType(const c10::ScalarType& scalar_type) {
   }
 }
 
-#if AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
 cusparseDnMatDescr_t createRawDnMatDescriptor(const Tensor& input, int64_t batch_offset, bool is_const=false) {
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.layout() == kStrided);
   IntArrayRef input_strides = input.strides();

@@ -121,7 +120,6 @@ CuSparseDnMatDescriptor::CuSparseDnMatDescriptor(const Tensor& input, int64_t ba
 CuSparseConstDnMatDescriptor::CuSparseConstDnMatDescriptor(const Tensor& input, int64_t batch_offset) {
   descriptor_.reset(createRawDnMatDescriptor(input, batch_offset, /*is_const*/true));
 }
-#endif // AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
 
 CuSparseDnVecDescriptor::CuSparseDnVecDescriptor(const Tensor& input) {
   // cuSPARSE doesn't support batched vectors

@@ -116,7 +116,7 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
 
 
   virtual void recordMemoryHistory(
-      std::optional<std::string> enabled,
+      const std::optional<std::string>& enabled,
       const std::string& stacks,
       size_t max_entries) const {
     FAIL_MTIAHOOKS_FUNC(__func__);

@@ -162,6 +162,7 @@ grid_sample_backward_helper_in(
 
 static std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
 grid_sample_backward_helper_out(
+    // NOLINTNEXTLINE(performance-unnecessary-value-param)
     std::tuple<Tensor, Tensor> bw_out,
     int64_t grad_input_out_bdim,
     int64_t grad_grid_out_bdim,

@@ -261,7 +262,7 @@ struct UpsampleBackwardBatchRuleHelper<F, Func, typelist<A, B, C, T...>> {
 
     auto out = Func(
         std::move(grad_output_),
-        std::move(output_size),
+        output_size,
         std::move(physical_input_size),
         std::forward<T>(extra_args)...);
     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)

@@ -579,7 +579,7 @@ static void _rrelu_with_noise_train(
     Tensor& noise,
     const Scalar& lower_,
     const Scalar& upper_,
-    std::optional<Generator> generator) {
+    const std::optional<Generator>& generator) {
   using opmath_t = at::opmath_type<scalar_t>;
   opmath_t lower = lower_.to<opmath_t>();
   opmath_t upper = upper_.to<opmath_t>();

@@ -17,7 +17,8 @@
 #include <ATen/ops/mm_native.h>
 #endif
 
-namespace at::native::xpu {
+namespace at::native {
+namespace xpu {
 
 // result = beta * self + alpha * (mat1 * mat2)
 Tensor& addmm_out(

@@ -454,7 +455,7 @@ Tensor& tensordot_out(
 TORCH_LIBRARY_IMPL(aten, XPU, m) {
   m.impl("tensordot.out", TORCH_FN(tensordot_out));
 }
-} // namespace at::native::xpu
+} // namespace xpu
 
 TORCH_IMPL_FUNC(addmm_out_xpu)
 (const Tensor& self,

@@ -469,11 +470,13 @@ TORCH_IMPL_FUNC(addmm_out_xpu)
 
 TORCH_IMPL_FUNC(mm_out_xpu)
 (const Tensor& self, const Tensor& mat2, const Tensor& result) {
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
   xpu::mm_out(self, mat2, const_cast<Tensor&>(result));
 }
 
 TORCH_IMPL_FUNC(bmm_out_xpu)
 (const Tensor& self, const Tensor& batch2, const Tensor& result) {
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
   xpu::bmm_out(self, batch2, const_cast<Tensor&>(result));
 }
 

@@ -498,7 +501,13 @@ TORCH_IMPL_FUNC(baddbmm_out_xpu)
     const Scalar& alpha,
     const Tensor& result) {
   xpu::baddbmm_out(
-      self, batch1, batch2, beta, alpha, const_cast<Tensor&>(result));
+      self,
+      batch1,
+      batch2,
+      beta,
+      alpha,
+      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+      const_cast<Tensor&>(result));
 }
 
 TORCH_IMPL_FUNC(addmv_out_xpu)

@@ -508,5 +517,8 @@ TORCH_IMPL_FUNC(addmv_out_xpu)
     const Scalar& beta,
     const Scalar& alpha,
     const Tensor& result) {
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
   xpu::addmv_out(self, mat, vec, beta, alpha, const_cast<Tensor&>(result));
 }
+
+} // namespace at::native

@@ -26,7 +26,7 @@ PyObject* THPGenerator_initDefaultGenerator(const at::Generator& cdata) {
   if (!self)
     throw python_error();
   auto self_ = reinterpret_cast<THPGenerator*>(self.get());
-  self_->cdata = std::move(cdata);
+  self_->cdata = cdata;
   return self.release();
 }
 

@@ -118,8 +118,7 @@ struct TensorDataContainer {
         type_(TensorDataContainerType::InitList) {}
 #define TENSOR(T, S)                            \
   TensorDataContainer(T value)                  \
-      : sizes_(),                               \
-        scalar_type_(at::k##S),                 \
+      : scalar_type_(at::k##S),                 \
         type_(TensorDataContainerType::Scalar), \
         scalar_(value) {}
   AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)

@@ -136,6 +136,7 @@ namespace torch::autograd {
 // NOTE: this function is written in a way that assumes it's only called for
 // backward; it's used by engine.cpp. This is responsible for forwarding a call
 // from C++'s Node::apply to a Python method "apply".
+// NOLINTNEXTLINE(*-rvalue-reference*)
 auto PyNode::apply(variable_list&& inputs) -> variable_list {
   pybind11::gil_scoped_acquire gil;
   at::OptionalDeviceGuard _device_guard;

@@ -184,7 +185,7 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list {
 }
 
 auto PyNode::defer_to_dynamo(
-    variable_list&& inputs,
+    const variable_list& inputs,
     const std::optional<PyObject*>& compiler) -> variable_list {
   pybind11::gil_scoped_acquire gil;
   at::OptionalDeviceGuard _device_guard;

@@ -526,7 +527,7 @@ static void THPFunction_dealloc(THPFunction* self) {
   Py_TYPE(self)->tp_free((PyObject*)self);
 }
 
-PyObject* THPFunction_new(
+static PyObject* THPFunction_new(
     PyTypeObject* type,
     PyObject* args,
     PyObject* kwargs) {

@@ -875,6 +876,7 @@ struct InputFlags {
   std::vector<bool> is_variable_input;
 };
 
+namespace {
 template <bool enforce_variables>
 std::pair<UnpackedInput, InputFlags> unpack_input(PyObject* args) {
   UnpackedInput unpacked;

@@ -938,7 +940,7 @@ std::pair<UnpackedInput, InputFlags> unpack_input(PyObject* args) {
 // value is assigned by the prim::PythonOp node and helps to eventually route
 // the outputs of the subgraph correctly This newly created subgraph is then
 // added to the prim::PythonOp node as a subgraph attribute
-static void _append_subgraph(
+void _append_subgraph(
     torch::jit::Node* node,
     torch::jit::Graph* graph,
     std::vector<torch::jit::Value*> trace_outputs,

@@ -980,7 +982,7 @@ static void _append_subgraph(
   }
 }
 
-static torch::jit::Node* _trace_pre_record(
+torch::jit::Node* _trace_pre_record(
     PyObject* op_obj,
     PyObject* input_objects,
     const variable_list& input_vars) {

@@ -1011,7 +1013,7 @@ static torch::jit::Node* _trace_pre_record(
       std::move(pyobj), arg_types, input_vars, std::move(scalar_args));
 }
 
-static void _trace_post_record(
+void _trace_post_record(
     torch::jit::Node* node,
     PyObject* op_obj,
     const variable_list& input_vars,

@@ -1218,8 +1220,6 @@ PyObject* THPFunction_maybe_clear_saved_tensors(
   END_HANDLE_TH_ERRORS
 }
 
-namespace {
-
 THPObjectPtr make_ctx_input_tuple(
     THPFunction* ctx,
     const UnpackedInput& unpacked_input,

@@ -1253,8 +1253,6 @@ THPObjectPtr make_ctx_input_output_tuple(
   return result;
 }
 
-} // namespace
-
 static PyObject* THPFunction_setup_context = nullptr;
 
 static PyObject* get_base_setup_context() {

@@ -1652,6 +1650,7 @@ PyObject* THPFunction_metadata(THPFunction* self, void* _unused) {
   return metadata;
   END_HANDLE_TH_ERRORS
 }
+} // namespace
 
 using getter = PyObject* (*)(PyObject*, void*);
 using setter = int (*)(PyObject*, PyObject*, void*);

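Several hunks above apply the same internal-linkage pattern: file-local helpers either gain `static` or move into an anonymous namespace, in which case a now-redundant `static` is dropped. A small sketch of the two equivalent spellings, with made-up names:

#include <cstdio>

// Spelling 1: `static` gives the free function internal linkage directly.
static void helper_one() { std::puts("helper_one"); }

// Spelling 2: an anonymous namespace does the same job, so `static`
// inside it would be redundant.
namespace {
void helper_two() { std::puts("helper_two"); }
} // namespace

int main() {
  helper_one();
  helper_two();
  return 0;
}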
@@ -36,7 +36,7 @@ struct PyNode : public Node {
 
   variable_list apply(variable_list&& inputs) override;
   variable_list defer_to_dynamo(
-      variable_list&& inputs,
+      const variable_list& inputs,
       const std::optional<PyObject*>& compiler);
 
   void release_variables() override;

@@ -1,7 +1,6 @@
 #include <c10/cuda/CUDACachingAllocator.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <mutex>
 #include <unordered_map>
-#include <utility>
 
 #include <torch/csrc/cuda/CUDAPluggableAllocator.h>

@@ -29,8 +28,7 @@ int device_count = 0;
 
 void custom_raw_deleter(void* ptr);
 
-_AllocationMetadata::_AllocationMetadata()
-    : size(0), device_idx(-1), stream{} {}
+_AllocationMetadata::_AllocationMetadata() : size(0), device_idx(-1) {}
 
 _AllocationMetadata::_AllocationMetadata(
     size_t size,

@@ -34,7 +34,7 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocatorDeleterContext {
   void* data_;
   size_t size_;
   int device_;
-  cudaStream_t stream_;
+  cudaStream_t stream_{};
 };
 
 #if defined(TORCH_HIP_VERSION)

@@ -63,7 +63,7 @@ struct _AllocationMetadata {
       cudaStream_t stream);
   size_t size;
   c10::DeviceIndex device_idx;
-  cudaStream_t stream;
+  cudaStream_t stream{};
 };
 
 struct TORCH_CUDA_CPP_API CUDAPluggableAllocator

@@ -15,6 +15,9 @@ class DetectorMap {
  public:
   DetectorMap(const DetectorMap&) = delete;
   DetectorMap& operator=(const DetectorMap&) = delete;
+  DetectorMap(DetectorMap&&) = delete;
+  DetectorMap& operator=(DetectorMap&&) = delete;
+  ~DetectorMap() = default;
   static DetectorMap& get() {
     static DetectorMap instance;
     return instance;

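The DetectorMap hunk above, like the ParallelRegionGuard, Message, and GloballyUniqueId hunks, spells out the remaining copy, move, and destructor members so each class declares its full rule-of-five set explicitly. A minimal sketch of that pattern on a hypothetical Meyers-style singleton (not a PyTorch class):

#include <iostream>

// Hypothetical singleton; only the special-member declarations mirror the diff.
class Registry {
 public:
  Registry(const Registry&) = delete;
  Registry& operator=(const Registry&) = delete;
  Registry(Registry&&) = delete;
  Registry& operator=(Registry&&) = delete;
  ~Registry() = default;

  static Registry& get() {
    static Registry instance; // constructed once, on first use
    return instance;
  }

  void hello() const { std::cout << "singleton alive\n"; }

 private:
  Registry() = default; // only get() may construct it
};

int main() {
  Registry::get().hello();
  return 0;
}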
@@ -99,6 +99,7 @@ class Lock {
 
   Lock(const Lock& that) = delete;
 
+  Lock& operator=(const Lock& other) = delete;
   Lock& operator=(Lock&& other) noexcept {
     if (this != &other) {
       fd_ = other.fd_;

@@ -512,8 +512,7 @@ void TCPStoreMasterDaemon::run() {
   tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);
 
   // receive the queries
-  bool finished = false;
-  while (!finished) {
+  while (true) {
     for (const auto i : c10::irange(sockets_.size())) {
       fds[i].revents = 0;
     }

@@ -524,7 +523,6 @@ void TCPStoreMasterDaemon::run() {
     if (res == 0) {
       auto rv = WaitForSingleObject(ghStopEvent_, 0);
       if (rv != WAIT_TIMEOUT) {
-        finished = true;
         break;
       }
       continue;

@@ -567,8 +565,7 @@ void TCPStoreMasterDaemon::run() {
   tcputil::addPollfd(fds, controlPipeFd_[0], POLLIN | POLLHUP);
 
   // receive the queries
-  bool finished = false;
-  while (!finished) {
+  while (true) {
     for (const auto i : c10::irange(sockets_.size())) {
       fds[i].revents = 0;
     }

@@ -602,7 +599,6 @@ void TCPStoreMasterDaemon::run() {
             "Unexpected poll revent on the control pipe's reading fd: " +
                 std::to_string(fds[1].revents));
       }
-      finished = true;
       break;
     }
     queryFds(fds);

@@ -32,6 +32,7 @@ class RequestImpl : public Request {
   }
 
  private:
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
   const httplib::Request& req_;
 };
 

@@ -49,6 +50,7 @@ class ResponseImpl : public Response {
   }
 
  private:
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
   httplib::Response& res_;
 };
 

@@ -133,6 +133,7 @@ class TORCH_API Message final : public torch::CustomClassHolder {
   Message(Message&& other) = delete;
   Message& operator=(Message const& rhs) = delete;
   Message& operator=(Message&& rhs) = delete;
+  ~Message() override = default;
 
   // Destructively retrieves the payload.
   std::vector<char>&& movePayload() &&;

@@ -24,6 +24,9 @@ struct TORCH_API GloballyUniqueId final {
   GloballyUniqueId(worker_id_t createdOn, local_id_t localId);
   GloballyUniqueId(const GloballyUniqueId& other) = default;
   GloballyUniqueId& operator=(const GloballyUniqueId& other) = delete;
+  GloballyUniqueId(GloballyUniqueId&& other) = default;
+  GloballyUniqueId& operator=(GloballyUniqueId&& other) = delete;
+  ~GloballyUniqueId() = default;
 
   bool operator==(const GloballyUniqueId& other) const;
   bool operator!=(const GloballyUniqueId& other) const;

@@ -73,8 +73,7 @@ static PyObject* convert_hook_list(std::vector<c10::SafePyObject>& inputs) {
 static void throw_python_error() {
   python_error err;
   err.persist();
-  // NOLINTNEXTLINE(misc-throw-by-value-catch-by-reference)
-  throw err;
+  throw std::move(err);
 }
 
 static PyObject* check(PyObject* pyresult) {

@@ -109,20 +108,21 @@ struct PythonLogger {
 
   // must be called while GIL is held
   void log(Level level, std::string_view msg) const {
-    THPObjectPtr pymethod(PyUnicode_FromString(levelNames_[level].data()));
+    THPObjectPtr pymethod(PyUnicode_FromString(levelNames_[level]));
     TORCH_INTERNAL_ASSERT(pymethod != nullptr);
     THPObjectPtr pyfunc(PyObject_GetAttr(logger_, pymethod.get()));
     if (pyfunc == nullptr) {
       throw_python_error();
     }
-    PyObject* result = PyObject_CallFunction(pyfunc.get(), "s", msg.data());
+    PyObject* result =
+        PyObject_CallFunction(pyfunc.get(), "s", std::string(msg).c_str());
     if (result == nullptr) {
       throw_python_error();
     }
   }
 
  private:
-  static constexpr std::array<std::string_view, COUNT> levelNames_ = {
+  static constexpr std::array<const char*, COUNT> levelNames_ = {
       "debug", // Level::DEBUG
       "info", // Level::INFO
       "warning", // Level::WARNING

@@ -421,7 +421,7 @@ static struct PyModuleDef _module = {
     -1,
     _methods};
 
-PyObject* wrap_lifted_ivalue_args(
+static PyObject* wrap_lifted_ivalue_args(
     const std::vector<LiftedIValueArg>& lifted_ivalue_args) {
   PyObject* pyivalueargs =
       PyList_New(static_cast<Py_ssize_t>(lifted_ivalue_args.size()));

@@ -440,7 +440,7 @@ PyObject* wrap_lifted_ivalue_args(
   return pyivalueargs;
 }
 
-PyObject* wrap_node_origins(
+static PyObject* wrap_node_origins(
     const AutogradCompilerCall& compiler,
     size_t dynamic_sizes) {
   TORCH_INTERNAL_ASSERT(

@@ -475,7 +475,7 @@ PyObject* wrap_node_origins(
   return pyallorigins;
 }
 
-void set_ivalue_proxies(
+static void set_ivalue_proxies(
     PyObject* fake_ivalue_args,
     std::vector<LiftedIValueArg>& lifted_ivalue_args) {
   TORCH_INTERNAL_ASSERT(PyList_Check(fake_ivalue_args));

@@ -569,7 +569,7 @@ static SizeInput::DynType get_default_dyn_type() {
 }
 
 // Only call this function while holding GIL
-CacheNode* _compiled_autograd_impl(
+static CacheNode* _compiled_autograd_impl(
     const std::shared_ptr<Node>& graph_root,
     GraphTask& graph_task,
     bool accumulate_grad,

@@ -808,10 +808,11 @@ struct LockGuardWithErrorLogs {
     mtx_.unlock();
   }
 
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
   std::mutex& mtx_;
 };
 
-variable_list compiled_autograd(
+static variable_list compiled_autograd(
     const std::shared_ptr<Node>& graph_root,
     GraphTask& graph_task,
     bool accumulate_grad,

@@ -134,6 +134,7 @@ class Cache {
   }
 
   mutable std::mutex lock_;
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
   const size_t max_size_ = 0;
   ElementList element_list_;
   ElementMap element_map_;

@@ -89,6 +89,7 @@ class MaybeRef {
 
  private:
  std::optional<T> storage_;
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const T& ref_;
 };
 

@@ -90,11 +90,11 @@ void initModule(PyObject* module) {
 
   m.def(
       "_mtia_recordMemoryHistory",
-      [](std::optional<std::string> enabled,
+      [](const std::optional<std::string>& enabled,
          const std::string& stacks,
          size_t max_entries) {
         at::detail::getMTIAHooks().recordMemoryHistory(
-            std::move(enabled), stacks, max_entries);
+            enabled, stacks, max_entries);
       });
 
   m.def("_mtia_memorySnapshot", []() {

@@ -14,7 +14,7 @@ template <typename T>
 struct strong_pointer_type_caster {
   template <typename T_>
   static handle cast(
-      T_&& src,
+      const T_& src,
       return_value_policy /*policy*/,
       handle /*parent*/) {
     const auto* ptr = reinterpret_cast<const void*>(src.value_of());

@@ -33,7 +33,7 @@ template <typename T>
 struct strong_uint_type_caster {
   template <typename T_>
   static handle cast(
-      T_&& src,
+      const T_& src,
       return_value_policy /*policy*/,
       handle /*parent*/) {
     return handle(THPUtils_packUInt64(src.value_of()));

@@ -47,6 +47,7 @@ struct DefaultStubs : public ProfilerStubs {
     TORCH_CHECK(false, name_, " used in profiler but not enabled.");
   }
 
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
   const char* const name_;
 };
 } // namespace

@@ -61,6 +61,7 @@ struct LineNumberProgram {
       uint64_t form;
     };
     std::vector<Member> directory_members;
+    directory_members.reserve(directory_entry_format_count);
     for (size_t i = 0; i < directory_entry_format_count; i++) {
       directory_members.push_back({L.readULEB128(), L.readULEB128()});
     }

@@ -85,6 +86,7 @@ struct LineNumberProgram {
     }
     auto file_name_entry_format_count = L.read<uint8_t>();
     std::vector<Member> file_members;
+    file_members.reserve(file_name_entry_format_count);
     for (size_t i = 0; i < file_name_entry_format_count; i++) {
       file_members.push_back({L.readULEB128(), L.readULEB128()});
     }

@@ -314,6 +316,7 @@ struct LineNumberProgram {
   uint64_t length_ = 0;
   bool is_64bit_ = false;
   std::vector<uint8_t> standard_opcode_lengths_;
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
   Sections& s_;
   uint64_t offset_;
   uint64_t start_address_ = 0;

@@ -85,6 +85,7 @@ ScriptModuleOutput ScriptModuleBenchmark::runOnce(
 }
 
 template <>
+// NOLINTNEXTLINE(*-rvalue-reference-param-not-moved)
 void ModuleBenchmark::runOnce(ModuleInput&& input) const {
   CHECK(initialized_);
   pybind11::gil_scoped_acquire gil_guard;

@@ -101,6 +102,7 @@ ModuleOutput ModuleBenchmark::runOnce(
 }
 
 template <>
+// NOLINTNEXTLINE(*-rvalue-reference-param-not-moved)
 void ScriptModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs) {
   jit::Stack stack = jit::createStackForSchema(
       model_.get_method("forward").function().getSchema(),