pytorch/torch/csrc/cuda/utils.cpp
Nikita Shulga eac02f85cf Fix more clang-tidy errors (#57235)
Summary:
In my last PR I missed the CUDA and distributed folders; this change fixes them.
This change is autogenerated by `python tool/clang_tidy.py -s`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/57235

Reviewed By: janeyx99

Differential Revision: D28084444

Pulled By: malfet

fbshipit-source-id: bf222f69ee90c7872c3cb0931e8cdb84f0cb3cda
2021-04-28 23:29:10 -07:00

52 lines
1.9 KiB
C++

#include <torch/csrc/python_headers.h>
// NOLINTNEXTLINE(modernize-deprecated-headers)
#include <stdarg.h>
#include <string>
#include <torch/csrc/cuda/THCP.h>
#include <torch/csrc/cuda/override_macros.h>
#define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp"
#include <THC/THCGenerateAllTypes.h>
#define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp"
#include <THC/THCGenerateComplexTypes.h>
#define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp"
#include <THC/THCGenerateBoolType.h>
#define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp"
#include <THC/THCGenerateBFloat16Type.h>
#ifdef USE_CUDA
// NB: It's a list of *optional* CUDAStream; when nullopt, that means to use
// whatever the current stream of the device the input is associated with was.
//
// Converts a Python sequence whose elements are each either a
// torch.cuda.Stream (an instance of THCPStreamClass) or None into a vector of
// optional CUDA streams. Element i of the result corresponds to element i of
// the input; None maps to an empty optional.
//
// Throws std::runtime_error if `obj` is not a sequence, if PySequence_Fast
// fails on it, or if any element is neither a stream nor None.
std::vector<c10::optional<at::cuda::CUDAStream>> THPUtils_PySequence_to_CUDAStreamList(PyObject *obj) {
  if (!PySequence_Check(obj)) {
    throw std::runtime_error("Expected a sequence in THPUtils_PySequence_to_CUDAStreamList");
  }
  THPObjectPtr seq(PySequence_Fast(obj, nullptr));
  if (seq.get() == nullptr) {
    throw std::runtime_error("expected PySequence, but got " + std::string(THPUtils_typename(obj)));
  }
  const Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get());
  std::vector<c10::optional<at::cuda::CUDAStream>> streams;
  streams.reserve(static_cast<size_t>(length));
  for (Py_ssize_t i = 0; i < length; i++) {
    PyObject *stream = PySequence_Fast_GET_ITEM(seq.get(), i);
    // PyObject_IsInstance returns 1 (true), 0 (false) or -1 (error). Only a
    // confirmed instance may be reinterpreted as a THCPStream; an error result
    // falls through to the rejection branch below instead of being treated as
    // "true".
    if (PyObject_IsInstance(stream, THCPStreamClass) > 0) {
      // Spicy hot reinterpret cast!!
      streams.emplace_back( at::cuda::CUDAStream::unpack((reinterpret_cast<THCPStream*>(stream))->cdata) );
    } else if (stream == Py_None) {
      streams.emplace_back();
    } else {
      // Must actually throw: silently skipping an unsupported element would
      // make the result shorter than, and misaligned with, the input sequence.
      throw std::runtime_error("Unknown data type found in stream list. Need torch.cuda.Stream or None");
    }
  }
  return streams;
}
#endif