#include #include #include #include #include #include #include #include "THCP.h" #include "override_macros.h" class THCPAutoGPU { public: THCPAutoGPU(PyObject *args, PyObject *self=NULL) { if (self && setDevice(self)) return; if (!args) return; for (int i = 0; i < PyTuple_Size(args); i++) { PyObject *arg = PyTuple_GET_ITEM(args, i); if (setDevice(arg)) return; } } bool setDevice(PyObject *obj) { int new_device = -1; PyObject *obj_type = (PyObject*)Py_TYPE(obj); if (obj_type == THCPDoubleTensorClass) { new_device = THCudaDoubleTensor_getDevice(LIBRARY_STATE ((THCPDoubleTensor*)obj)->cdata); } else if (obj_type == THCPFloatTensorClass) { new_device = THCudaTensor_getDevice(LIBRARY_STATE ((THCPFloatTensor*)obj)->cdata); } else if (obj_type == THCPLongTensorClass) { new_device = THCudaLongTensor_getDevice(LIBRARY_STATE ((THCPLongTensor*)obj)->cdata); } else if (obj_type == THCPIntTensorClass) { new_device = THCudaIntTensor_getDevice(LIBRARY_STATE ((THCPIntTensor*)obj)->cdata); } else if (obj_type == THCPShortTensorClass) { new_device = THCudaShortTensor_getDevice(LIBRARY_STATE ((THCPShortTensor*)obj)->cdata); } else if (obj_type == THCPCharTensorClass) { new_device = THCudaCharTensor_getDevice(LIBRARY_STATE ((THCPCharTensor*)obj)->cdata); } else if (obj_type == THCPByteTensorClass) { new_device = THCudaByteTensor_getDevice(LIBRARY_STATE ((THCPByteTensor*)obj)->cdata); } if (new_device != -1) { THCudaCheck(cudaGetDevice(&device)); THCPModule_setDevice(new_device); return true; } return false; } // This can throw... But if it does I have no idea how to recover. ~THCPAutoGPU() { if (device != -1) THCPModule_setDevice(device); } int device = -1; }; #define THC_GENERIC_FILE "torch/csrc/generic/Tensor.cpp" #include #define THC_GENERIC_FILE "torch/csrc/generic/TensorCopy.cpp" #include