Expose gather and equals for CUDA tensors

This commit is contained in:
Adam Paszke 2016-12-19 23:18:05 +01:00 committed by Soumith Chintala
parent e46d942ca6
commit 59b9eeff49
4 changed files with 11 additions and 7 deletions

View File

@ -139,7 +139,7 @@ tests = [
('eq', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal' ),
('ne', small_3d_ones, lambda t: [small_3d(t)], ),
('ne', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal' ),
('equal', small_3d_ones, lambda t: [small_3d_ones(t)], ),
('equal', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal' ),
('equal', small_3d_ones, lambda t: [small_3d(t)], ),
('expand', new_t(M, 1, M), lambda t: [M, 4, M], ),
('expand_as', new_t(M, 1, M), lambda t: [new_t(M, 4, M)(t)], ),

View File

@ -22,7 +22,7 @@ class THPPlugin(CWrapPlugin):
'void*': Template('THPUtils_unpackLong($arg)'),
'long': Template('THPUtils_unpackLong($arg)'),
'int': Template('THPUtils_unpackLong($arg)'),
'bool': Template('THPUtils_unpackLong($arg)'),
'bool': Template('($arg == Py_True ? true : false)'),
'float': Template('THPFloatUtils_unpackReal($arg)'),
'double': Template('THPDoubleUtils_unpackReal($arg)'),
'real': Template('THPUtils_(unpackReal)($arg)'),
@ -46,7 +46,7 @@ class THPPlugin(CWrapPlugin):
'void*': Template('THPUtils_checkLong($arg)'),
'long': Template('THPUtils_checkLong($arg)'),
'int': Template('THPUtils_checkLong($arg)'),
'bool': Template('THPUtils_checkLong($arg)'),
'bool': Template('(($arg == Py_True) || ($arg == Py_False))'),
'float': Template('THPFloatUtils_checkReal($arg)'),
'double': Template('THPDoubleUtils_checkReal($arg)'),
'real': Template('THPUtils_(checkReal)($arg)'),

View File

@ -28,9 +28,13 @@
#endif
#if IS_CUDA
#define THIndexTensor THCudaLongTensor
#define THIndexTensor_(NAME) TH_CONCAT_2(THCudaLongTensor_,NAME)
#define THPIndexTensor THCPLongTensor
#define THPIndexTensorClass THCPLongTensorClass
#else
#define THIndexTensor THLongTensor
#define THIndexTensor_(NAME) TH_CONCAT_2(THLongTensor_,NAME)
#define THPIndexTensor THPLongTensor
#define THPIndexTensorClass THPLongTensorClass
#endif
@ -74,6 +78,8 @@ typedef THLongStorage THStride;
#undef CUDA_FLOAT
#undef CUDA_DOUBLE
#undef CUDA_HALF
#undef THIndexTensor
#undef THIndexTensor_
#undef THPIndexTensor
#undef THPIndexTensorClass
#undef THPBoolTensor

View File

@ -581,18 +581,17 @@ PyObject * THPTensor_(stride)(PyObject *self, PyObject *args, PyObject *kwargs)
[[
name: gather
defined_if: "!IS_CUDA"
with_stateless: True
return: argument 0
before_call: |
THLongStoragePtr _size = THLongTensor_newSizeOf(LIBRARY_STATE ((THPLongTensor*)$arg3)->cdata);
THLongStoragePtr _size = THIndexTensor_(newSizeOf)(LIBRARY_STATE ((THPIndexTensor*)$arg3)->cdata);
THTensor_(resize)(LIBRARY_STATE ((THPTensor*)$arg0)->cdata, _size, NULL);
arguments:
- arg: THTensor* result
allocate: True
- THTensor* self
- long dim
- THLongTensor* index
- THIndexTensor* index
]]
[[
@ -665,7 +664,6 @@ invalid_arguments:
[[
name: equal
defined_if: "!IS_CUDA"
with_stateless: True
return: bool
arguments: