mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: How did we get so many uses of `NULL` again? ezyang. Pull Request resolved: https://github.com/pytorch/pytorch/pull/11047; Differential Revision: D9566799; Pulled By: goldsborough; fbshipit-source-id: 83469f352ac69aa65bdaf1a1a21f922d892e0db3
56 lines
1.4 KiB
C++
56 lines
1.4 KiB
C++
#ifndef THDP_COPY_UTILS_H
|
|
#define THDP_COPY_UTILS_H
|
|
|
|
extern THDTensorDescriptor* THDPModule_makeDescriptor(PyObject *obj);
|
|
template <typename TensorSrc>
|
|
void THDPInsertCopyFunctionFromWorker(
|
|
THPCopyList& copyList,
|
|
void (*copyFunc)(THDTensorDescriptor* x, TensorSrc *z))
|
|
{
|
|
auto wrapper = [copyFunc](PyObject* dst_, PyObject* src_) {
|
|
TensorSrc* src = THPTypeInfo<TensorSrc>::cdata(src_);
|
|
|
|
PyThreadState *_save = nullptr;
|
|
try {
|
|
Py_UNBLOCK_THREADS;
|
|
copyFunc(LIBRARY_STATE THDPModule_makeDescriptor(dst_), src);
|
|
Py_BLOCK_THREADS;
|
|
} catch (...) {
|
|
if (_save) {
|
|
Py_BLOCK_THREADS;
|
|
}
|
|
throw;
|
|
}
|
|
};
|
|
|
|
PyTypeObject* srcType = THPTypeInfo<TensorSrc>::pyType();
|
|
copyList.push_back({ srcType, wrapper, false });
|
|
}
|
|
|
|
template <typename TensorDst>
|
|
void THDPInsertCopyFunctionFromMaster(
|
|
THPCopyList& copyList,
|
|
void (*copyFunc)(TensorDst *x, THDTensorDescriptor* z),
|
|
PyTypeObject *srcType)
|
|
{
|
|
auto wrapper = [copyFunc](PyObject* dst_, PyObject* src_) {
|
|
TensorDst* dst = THPTypeInfo<TensorDst>::cdata(dst_);
|
|
|
|
PyThreadState *_save = nullptr;
|
|
try {
|
|
Py_UNBLOCK_THREADS;
|
|
copyFunc(LIBRARY_STATE dst, THDPModule_makeDescriptor(src_));
|
|
Py_BLOCK_THREADS;
|
|
} catch (...) {
|
|
if (_save) {
|
|
Py_BLOCK_THREADS;
|
|
}
|
|
throw;
|
|
}
|
|
};
|
|
|
|
copyList.push_back({ srcType, wrapper, false });
|
|
}
|
|
|
|
#endif
|