mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Expose gather and equals for CUDA tensors
This commit is contained in:
parent
e46d942ca6
commit
59b9eeff49
|
|
@ -139,7 +139,7 @@ tests = [
|
|||
('eq', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal' ),
|
||||
('ne', small_3d_ones, lambda t: [small_3d(t)], ),
|
||||
('ne', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal' ),
|
||||
('equal', small_3d_ones, lambda t: [small_3d_ones(t)], ),
|
||||
('equal', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal' ),
|
||||
('equal', small_3d_ones, lambda t: [small_3d(t)], ),
|
||||
('expand', new_t(M, 1, M), lambda t: [M, 4, M], ),
|
||||
('expand_as', new_t(M, 1, M), lambda t: [new_t(M, 4, M)(t)], ),
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ class THPPlugin(CWrapPlugin):
|
|||
'void*': Template('THPUtils_unpackLong($arg)'),
|
||||
'long': Template('THPUtils_unpackLong($arg)'),
|
||||
'int': Template('THPUtils_unpackLong($arg)'),
|
||||
'bool': Template('THPUtils_unpackLong($arg)'),
|
||||
'bool': Template('($arg == Py_True ? true : false)'),
|
||||
'float': Template('THPFloatUtils_unpackReal($arg)'),
|
||||
'double': Template('THPDoubleUtils_unpackReal($arg)'),
|
||||
'real': Template('THPUtils_(unpackReal)($arg)'),
|
||||
|
|
@ -46,7 +46,7 @@ class THPPlugin(CWrapPlugin):
|
|||
'void*': Template('THPUtils_checkLong($arg)'),
|
||||
'long': Template('THPUtils_checkLong($arg)'),
|
||||
'int': Template('THPUtils_checkLong($arg)'),
|
||||
'bool': Template('THPUtils_checkLong($arg)'),
|
||||
'bool': Template('(($arg == Py_True) || ($arg == Py_False))'),
|
||||
'float': Template('THPFloatUtils_checkReal($arg)'),
|
||||
'double': Template('THPDoubleUtils_checkReal($arg)'),
|
||||
'real': Template('THPUtils_(checkReal)($arg)'),
|
||||
|
|
|
|||
|
|
@ -28,9 +28,13 @@
|
|||
#endif
|
||||
|
||||
#if IS_CUDA
|
||||
#define THIndexTensor THCudaLongTensor
|
||||
#define THIndexTensor_(NAME) TH_CONCAT_2(THCudaLongTensor_,NAME)
|
||||
#define THPIndexTensor THCPLongTensor
|
||||
#define THPIndexTensorClass THCPLongTensorClass
|
||||
#else
|
||||
#define THIndexTensor THLongTensor
|
||||
#define THIndexTensor_(NAME) TH_CONCAT_2(THLongTensor_,NAME)
|
||||
#define THPIndexTensor THPLongTensor
|
||||
#define THPIndexTensorClass THPLongTensorClass
|
||||
#endif
|
||||
|
|
@ -74,6 +78,8 @@ typedef THLongStorage THStride;
|
|||
#undef CUDA_FLOAT
|
||||
#undef CUDA_DOUBLE
|
||||
#undef CUDA_HALF
|
||||
#undef THIndexTensor
|
||||
#undef THIndexTensor_
|
||||
#undef THPIndexTensor
|
||||
#undef THPIndexTensorClass
|
||||
#undef THPBoolTensor
|
||||
|
|
|
|||
|
|
@ -581,18 +581,17 @@ PyObject * THPTensor_(stride)(PyObject *self, PyObject *args, PyObject *kwargs)
|
|||
|
||||
[[
|
||||
name: gather
|
||||
defined_if: "!IS_CUDA"
|
||||
with_stateless: True
|
||||
return: argument 0
|
||||
before_call: |
|
||||
THLongStoragePtr _size = THLongTensor_newSizeOf(LIBRARY_STATE ((THPLongTensor*)$arg3)->cdata);
|
||||
THLongStoragePtr _size = THIndexTensor_(newSizeOf)(LIBRARY_STATE ((THPIndexTensor*)$arg3)->cdata);
|
||||
THTensor_(resize)(LIBRARY_STATE ((THPTensor*)$arg0)->cdata, _size, NULL);
|
||||
arguments:
|
||||
- arg: THTensor* result
|
||||
allocate: True
|
||||
- THTensor* self
|
||||
- long dim
|
||||
- THLongTensor* index
|
||||
- THIndexTensor* index
|
||||
]]
|
||||
|
||||
[[
|
||||
|
|
@ -665,7 +664,6 @@ invalid_arguments:
|
|||
|
||||
[[
|
||||
name: equal
|
||||
defined_if: "!IS_CUDA"
|
||||
with_stateless: True
|
||||
return: bool
|
||||
arguments:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user