Summary: Following up on https://github.com/pytorch/pytorch/pull/35851: cross-dtype storage copy is not being used internally, so I have not included cross-dtype copy for complex.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/35771
Differential Revision: D21319650
Pulled By: anjali411
fbshipit-source-id: 07c72996ee598eba0cf401ad61534494d6f5b5b3
50 lines | 1.8 KiB | C++
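For context on the Summary above: the snippet below is a minimal standalone sketch, not code from this change, illustrating the difference between a same-dtype and a cross-dtype copy using today's public ATen tensor API (the change itself concerns the legacy THC storage copy path, and assumes a libtorch build with complex support).

#include <ATen/ATen.h>

int main() {
  // Same-dtype copy, complex64 -> complex64: the kind of copy this change
  // wires up for complex.
  at::Tensor dst = at::zeros({4}, at::kComplexFloat);
  at::Tensor src = at::ones({4}, at::kComplexFloat);
  dst.copy_(src);

  // A cross-dtype copy converts while copying (float -> complex64 here).
  // Per the Summary, the storage-level path deliberately does not gain this
  // for complex, since nothing uses it internally.
  at::Tensor real = at::ones({4}, at::kFloat);
  at::Tensor cplx = at::zeros({4}, at::kComplexFloat);
  cplx.copy_(real);
  return 0;
}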
#include <torch/csrc/python_headers.h>
#include <stdarg.h>
#include <string>
#include <torch/csrc/cuda/THCP.h>

#include <torch/csrc/cuda/override_macros.h>

#define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp"
#include <THC/THCGenerateAllTypes.h>

#define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp"
#include <THC/THCGenerateComplexTypes.h>

#define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp"
#include <THC/THCGenerateBoolType.h>

#define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp"
#include <THC/THCGenerateBFloat16Type.h>
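
// The four #define/#include pairs above use the THC "generic file" pattern:
// torch/csrc/generic/utils.cpp is written against placeholder type macros,
// and each THCGenerate*.h header defines those macros and re-includes the
// generic file once per scalar type, stamping out one concrete instantiation
// per dtype (each generate header consumes THC_GENERIC_FILE, which is why it
// is re-defined before every include). The THCGenerateComplexTypes.h pair is
// the block this change adds.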

#ifdef USE_CUDA
// NB: It's a list of *optional* CUDAStream; when nullopt, that means to use
// whatever the current stream of the device the input is associated with was.
std::vector<c10::optional<at::cuda::CUDAStream>> THPUtils_PySequence_to_CUDAStreamList(PyObject *obj) {
  if (!PySequence_Check(obj)) {
    throw std::runtime_error("Expected a sequence in THPUtils_PySequence_to_CUDAStreamList");
  }
  THPObjectPtr seq = THPObjectPtr(PySequence_Fast(obj, nullptr));
  if (seq.get() == nullptr) {
    throw std::runtime_error("expected PySequence, but got " + std::string(THPUtils_typename(obj)));
  }

  std::vector<c10::optional<at::cuda::CUDAStream>> streams;
  Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get());
  for (Py_ssize_t i = 0; i < length; i++) {
    PyObject *stream = PySequence_Fast_GET_ITEM(seq.get(), i);

    if (PyObject_IsInstance(stream, THCPStreamClass)) {
      // Spicy hot reinterpret cast!! THPStream::cdata is the CUDAStream
      // packed into a uint64_t; unpack() reconstructs the stream from it.
      streams.emplace_back(at::cuda::CUDAStream::unpack((reinterpret_cast<THCPStream*>(stream))->cdata));
    } else if (stream == Py_None) {
      streams.emplace_back();
    } else {
      throw std::runtime_error("Unknown data type found in stream list. Need torch.cuda.Stream or None");
    }
  }
  return streams;
}

#endif
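
A hypothetical caller, to illustrate the nullopt convention described in the NB comment above. demo and py_streams are illustrative names, not part of the sources, and the sketch assumes the declarations pulled in via THCP.h:

#include <torch/csrc/python_headers.h>
#include <torch/csrc/cuda/THCP.h>

// py_streams is expected to be a Python list such as
// [torch.cuda.Stream(), None], arriving through the C API.
void demo(PyObject *py_streams) {
  auto streams = THPUtils_PySequence_to_CUDAStreamList(py_streams);
  for (const auto& s : streams) {
    if (s.has_value()) {
      at::cuda::CUDAStream stream = *s;  // an explicit stream was supplied
      (void)stream;
    } else {
      // nullopt: fall back to the current stream of the input's device.
    }
  }
}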