pytorch/torch/csrc/cuda/utils.cpp
Richard Barnes ed327876f5 [codemod] c10:optional -> std::optional (#126135)
Generated by running the following from PyTorch root:
```
find . -regex ".*\.\(cpp\|h\|cu\|hpp\|cc\|cxx\)$" | grep -v "build/" | xargs -n 50 -P 4 perl -pi -e 's/c10::optional/std::optional/'
```

`c10::optional` is just an alias for `std::optional`. This removes usages of that alias in preparation for eliminating it entirely.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126135
Approved by: https://github.com/Skylion007, https://github.com/malfet, https://github.com/albanD, https://github.com/aaronenyeshi
2024-05-14 19:35:51 +00:00

45 lines
1.6 KiB
C++

#include <torch/csrc/cuda/THCP.h>
#include <torch/csrc/python_headers.h>
#include <cstdarg>
#include <string>
#ifdef USE_CUDA
// NB: It's a list of *optional* CUDAStream; when nullopt, that means to use
// whatever the current stream of the device the input is associated with was.
std::vector<std::optional<at::cuda::CUDAStream>>
THPUtils_PySequence_to_CUDAStreamList(PyObject* obj) {
if (!PySequence_Check(obj)) {
throw std::runtime_error(
"Expected a sequence in THPUtils_PySequence_to_CUDAStreamList");
}
THPObjectPtr seq = THPObjectPtr(PySequence_Fast(obj, nullptr));
if (seq.get() == nullptr) {
throw std::runtime_error(
"expected PySequence, but got " + std::string(THPUtils_typename(obj)));
}
std::vector<std::optional<at::cuda::CUDAStream>> streams;
Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get());
for (Py_ssize_t i = 0; i < length; i++) {
PyObject* stream = PySequence_Fast_GET_ITEM(seq.get(), i);
if (PyObject_IsInstance(stream, THCPStreamClass)) {
// Spicy hot reinterpret cast!!
streams.emplace_back(at::cuda::CUDAStream::unpack3(
(reinterpret_cast<THCPStream*>(stream))->stream_id,
(reinterpret_cast<THCPStream*>(stream))->device_index,
static_cast<c10::DeviceType>(
(reinterpret_cast<THCPStream*>(stream))->device_type)));
} else if (stream == Py_None) {
streams.emplace_back();
} else {
// NOLINTNEXTLINE(bugprone-throw-keyword-missing)
std::runtime_error(
"Unknown data type found in stream list. Need torch.cuda.Stream or None");
}
}
return streams;
}
#endif