Allowing fill assign when narrowing/selecting

This commit is contained in:
Clement Farabet 2012-02-25 14:57:20 -05:00
parent 916afcc290
commit d9e051ba70

View File

@ -478,41 +478,55 @@ static int torch_Tensor_(__newindex__)(lua_State *L)
void *src;
if (lua_isnumber(L,3)) {
real value = (real)luaL_checknumber(L,3);
luaL_argcheck(L, tensor->nDimension == 1, 1, "must be a one dimensional tensor");
luaL_argcheck(L, index >= 0 && index < tensor->size[0], 2, "out of range");
THStorage_(set)(tensor->storage, tensor->storageOffset+index*tensor->stride[0], value);
if (tensor->nDimension == 1) {
luaL_argcheck(L, index >= 0 && index < tensor->size[0], 2, "out of range");
THStorage_(set)(tensor->storage, tensor->storageOffset+index*tensor->stride[0], value);
} else {
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(narrow)(tensor, NULL, 0, index, 1);
THTensor_(fill)(tensor, value);
THTensor_(free)(tensor);
}
} else if( (src = luaT_toudata(L, 3, torch_Tensor_id)) ) {
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(narrow)(tensor, NULL, 0, index, 1);
THTensor_(copy)(tensor, src);
THTensor_(free)(tensor);
} else if( (src = luaT_toudata(L, 3, torch_ByteTensor_id)) ) {
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(narrow)(tensor, NULL, 0, index, 1);
THTensor_(copyByte)(tensor, src);
THTensor_(free)(tensor);
} else if( (src = luaT_toudata(L, 3, torch_CharTensor_id)) ) {
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(narrow)(tensor, NULL, 0, index, 1);
THTensor_(copyChar)(tensor, src);
THTensor_(free)(tensor);
} else if( (src = luaT_toudata(L, 3, torch_ShortTensor_id)) ) {
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(narrow)(tensor, NULL, 0, index, 1);
THTensor_(copyShort)(tensor, src);
THTensor_(free)(tensor);
} else if( (src = luaT_toudata(L, 3, torch_IntTensor_id)) ) {
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(narrow)(tensor, NULL, 0, index, 1);
THTensor_(copyInt)(tensor, src);
THTensor_(free)(tensor);
} else if( (src = luaT_toudata(L, 3, torch_LongTensor_id)) ) {
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(narrow)(tensor, NULL, 0, index, 1);
THTensor_(copyLong)(tensor, src);
THTensor_(free)(tensor);
} else if( (src = luaT_toudata(L, 3, torch_FloatTensor_id)) ) {
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(narrow)(tensor, NULL, 0, index, 1);
THTensor_(copyFloat)(tensor, src);
THTensor_(free)(tensor);
} else if( (src = luaT_toudata(L, 3, torch_DoubleTensor_id)) ) {
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(narrow)(tensor, NULL, 0, index, 1);
THTensor_(copyDouble)(tensor, src);
THTensor_(free)(tensor);
} else {
luaL_typerror(L, 3, "torch.*Tensor");
}
@ -585,7 +599,7 @@ static int torch_Tensor_(__newindex__)(lua_State *L)
// doing a copy
void *src;
if (lua_isnumber(L,3)) {
luaL_typerror(L, 3, "torch.*Tensor");
THTensor_(fill)(tensor, lua_tonumber(L,3));
} else if( (src = luaT_toudata(L, 3, torch_Tensor_id)) ) {
THTensor_(copy)(tensor, src);
} else if( (src = luaT_toudata(L, 3, torch_ByteTensor_id)) ) {