From 59b9eeff495fca1e52451760b97da61d8025f0ce Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 19 Dec 2016 23:18:05 +0100 Subject: [PATCH] Expose gather and equals for CUDA tensors --- test/test_cuda.py | 2 +- tools/cwrap/plugins/THPPlugin.py | 4 ++-- torch/csrc/generic/TensorMethods.cwrap | 6 ++++++ torch/csrc/generic/methods/Tensor.cwrap | 6 ++---- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/test/test_cuda.py b/test/test_cuda.py index 2d85f740c4f..4bed1bd0958 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -139,7 +139,7 @@ tests = [ ('eq', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal' ), ('ne', small_3d_ones, lambda t: [small_3d(t)], ), ('ne', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal' ), - ('equal', small_3d_ones, lambda t: [small_3d_ones(t)], ), + ('equal', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal' ), ('equal', small_3d_ones, lambda t: [small_3d(t)], ), ('expand', new_t(M, 1, M), lambda t: [M, 4, M], ), ('expand_as', new_t(M, 1, M), lambda t: [new_t(M, 4, M)(t)], ), diff --git a/tools/cwrap/plugins/THPPlugin.py b/tools/cwrap/plugins/THPPlugin.py index eb98142a971..5eddba8aecc 100644 --- a/tools/cwrap/plugins/THPPlugin.py +++ b/tools/cwrap/plugins/THPPlugin.py @@ -22,7 +22,7 @@ class THPPlugin(CWrapPlugin): 'void*': Template('THPUtils_unpackLong($arg)'), 'long': Template('THPUtils_unpackLong($arg)'), 'int': Template('THPUtils_unpackLong($arg)'), - 'bool': Template('THPUtils_unpackLong($arg)'), + 'bool': Template('($arg == Py_True ? true : false)'), 'float': Template('THPFloatUtils_unpackReal($arg)'), 'double': Template('THPDoubleUtils_unpackReal($arg)'), 'real': Template('THPUtils_(unpackReal)($arg)'), @@ -46,7 +46,7 @@ class THPPlugin(CWrapPlugin): 'void*': Template('THPUtils_checkLong($arg)'), 'long': Template('THPUtils_checkLong($arg)'), 'int': Template('THPUtils_checkLong($arg)'), - 'bool': Template('THPUtils_checkLong($arg)'), + 'bool': Template('(($arg == Py_True) || ($arg == Py_False))'), 'float': Template('THPFloatUtils_checkReal($arg)'), 'double': Template('THPDoubleUtils_checkReal($arg)'), 'real': Template('THPUtils_(checkReal)($arg)'), diff --git a/torch/csrc/generic/TensorMethods.cwrap b/torch/csrc/generic/TensorMethods.cwrap index b07f6b11d93..dc7fa978a08 100644 --- a/torch/csrc/generic/TensorMethods.cwrap +++ b/torch/csrc/generic/TensorMethods.cwrap @@ -28,9 +28,13 @@ #endif #if IS_CUDA +#define THIndexTensor THCudaLongTensor +#define THIndexTensor_(NAME) TH_CONCAT_2(THCudaLongTensor_,NAME) #define THPIndexTensor THCPLongTensor #define THPIndexTensorClass THCPLongTensorClass #else +#define THIndexTensor THLongTensor +#define THIndexTensor_(NAME) TH_CONCAT_2(THLongTensor_,NAME) #define THPIndexTensor THPLongTensor #define THPIndexTensorClass THPLongTensorClass #endif @@ -74,6 +78,8 @@ typedef THLongStorage THStride; #undef CUDA_FLOAT #undef CUDA_DOUBLE #undef CUDA_HALF +#undef THIndexTensor +#undef THIndexTensor_ #undef THPIndexTensor #undef THPIndexTensorClass #undef THPBoolTensor diff --git a/torch/csrc/generic/methods/Tensor.cwrap b/torch/csrc/generic/methods/Tensor.cwrap index 5cfa847855f..9e8da6b4856 100644 --- a/torch/csrc/generic/methods/Tensor.cwrap +++ b/torch/csrc/generic/methods/Tensor.cwrap @@ -581,18 +581,17 @@ PyObject * THPTensor_(stride)(PyObject *self, PyObject *args, PyObject *kwargs) [[ name: gather - defined_if: "!IS_CUDA" with_stateless: True return: argument 0 before_call: | - THLongStoragePtr _size = THLongTensor_newSizeOf(LIBRARY_STATE ((THPLongTensor*)$arg3)->cdata); + THLongStoragePtr _size = THIndexTensor_(newSizeOf)(LIBRARY_STATE ((THPIndexTensor*)$arg3)->cdata); THTensor_(resize)(LIBRARY_STATE ((THPTensor*)$arg0)->cdata, _size, NULL); arguments: - arg: THTensor* result allocate: True - THTensor* self - long dim - - THLongTensor* index + - THIndexTensor* index ]] [[ @@ -665,7 +664,6 @@ invalid_arguments: [[ name: equal - defined_if: "!IS_CUDA" with_stateless: True return: bool arguments: