#include #include #include #include #include #include // NB: It's a list of *optional* CUDAStream; when nullopt, that means to use // whatever the current stream of the device the input is associated with was. std::vector> THPUtils_PySequence_to_CUDAStreamList(PyObject* obj) { TORCH_CHECK( PySequence_Check(obj), "Expected a sequence in THPUtils_PySequence_to_CUDAStreamList"); THPObjectPtr seq = THPObjectPtr(PySequence_Fast(obj, nullptr)); TORCH_CHECK( seq.get() != nullptr, "expected PySequence, but got " + std::string(THPUtils_typename(obj))); std::vector> streams; Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get()); streams.reserve(length); for (Py_ssize_t i = 0; i < length; i++) { PyObject* stream = PySequence_Fast_GET_ITEM(seq.get(), i); if (PyObject_IsInstance(stream, (PyObject*)THPStreamClass)) { // Spicy hot reinterpret cast!! streams.emplace_back(at::cuda::CUDAStream::unpack3( (reinterpret_cast(stream))->stream_id, static_cast( reinterpret_cast(stream)->device_index), static_cast( (reinterpret_cast(stream))->device_type))); } else if (stream == Py_None) { streams.emplace_back(); } else { TORCH_CHECK( false, "Unknown data type found in stream list. Need torch.cuda.Stream or None"); } } return streams; }