[Clang-tidy header][17/N] Apply clang-tidy on headers in torch/csrc/cuda (#117829)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117829
Approved by: https://github.com/albanD
This commit is contained in:
cyy 2024-01-26 13:33:24 +00:00 committed by PyTorch MergeBot
parent 8ff55c7e68
commit 6da0e7f84b
10 changed files with 45 additions and 40 deletions

View File

@ -100,6 +100,7 @@ c10::DataPtr CUDAPluggableAllocator::allocate(size_t size) const {
cudaStream_t stream = cudaStream_t stream =
c10::cuda::getCurrentCUDAStream(static_cast<c10::DeviceIndex>(device)); c10::cuda::getCurrentCUDAStream(static_cast<c10::DeviceIndex>(device));
void* r = void* r =
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<CUDAPluggableAllocator*>(this)->malloc(size, device, stream); const_cast<CUDAPluggableAllocator*>(this)->malloc(size, device, stream);
c10::DataPtr data_ptr = { c10::DataPtr data_ptr = {
r, r,

View File

@ -80,17 +80,13 @@ static void poison_fork() {
// CUDA management methods // CUDA management methods
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
void THCPModule_setDevice(int device) {
c10::cuda::set_device(static_cast<c10::DeviceIndex>(device));
}
PyObject* THCPModule_setDevice_wrap(PyObject* self, PyObject* arg) { PyObject* THCPModule_setDevice_wrap(PyObject* self, PyObject* arg) {
HANDLE_TH_ERRORS HANDLE_TH_ERRORS
TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice"); TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice");
int64_t device = THPUtils_unpackLong(arg); auto device = THPUtils_unpackLong(arg);
torch::utils::cuda_lazy_init(); torch::utils::cuda_lazy_init();
THCPModule_setDevice(device); c10::cuda::set_device(static_cast<c10::DeviceIndex>(device));
Py_RETURN_NONE; Py_RETURN_NONE;
END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS
@ -259,6 +255,7 @@ PyObject* THCPModule_setStream_wrap(
args, args,
kwargs, kwargs,
"|LLL", "|LLL",
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<char**>(kwlist), const_cast<char**>(kwlist),
&stream_id, &stream_id,
&device_index, &device_index,
@ -266,11 +263,13 @@ PyObject* THCPModule_setStream_wrap(
} }
auto stream = at::cuda::CUDAStream::unpack3( auto stream = at::cuda::CUDAStream::unpack3(
stream_id, device_index, static_cast<c10::DeviceType>(device_type)); stream_id,
static_cast<c10::DeviceIndex>(device_index),
static_cast<c10::DeviceType>(device_type));
auto device = c10::cuda::current_device(); auto device = c10::cuda::current_device();
if (device != stream.device_index()) { if (device != stream.device_index()) {
THCPModule_setDevice(stream.device_index()); c10::cuda::set_device(stream.device_index());
} }
at::cuda::setCurrentCUDAStream(stream); at::cuda::setCurrentCUDAStream(stream);
Py_RETURN_NONE; Py_RETURN_NONE;
@ -926,7 +925,7 @@ static void registerCudaDeviceProperties(PyObject* module) {
static_cast<void (*)( static_cast<void (*)(
c10::optional<std::string>, c10::optional<std::string>,
c10::optional<std::string>, c10::optional<std::string>,
std::string, const std::string&,
size_t)>(torch::cuda::_record_memory_history)); size_t)>(torch::cuda::_record_memory_history));
m.def("_cuda_isHistoryEnabled", []() { m.def("_cuda_isHistoryEnabled", []() {
@ -1211,7 +1210,7 @@ static void registerCudaPluggableAllocator(PyObject* module) {
} }
} }
auto delta = c10::cuda::CUDACachingAllocator::setCheckpointPoolState( auto delta = c10::cuda::CUDACachingAllocator::setCheckpointPoolState(
device, pps); device, std::move(pps));
auto& freed_pointers = delta.ptrs_freed; auto& freed_pointers = delta.ptrs_freed;
std::unordered_set<void*> allocd_set; std::unordered_set<void*> allocd_set;

View File

@ -1,7 +1,6 @@
#ifndef THCP_CUDA_MODULE_INC #ifndef THCP_CUDA_MODULE_INC
#define THCP_CUDA_MODULE_INC #define THCP_CUDA_MODULE_INC
void THCPModule_setDevice(int idx);
PyObject* THCPModule_getDevice_wrap(PyObject* self); PyObject* THCPModule_getDevice_wrap(PyObject* self);
PyObject* THCPModule_setDevice_wrap(PyObject* self, PyObject* arg); PyObject* THCPModule_setDevice_wrap(PyObject* self, PyObject* arg);
PyObject* THCPModule_getDeviceName_wrap(PyObject* self, PyObject* arg); PyObject* THCPModule_getDeviceName_wrap(PyObject* self, PyObject* arg);

View File

@ -5,6 +5,7 @@
#include <torch/csrc/Stream.h> #include <torch/csrc/Stream.h>
#include <torch/csrc/python_headers.h> #include <torch/csrc/python_headers.h>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct THCPStream : THPStream { struct THCPStream : THPStream {
at::cuda::CUDAStream cuda_stream; at::cuda::CUDAStream cuda_stream;
}; };

View File

@ -103,8 +103,9 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntArrayRef devices) {
if (device != tensor.get_device()) { if (device != tensor.get_device()) {
diff_device_dst_tensors.emplace_back(at::empty( diff_device_dst_tensors.emplace_back(at::empty(
tensor.sizes(), tensor.sizes(),
tensor.options().device( tensor.options().device(at::Device(
at::Device(DeviceType::CUDA, device)))); // preserve memory format DeviceType::CUDA,
static_cast<DeviceIndex>(device))))); // preserve memory format
} }
} }
_broadcast_out_impl(tensor, diff_device_dst_tensors); _broadcast_out_impl(tensor, diff_device_dst_tensors);
@ -178,7 +179,7 @@ tensor_list2d broadcast_coalesced(
o.reserve(tensors.size()); o.reserve(tensors.size());
unique_type_checker type_checker; unique_type_checker type_checker;
at::cuda::CUDAGuard device_guard(devices[0]); at::cuda::CUDAGuard device_guard(static_cast<DeviceIndex>(devices[0]));
for (auto& chunk : torch::utils::take_tensors(tensors, buffer_size)) { for (auto& chunk : torch::utils::take_tensors(tensors, buffer_size)) {
auto type_id = chunk.type_id(); auto type_id = chunk.type_id();
type_checker.show(type_id); type_checker.show(type_id);
@ -189,7 +190,7 @@ tensor_list2d broadcast_coalesced(
auto broadcast_values = broadcast(flat_tuple.second, devices); auto broadcast_values = broadcast(flat_tuple.second, devices);
results.reserve(devices.size()); results.reserve(devices.size());
for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) { for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) {
device_guard.set_index(devices[i]); device_guard.set_index(static_cast<DeviceIndex>(devices[i]));
auto& device_outputs = outputs[i]; auto& device_outputs = outputs[i];
auto& inds = broadcast_indices[i]; auto& inds = broadcast_indices[i];
auto& vals = broadcast_values[i]; auto& vals = broadcast_values[i];
@ -203,7 +204,7 @@ tensor_list2d broadcast_coalesced(
auto results = broadcast( auto results = broadcast(
torch::utils::flatten_dense_tensors(chunk.tensors), devices); torch::utils::flatten_dense_tensors(chunk.tensors), devices);
for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) { for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) {
device_guard.set_index(devices[i]); device_guard.set_index(static_cast<DeviceIndex>(devices[i]));
auto& device_outputs = outputs[i]; auto& device_outputs = outputs[i];
for (auto& var : for (auto& var :
torch::utils::unflatten_dense_tensors(results[i], chunk.tensors)) { torch::utils::unflatten_dense_tensors(results[i], chunk.tensors)) {
@ -327,10 +328,10 @@ std::vector<at::Tensor> scatter(
chunk_sizes->size()); chunk_sizes->size());
} }
dim = at::maybe_wrap_dim(dim, tensor); dim = at::maybe_wrap_dim(dim, tensor);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<at::Tensor> chunks = chunk_sizes std::vector<at::Tensor> chunks = chunk_sizes
? tensor.split_with_sizes(/*split_sizes=*/*chunk_sizes, /*dim=*/dim) ? tensor.split_with_sizes(/*split_sizes=*/*chunk_sizes, /*dim=*/dim)
: tensor.chunk(/*chunks=*/devices.size(), /*dim=*/dim); : tensor.chunk(
/*chunks=*/static_cast<int64_t>(devices.size()), /*dim=*/dim);
at::cuda::OptionalCUDAStreamGuard cuda_guard; at::cuda::OptionalCUDAStreamGuard cuda_guard;
for (const auto i : c10::irange(chunks.size())) { for (const auto i : c10::irange(chunks.size())) {
const auto device_index = static_cast<int16_t>(devices[i]); const auto device_index = static_cast<int16_t>(devices[i]);
@ -494,7 +495,9 @@ at::Tensor gather(
at::Device device(DeviceType::CPU); at::Device device(DeviceType::CPU);
if (!destination_index || *destination_index != -1) { if (!destination_index || *destination_index != -1) {
device = at::Device( device = at::Device(
DeviceType::CUDA, destination_index ? *destination_index : -1); DeviceType::CUDA,
destination_index ? static_cast<DeviceIndex>(*destination_index)
: DeviceIndex(-1));
} }
at::Tensor result = at::Tensor result =

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include <bitset> #include <bitset>
#include <cstddef>
namespace torch { namespace torch {

View File

@ -68,7 +68,7 @@ std::vector<IValue> ivalue_symbolize(
for (const auto& e : t) { for (const auto& e : t) {
l.push_back(all_frames.at(e)); l.push_back(all_frames.at(e));
} }
py_unique_frames.push_back(std::move(l)); py_unique_frames.emplace_back(std::move(l));
} }
std::vector<IValue> result; std::vector<IValue> result;
@ -132,7 +132,7 @@ static void checkOptionIn(
void _record_memory_history( void _record_memory_history(
c10::optional<std::string> enabled, c10::optional<std::string> enabled,
c10::optional<std::string> context, c10::optional<std::string> context,
std::string stacks, const std::string& stacks,
size_t max_entries) { size_t max_entries) {
if (enabled) { if (enabled) {
checkOptionIn( checkOptionIn(

View File

@ -18,7 +18,7 @@ TORCH_CUDA_CU_API void _record_memory_history(
TORCH_CUDA_CU_API void _record_memory_history( TORCH_CUDA_CU_API void _record_memory_history(
c10::optional<std::string> enabled = "all", c10::optional<std::string> enabled = "all",
c10::optional<std::string> context = "all", c10::optional<std::string> context = "all",
std::string stacks = "all", const std::string& stacks = "all",
size_t max_entries = UINT64_MAX); size_t max_entries = UINT64_MAX);
TORCH_CUDA_CU_API std::string _memory_snapshot_pickled(); TORCH_CUDA_CU_API std::string _memory_snapshot_pickled();

View File

@ -21,7 +21,7 @@ using namespace torch::cuda::nccl::detail;
static const char* COMM_CAPSULE_NAME = "torch.cuda.nccl.Communicator"; static const char* COMM_CAPSULE_NAME = "torch.cuda.nccl.Communicator";
PyObject* THCPModule_nccl_version(PyObject* self, PyObject* args) { PyObject* THCPModule_nccl_version(PyObject* self, PyObject* args) {
return PyInt_FromLong(version()); return PyLong_FromUnsignedLongLong(version());
} }
PyObject* THCPModule_nccl_version_suffix(PyObject* self, PyObject* args) { PyObject* THCPModule_nccl_version_suffix(PyObject* self, PyObject* args) {
@ -99,10 +99,10 @@ static std::vector<ncclComm_t> unpack_comms(PyObject* obj, size_t size) {
PyObject* THCPModule_nccl_init_rank(PyObject* self, PyObject* args) { PyObject* THCPModule_nccl_init_rank(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS HANDLE_TH_ERRORS
int nranks; int nranks = 0;
const char* id; const char* id = nullptr;
Py_ssize_t id_len; Py_ssize_t id_len = 0;
int rank; int rank = 0;
if (!PyArg_ParseTuple( if (!PyArg_ParseTuple(
args, "is#i:nccl_init_rank", &nranks, &id, &id_len, &rank)) { args, "is#i:nccl_init_rank", &nranks, &id, &id_len, &rank)) {
@ -118,7 +118,7 @@ PyObject* THCPModule_nccl_init_rank(PyObject* self, PyObject* args) {
ncclUniqueId commId; ncclUniqueId commId;
memcpy(&commId, id, NCCL_UNIQUE_ID_BYTES); memcpy(&commId, id, NCCL_UNIQUE_ID_BYTES);
ncclComm_t comm; ncclComm_t comm = nullptr;
{ {
pybind11::gil_scoped_release no_gil; pybind11::gil_scoped_release no_gil;
comm = comm_init_rank(nranks, commId, rank); comm = comm_init_rank(nranks, commId, rank);
@ -129,8 +129,9 @@ PyObject* THCPModule_nccl_init_rank(PyObject* self, PyObject* args) {
PyObject* THCPModule_nccl_reduce(PyObject* self, PyObject* args) { PyObject* THCPModule_nccl_reduce(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS HANDLE_TH_ERRORS
PyObject *_inputs, *_output, *_streams, *_comms; PyObject *_inputs = nullptr, *_output = nullptr, *_streams = nullptr,
int root, op; *_comms = nullptr;
int root = 0, op = 0;
if (!PyArg_ParseTuple( if (!PyArg_ParseTuple(
args, "OOiiOO", &_inputs, &_output, &root, &op, &_streams, &_comms)) { args, "OOiiOO", &_inputs, &_output, &root, &op, &_streams, &_comms)) {
@ -161,8 +162,9 @@ PyObject* THCPModule_nccl_reduce(PyObject* self, PyObject* args) {
PyObject* THCPModule_nccl_all_reduce(PyObject* self, PyObject* args) { PyObject* THCPModule_nccl_all_reduce(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS HANDLE_TH_ERRORS
PyObject *_inputs, *_outputs, *_streams, *_comms; PyObject *_inputs = nullptr, *_outputs = nullptr, *_streams = nullptr,
int op; *_comms = nullptr;
int op = 0;
if (!PyArg_ParseTuple( if (!PyArg_ParseTuple(
args, "OOiOO", &_inputs, &_outputs, &op, &_streams, &_comms)) { args, "OOiOO", &_inputs, &_outputs, &op, &_streams, &_comms)) {
@ -193,8 +195,8 @@ PyObject* THCPModule_nccl_all_reduce(PyObject* self, PyObject* args) {
PyObject* THCPModule_nccl_broadcast(PyObject* self, PyObject* args) { PyObject* THCPModule_nccl_broadcast(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS HANDLE_TH_ERRORS
PyObject *_inputs, *_streams, *_comms; PyObject *_inputs = nullptr, *_streams = nullptr, *_comms = nullptr;
int root; int root = 0;
if (!PyArg_ParseTuple(args, "OiOO", &_inputs, &root, &_streams, &_comms)) { if (!PyArg_ParseTuple(args, "OiOO", &_inputs, &root, &_streams, &_comms)) {
THPUtils_invalidArguments( THPUtils_invalidArguments(
@ -224,7 +226,8 @@ PyObject* THCPModule_nccl_broadcast(PyObject* self, PyObject* args) {
PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args) { PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS HANDLE_TH_ERRORS
PyObject *_inputs, *_outputs, *_streams, *_comms; PyObject *_inputs = nullptr, *_outputs = nullptr, *_streams = nullptr,
*_comms = nullptr;
if (!PyArg_ParseTuple( if (!PyArg_ParseTuple(
args, "OOOO", &_inputs, &_outputs, &_streams, &_comms)) { args, "OOOO", &_inputs, &_outputs, &_streams, &_comms)) {
@ -255,8 +258,9 @@ PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args) {
PyObject* THCPModule_nccl_reduce_scatter(PyObject* self, PyObject* args) { PyObject* THCPModule_nccl_reduce_scatter(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS HANDLE_TH_ERRORS
PyObject *_inputs, *_outputs, *_streams, *_comms; PyObject *_inputs = nullptr, *_outputs = nullptr, *_streams = nullptr,
int op; *_comms = nullptr;
int op = 0;
if (!PyArg_ParseTuple( if (!PyArg_ParseTuple(
args, "OOiOO", &_inputs, &_outputs, &op, &_streams, &_comms)) { args, "OOiOO", &_inputs, &_outputs, &op, &_streams, &_comms)) {

View File

@ -15,9 +15,6 @@
#include <torch/csrc/Stream.h> #include <torch/csrc/Stream.h>
#include <torch/csrc/utils/tensor_memoryformats.h> #include <torch/csrc/utils/tensor_memoryformats.h>
#include <stdexcept>
#include <utility>
namespace py = pybind11; namespace py = pybind11;
// This makes intrusive_ptr to be available as a custom pybind11 holder type, // This makes intrusive_ptr to be available as a custom pybind11 holder type,