[Clang-tidy header][17/N] Apply clang-tidy on headers in torch/csrc/cuda (#117829)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117829 Approved by: https://github.com/albanD
Parent: 8ff55c7e68
Commit: 6da0e7f84b
@@ -100,6 +100,7 @@ c10::DataPtr CUDAPluggableAllocator::allocate(size_t size) const {
   cudaStream_t stream =
       c10::cuda::getCurrentCUDAStream(static_cast<c10::DeviceIndex>(device));
   void* r =
+      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
       const_cast<CUDAPluggableAllocator*>(this)->malloc(size, device, stream);
   c10::DataPtr data_ptr = {
       r,
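Note: cppcoreguidelines-pro-type-const-cast flags every const_cast; here the cast is intentional, since the const allocate() must reach the allocator's non-const malloc(), so the warning is suppressed on that single line instead of disabling the check. A minimal sketch of the pattern, with a hypothetical Allocator type standing in for CUDAPluggableAllocator:

    #include <cstddef>
    #include <new>

    struct Allocator {
      void* malloc(std::size_t size) { return ::operator new(size); }

      // A const entry point that must call a non-const mutator: the cast is
      // confined to one line and the clang-tidy check is silenced only there.
      void* allocate(std::size_t size) const {
        // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
        return const_cast<Allocator*>(this)->malloc(size);
      }
    };

    int main() {
      const Allocator a;
      void* p = a.allocate(16);
      ::operator delete(p);
      return 0;
    }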
@@ -80,17 +80,13 @@ static void poison_fork() {
 // CUDA management methods
 ////////////////////////////////////////////////////////////////////////////////

-void THCPModule_setDevice(int device) {
-  c10::cuda::set_device(static_cast<c10::DeviceIndex>(device));
-}
-
 PyObject* THCPModule_setDevice_wrap(PyObject* self, PyObject* arg) {
   HANDLE_TH_ERRORS
   TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice");
-  int64_t device = THPUtils_unpackLong(arg);
+  auto device = THPUtils_unpackLong(arg);

   torch::utils::cuda_lazy_init();
-  THCPModule_setDevice(device);
+  c10::cuda::set_device(static_cast<c10::DeviceIndex>(device));

   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
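Note: this hunk does two things. The single-use helper THCPModule_setDevice is inlined into its callers (its declaration disappears from the header later in this diff), and the value unpacked from the Python argument is narrowed to c10::DeviceIndex explicitly at the call site. A sketch of the explicit-narrowing idiom; DeviceIndex is assumed here to be a small signed integer, as c10::DeviceIndex is in PyTorch:

    #include <cstdint>
    #include <iostream>

    using DeviceIndex = std::int8_t;  // illustrative stand-in for c10::DeviceIndex

    void set_device(DeviceIndex idx) {
      std::cout << "switching to device " << static_cast<int>(idx) << '\n';
    }

    int main() {
      std::int64_t device = 3;  // e.g. the result of THPUtils_unpackLong(arg)
      // The narrowing conversion is spelled out instead of happening silently.
      set_device(static_cast<DeviceIndex>(device));
      return 0;
    }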
@@ -259,6 +255,7 @@ PyObject* THCPModule_setStream_wrap(
           args,
           kwargs,
           "|LLL",
+          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
           const_cast<char**>(kwlist),
           &stream_id,
           &device_index,
@@ -266,11 +263,13 @@ PyObject* THCPModule_setStream_wrap(
   }

   auto stream = at::cuda::CUDAStream::unpack3(
-      stream_id, device_index, static_cast<c10::DeviceType>(device_type));
+      stream_id,
+      static_cast<c10::DeviceIndex>(device_index),
+      static_cast<c10::DeviceType>(device_type));

   auto device = c10::cuda::current_device();
   if (device != stream.device_index()) {
-    THCPModule_setDevice(stream.device_index());
+    c10::cuda::set_device(stream.device_index());
   }
   at::cuda::setCurrentCUDAStream(stream);
   Py_RETURN_NONE;
@@ -926,7 +925,7 @@ static void registerCudaDeviceProperties(PyObject* module) {
       static_cast<void (*)(
           c10::optional<std::string>,
           c10::optional<std::string>,
-          std::string,
+          const std::string&,
           size_t)>(torch::cuda::_record_memory_history));

   m.def("_cuda_isHistoryEnabled", []() {
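Note: the static_cast here spells out a full function-pointer type to pick one overload of torch::cuda::_record_memory_history for registration; once the definition's parameter becomes const std::string&, the cast must change in lockstep or the overload no longer matches. A standalone sketch of overload selection by function-pointer cast, using hypothetical record() overloads:

    #include <cstddef>
    #include <iostream>
    #include <string>

    void record(int) { std::cout << "int overload\n"; }
    void record(const std::string& mode, std::size_t) {
      std::cout << "string overload: " << mode << '\n';
    }

    int main() {
      // The cast's exact signature determines which overload is taken.
      auto* fn = static_cast<void (*)(const std::string&, std::size_t)>(record);
      fn("all", 10);
      return 0;
    }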
@@ -1211,7 +1210,7 @@ static void registerCudaPluggableAllocator(PyObject* module) {
       }
     }
     auto delta = c10::cuda::CUDACachingAllocator::setCheckpointPoolState(
-        device, pps);
+        device, std::move(pps));
     auto& freed_pointers = delta.ptrs_freed;

     std::unordered_set<void*> allocd_set;
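Note: pps is not used after this call and the callee takes its argument by value, so moving it transfers the underlying storage instead of copying it. A minimal sketch of moving into a by-value "sink" parameter, with illustrative types rather than the real allocator state:

    #include <iostream>
    #include <utility>
    #include <vector>

    void set_state(std::vector<int> state) {  // by-value sink parameter
      std::cout << "received " << state.size() << " entries\n";
    }

    int main() {
      std::vector<int> pps(1000, 7);
      set_state(std::move(pps));  // hands over the buffer; no deep copy
      // pps is now valid but unspecified; it must not be read afterwards.
      return 0;
    }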
@@ -1,7 +1,6 @@
 #ifndef THCP_CUDA_MODULE_INC
 #define THCP_CUDA_MODULE_INC

-void THCPModule_setDevice(int idx);
 PyObject* THCPModule_getDevice_wrap(PyObject* self);
 PyObject* THCPModule_setDevice_wrap(PyObject* self, PyObject* arg);
 PyObject* THCPModule_getDeviceName_wrap(PyObject* self, PyObject* arg);
@@ -5,6 +5,7 @@
 #include <torch/csrc/Stream.h>
 #include <torch/csrc/python_headers.h>

+// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
 struct THCPStream : THPStream {
   at::cuda::CUDAStream cuda_stream;
 };
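Note: cppcoreguidelines-pro-type-member-init wants every member initialized, but at::cuda::CUDAStream has no default constructor and THCPStream instances are filled in explicitly by the binding code, so the check is suppressed on the struct. A hedged sketch of that shape (types are illustrative, not the PyTorch ones):

    // Stand-in for at::cuda::CUDAStream: no default constructor.
    struct Stream {
      explicit Stream(long id) : id_(id) {}
      long id_;
    };

    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
    struct PyStreamWrapper {
      Stream stream;  // always constructed explicitly by the allocating code
    };

    int main() {
      PyStreamWrapper w{Stream{42}};  // aggregate-initialized, member set here
      return w.stream.id_ == 42 ? 0 : 1;
    }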
@@ -103,8 +103,9 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntArrayRef devices) {
     if (device != tensor.get_device()) {
       diff_device_dst_tensors.emplace_back(at::empty(
           tensor.sizes(),
-          tensor.options().device(
-              at::Device(DeviceType::CUDA, device)))); // preserve memory format
+          tensor.options().device(at::Device(
+              DeviceType::CUDA,
+              static_cast<DeviceIndex>(device))))); // preserve memory format
     }
   }
   _broadcast_out_impl(tensor, diff_device_dst_tensors);
@@ -178,7 +179,7 @@ tensor_list2d broadcast_coalesced(
   o.reserve(tensors.size());

   unique_type_checker type_checker;
-  at::cuda::CUDAGuard device_guard(devices[0]);
+  at::cuda::CUDAGuard device_guard(static_cast<DeviceIndex>(devices[0]));
   for (auto& chunk : torch::utils::take_tensors(tensors, buffer_size)) {
     auto type_id = chunk.type_id();
     type_checker.show(type_id);
@@ -189,7 +190,7 @@ tensor_list2d broadcast_coalesced(
       auto broadcast_values = broadcast(flat_tuple.second, devices);
       results.reserve(devices.size());
       for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) {
-        device_guard.set_index(devices[i]);
+        device_guard.set_index(static_cast<DeviceIndex>(devices[i]));
         auto& device_outputs = outputs[i];
         auto& inds = broadcast_indices[i];
         auto& vals = broadcast_values[i];
@@ -203,7 +204,7 @@ tensor_list2d broadcast_coalesced(
       auto results = broadcast(
           torch::utils::flatten_dense_tensors(chunk.tensors), devices);
       for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) {
-        device_guard.set_index(devices[i]);
+        device_guard.set_index(static_cast<DeviceIndex>(devices[i]));
         auto& device_outputs = outputs[i];
         for (auto& var :
              torch::utils::unflatten_dense_tensors(results[i], chunk.tensors)) {
@@ -327,10 +328,10 @@ std::vector<at::Tensor> scatter(
         chunk_sizes->size());
   }
   dim = at::maybe_wrap_dim(dim, tensor);
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   std::vector<at::Tensor> chunks = chunk_sizes
       ? tensor.split_with_sizes(/*split_sizes=*/*chunk_sizes, /*dim=*/dim)
-      : tensor.chunk(/*chunks=*/devices.size(), /*dim=*/dim);
+      : tensor.chunk(
+            /*chunks=*/static_cast<int64_t>(devices.size()), /*dim=*/dim);
   at::cuda::OptionalCUDAStreamGuard cuda_guard;
   for (const auto i : c10::irange(chunks.size())) {
     const auto device_index = static_cast<int16_t>(devices[i]);
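Note: std::vector::size() returns an unsigned size_t while Tensor::chunk takes a signed int64_t; the new static_cast makes that signedness conversion visible rather than implicit. A small sketch of the idiom:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Returns a count in the signed type a tensor API would expect.
    std::int64_t chunk_count(const std::vector<int>& devices) {
      return static_cast<std::int64_t>(devices.size());
    }

    int main() {
      std::vector<int> devices{0, 1, 2};
      std::cout << chunk_count(devices) << '\n';  // prints 3
      return 0;
    }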
@@ -494,7 +495,9 @@ at::Tensor gather(
   at::Device device(DeviceType::CPU);
   if (!destination_index || *destination_index != -1) {
     device = at::Device(
-        DeviceType::CUDA, destination_index ? *destination_index : -1);
+        DeviceType::CUDA,
+        destination_index ? static_cast<DeviceIndex>(*destination_index)
+                          : DeviceIndex(-1));
   }

   at::Tensor result =
@@ -1,6 +1,7 @@
 #pragma once

 #include <bitset>
+#include <cstddef>

 namespace torch {

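Note: this is include hygiene; a header that names std::size_t (or another <cstddef> type) should include <cstddef> itself rather than rely on a transitive include. A one-file sketch:

    #include <cstddef>  // self-sufficient: brings in std::size_t directly

    inline std::size_t padded(std::size_t n) {
      return (n + 7) / 8 * 8;  // round up to a multiple of 8
    }

    int main() {
      return padded(13) == 16 ? 0 : 1;
    }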
@@ -68,7 +68,7 @@ std::vector<IValue> ivalue_symbolize(
     for (const auto& e : t) {
       l.push_back(all_frames.at(e));
     }
-    py_unique_frames.push_back(std::move(l));
+    py_unique_frames.emplace_back(std::move(l));
   }

   std::vector<IValue> result;
@@ -132,7 +132,7 @@ static void checkOptionIn(
 void _record_memory_history(
     c10::optional<std::string> enabled,
     c10::optional<std::string> context,
-    std::string stacks,
+    const std::string& stacks,
     size_t max_entries) {
   if (enabled) {
     checkOptionIn(
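Note: a string parameter that is only read is switched to pass-by-const-reference (clang-tidy's performance-unnecessary-value-param), saving a copy per call; the matching header change below keeps the default argument "all", which still works because the temporary bound to the reference lives for the whole call. A minimal sketch:

    #include <cstddef>
    #include <iostream>
    #include <string>

    // was: void record_memory_history(std::string stacks, std::size_t max_entries)
    void record_memory_history(const std::string& stacks,
                               std::size_t max_entries) {
      std::cout << "stacks=" << stacks << " max_entries=" << max_entries << '\n';
    }

    int main() {
      record_memory_history("all", 100000);  // no std::string copy is made
      return 0;
    }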
@@ -18,7 +18,7 @@ TORCH_CUDA_CU_API void _record_memory_history(
 TORCH_CUDA_CU_API void _record_memory_history(
     c10::optional<std::string> enabled = "all",
     c10::optional<std::string> context = "all",
-    std::string stacks = "all",
+    const std::string& stacks = "all",
     size_t max_entries = UINT64_MAX);

 TORCH_CUDA_CU_API std::string _memory_snapshot_pickled();
@@ -21,7 +21,7 @@ using namespace torch::cuda::nccl::detail;
 static const char* COMM_CAPSULE_NAME = "torch.cuda.nccl.Communicator";

 PyObject* THCPModule_nccl_version(PyObject* self, PyObject* args) {
-  return PyInt_FromLong(version());
+  return PyLong_FromUnsignedLongLong(version());
 }

 PyObject* THCPModule_nccl_version_suffix(PyObject* self, PyObject* args) {
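Note: PyInt_FromLong is Python 2 API; under Python 3 the integer is boxed with the PyLong_* family, and PyLong_FromUnsignedLongLong also matches the unsigned value returned by version(). A sketch using an embedded interpreter (requires the Python development headers; the version value is made up):

    #include <Python.h>

    int main() {
      Py_Initialize();
      unsigned long long version = 21806;  // illustrative placeholder value
      PyObject* obj = PyLong_FromUnsignedLongLong(version);  // Python 3 boxing
      PyObject_Print(obj, stdout, Py_PRINT_RAW);
      Py_DECREF(obj);
      Py_FinalizeEx();
      return 0;
    }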
@@ -99,10 +99,10 @@ static std::vector<ncclComm_t> unpack_comms(PyObject* obj, size_t size) {

 PyObject* THCPModule_nccl_init_rank(PyObject* self, PyObject* args) {
   HANDLE_TH_ERRORS
-  int nranks;
-  const char* id;
-  Py_ssize_t id_len;
-  int rank;
+  int nranks = 0;
+  const char* id = nullptr;
+  Py_ssize_t id_len = 0;
+  int rank = 0;

   if (!PyArg_ParseTuple(
           args, "is#i:nccl_init_rank", &nranks, &id, &id_len, &rank)) {
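Note: cppcoreguidelines-init-variables asks that the out-parameters PyArg_ParseTuple fills in start with explicit values, so nothing is read uninitialized if parsing fails part-way. A sketch of the before/after:

    #include <cstdio>

    int main() {
      int nranks = 0;            // was: int nranks;
      const char* id = nullptr;  // was: const char* id;
      long id_len = 0;           // stand-in for Py_ssize_t id_len = 0;
      int rank = 0;              // was: int rank;
      // ...PyArg_ParseTuple(args, "is#i", &nranks, &id, &id_len, &rank)...
      std::printf("%d %p %ld %d\n", nranks,
                  static_cast<const void*>(id), id_len, rank);
      return 0;
    }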
@@ -118,7 +118,7 @@ PyObject* THCPModule_nccl_init_rank(PyObject* self, PyObject* args) {

   ncclUniqueId commId;
   memcpy(&commId, id, NCCL_UNIQUE_ID_BYTES);
-  ncclComm_t comm;
+  ncclComm_t comm = nullptr;
   {
     pybind11::gil_scoped_release no_gil;
     comm = comm_init_rank(nranks, commId, rank);
@@ -129,8 +129,9 @@ PyObject* THCPModule_nccl_init_rank(PyObject* self, PyObject* args) {

 PyObject* THCPModule_nccl_reduce(PyObject* self, PyObject* args) {
   HANDLE_TH_ERRORS
-  PyObject *_inputs, *_output, *_streams, *_comms;
-  int root, op;
+  PyObject *_inputs = nullptr, *_output = nullptr, *_streams = nullptr,
+      *_comms = nullptr;
+  int root = 0, op = 0;

   if (!PyArg_ParseTuple(
           args, "OOiiOO", &_inputs, &_output, &root, &op, &_streams, &_comms)) {
@@ -161,8 +162,9 @@ PyObject* THCPModule_nccl_reduce(PyObject* self, PyObject* args) {

 PyObject* THCPModule_nccl_all_reduce(PyObject* self, PyObject* args) {
   HANDLE_TH_ERRORS
-  PyObject *_inputs, *_outputs, *_streams, *_comms;
-  int op;
+  PyObject *_inputs = nullptr, *_outputs = nullptr, *_streams = nullptr,
+      *_comms = nullptr;
+  int op = 0;

   if (!PyArg_ParseTuple(
           args, "OOiOO", &_inputs, &_outputs, &op, &_streams, &_comms)) {
@@ -193,8 +195,8 @@ PyObject* THCPModule_nccl_all_reduce(PyObject* self, PyObject* args) {

 PyObject* THCPModule_nccl_broadcast(PyObject* self, PyObject* args) {
   HANDLE_TH_ERRORS
-  PyObject *_inputs, *_streams, *_comms;
-  int root;
+  PyObject *_inputs = nullptr, *_streams = nullptr, *_comms = nullptr;
+  int root = 0;

   if (!PyArg_ParseTuple(args, "OiOO", &_inputs, &root, &_streams, &_comms)) {
     THPUtils_invalidArguments(
@@ -224,7 +226,8 @@ PyObject* THCPModule_nccl_broadcast(PyObject* self, PyObject* args) {

 PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args) {
   HANDLE_TH_ERRORS
-  PyObject *_inputs, *_outputs, *_streams, *_comms;
+  PyObject *_inputs = nullptr, *_outputs = nullptr, *_streams = nullptr,
+      *_comms = nullptr;

   if (!PyArg_ParseTuple(
           args, "OOOO", &_inputs, &_outputs, &_streams, &_comms)) {
@@ -255,8 +258,9 @@ PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args) {

 PyObject* THCPModule_nccl_reduce_scatter(PyObject* self, PyObject* args) {
   HANDLE_TH_ERRORS
-  PyObject *_inputs, *_outputs, *_streams, *_comms;
-  int op;
+  PyObject *_inputs = nullptr, *_outputs = nullptr, *_streams = nullptr,
+      *_comms = nullptr;
+  int op = 0;

   if (!PyArg_ParseTuple(
           args, "OOiOO", &_inputs, &_outputs, &op, &_streams, &_comms)) {
@@ -15,9 +15,6 @@
 #include <torch/csrc/Stream.h>
 #include <torch/csrc/utils/tensor_memoryformats.h>

-#include <stdexcept>
-#include <utility>
-
 namespace py = pybind11;

 // This makes intrusive_ptr to be available as a custom pybind11 holder type,