// Mirror of https://github.com/zebrajr/pytorch.git
// (synced 2025-12-06 12:20:52 +01:00) - 68 lines, 2.1 KiB, C++.
#include <Python.h>
|
|
#include <structmember.h>
|
|
|
|
#include <TH/THMath.h>
|
|
#include <stdbool.h>
|
|
#include <vector>
|
|
#include <stack>
|
|
#include <tuple>
|
|
#include "THCP.h"
|
|
|
|
#include "override_macros.h"
|
|
|
|
class THCPAutoGPU {
|
|
public:
|
|
THCPAutoGPU(PyObject *args, PyObject *self=NULL) {
|
|
if (self && setDevice(self))
|
|
return;
|
|
|
|
if (!args)
|
|
return;
|
|
for (int i = 0; i < PyTuple_Size(args); i++) {
|
|
PyObject *arg = PyTuple_GET_ITEM(args, i);
|
|
if (setDevice(arg)) return;
|
|
}
|
|
}
|
|
|
|
bool setDevice(PyObject *obj) {
|
|
int new_device = -1;
|
|
PyObject *obj_type = (PyObject*)Py_TYPE(obj);
|
|
if (obj_type == THCPDoubleTensorClass) {
|
|
new_device = THCudaDoubleTensor_getDevice(LIBRARY_STATE ((THCPDoubleTensor*)obj)->cdata);
|
|
} else if (obj_type == THCPFloatTensorClass) {
|
|
new_device = THCudaTensor_getDevice(LIBRARY_STATE ((THCPFloatTensor*)obj)->cdata);
|
|
} else if (obj_type == THCPLongTensorClass) {
|
|
new_device = THCudaLongTensor_getDevice(LIBRARY_STATE ((THCPLongTensor*)obj)->cdata);
|
|
} else if (obj_type == THCPIntTensorClass) {
|
|
new_device = THCudaIntTensor_getDevice(LIBRARY_STATE ((THCPIntTensor*)obj)->cdata);
|
|
} else if (obj_type == THCPShortTensorClass) {
|
|
new_device = THCudaShortTensor_getDevice(LIBRARY_STATE ((THCPShortTensor*)obj)->cdata);
|
|
} else if (obj_type == THCPCharTensorClass) {
|
|
new_device = THCudaCharTensor_getDevice(LIBRARY_STATE ((THCPCharTensor*)obj)->cdata);
|
|
} else if (obj_type == THCPByteTensorClass) {
|
|
new_device = THCudaByteTensor_getDevice(LIBRARY_STATE ((THCPByteTensor*)obj)->cdata);
|
|
}
|
|
if (new_device != -1) {
|
|
THCudaCheck(cudaGetDevice(&device));
|
|
THCPModule_setDevice(new_device);
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// This can throw... But if it does I have no idea how to recover.
|
|
~THCPAutoGPU() {
|
|
if (device != -1)
|
|
THCPModule_setDevice(device);
|
|
}
|
|
|
|
int device = -1;
|
|
};
|
|
|
|
#define THC_GENERIC_FILE "torch/csrc/generic/Tensor.cpp"
|
|
#include <THC/THCGenerateAllTypes.h>
|
|
|
|
#define THC_GENERIC_FILE "torch/csrc/generic/TensorCopy.cpp"
|
|
#include <THC/THCGenerateAllTypes.h>
|
|
|