mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Allowing fill assign when narrowing/selecting
This commit is contained in:
parent
916afcc290
commit
d9e051ba70
|
|
@ -478,41 +478,55 @@ static int torch_Tensor_(__newindex__)(lua_State *L)
|
|||
void *src;
|
||||
if (lua_isnumber(L,3)) {
|
||||
real value = (real)luaL_checknumber(L,3);
|
||||
luaL_argcheck(L, tensor->nDimension == 1, 1, "must be a one dimensional tensor");
|
||||
luaL_argcheck(L, index >= 0 && index < tensor->size[0], 2, "out of range");
|
||||
THStorage_(set)(tensor->storage, tensor->storageOffset+index*tensor->stride[0], value);
|
||||
if (tensor->nDimension == 1) {
|
||||
luaL_argcheck(L, index >= 0 && index < tensor->size[0], 2, "out of range");
|
||||
THStorage_(set)(tensor->storage, tensor->storageOffset+index*tensor->stride[0], value);
|
||||
} else {
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(narrow)(tensor, NULL, 0, index, 1);
|
||||
THTensor_(fill)(tensor, value);
|
||||
THTensor_(free)(tensor);
|
||||
}
|
||||
} else if( (src = luaT_toudata(L, 3, torch_Tensor_id)) ) {
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(narrow)(tensor, NULL, 0, index, 1);
|
||||
THTensor_(copy)(tensor, src);
|
||||
THTensor_(free)(tensor);
|
||||
} else if( (src = luaT_toudata(L, 3, torch_ByteTensor_id)) ) {
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(narrow)(tensor, NULL, 0, index, 1);
|
||||
THTensor_(copyByte)(tensor, src);
|
||||
THTensor_(free)(tensor);
|
||||
} else if( (src = luaT_toudata(L, 3, torch_CharTensor_id)) ) {
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(narrow)(tensor, NULL, 0, index, 1);
|
||||
THTensor_(copyChar)(tensor, src);
|
||||
THTensor_(free)(tensor);
|
||||
} else if( (src = luaT_toudata(L, 3, torch_ShortTensor_id)) ) {
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(narrow)(tensor, NULL, 0, index, 1);
|
||||
THTensor_(copyShort)(tensor, src);
|
||||
THTensor_(free)(tensor);
|
||||
} else if( (src = luaT_toudata(L, 3, torch_IntTensor_id)) ) {
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(narrow)(tensor, NULL, 0, index, 1);
|
||||
THTensor_(copyInt)(tensor, src);
|
||||
THTensor_(free)(tensor);
|
||||
} else if( (src = luaT_toudata(L, 3, torch_LongTensor_id)) ) {
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(narrow)(tensor, NULL, 0, index, 1);
|
||||
THTensor_(copyLong)(tensor, src);
|
||||
THTensor_(free)(tensor);
|
||||
} else if( (src = luaT_toudata(L, 3, torch_FloatTensor_id)) ) {
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(narrow)(tensor, NULL, 0, index, 1);
|
||||
THTensor_(copyFloat)(tensor, src);
|
||||
THTensor_(free)(tensor);
|
||||
} else if( (src = luaT_toudata(L, 3, torch_DoubleTensor_id)) ) {
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(narrow)(tensor, NULL, 0, index, 1);
|
||||
THTensor_(copyDouble)(tensor, src);
|
||||
THTensor_(free)(tensor);
|
||||
} else {
|
||||
luaL_typerror(L, 3, "torch.*Tensor");
|
||||
}
|
||||
|
|
@ -585,7 +599,7 @@ static int torch_Tensor_(__newindex__)(lua_State *L)
|
|||
// doing a copy
|
||||
void *src;
|
||||
if (lua_isnumber(L,3)) {
|
||||
luaL_typerror(L, 3, "torch.*Tensor");
|
||||
THTensor_(fill)(tensor, lua_tonumber(L,3));
|
||||
} else if( (src = luaT_toudata(L, 3, torch_Tensor_id)) ) {
|
||||
THTensor_(copy)(tensor, src);
|
||||
} else if( (src = luaT_toudata(L, 3, torch_ByteTensor_id)) ) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user