commit 053065ba239e3421d1b2cbec55e3f177a21d6829 Author: Ronan Collobert Date: Wed Jan 25 14:55:20 2012 +0100 initial revamp of torch7 tree diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 00000000000..158e56aca18 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,18 @@ +SET(src DiskFile.c File.c MemoryFile.c PipeFile.c Storage.c Tensor.c Timer.c utils.c init.c TensorOperator.c TensorMathWrap.c random.c) +SET(luasrc init.lua File.lua Tensor.lua TensorMath.lua CmdLine.lua Tester.lua torch.lua test/test.lua) + +# Necessary do generate wrapper +ADD_TORCH_WRAP(tensormathwrap TensorMathWrap.lua) +ADD_TORCH_WRAP(randomwrap random.lua) + +ADD_TORCH_PACKAGE(torch "${src}" "${luasrc}" "Basics") +ADD_TORCH_DOK(dok torch "Fundamentals" "Torch package" 1.1) + +TARGET_LINK_LIBRARIES(torch luaT TH) + +CONFIGURE_FILE(torch.in "${Torch_BINARY_DIR}/torch") +INSTALL(FILES "${Torch_BINARY_DIR}/torch" + DESTINATION "${Torch_INSTALL_BIN_SUBDIR}" + PERMISSIONS OWNER_EXECUTE OWNER_WRITE OWNER_READ + GROUP_EXECUTE GROUP_READ + WORLD_EXECUTE WORLD_READ) diff --git a/CmdLine.lua b/CmdLine.lua new file mode 100644 index 00000000000..d86e9512917 --- /dev/null +++ b/CmdLine.lua @@ -0,0 +1,244 @@ +local CmdLine = torch.class('torch.CmdLine') + +local function strip(str) + return string.match(str, '%-*(.*)') +end + +local function pad(str, sz) + return str .. string.rep(' ', sz-#str) +end + +function CmdLine:error(msg) + print('') + print(msg) + print('') + self:help() + os.exit(0) +end + +function CmdLine:__readArgument__(params, arg, i, nArgument) + local argument = self.arguments[nArgument] + local value = arg[i] + + if nArgument > #self.arguments then + self:error('invalid argument: ' .. value) + end + if argument.type and type(value) ~= argument.type then + self:error('invalid argument type for argument ' .. argument.key .. ' (should be ' .. argument.type .. ')') + end + params[strip(argument.key)] = value + return 1 +end + +function CmdLine:__readOption__(params, arg, i) + local key = arg[i] + local option = self.options[key] + if not option then + self:error('unknown option ' .. key) + end + + if option.type and option.type == 'boolean' then + params[strip(key)] = not option.default + return 1 + else + local value = arg[i+1] + if not value then + self:error('missing argument for option ' .. key) + end + if not option.type or option.type == 'string' then + elseif option.type == 'number' then + value = tonumber(value) + else + self:error('unknown required option type ' .. option.type) + end + if not value then + self:error('invalid type for option ' .. key .. ' (should be ' .. option.type .. ')') + end + params[strip(key)] = value + return 2 + end +end + +function CmdLine:__init(argseparator_,keyseparator_) + self.argseparator = argseparator_ or ',' + self.keyseparator = keyseparator_ or '=' + self.options = {} + self.arguments = {} + self.helplines = {} +end + +function CmdLine:argument(key, help, _type_) + table.insert(self.arguments, {key=key, help=help, type=_type_}) + table.insert(self.helplines, self.arguments[#self.arguments]) +end + +function CmdLine:option(key, default, help, _type_) + if default == nil then + error('option ' .. key .. ' has no default value') + end + _type_ = _type_ or type(default) + if type(default) ~= _type_ then + error('option ' .. key .. 
' has wrong default type value') + end + self.options[key] = {key=key, default=default, help=help, type=_type_} + table.insert(self.helplines, self.options[key]) +end + +function CmdLine:default() + local params = {} + for option,v in pairs(self.options) do + params[strip(option)] = v.default + end + return params +end + +function CmdLine:parse(arg) + local i = 1 + local params = self:default() + + local nArgument = 0 + + while i <= #arg do + if arg[i] == '-help' or arg[i] == '-h' or arg[i] == '--help' then + self:help(arg) + os.exit(0) + end + + if self.options[arg[i]] then + i = i + self:__readOption__(params, arg, i) + else + nArgument = nArgument + 1 + i = i + self:__readArgument__(params, arg, i, nArgument) + end + end + + if nArgument ~= #self.arguments then + self:error('not enough arguments') + end + + return params +end + +function CmdLine:string(prefix, params, ignore) + local arguments = {} + local options = {} + prefix = prefix or '' + + for k,v in pairs(params) do + if ignore[k] then + print('-- ignore option ' .. k) + elseif self.options['-' .. k] then + if v ~= self.options['-' .. k].default then + if type(v) == 'boolean' then + if v then + v = 't' + else + v = 'f' + end + end + table.insert(options, k .. self.keyseparator .. v) + print(k,v,self.options['-' .. k].default) + end + else + local narg + for i=1,#self.arguments do + if strip(self.arguments[i].key) == k then + narg = i + end + end + if narg then + arguments[narg] = k .. self.keyseparator .. v + else + print('WARNING: unknown option/argument: ' .. k .. ' IGNORING for DIRECTORY NAME') + end + end + end + table.sort(options) + local str = table.concat(arguments, self.argseparator) + if str == '' then + str = table.concat(options, self.argseparator) + else + str = str .. self.argseparator .. table.concat(options, self.argseparator) + end + if str == '' then + return prefix + else + return prefix .. self.argseparator .. str + end +end + +function CmdLine:log(file, params) + local f = io.open(file, 'w') + local oprint = print + function print(...) + local n = select("#", ...) + oprint(...) + for i=1,n do + f:write(tostring(select(i, ...))) + if i ~= n then + f:write(' ') + else + f:write('\n') + end + end + f:flush() + end + print('[program started on ' .. os.date() .. ']') + print('[command line arguments]') + if params then + for k,v in pairs(params) do + print(k,v) + end + end + print('[----------------------]') +end + +function CmdLine:text(txt) + txt = txt or '' + assert(type(txt) == 'string') + table.insert(self.helplines, txt) +end + +function CmdLine:help(arg) + io.write('Usage: ') + if arg then io.write(arg[0] .. ' ') end + io.write('[options] ') + for i=1,#self.arguments do + io.write('<' .. strip(self.arguments[i].key) .. '>') + end + io.write('\n') + + -- first pass to compute max length + local optsz = 0 + for _,option in ipairs(self.helplines) do + if type(option) == 'table' then + if option.default ~= nil then -- it is an option + if #option.key > optsz then + optsz = #option.key + end + else -- it is an argument + if #strip(option.key)+2 > optsz then + optsz = #strip(option.key)+2 + end + end + end + end + + -- second pass to print + for _,option in ipairs(self.helplines) do + if type(option) == 'table' then + io.write(' ') + if option.default ~= nil then -- it is an option + io.write(pad(option.key, optsz)) + if option.help then io.write(' ' .. option.help) end + io.write(' [' .. tostring(option.default) .. ']') + else -- it is an argument + io.write(pad('<' .. strip(option.key) .. 
'>', optsz)) + if option.help then io.write(' ' .. option.help) end + end + else + io.write(option) -- just some additional help + end + io.write('\n') + end +end diff --git a/DiskFile.c b/DiskFile.c new file mode 100644 index 00000000000..99c76d19493 --- /dev/null +++ b/DiskFile.c @@ -0,0 +1,87 @@ +#include "general.h" + +static const void* torch_DiskFile_id = NULL; + +static int torch_DiskFile_new(lua_State *L) +{ + const char *name = luaL_checkstring(L, 1); + const char *mode = luaL_optstring(L, 2, "r"); + int isQuiet = luaT_optboolean(L, 3, 0); + THFile *self = THDiskFile_new(name, mode, isQuiet); + + luaT_pushudata(L, self, torch_DiskFile_id); + return 1; +} + +static int torch_DiskFile_free(lua_State *L) +{ + THFile *self = luaT_checkudata(L, 1, torch_DiskFile_id); + THFile_free(self); + return 0; +} + +static int torch_DiskFile_isLittleEndianCPU(lua_State *L) +{ + lua_pushboolean(L, THDiskFile_isLittleEndianCPU()); + return 1; +} + +static int torch_DiskFile_isBigEndianCPU(lua_State *L) +{ + lua_pushboolean(L, !THDiskFile_isLittleEndianCPU()); + return 1; +} + +static int torch_DiskFile_nativeEndianEncoding(lua_State *L) +{ + THFile *self = luaT_checkudata(L, 1, torch_DiskFile_id); + THDiskFile_nativeEndianEncoding(self); + lua_settop(L, 1); + return 1; +} + +static int torch_DiskFile_littleEndianEncoding(lua_State *L) +{ + THFile *self = luaT_checkudata(L, 1, torch_DiskFile_id); + THDiskFile_littleEndianEncoding(self); + lua_settop(L, 1); + return 1; +} + +static int torch_DiskFile_bigEndianEncoding(lua_State *L) +{ + THFile *self = luaT_checkudata(L, 1, torch_DiskFile_id); + THDiskFile_bigEndianEncoding(self); + lua_settop(L, 1); + return 1; +} + +static int torch_DiskFile___tostring__(lua_State *L) +{ + THFile *self = luaT_checkudata(L, 1, torch_DiskFile_id); + lua_pushfstring(L, "torch.DiskFile on <%s> [status: %s -- mode %c%c]", + THDiskFile_name(self), + (THFile_isOpened(self) ? "open" : "closed"), + (THFile_isReadable(self) ? 'r' : ' '), + (THFile_isWritable(self) ? 
'w' : ' ')); + + return 1; +} +static const struct luaL_Reg torch_DiskFile__ [] = { + {"isLittleEndianCPU", torch_DiskFile_isLittleEndianCPU}, + {"isBigEndianCPU", torch_DiskFile_isBigEndianCPU}, + {"nativeEndianEncoding", torch_DiskFile_nativeEndianEncoding}, + {"littleEndianEncoding", torch_DiskFile_littleEndianEncoding}, + {"bigEndianEncoding", torch_DiskFile_bigEndianEncoding}, + {"__tostring__", torch_DiskFile___tostring__}, + {NULL, NULL} +}; + +void torch_DiskFile_init(lua_State *L) +{ + torch_DiskFile_id = luaT_newmetatable(L, "torch.DiskFile", "torch.File", + torch_DiskFile_new, torch_DiskFile_free, NULL); + + luaL_register(L, NULL, torch_DiskFile__); + lua_pop(L, 1); +} diff --git a/File.c b/File.c new file mode 100644 index 00000000000..0344d67d7f1 --- /dev/null +++ b/File.c @@ -0,0 +1,225 @@ +#include "THFile.h" +#include "luaT.h" + +static const void *torch_File_id = NULL; +static const void *torch_ByteStorage_id = NULL; +static const void *torch_CharStorage_id = NULL; +static const void *torch_ShortStorage_id = NULL; +static const void *torch_IntStorage_id = NULL; +static const void *torch_LongStorage_id = NULL; +static const void *torch_FloatStorage_id = NULL; +static const void *torch_DoubleStorage_id = NULL; + +#define IMPLEMENT_TORCH_FILE_FLAG(NAME) \ + static int torch_File_##NAME(lua_State *L) \ + { \ + THFile *self = luaT_checkudata(L, 1, torch_File_id); \ + lua_pushboolean(L, THFile_##NAME(self)); \ + return 1; \ + } + +IMPLEMENT_TORCH_FILE_FLAG(isQuiet) +IMPLEMENT_TORCH_FILE_FLAG(isReadable) +IMPLEMENT_TORCH_FILE_FLAG(isWritable) +IMPLEMENT_TORCH_FILE_FLAG(isBinary) +IMPLEMENT_TORCH_FILE_FLAG(isAutoSpacing) +IMPLEMENT_TORCH_FILE_FLAG(hasError) + +#define IMPLEMENT_TORCH_FILE_FUNC(NAME) \ + static int torch_File_##NAME(lua_State *L) \ + { \ + THFile *self = luaT_checkudata(L, 1, torch_File_id); \ + THFile_##NAME(self); \ + lua_settop(L, 1); \ + return 1; \ + } + +IMPLEMENT_TORCH_FILE_FUNC(binary) +IMPLEMENT_TORCH_FILE_FUNC(ascii) +IMPLEMENT_TORCH_FILE_FUNC(autoSpacing) +IMPLEMENT_TORCH_FILE_FUNC(noAutoSpacing) +IMPLEMENT_TORCH_FILE_FUNC(quiet) +IMPLEMENT_TORCH_FILE_FUNC(pedantic) +IMPLEMENT_TORCH_FILE_FUNC(clearError) + +IMPLEMENT_TORCH_FILE_FUNC(synchronize) + +static int torch_File_seek(lua_State *L) +{ + THFile *self = luaT_checkudata(L, 1, torch_File_id); + long position = luaL_checklong(L, 2)-1; + THFile_seek(self, position); + lua_settop(L, 1); + return 1; +} + +IMPLEMENT_TORCH_FILE_FUNC(seekEnd) + +static int torch_File_position(lua_State *L) +{ + THFile *self = luaT_checkudata(L, 1, torch_File_id); + lua_pushnumber(L, THFile_position(self)+1); + return 1; +} + +IMPLEMENT_TORCH_FILE_FUNC(close) + +#define IMPLEMENT_TORCH_FILE_RW(TYPEC, TYPE) \ + static int torch_File_read##TYPEC(lua_State *L) \ + { \ + THFile *self = luaT_checkudata(L, 1, torch_File_id); \ + int narg = lua_gettop(L); \ + \ + if(narg == 1) \ + { \ + lua_pushnumber(L, THFile_read##TYPEC##Scalar(self)); \ + return 1; \ + } \ + else if(narg == 2) \ + { \ + if(lua_isnumber(L, 2)) \ + { \ + long size = lua_tonumber(L, 2); \ + long nread; \ + \ + TH##TYPEC##Storage *storage = TH##TYPEC##Storage_newWithSize(size); \ + luaT_pushudata(L, storage, torch_##TYPEC##Storage_id); \ + nread = THFile_read##TYPEC(self, storage); \ + if(nread != size) \ + TH##TYPEC##Storage_resize(storage, size); \ + return 1; \ + } \ + else if(luaT_toudata(L, 2, torch_##TYPEC##Storage_id)) \ + { \ + TH##TYPEC##Storage *storage = luaT_toudata(L, 2, torch_##TYPEC##Storage_id); \ + lua_pushnumber(L, THFile_read##TYPEC(self, 
storage)); \ + return 1; \ + } \ + } \ + \ + luaL_error(L, "nothing, number, or Storage expected"); \ + return 0; \ + } \ + \ + static int torch_File_write##TYPEC(lua_State *L) \ + { \ + THFile *self = luaT_checkudata(L, 1, torch_File_id); \ + int narg = lua_gettop(L); \ + \ + if(narg == 2) \ + { \ + if(lua_isnumber(L, 2)) \ + { \ + TYPE value = lua_tonumber(L, 2); \ + THFile_write##TYPEC##Scalar(self, (TYPE)value); \ + return 0; \ + } \ + else if(luaT_toudata(L, 2, torch_##TYPEC##Storage_id)) \ + { \ + TH##TYPEC##Storage *storage = luaT_toudata(L, 2, torch_##TYPEC##Storage_id); \ + lua_pushnumber(L, THFile_write##TYPEC(self, storage)); \ + return 1; \ + } \ + } \ + \ + luaL_error(L, "number, or Storage expected"); \ + return 0; \ + } + + +IMPLEMENT_TORCH_FILE_RW(Byte, unsigned char) +IMPLEMENT_TORCH_FILE_RW(Char, char) +IMPLEMENT_TORCH_FILE_RW(Short, short) +IMPLEMENT_TORCH_FILE_RW(Int, int) +IMPLEMENT_TORCH_FILE_RW(Long, long) +IMPLEMENT_TORCH_FILE_RW(Float, float) +IMPLEMENT_TORCH_FILE_RW(Double, double) + +static int torch_File_readString(lua_State *L) +{ + THFile *self = luaT_checkudata(L, 1, torch_File_id); + const char *format = luaL_checkstring(L, 2); + char *str; + long size; + + size = THFile_readStringRaw(self, format, &str); + lua_pushlstring(L, str, size); + THFree(str); + + return 1; +} + +static int torch_File_writeString(lua_State *L) +{ + THFile *self = luaT_checkudata(L, 1, torch_File_id); + const char *str = NULL; + size_t size; + long nwrite; + + luaL_checktype(L, 2, LUA_TSTRING); + str = lua_tolstring(L, 2, &size); + lua_pushnumber(L, THFile_writeStringRaw(self, str, (long)size)); + return 1; +} + +static const struct luaL_Reg torch_File__ [] = { + {"isQuiet", torch_File_isQuiet}, + {"isReadable", torch_File_isReadable}, + {"isWritable", torch_File_isWritable}, + {"isBinary", torch_File_isBinary}, + {"isAutoSpacing", torch_File_isAutoSpacing}, + {"hasError", torch_File_hasError}, + {"binary", torch_File_binary}, + {"ascii", torch_File_ascii}, + {"autoSpacing", torch_File_autoSpacing}, + {"noAutoSpacing", torch_File_noAutoSpacing}, + {"quiet", torch_File_quiet}, + {"pedantic", torch_File_pedantic}, + {"clearError", torch_File_clearError}, + + /* DEBUG: CHECK DISK FREE & READ/WRITE STRING*/ + + {"readByte", torch_File_readByte}, + {"readChar", torch_File_readChar}, + {"readShort", torch_File_readShort}, + {"readInt", torch_File_readInt}, + {"readLong", torch_File_readLong}, + {"readFloat", torch_File_readFloat}, + {"readDouble", torch_File_readDouble}, + {"readString", torch_File_readString}, + + {"writeByte", torch_File_writeByte}, + {"writeChar", torch_File_writeChar}, + {"writeShort", torch_File_writeShort}, + {"writeInt", torch_File_writeInt}, + {"writeLong", torch_File_writeLong}, + {"writeFloat", torch_File_writeFloat}, + {"writeDouble", torch_File_writeDouble}, + {"writeString", torch_File_writeString}, + + {"synchronize", torch_File_synchronize}, + {"seek", torch_File_seek}, + {"seekEnd", torch_File_seekEnd}, + {"position", torch_File_position}, + {"close", torch_File_close}, + + {NULL, NULL} +}; + +void torch_File_init(lua_State *L) +{ + torch_File_id = luaT_newmetatable(L, "torch.File", NULL, NULL, NULL, NULL); + luaL_register(L, NULL, torch_File__); + lua_pop(L, 1); +} + +void torch_File_init_storage_id(lua_State *L) +{ + torch_ByteStorage_id = luaT_checktypename2id(L, "torch.ByteStorage"); + torch_CharStorage_id = luaT_checktypename2id(L, "torch.CharStorage"); + torch_ShortStorage_id = luaT_checktypename2id(L, "torch.ShortStorage"); + torch_IntStorage_id = 
luaT_checktypename2id(L, "torch.IntStorage"); + torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage"); + torch_FloatStorage_id = luaT_checktypename2id(L, "torch.FloatStorage"); + torch_DoubleStorage_id = luaT_checktypename2id(L, "torch.DoubleStorage"); +} diff --git a/File.lua b/File.lua new file mode 100644 index 00000000000..0ed648f47b8 --- /dev/null +++ b/File.lua @@ -0,0 +1,240 @@ +local File = torch.getmetatable('torch.File') + +function File:writeBool(value) + if value then + self:writeInt(1) + else + self:writeInt(0) + end +end + +function File:readBool() + return (self:readInt() == 1) +end + +local TYPE_NIL = 0 +local TYPE_NUMBER = 1 +local TYPE_STRING = 2 +local TYPE_TABLE = 3 +local TYPE_TORCH = 4 +local TYPE_BOOLEAN = 5 +local TYPE_FUNCTION = 6 + +function File:isWritableObject(object) + local typename = type(object) + local typeidx + if type(object) ~= 'boolean' and not object then + typeidx = TYPE_NIL + elseif torch.typename(object) and torch.factory(torch.typename(object)) then + typeidx = TYPE_TORCH + elseif typename == 'table' then + typeidx = TYPE_TABLE + elseif typename == 'number' then + typeidx = TYPE_NUMBER + elseif typename == 'string' then + typeidx = TYPE_STRING + elseif typename == 'boolean' then + typeidx = TYPE_BOOLEAN + elseif typename == 'function' and pcall(string.dump, object) then + typeidx = TYPE_FUNCTION + end + return typeidx +end + +function File:writeObject(object) + -- we use an environment to keep a record of written objects + if not torch.getenv(self).writeObjects then + torch.setenv(self, {writeObjects={}, writeObjectsRef={}, readObjects={}}) + end + + -- if nil object, only write the type and return + if type(object) ~= 'boolean' and not object then + self:writeInt(TYPE_NIL) + return + end + + -- check the type we are dealing with + local typeidx = self:isWritableObject(object) + if not typeidx then + error(string.format('Unwritable object <%s>', type(object))) + end + self:writeInt(typeidx) + + if typeidx == TYPE_NUMBER then + self:writeDouble(object) + elseif typeidx == TYPE_BOOLEAN then + self:writeBool(object) + elseif typeidx == TYPE_STRING then + local stringStorage = torch.CharStorage():string(object) + self:writeInt(#stringStorage) + self:writeChar(stringStorage) + elseif typeidx == TYPE_FUNCTION then + local upvalues = {} + while true do + local name,value = debug.getupvalue(object, #upvalues+1) + if not name then break end + table.insert(upvalues, value) + end + local dumped = string.dump(object) + local stringStorage = torch.CharStorage():string(dumped) + self:writeInt(#stringStorage) + self:writeChar(stringStorage) + self:writeObject(upvalues) + elseif typeidx == TYPE_TORCH or typeidx == TYPE_TABLE then + -- check it exists already (we look at the pointer!) + local objects = torch.getenv(self).writeObjects + local objectsRef = torch.getenv(self).writeObjectsRef + local index = objects[torch.pointer(object)] + + if index then + -- if already exists, write only its index + self:writeInt(index) + else + -- else write the object itself + index = objects.nWriteObject or 0 + index = index + 1 + objects[torch.pointer(object)] = index + objectsRef[object] = index -- we make sure the object is not going to disappear + self:writeInt(index) + objects.nWriteObject = index + + if typeidx == TYPE_TORCH then + local version = torch.CharStorage():string('V ' .. 
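-- record layout for a serialized Torch object: a length-prefixed CharStorage
-- holding the string 'V <version>', then a length-prefixed CharStorage holding
-- the class name, then the payload (object:write(self) if the class defines it,
-- otherwise a plain table of its writable fields). readObject() below matches
-- '^V (.*)$' to recover the version number and falls back to version 0 for
-- files written before this versioning scheme existed.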
torch.version(object)) + local className = torch.CharStorage():string(torch.typename(object)) + self:writeInt(#version) + self:writeChar(version) + self:writeInt(#className) + self:writeChar(className) + if object.write then + object:write(self) + elseif type(object) == 'table' then + local var = {} + for k,v in pairs(object) do + if self:isWritableObject(v) then + var[k] = v + else + print(string.format('$ Warning: cannot write object field <%s>', k)) + end + end + self:writeObject(var) + else + error(string.format('<%s> is a non-serializable Torch object', torch.typename(object))) + end + else -- it is a table + local size = 0; for k,v in pairs(object) do size = size + 1 end + self:writeInt(size) + for k,v in pairs(object) do + self:writeObject(k) + self:writeObject(v) + end + end + end + else + error('Unwritable object') + end +end + +function File:readObject() + -- we use an environment to keep a record of read objects + if not torch.getenv(self).writeObjects then + torch.setenv(self, {writeObjects={}, writeObjectsRef={}, readObjects={}}) + end + + -- read the typeidx + local typeidx = self:readInt() + + -- is it nil? + if typeidx == TYPE_NIL then + return nil + end + + if typeidx == TYPE_NUMBER then + return self:readDouble() + elseif typeidx == TYPE_BOOLEAN then + return self:readBool() + elseif typeidx == TYPE_STRING then + local size = self:readInt() + return self:readChar(size):string() + elseif typeidx == TYPE_FUNCTION then + local size = self:readInt() + local dumped = self:readChar(size):string() + local func = loadstring(dumped) + local upvalues = self:readObject() + for index,upvalue in ipairs(upvalues) do + debug.setupvalue(func, index, upvalue) + end + return func + elseif typeidx == TYPE_TABLE or typeidx == TYPE_TORCH then + -- read the index + local index = self:readInt() + + -- check it is loaded already + local objects = torch.getenv(self).readObjects + if objects[index] then + return objects[index] + end + + -- otherwise read it + if typeidx == TYPE_TORCH then + local version, className, versionNumber + version = self:readChar(self:readInt()):string() + versionNumber = tonumber(string.match(version, '^V (.*)$')) + if not versionNumber then + className = version + versionNumber = 0 -- file created before existence of versioning system + else + className = self:readChar(self:readInt()):string() + end + if not torch.factory(className) then + error(string.format('unknown Torch class <%s>' .. 
className)) + end + local object = torch.factory(className)() + objects[index] = object + if object.read then + object:read(self, versionNumber) + elseif type(object) == 'table' then + local var = self:readObject(var) + for k,v in pairs(var) do + object[k] = v + end + else + error(string.format('Cannot load object class <%s>', className)) + end + return object + else -- it is a table + local size = self:readInt() + local object = {} + objects[index] = object + for i = 1,size do + local k = self:readObject() + local v = self:readObject() + object[k] = v + end + return object + end + else + error('unknown object') + end +end + +-- simple helpers to save/load arbitrary objects/tables +function torch.save(filename, object, mode) + mode = mode or 'binary' + local file = torch.DiskFile(filename, 'w') + file[mode](file) + file:writeObject(object) + file:close() +end + +function torch.load(filename, mode) + mode = mode or 'binary' + local file = torch.DiskFile(filename, 'r') + file[mode](file) + local object = file:readObject() + file:close() + return object +end + +-- public API (saveobj/loadobj are safe for global import) +torch.saveobj = torch.save +torch.loadobj = torch.load diff --git a/MemoryFile.c b/MemoryFile.c new file mode 100644 index 00000000000..3197170a018 --- /dev/null +++ b/MemoryFile.c @@ -0,0 +1,67 @@ +#include "general.h" + +static const void* torch_MemoryFile_id; +static const void* torch_CharStorage_id; + +static int torch_MemoryFile_new(lua_State *L) +{ + const char *mode; + THCharStorage *storage = luaT_toudata(L, 1, torch_CharStorage_id); + THFile *self; + + if(storage) + { + mode = luaL_optstring(L, 2, "rw"); + self = THMemoryFile_newWithStorage(storage, mode); + } + else + { + mode = luaL_optstring(L, 1, "rw"); + self = THMemoryFile_new(mode); + } + + luaT_pushudata(L, self, torch_MemoryFile_id); + return 1; +} + +static int torch_MemoryFile_storage(lua_State *L) +{ + THFile *self = luaT_checkudata(L, 1, torch_MemoryFile_id); + THCharStorage_retain(THMemoryFile_storage(self)); + luaT_pushudata(L, THMemoryFile_storage(self), torch_CharStorage_id); + return 1; +} + +static int torch_MemoryFile_free(lua_State *L) +{ + THFile *self = luaT_checkudata(L, 1, torch_MemoryFile_id); + THFile_free(self); + return 0; +} + +static int torch_MemoryFile___tostring__(lua_State *L) +{ + THFile *self = luaT_checkudata(L, 1, torch_MemoryFile_id); + lua_pushfstring(L, "torch.MemoryFile [status: %s -- mode: %c%c]", + (THFile_isOpened(self) ? "open" : "closed"), + (THFile_isReadable(self) ? 'r' : ' '), + (THFile_isWritable(self) ? 
'w' : ' ')); + return 1; +} + +static const struct luaL_Reg torch_MemoryFile__ [] = { + {"storage", torch_MemoryFile_storage}, + {"__tostring__", torch_MemoryFile___tostring__}, + {NULL, NULL} +}; + +void torch_MemoryFile_init(lua_State *L) +{ + torch_CharStorage_id = luaT_checktypename2id(L, "torch.CharStorage"); + + torch_MemoryFile_id = luaT_newmetatable(L, "torch.MemoryFile", "torch.File", + torch_MemoryFile_new, torch_MemoryFile_free, NULL); + + luaL_register(L, NULL, torch_MemoryFile__); + lua_pop(L, 1); +} diff --git a/PipeFile.c b/PipeFile.c new file mode 100644 index 00000000000..32c275da858 --- /dev/null +++ b/PipeFile.c @@ -0,0 +1,46 @@ +#include "general.h" + +static const void* torch_PipeFile_id = NULL; + +static int torch_PipeFile_new(lua_State *L) +{ + const char *name = luaL_checkstring(L, 1); + const char *mode = luaL_optstring(L, 2, "r"); + int isQuiet = luaT_optboolean(L, 3, 0); + THFile *self = THPipeFile_new(name, mode, isQuiet); + + luaT_pushudata(L, self, torch_PipeFile_id); + return 1; +} + +static int torch_PipeFile_free(lua_State *L) +{ + THFile *self = luaT_checkudata(L, 1, torch_PipeFile_id); + THFile_free(self); + return 0; +} + +static int torch_PipeFile___tostring__(lua_State *L) +{ + THFile *self = luaT_checkudata(L, 1, torch_PipeFile_id); + lua_pushfstring(L, "torch.PipeFile on <%s> [status: %s -- mode: %c%c]", + THDiskFile_name(self), + (THFile_isOpened(self) ? "open" : "closed"), + (THFile_isReadable(self) ? 'r' : ' '), + (THFile_isWritable(self) ? 'w' : ' ')); + return 1; +} + +static const struct luaL_Reg torch_PipeFile__ [] = { + {"__tostring__", torch_PipeFile___tostring__}, + {NULL, NULL} +}; + +void torch_PipeFile_init(lua_State *L) +{ + torch_PipeFile_id = luaT_newmetatable(L, "torch.PipeFile", "torch.DiskFile", + torch_PipeFile_new, torch_PipeFile_free, NULL); + + luaL_register(L, NULL, torch_PipeFile__); + lua_pop(L, 1); +} diff --git a/Storage.c b/Storage.c new file mode 100644 index 00000000000..bb7f7db3e48 --- /dev/null +++ b/Storage.c @@ -0,0 +1,19 @@ +#include "general.h" + +static const void *torch_File_id = NULL; +static const void *torch_ByteStorage_id = NULL; +static const void *torch_CharStorage_id = NULL; +static const void *torch_ShortStorage_id = NULL; +static const void *torch_IntStorage_id = NULL; +static const void *torch_LongStorage_id = NULL; +static const void *torch_FloatStorage_id = NULL; +static const void *torch_DoubleStorage_id = NULL; + +#define torch_Storage_(NAME) TH_CONCAT_4(torch_,Real,Storage_,NAME) +#define torch_Storage_id TH_CONCAT_3(torch_,Real,Storage_id) +#define THFile_readRealRaw TH_CONCAT_3(THFile_read, Real, Raw) +#define THFile_writeRealRaw TH_CONCAT_3(THFile_write, Real, Raw) +#define STRING_torchStorage TH_CONCAT_STRING_3(torch.,Real,Storage) + +#include "generic/Storage.c" +#include "THGenerateAllTypes.h" diff --git a/Tensor.c b/Tensor.c new file mode 100644 index 00000000000..dc65202d4af --- /dev/null +++ b/Tensor.c @@ -0,0 +1,29 @@ +#include "general.h" + +static const void *torch_File_id = NULL; + +static const void *torch_ByteStorage_id = NULL; +static const void *torch_CharStorage_id = NULL; +static const void *torch_ShortStorage_id = NULL; +static const void *torch_IntStorage_id = NULL; +static const void *torch_LongStorage_id = NULL; +static const void *torch_FloatStorage_id = NULL; +static const void *torch_DoubleStorage_id = NULL; + +static const void *torch_ByteTensor_id = NULL; +static const void *torch_CharTensor_id = NULL; +static const void *torch_ShortTensor_id = NULL; +static const void 
*torch_IntTensor_id = NULL; +static const void *torch_LongTensor_id = NULL; +static const void *torch_FloatTensor_id = NULL; +static const void *torch_DoubleTensor_id = NULL; + +#define torch_Storage_(NAME) TH_CONCAT_4(torch_,Real,Storage_,NAME) +#define torch_Storage_id TH_CONCAT_3(torch_,Real,Storage_id) +#define STRING_torchStorage TH_CONCAT_STRING_3(torch.,Real,Storage) +#define torch_Tensor_(NAME) TH_CONCAT_4(torch_,Real,Tensor_,NAME) +#define torch_Tensor_id TH_CONCAT_3(torch_,Real,Tensor_id) +#define STRING_torchTensor TH_CONCAT_STRING_3(torch.,Real,Tensor) + +#include "generic/Tensor.c" +#include "THGenerateAllTypes.h" diff --git a/Tensor.lua b/Tensor.lua new file mode 100644 index 00000000000..43465e3c2cd --- /dev/null +++ b/Tensor.lua @@ -0,0 +1,279 @@ +-- additional methods for Storage +local Storage = {} + +-- additional methods for Tensor +local Tensor = {} + +-- types +local types = {'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Double'} + +-- tostring() functions for Tensor and Storage +local function Storage__printformat(self) + local intMode = true + local type = torch.typename(self) +-- if type == 'torch.FloatStorage' or type == 'torch.DoubleStorage' then + for i=1,self:size() do + if self[i] ~= math.ceil(self[i]) then + intMode = false + break + end + end +-- end + local tensor = torch.DoubleTensor(torch.DoubleStorage(self:size()):copy(self), 1, self:size()):abs() + local expMin = tensor:minall() + if expMin ~= 0 then + expMin = math.floor(math.log10(expMin)) + 1 + end + local expMax = tensor:maxall() + if expMax ~= 0 then + expMax = math.floor(math.log10(expMax)) + 1 + end + + local format + local scale + local sz + if intMode then + if expMax > 9 then + format = "%11.4e" + sz = 11 + else + format = "%SZd" + sz = expMax + 1 + end + else + if expMax-expMin > 4 then + format = "%SZ.4e" + sz = 11 + if math.abs(expMax) > 99 or math.abs(expMin) > 99 then + sz = sz + 1 + end + else + if expMax > 5 or expMax < 0 then + format = "%SZ.4f" + sz = 7 + scale = math.pow(10, expMax-1) + else + format = "%SZ.4f" + if expMax == 0 then + sz = 7 + else + sz = expMax+6 + end + end + end + end + format = string.gsub(format, 'SZ', sz) + if scale == 1 then + scale = nil + end + return format, scale, sz +end + +function Storage.__tostring__(self) + local strt = {'\n'} + local format,scale = Storage__printformat(self) + if format:sub(2,4) == 'nan' then format = '%f' end + if scale then + table.insert(strt, string.format('%g', scale) .. ' *\n') + for i = 1,self:size() do + table.insert(strt, string.format(format, self[i]/scale) .. '\n') + end + else + for i = 1,self:size() do + table.insert(strt, string.format(format, self[i]) .. '\n') + end + end + table.insert(strt, '[' .. torch.typename(self) .. ' of size ' .. self:size() .. ']\n') + str = table.concat(strt) + return str +end + +for _,type in ipairs(types) do + local metatable = torch.getmetatable('torch.' .. type .. 'Storage') + for funcname, func in pairs(Storage) do + rawset(metatable, funcname, func) + end +end + +local function Tensor__printMatrix(self, indent) + local format,scale,sz = Storage__printformat(self:storage()) + if format:sub(2,4) == 'nan' then format = '%f' end +-- print('format = ' .. format) + scale = scale or 1 + indent = indent or '' + local strt = {indent} + local nColumnPerLine = math.floor((80-#indent)/(sz+1)) +-- print('sz = ' .. sz .. ' and nColumnPerLine = ' .. 
nColumnPerLine) + local firstColumn = 1 + local lastColumn = -1 + while firstColumn <= self:size(2) do + if firstColumn + nColumnPerLine - 1 <= self:size(2) then + lastColumn = firstColumn + nColumnPerLine - 1 + else + lastColumn = self:size(2) + end + if nColumnPerLine < self:size(2) then + if firstColumn ~= 1 then + table.insert(strt, '\n') + end + table.insert(strt, 'Columns ' .. firstColumn .. ' to ' .. lastColumn .. '\n' .. indent) + end + if scale ~= 1 then + table.insert(strt, string.format('%g', scale) .. ' *\n ' .. indent) + end + for l=1,self:size(1) do + local row = self:select(1, l) + for c=firstColumn,lastColumn do + table.insert(strt, string.format(format, row[c]/scale)) + if c == lastColumn then + table.insert(strt, '\n') + if l~=self:size(1) then + if scale ~= 1 then + table.insert(strt, indent .. ' ') + else + table.insert(strt, indent) + end + end + else + table.insert(strt, ' ') + end + end + end + firstColumn = lastColumn + 1 + end + local str = table.concat(strt) + return str +end + +local function Tensor__printTensor(self) + local counter = torch.LongStorage(self:nDimension()-2) + local strt = {''} + local finished + counter:fill(1) + counter[1] = 0 + while true do + for i=1,self:nDimension()-2 do + counter[i] = counter[i] + 1 + if counter[i] > self:size(i) then + if i == self:nDimension()-2 then + finished = true + break + end + counter[i] = 1 + else + break + end + end + if finished then + break + end +-- print(counter) + if #strt > 1 then + table.insert(strt, '\n') + end + table.insert(strt, '(') + local tensor = self + for i=1,self:nDimension()-2 do + tensor = tensor:select(1, counter[i]) + table.insert(strt, counter[i] .. ',') + end + table.insert(strt, '.,.) = \n') + table.insert(strt, Tensor__printMatrix(tensor, ' ')) + end + local str = table.concat(strt) + return str +end + +function Tensor.__tostring__(self) + local str = '\n' + local strt = {''} + if self:nDimension() == 0 then + table.insert(strt, '[' .. torch.typename(self) .. ' with no dimension]\n') + else + local tensor = torch.DoubleTensor():resize(self:size()):copy(self) + if tensor:nDimension() == 1 then + local format,scale,sz = Storage__printformat(tensor:storage()) + if format:sub(2,4) == 'nan' then format = '%f' end + if scale then + table.insert(strt, string.format('%g', scale) .. ' *\n') + for i = 1,tensor:size(1) do + table.insert(strt, string.format(format, tensor[i]/scale) .. '\n') + end + else + for i = 1,tensor:size(1) do + table.insert(strt, string.format(format, tensor[i]) .. '\n') + end + end + table.insert(strt, '[' .. torch.typename(self) .. ' of dimension ' .. tensor:size(1) .. ']\n') + elseif tensor:nDimension() == 2 then + table.insert(strt, Tensor__printMatrix(tensor)) + table.insert(strt, '[' .. torch.typename(self) .. ' of dimension ' .. tensor:size(1) .. 'x' .. tensor:size(2) .. ']\n') + else + table.insert(strt, Tensor__printTensor(tensor)) + table.insert(strt, '[' .. torch.typename(self) .. 
' of dimension ') + for i=1,tensor:nDimension() do + table.insert(strt, tensor:size(i)) + if i ~= tensor:nDimension() then + table.insert(strt, 'x') + end + end + table.insert(strt, ']\n') + end + end + local str = table.concat(strt) + return str +end + +function Tensor.type(self,type) + local current = torch.typename(self) + if not type then return current end + if type ~= current then + local new = torch.getmetatable(type).new() + if self:nElement() > 0 then + new:resize(self:size()):copy(self) + end + return new + else + return self + end +end + +function Tensor.typeAs(self,tensor) + return self:type(tensor:type()) +end + +function Tensor.byte(self,type) + return self:type('torch.ByteTensor') +end + +function Tensor.char(self,type) + return self:type('torch.CharTensor') +end + +function Tensor.short(self,type) + return self:type('torch.ShortTensor') +end + +function Tensor.int(self,type) + return self:type('torch.IntTensor') +end + +function Tensor.long(self,type) + return self:type('torch.LongTensor') +end + +function Tensor.float(self,type) + return self:type('torch.FloatTensor') +end + +function Tensor.double(self,type) + return self:type('torch.DoubleTensor') +end + + +for _,type in ipairs(types) do + local metatable = torch.getmetatable('torch.' .. type .. 'Tensor') + for funcname, func in pairs(Tensor) do + rawset(metatable, funcname, func) + end +end diff --git a/TensorConvWrap.lua b/TensorConvWrap.lua new file mode 100644 index 00000000000..4addd903493 --- /dev/null +++ b/TensorConvWrap.lua @@ -0,0 +1,127 @@ +-- +-- require 'wrap' +--- + +interface = wrap.CInterface.new() + + +interface.dispatchregistry = {} +function interface:wrap(name, ...) + -- usual stuff + --wrap.CInterface.wrap(self, name, ...) + + -- dispatch function + if not interface.dispatchregistry[name] then + interface.dispatchregistry[name] = true + table.insert(interface.dispatchregistry, {name=name, wrapname=string.format("torch_%s", name)}) + + interface:print(string.gsub([[ +static int torch_NAME(lua_State *L) +{ + int narg = lua_gettop(L); + const void *id; + + if(narg < 1 || !(id = torch_istensorid(L, luaT_id(L, 1)))) /* first argument is tensor? */ + { + if(narg < 2 || !(id = torch_istensorid(L, luaT_id(L, 2)))) /* second? */ + { + if(lua_isstring(L, -1) && (id = torch_istensorid(L, luaT_typename2id(L, lua_tostring(L, -1))))) /* do we have a valid string then? 
*/ + lua_pop(L, 1); + else if(!(id = torch_istensorid(L, torch_getdefaulttensorid()))) + luaL_error(L, "internal error: the default tensor type does not seem to be an actual tensor"); + } + } + + lua_pushstring(L, "NAME"); + lua_rawget(L, -2); + if(lua_isfunction(L, -1)) + { + lua_insert(L, 1); + lua_pop(L, 2); /* the two tables we put on the stack above */ + lua_call(L, lua_gettop(L)-1, LUA_MULTRET); + } + else + return luaL_error(L, "%s does not implement the torch.NAME() function", luaT_id2typename(L, id)); + + return lua_gettop(L); +} +]], 'NAME', name)) + end +end + +function interface:dispatchregister(name) + local txt = self.txt + table.insert(txt, string.format('static const struct luaL_Reg %s [] = {', name)) + for _,reg in ipairs(self.dispatchregistry) do + table.insert(txt, string.format('{"%s", %s},', reg.name, reg.wrapname)) + end + table.insert(txt, '{NULL, NULL}') + table.insert(txt, '};') + table.insert(txt, '') + self.dispatchregistry = {} +end + +interface:print('/* WARNING: autogenerated file */') +interface:print('') + +local reals = {ByteTensor='byte', + CharTensor='char', + ShortTensor='short', + IntTensor='int', + LongTensor='long', + FloatTensor='float', + DoubleTensor='double'} + +for _,Tensor in ipairs({"FloatTensor", "DoubleTensor", "IntTensor", "LongTensor", "ByteTensor", "CharTensor","ShortTensor"}) do + + local real = reals[Tensor] + + function interface.luaname2wrapname(self, name) + return string.format('torch_%s_%s', Tensor, name) + end + + local function cname(name) + return string.format('TH%s_%s', Tensor, name) + end + + local function lastdim(argn) + return function(arg) + return string.format("TH%s_nDimension(%s)", Tensor, arg.args[argn]:carg()) + end + end + + + for _,name in ipairs({"conv2","xcorr2","conv3","xcorr3"}) do + interface:wrap(name, + cname(name), + {{name=Tensor, default=true, returned=true}, + {name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name=Tensor}} + ) + end + + + --interface:register(string.format("torch_%sLapack__", Tensor)) + +-- interface:print(string.gsub([[ +-- static void torch_TensorLapack_init(lua_State *L) +-- { +-- torch_Tensor_id = luaT_checktypename2id(L, "torch.Tensor"); +-- torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage"); + +-- luaT_pushmetaclass(L, torch_Tensor_id); +-- lua_getfield(L,-1,"torch"); +-- luaL_register(L, NULL, torch_TensorLapack__); +-- lua_pop(L, 2); +-- } +-- ]], 'Tensor', Tensor)) +end + +interface:dispatchregister("torch_TensorConv__") + +if arg[1] then + interface:tofile(arg[1]) +else + interface:tostdio() +end diff --git a/TensorLapackWrap.lua b/TensorLapackWrap.lua new file mode 100644 index 00000000000..ae1fba0f4e2 --- /dev/null +++ b/TensorLapackWrap.lua @@ -0,0 +1,132 @@ +-- +-- require 'wrap' +--- + +interface = wrap.CInterface.new() + + +interface.dispatchregistry = {} +function interface:wrap(name, ...) + -- usual stuff + wrap.CInterface.wrap(self, name, ...) + + -- dispatch function + if not interface.dispatchregistry[name] then + interface.dispatchregistry[name] = true + table.insert(interface.dispatchregistry, {name=name, wrapname=string.format("torch_%s", name)}) + + interface:print(string.gsub([[ +static int torch_NAME(lua_State *L) +{ + int narg = lua_gettop(L); + const void *id; + + if(narg < 1 || !(id = torch_istensorid(L, luaT_id(L, 1)))) /* first argument is tensor? */ + { + if(narg < 2 || !(id = torch_istensorid(L, luaT_id(L, 2)))) /* second? 
*/ + { + if(lua_isstring(L, -1) && (id = torch_istensorid(L, luaT_typename2id(L, lua_tostring(L, -1))))) /* do we have a valid string then? */ + lua_pop(L, 1); + else if(!(id = torch_istensorid(L, torch_getdefaulttensorid()))) + luaL_error(L, "internal error: the default tensor type does not seem to be an actual tensor"); + } + } + + lua_pushstring(L, "NAME"); + lua_rawget(L, -2); + if(lua_isfunction(L, -1)) + { + lua_insert(L, 1); + lua_pop(L, 2); /* the two tables we put on the stack above */ + lua_call(L, lua_gettop(L)-1, LUA_MULTRET); + } + else + return luaL_error(L, "%s does not implement the torch.NAME() function", luaT_id2typename(L, id)); + + return lua_gettop(L); +} +]], 'NAME', name)) + end +end + +function interface:dispatchregister(name) + local txt = self.txt + table.insert(txt, string.format('static const struct luaL_Reg %s [] = {', name)) + for _,reg in ipairs(self.dispatchregistry) do + table.insert(txt, string.format('{"%s", %s},', reg.name, reg.wrapname)) + end + table.insert(txt, '{NULL, NULL}') + table.insert(txt, '};') + table.insert(txt, '') + self.dispatchregistry = {} +end + +interface:print('/* WARNING: autogenerated file */') +interface:print('') + +local reals = {ByteTensor='byte', + CharTensor='char', + ShortTensor='short', + IntTensor='int', + LongTensor='long', + FloatTensor='float', + DoubleTensor='double'} + +for _,Tensor in ipairs({"FloatTensor", "DoubleTensor"}) do + + local real = reals[Tensor] + + function interface.luaname2wrapname(self, name) + return string.format('torch_%s_%s', Tensor, name) + end + + local function cname(name) + return string.format('TH%s_%s', Tensor, name) + end + + local function lastdim(argn) + return function(arg) + return string.format("TH%s_nDimension(%s)", Tensor, arg.args[argn]:carg()) + end + end + + +-- for _,name in ipairs({"gesv","gels","eig","svd"}) do +-- interface:wrap(name, +-- cname(name), +-- {{name=Tensor, returned=true}, +-- {name=Tensor, returned=true}, +-- {name=Tensor}, +-- {name=Tensor}}, +-- cname(name), +-- {{name=Tensor, default=true, returned=true, invisible=true}, +-- {name=Tensor, default=true, returned=true, invisible=true}, +-- {name=Tensor}, +-- {name=Tensor}} +-- ) +-- end + + + --interface:register(string.format("torch_%sLapack__", Tensor)) + +-- interface:print(string.gsub([[ +-- static void torch_TensorLapack_init(lua_State *L) +-- { +-- torch_Tensor_id = luaT_checktypename2id(L, "torch.Tensor"); +-- torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage"); + +-- luaT_pushmetaclass(L, torch_Tensor_id); +-- lua_getfield(L,-1,"torch"); +-- luaL_register(L, NULL, torch_TensorLapack__); +-- lua_pop(L, 2); +-- } +-- ]], 'Tensor', Tensor)) +end + +interface:dispatchregister("torch_TensorLapack__") + +if arg[1] then + interface:tofile(arg[1]) +else + interface:tostdio() +end diff --git a/TensorMath.c b/TensorMath.c new file mode 100644 index 00000000000..05ed6fe8ebf --- /dev/null +++ b/TensorMath.c @@ -0,0 +1,54 @@ +#include "TH.h" +#include "luaT.h" +#include "utils.h" + +#include "sys/time.h" + +#define torch_(NAME) TH_CONCAT_3(torch_, Real, NAME) +#define torch_string_(NAME) TH_CONCAT_STRING_3(torch., Real, NAME) + +static const void* torch_ByteTensor_id; +static const void* torch_CharTensor_id; +static const void* torch_ShortTensor_id; +static const void* torch_IntTensor_id; +static const void* torch_LongTensor_id; +static const void* torch_FloatTensor_id; +static const void* torch_DoubleTensor_id; + +static const void* torch_LongStorage_id; + + +#include "TensorMathWrap.c" +//#include 
"TensorLapackWrap.c" +//#include "TensorConvWrap.c" + +//#include "generic/TensorLapack.c" +//#include "THGenerateFloatTypes.h" + +//#include "generic/TensorConv.c" +//#include "THGenerateAllTypes.h" + +void torch_TensorMath_init(lua_State *L) +{ + torch_ByteTensorMath_init(L); + torch_CharTensorMath_init(L); + torch_ShortTensorMath_init(L); + torch_IntTensorMath_init(L); + torch_LongTensorMath_init(L); + torch_FloatTensorMath_init(L); + torch_DoubleTensorMath_init(L); + luaL_register(L, NULL, torch_TensorMath__); + +/* torch_FloatLapack_init(L); */ +/* torch_DoubleLapack_init(L); */ +/* luaL_register(L, NULL, torch_TensorLapack__); */ + +/* torch_ByteConv_init(L); */ +/* torch_CharConv_init(L); */ +/* torch_ShortConv_init(L); */ +/* torch_IntConv_init(L); */ +/* torch_LongConv_init(L); */ +/* torch_FloatConv_init(L); */ +/* torch_DoubleConv_init(L); */ +/* luaL_register(L, NULL, torch_TensorConv__); */ +} diff --git a/TensorMath.lua b/TensorMath.lua new file mode 100644 index 00000000000..7dddb00afb7 --- /dev/null +++ b/TensorMath.lua @@ -0,0 +1,110 @@ +for _,tensortype in ipairs({'ByteTensor', + 'CharTensor', + 'ShortTensor', + 'IntTensor', + 'LongTensor', + 'FloatTensor', + 'DoubleTensor'}) do + + for _,func in ipairs({'add', + 'mul', + 'div', + 'cmul', + 'cdiv', + 'addcmul', + 'addcdiv', + 'log', + 'log1p', + 'exp', + 'cos', + 'acos', + 'cosh', + 'sin', + 'asin', + 'sinh', + 'tan', + 'atan', + 'tanh', + 'pow', + 'sqrt', + 'ceil', + 'floor', + 'abs', + 'sign' + }) do + + local torchfunc = torch[tensortype].torch[func] + torch[tensortype][func] = function(self, ...) + return torchfunc(self, self, ...) + end + end + + for _,func in ipairs({'addmv', + 'addmm', + 'addr'}) do + + local torchfunc = torch[tensortype].torch[func] + torch[tensortype][func] = function(self, next1, next2, ...) + if type(next1) == 'number' and type(next2) == 'number' then + return torchfunc(self, next1, self, next2, ...) + elseif type(next1) == 'number' then + return torchfunc(self, self, next1, next2, ...) + else + return torchfunc(self, self, next1, next2, ...) 
+ end + end + end + + for _,func in ipairs({'zero', + 'fill', + 'dot', + 'minall', + 'maxall', + 'sumall', + 'numel', + 'max', + 'min', + 'sum', + 'prod', + 'cumsum', + 'cumprod', + 'trace', + 'cross', + 'zeros', + 'ones', + 'diag', + 'eye', + 'range', + 'randperm', + 'reshape', + 'sort', + 'tril', + 'triu', + '_histc', + 'cat', + 'mean', + 'std', + 'var', + 'norm', + 'dist', + 'meanall', + 'varall', + 'stdall', + 'linspace', + 'logspace', + 'rand', + 'randn', + 'random', + 'uniform', + 'normal', + 'cauchy', + 'logNormal', + 'exponential', + 'geometric', + 'bernoulli', + 'squeeze' + }) do + + torch[tensortype][func] = torch[tensortype].torch[func] + end +end diff --git a/TensorMathWrap.lua b/TensorMathWrap.lua new file mode 100644 index 00000000000..7f2f6cb0172 --- /dev/null +++ b/TensorMathWrap.lua @@ -0,0 +1,870 @@ +-- +-- require 'wrap' +--- + +local interface = wrap.CInterface.new() + +interface:print([[ +#include "TH.h" +#include "luaT.h" +#include "utils.h" + +static const void* torch_ByteTensor_id; +static const void* torch_CharTensor_id; +static const void* torch_ShortTensor_id; +static const void* torch_IntTensor_id; +static const void* torch_LongTensor_id; +static const void* torch_FloatTensor_id; +static const void* torch_DoubleTensor_id; + +static const void* torch_LongStorage_id; + ]]) + +-- special argument specific to torch package +interface.argtypes.LongArg = { + + vararg = true, + + helpname = function(arg) + return "(LongStorage | dim1 [dim2...])" + end, + + declare = function(arg) + return string.format("THLongStorage *arg%d = NULL;", arg.i) + end, + + init = function(arg) + if arg.default then + error('LongArg cannot have a default value') + end + end, + + check = function(arg, idx) + return string.format("torch_islongargs(L, %d)", idx) + end, + + read = function(arg, idx) + return string.format("arg%d = torch_checklongargs(L, %d);", arg.i, idx) + end, + + carg = function(arg, idx) + return string.format('arg%d', arg.i) + end, + + creturn = function(arg, idx) + return string.format('arg%d', arg.i) + end, + + precall = function(arg) + local txt = {} + if arg.returned then + table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_LongStorage_id);', arg.i)) + end + return table.concat(txt, '\n') + end, + + postcall = function(arg) + local txt = {} + if arg.creturned then + -- this next line is actually debatable + table.insert(txt, string.format('THLongStorage_retain(arg%d);', arg.i)) + table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_LongStorage_id);', arg.i)) + end + if not arg.returned and not arg.creturned then + table.insert(txt, string.format('THLongStorage_free(arg%d);', arg.i)) + end + return table.concat(txt, '\n') + end +} + +interface.argtypes.charoption = { + + helpname = function(arg) + if arg.values then + return "(" .. table.concat(arg.values, '|') .. 
")" + end + end, + + declare = function(arg) + local txt = {} + table.insert(txt, string.format("const char *arg%d = NULL;", arg.i)) + if arg.default then + table.insert(txt, string.format("char arg%d_default = '%s';", arg.i, arg.default)) + end + return table.concat(txt, '\n') + end, + + init = function(arg) + return string.format("arg%d = &arg%d_default;", arg.i, arg.i) + end, + + check = function(arg, idx) + local txt = {} + local txtv = {} + table.insert(txt, string.format('(arg%d = lua_tostring(L, %d)) && (', arg.i, idx)) + for _,value in ipairs(arg.values) do + table.insert(txtv, string.format("*arg%d == '%s'", arg.i, value)) + end + table.insert(txt, table.concat(txtv, ' || ')) + table.insert(txt, ')') + return table.concat(txt, '') + end, + + read = function(arg, idx) + end, + + carg = function(arg, idx) + return string.format('arg%d', arg.i) + end, + + creturn = function(arg, idx) + end, + + precall = function(arg) + end, + + postcall = function(arg) + end +} + +-- also specific to torch: we generate a 'dispatch' function +-- first we create a helper function +interface:print([[ +static const void* torch_istensorid(lua_State *L, const void *id) +{ + if(!id) + return NULL; + + luaT_pushmetaclass(L, id); + lua_pushstring(L, "torch"); + lua_rawget(L, -2); + if(lua_istable(L, -1)) + return id; + else + { + lua_pop(L, 2); + return NULL; + } + + return NULL; +} +]]) + +interface.dispatchregistry = {} +function interface:wrap(name, ...) + -- usual stuff + wrap.CInterface.wrap(self, name, ...) + + -- dispatch function + if not interface.dispatchregistry[name] then + interface.dispatchregistry[name] = true + table.insert(interface.dispatchregistry, {name=name, wrapname=string.format("torch_%s", name)}) + + interface:print(string.gsub([[ +static int torch_NAME(lua_State *L) +{ + int narg = lua_gettop(L); + const void *id; + + if(narg < 1 || !(id = torch_istensorid(L, luaT_id(L, 1)))) /* first argument is tensor? */ + { + if(narg < 2 || !(id = torch_istensorid(L, luaT_id(L, 2)))) /* second? */ + { + if(lua_isstring(L, -1) && (id = torch_istensorid(L, luaT_typename2id(L, lua_tostring(L, -1))))) /* do we have a valid string then? 
*/ + lua_pop(L, 1); + else if(!(id = torch_istensorid(L, torch_getdefaulttensorid()))) + luaL_error(L, "internal error: the default tensor type does not seem to be an actual tensor"); + } + } + + lua_pushstring(L, "NAME"); + lua_rawget(L, -2); + if(lua_isfunction(L, -1)) + { + lua_insert(L, 1); + lua_pop(L, 2); /* the two tables we put on the stack above */ + lua_call(L, lua_gettop(L)-1, LUA_MULTRET); + } + else + return luaL_error(L, "%s does not implement the torch.NAME() function", luaT_id2typename(L, id)); + + return lua_gettop(L); +} +]], 'NAME', name)) + end +end + +function interface:dispatchregister(name) + local txt = self.txt + table.insert(txt, string.format('static const struct luaL_Reg %s [] = {', name)) + for _,reg in ipairs(self.dispatchregistry) do + table.insert(txt, string.format('{"%s", %s},', reg.name, reg.wrapname)) + end + table.insert(txt, '{NULL, NULL}') + table.insert(txt, '};') + table.insert(txt, '') + self.dispatchregistry = {} +end + +interface:print('/* WARNING: autogenerated file */') +interface:print('') + +local reals = {ByteTensor='unsigned char', + CharTensor='char', + ShortTensor='short', + IntTensor='int', + LongTensor='long', + FloatTensor='float', + DoubleTensor='double'} + +for _,Tensor in ipairs({"ByteTensor", "CharTensor", + "ShortTensor", "IntTensor", "LongTensor", + "FloatTensor", "DoubleTensor"}) do + + local real = reals[Tensor] + + function interface.luaname2wrapname(self, name) + return string.format('torch_%s_%s', Tensor, name) + end + + local function cname(name) + return string.format('TH%s_%s', Tensor, name) + end + + local function lastdim(argn) + return function(arg) + return string.format("TH%s_nDimension(%s)", Tensor, arg.args[argn]:carg()) + end + end + + interface:wrap("zero", + cname("zero"), + {{name=Tensor, returned=true}}) + + interface:wrap("fill", + cname("fill"), + {{name=Tensor, returned=true}, + {name=real}}) + + interface:wrap("zeros", + cname("zeros"), + {{name=Tensor, default=true, returned=true}, + {name="LongArg"}}) + + interface:wrap("ones", + cname("ones"), + {{name=Tensor, default=true, returned=true}, + {name="LongArg"}}) + + interface:wrap("reshape", + cname("reshape"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name="LongArg"}}) + + interface:wrap("dot", + cname("dot"), + {{name=Tensor}, + {name=Tensor}, + {name=real, creturned=true}}) + + for _,name in ipairs({"minall", "maxall", "sumall"}) do + interface:wrap(name, + cname(name), + {{name=Tensor}, + {name=real, creturned=true}}) + end + + interface:wrap("add", + cname("add"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name=real}}, + cname("cadd"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name=real, default=1}, + {name=Tensor}}) + + interface:wrap("mul", + cname("mul"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name=real}}) + + interface:wrap("div", + cname("div"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name=real}}) + + interface:wrap("cmul", + cname("cmul"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name=Tensor}}) + + interface:wrap("cdiv", + cname("cdiv"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name=Tensor}}) + + interface:wrap("addcmul", + cname("addcmul"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name=real, default=1}, + {name=Tensor}, + {name=Tensor}}) + + interface:wrap("addcdiv", + cname("addcdiv"), + {{name=Tensor, default=true, returned=true}, + 
{name=Tensor}, + {name=real, default=1}, + {name=Tensor}, + {name=Tensor}}) + + for _,name in ipairs({"addmv", "addmm", "addr"}) do + interface:wrap(name, + cname(name), + {{name=Tensor, default=true, returned=true}, + {name=real, default=1}, + {name=Tensor}, + {name=real, default=1}, + {name=Tensor}, + {name=Tensor}}) + end + + interface:wrap("numel", + cname("numel"), + {{name=Tensor}, + {name=real, creturned=true}}) + + for _,name in ipairs({"sum", "prod", "cumsum", "cumprod"}) do + interface:wrap(name, + cname(name), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name="index", default=lastdim(2)}}) + end + + interface:wrap("min", + cname("min"), + {{name=Tensor, default=true, returned=true}, + {name="IndexTensor", default=true, returned=true}, + {name=Tensor}, + {name="index", default=lastdim(3)}}) + + interface:wrap("max", + cname("max"), + {{name=Tensor, default=true, returned=true}, + {name="IndexTensor", default=true, returned=true}, + {name=Tensor}, + {name="index", default=lastdim(3)}}) + + interface:wrap("trace", + cname("trace"), + {{name=Tensor}, + {name=real, creturned=true}}) + + interface:wrap("cross", + cname("cross"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name=Tensor}, + {name="index", default=0}}) + + interface:wrap("diag", + cname("diag"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name="long", default=0}}) + + interface:wrap("eye", + cname("eye"), + {{name=Tensor, default=true, returned=true}, + {name="long"}, + {name="long", default=0}}) + + interface:wrap("range", + cname("range"), + {{name=Tensor, default=true, returned=true}, + {name=real}, + {name=real}, + {name=real, default=1}}) + + interface:wrap("randperm", + cname("randperm"), + {{name=Tensor, default=true, returned=true, userpostcall=function(arg) + return string.format("TH%s_add(%s, %s, 1);", Tensor, arg:carg(), arg:carg()) + end}, + {name="long"}}) + + interface:wrap("sort", + cname("sort"), + {{name=Tensor, default=true, returned=true}, + {name="IndexTensor", default=true, returned=true}, + {name=Tensor}, + {name="index", default=lastdim(3)}, + {name="boolean", default=0}}) + + + interface:wrap("tril", + cname("tril"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name="int", default=0}}) + + interface:wrap("triu", + cname("triu"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name="int", default=0}}) + + interface:wrap("cat", + cname("cat"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name=Tensor}, + {name="index", default=lastdim(2)}}) + + if Tensor == 'ByteTensor' then -- we declare this only once + interface:print( + [[ +static int THRandom_random2__(long a, long b) +{ + THArgCheck(b >= a, 2, "upper bound must be larger than lower bound"); + return((THRandom_random() % (b+1-a)) + a); +} + +static int THRandom_random1__(long b) +{ + THArgCheck(b > 0, 1, "upper bound must be strictly positive"); + return(THRandom_random() % b + 1); +} + ]]) + end + + interface:print(string.gsub( + [[ +static void THTensor_random2__(THTensor *self, long a, long b) +{ + THArgCheck(b >= a, 2, "upper bound must be larger than lower bound"); + TH_TENSOR_APPLY(real, self, *self_data = ((THRandom_random() % (b+1-a)) + a);) +} + +static void THTensor_random1__(THTensor *self, long b) +{ + THArgCheck(b > 0, 1, "upper bound must be strictly positive"); + TH_TENSOR_APPLY(real, self, *self_data = (THRandom_random() % b + 1);) +} +]], 'Tensor', Tensor):gsub('real', real)) + + 
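-- the per-type helpers emitted above implement bounded integer draws on top of
-- THRandom_random() (modulo-based, giving [1,b] and [a,b] ranges), and the
-- wrap() call below registers them all under the single name 'random':
-- roughly, torch.random() returns a raw long, torch.random(b) draws in [1,b],
-- torch.random(a,b) draws in [a,b], and the tensor overloads fill the given
-- tensor in place with such draws. A sketch of the generated Lua-level API:
--   local x = torch.DoubleTensor(5)
--   x:random(1, 10)            -- each entry uniform in {1,...,10}, in place
--   local n = torch.random(1, 6) -- a single draw in {1,...,6}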
interface:wrap('random', + 'THRandom_random2__', + {{name='long'}, + {name='long'}, + {name='long', creturned=true}}, + 'THRandom_random1__', + {{name='long'}, + {name='long', creturned=true}}, + 'THRandom_random', + {{name='long', creturned=true}}, + cname("random2__"), + {{name=Tensor}, + {name='long'}, + {name='long'}}, + cname("random1__"), + {{name=Tensor}, + {name='long'}}, + cname("random"), + {{name=Tensor}}) + + for _,f in ipairs({{name='geometric'}, + {name='bernoulli', a=0.5}}) do + + interface:wrap(f.name, + string.format("THRandom_%s", f.name), + {{name="double", default=f.a}, + {name="double", creturned=true}}, + cname(f.name), + {{name=Tensor, returned=true}, + {name=real, default=f.a}}) + end + + interface:wrap("squeeze", + cname("squeeze"), + {{name=Tensor, default=true, returned=true, postcall=function(arg) + local txt = {} + if arg.returned then + table.insert(txt, string.format('if(arg%d->nDimension == 1 && arg%d->size[0] == 1)', arg.i, arg.i)) -- number + table.insert(txt, string.format('lua_pushnumber(L, (lua_Number)(*TH%s_data(arg%d)));', Tensor, arg.i)) + end + return table.concat(txt, '\n') + end}, + {name=Tensor}}, + cname("squeeze1d"), + {{name=Tensor, default=true, returned=true, postcall=function(arg) + local txt = {} + if arg.returned then + table.insert(txt, string.format('if(arg%d->nDimension == 1 && arg%d->size[0] == 1)', arg.i, arg.i)) -- number + table.insert(txt, string.format('lua_pushnumber(L, (lua_Number)(*TH%s_data(arg%d)));', Tensor, arg.i)) + end + return table.concat(txt, '\n') + end}, + {name=Tensor}, + {name="index"}}) + + interface:wrap("sign", + cname("sign"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}}) + + interface:wrap("conv2", + cname("conv2Dmul"), + {{name=Tensor, default=true, returned=true}, + {name=real, default=0, invisible=true}, + {name=real, default=1, invisible=true}, + {name=Tensor, dim=2}, + {name=Tensor, dim=2}, + {name=real, default=1, invisible=true}, + {name=real, default=1, invisible=true}, + {name='charoption', values={'V', 'F'}, default='V'}, + {name='charoption', default="C", invisible=true}}, + cname("conv2Dcmul"), + {{name=Tensor, default=true, returned=true}, + {name=real, default=0, invisible=true}, + {name=real, default=1, invisible=true}, + {name=Tensor, dim=3}, + {name=Tensor, dim=3}, + {name=real, default=1, invisible=true}, + {name=real, default=1, invisible=true}, + {name='charoption', values={'V', 'F'}, default='V'}, + {name='charoption', default="C", invisible=true}}, + cname("conv2Dmv"), + {{name=Tensor, default=true, returned=true}, + {name=real, default=0, invisible=true}, + {name=real, default=1, invisible=true}, + {name=Tensor, dim=3}, + {name=Tensor, dim=4}, + {name=real, default=1, invisible=true}, + {name=real, default=1, invisible=true}, + {name='charoption', values={'V', 'F'}, default='V'}, + {name='charoption', default="C", invisible=true}} + ) + + interface:wrap("xcorr2", + cname("conv2Dmul"), + {{name=Tensor, default=true, returned=true}, + {name=real, default=0, invisible=true}, + {name=real, default=1, invisible=true}, + {name=Tensor, dim=2}, + {name=Tensor, dim=2}, + {name=real, default=1, invisible=true}, + {name=real, default=1, invisible=true}, + {name='charoption', values={'V', 'F'}, default='V'}, + {name='charoption', default="X", invisible=true}}, + cname("conv2Dcmul"), + {{name=Tensor, default=true, returned=true}, + {name=real, default=0, invisible=true}, + {name=real, default=1, invisible=true}, + {name=Tensor, dim=3}, + {name=Tensor, dim=3}, + {name=real, 
default=1, invisible=true}, + {name=real, default=1, invisible=true}, + {name='charoption', values={'V', 'F'}, default='V'}, + {name='charoption', default="X", invisible=true}}, + cname("conv2Dmv"), + {{name=Tensor, default=true, returned=true}, + {name=real, default=0, invisible=true}, + {name=real, default=1, invisible=true}, + {name=Tensor, dim=3}, + {name=Tensor, dim=4}, + {name=real, default=1, invisible=true}, + {name=real, default=1, invisible=true}, + {name='charoption', values={'V', 'F'}, default='V'}, + {name='charoption', default="X", invisible=true}} + ) + + interface:wrap("conv3", + cname("conv3Dmul"), + {{name=Tensor, default=true, returned=true}, + {name=real, default=0, invisible=true}, + {name=real, default=1, invisible=true}, + {name=Tensor, dim=3}, + {name=Tensor, dim=3}, + {name=real, default=1, invisible=true}, + {name=real, default=1, invisible=true}, + {name=real, default=1, invisible=true}, + {name='charoption', values={'V', 'F'}, default='V'}, + {name='charoption', default="C", invisible=true}}, + cname("conv3Dcmul"), + {{name=Tensor, default=true, returned=true}, + {name=real, default=0, invisible=true}, + {name=real, default=1, invisible=true}, + {name=Tensor, dim=4}, + {name=Tensor, dim=4}, + {name=real, default=1, invisible=true}, + {name=real, default=1, invisible=true}, + {name=real, default=1, invisible=true}, + {name='charoption', values={'V', 'F'}, default='V'}, + {name='charoption', default="C", invisible=true}}, + cname("conv3Dmv"), + {{name=Tensor, default=true, returned=true}, + {name=real, default=0, invisible=true}, + {name=real, default=1, invisible=true}, + {name=Tensor, dim=4}, + {name=Tensor, dim=5}, + {name=real, default=1, invisible=true}, + {name=real, default=1, invisible=true}, + {name=real, default=1, invisible=true}, + {name='charoption', values={'V', 'F'}, default='V'}, + {name='charoption', default="C", invisible=true}} + ) + + interface:wrap("xcorr3", + cname("conv3Dmul"), + {{name=Tensor, default=true, returned=true}, + {name=real, default=0, invisible=true}, + {name=real, default=1, invisible=true}, + {name=Tensor, dim=3}, + {name=Tensor, dim=3}, + {name=real, default=1, invisible=true}, + {name=real, default=1, invisible=true}, + {name=real, default=1, invisible=true}, + {name='charoption', values={'V', 'F'}, default='V'}, + {name='charoption', default="X", invisible=true}}, + cname("conv3Dcmul"), + {{name=Tensor, default=true, returned=true}, + {name=real, default=0, invisible=true}, + {name=real, default=1, invisible=true}, + {name=Tensor, dim=4}, + {name=Tensor, dim=4}, + {name=real, default=1, invisible=true}, + {name=real, default=1, invisible=true}, + {name=real, default=1, invisible=true}, + {name='charoption', values={'V', 'F'}, default='V'}, + {name='charoption', default="X", invisible=true}}, + cname("conv3Dmv"), + {{name=Tensor, default=true, returned=true}, + {name=real, default=0, invisible=true}, + {name=real, default=1, invisible=true}, + {name=Tensor, dim=4}, + {name=Tensor, dim=5}, + {name=real, default=1, invisible=true}, + {name=real, default=1, invisible=true}, + {name=real, default=1, invisible=true}, + {name='charoption', values={'V', 'F'}, default='V'}, + {name='charoption', default="X", invisible=true}} + ) + + if Tensor == 'FloatTensor' or Tensor == 'DoubleTensor' then + + interface:wrap("mean", + cname("mean"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name="index", default=lastdim(2)}}) + + interface:wrap("std", + cname("std"), + {{name=Tensor, default=true, returned=true}, + 
{name=Tensor}, + {name="index", default=lastdim(2)}, + {name="boolean", default=false}}) + + interface:wrap("var", + cname("var"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name="index", default=lastdim(2)}, + {name="boolean", default=false}}) + + interface:wrap("norm", + cname("norm"), + {{name=Tensor}, + {name=real, default=2}, + {name=real, creturned=true}}) + + interface:wrap("dist", + cname("dist"), + {{name=Tensor}, + {name=Tensor}, + {name=real, default=2}, + {name=real, creturned=true}}) + + for _,name in ipairs({"meanall", "varall", "stdall"}) do + interface:wrap(name, + cname(name), + {{name=Tensor}, + {name=real, creturned=true}}) + end + + interface:wrap("linspace", + cname("linspace"), + {{name=Tensor, default=true, returned=true}, + {name=real}, + {name=real}, + {name="long", default=100}}) + + interface:wrap("logspace", + cname("logspace"), + {{name=Tensor, default=true, returned=true}, + {name=real}, + {name=real}, + {name="long", default=100}}) + + for _,name in ipairs({"log", "log1p", "exp", + "cos", "acos", "cosh", + "sin", "asin", "sinh", + "tan", "atan", "tanh", + "sqrt", + "ceil", "floor", + "abs"}) do + + interface:wrap(name, + cname(name), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}}, + name, + {{name=real}, + {name=real, creturned=true}}) + + end + + interface:wrap("pow", + cname("pow"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name=real}}, + "pow", + {{name=real}, + {name=real}, + {name=real, creturned=true}}) + + interface:wrap("rand", + cname("rand"), + {{name=Tensor, default=true, returned=true}, + {name="LongArg"}}) + + interface:wrap("randn", + cname("randn"), + {{name=Tensor, default=true, returned=true}, + {name="LongArg"}}) + + for _,f in ipairs({{name='uniform', a=0, b=1}, + {name='normal', a=0, b=1}, + {name='cauchy', a=0, b=1}, + {name='logNormal', a=1, b=2}}) do + + interface:wrap(f.name, + string.format("THRandom_%s", f.name), + {{name="double", default=f.a}, + {name="double", default=f.b}, + {name="double", creturned=true}}, + cname(f.name), + {{name=Tensor, returned=true}, + {name=real, default=f.a}, + {name=real, default=f.b}}) + end + + for _,f in ipairs({{name='exponential'}}) do + + interface:wrap(f.name, + string.format("THRandom_%s", f.name), + {{name="double", default=f.a}, + {name="double", creturned=true}}, + cname(f.name), + {{name=Tensor, returned=true}, + {name=real, default=f.a}}) + end + + for _,name in ipairs({"gesv","gels"}) do + interface:wrap(name, + cname(name), + {{name=Tensor, returned=true}, + {name=Tensor, returned=true}, + {name=Tensor}, + {name=Tensor}}, + cname(name), + {{name=Tensor, default=true, returned=true, invisible=true}, + {name=Tensor, default=true, returned=true, invisible=true}, + {name=Tensor}, + {name=Tensor}} + ) + end + + interface:wrap("eig", + cname("syev"), + {{name=Tensor, returned=true}, + {name=Tensor, returned=true}, + {name=Tensor}, + {name='charoption', values={'N', 'V'}, default='N'}, + {name='charoption', values={'U', 'L'}, default='U'}}, + cname("syev"), + {{name=Tensor, default=true, returned=true, invisible=true}, + {name=Tensor, default=true, returned=true, invisible=true}, + {name=Tensor}, + {name='charoption', values={'N', 'V'}, default='N'}, + {name='charoption', values={'U', 'L'}, default='U'}} + ) + + interface:wrap("svd", + cname("gesvd"), + {{name=Tensor, returned=true}, + {name=Tensor, returned=true}, + {name=Tensor, returned=true}, + {name=Tensor}, + {name='charoption', values={'A', 'S'}, default='S'}}, + 
cname("gesvd"), + {{name=Tensor, default=true, returned=true, invisible=true}, + {name=Tensor, default=true, returned=true, invisible=true}, + {name=Tensor, default=true, returned=true, invisible=true}, + {name=Tensor}, + {name='charoption', values={'A', 'S'}, default='S'}} + ) + + end + + interface:register(string.format("torch_%sMath__", Tensor)) + + interface:print(string.gsub([[ +static void torch_TensorMath_init(lua_State *L) +{ + torch_Tensor_id = luaT_checktypename2id(L, "torch.Tensor"); + torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage"); + + /* register everything into the "torch" field of the tensor metaclass */ + luaT_pushmetaclass(L, torch_Tensor_id); + lua_pushstring(L, "torch"); + lua_newtable(L); + luaL_register(L, NULL, torch_TensorMath__); + lua_rawset(L, -3); + lua_pop(L, 1); +} +]], 'Tensor', Tensor)) +end + +interface:dispatchregister("torch_TensorMath__") + +interface:print([[ +void torch_TensorMath_init(lua_State *L) +{ + torch_ByteTensorMath_init(L); + torch_CharTensorMath_init(L); + torch_ShortTensorMath_init(L); + torch_IntTensorMath_init(L); + torch_LongTensorMath_init(L); + torch_FloatTensorMath_init(L); + torch_DoubleTensorMath_init(L); + luaL_register(L, NULL, torch_TensorMath__); +} +]]) + +if arg[1] then + interface:tofile(arg[1]) +else + interface:tostdio() +end diff --git a/TensorOperator.c b/TensorOperator.c new file mode 100644 index 00000000000..d0671490a92 --- /dev/null +++ b/TensorOperator.c @@ -0,0 +1,8 @@ +#include "general.h" + +#define torch_TensorOperator_(NAME) TH_CONCAT_4(torch_,Real,TensorOperator_,NAME) +#define torch_Tensor_id TH_CONCAT_3(torch_,Real,Tensor_id) +#define STRING_torchTensor TH_CONCAT_STRING_3(torch.,Real,Tensor) + +#include "generic/TensorOperator.c" +#include "THGenerateAllTypes.h" diff --git a/Tester.lua b/Tester.lua new file mode 100644 index 00000000000..2be1a48fe42 --- /dev/null +++ b/Tester.lua @@ -0,0 +1,124 @@ +local Tester = torch.class('torch.Tester') + +function Tester:__init() + self.errors = {} + self.tests = {} + self.testnames = {} + self.curtestname = '' +end + + +function Tester:assert_sub (condition, message) + if not condition then + local ss = debug.traceback('tester',2) + --print(ss) + ss = ss:match('[^\n]+\n[^\n]+\n([^\n]+\n[^\n]+)\n') + self.errors[#self.errors+1] = self.curtestname .. '\n' .. message .. '\n' .. ss .. 
'\n'
+   end
+end
+function Tester:assert (condition, message)
+   self:assert_sub(condition,string.format('%s\n%s condition=%s',message,' BOOL violation ', tostring(condition)))
+end
+function Tester:assertlt (val, condition, message)
+   self:assert_sub(val<condition,string.format('%s\n%s val=%s, condition=%s',message,' LT(<) violation ', tostring(val), tostring(condition)))
+end
+function Tester:assertgt (val, condition, message)
+   self:assert_sub(val>condition,string.format('%s\n%s val=%s, condition=%s',message,' GT(>) violation ', tostring(val), tostring(condition)))
+end
+function Tester:assertle (val, condition, message)
+   self:assert_sub(val<=condition,string.format('%s\n%s val=%s, condition=%s',message,' LE(<=) violation ', tostring(val), tostring(condition)))
+end
+function Tester:assertge (val, condition, message)
+   self:assert_sub(val>=condition,string.format('%s\n%s val=%s, condition=%s',message,' GE(>=) violation ', tostring(val), tostring(condition)))
+end
+function Tester:asserteq (val, condition, message)
+   self:assert_sub(val==condition,string.format('%s\n%s val=%s, condition=%s',message,' EQ(==) violation ', tostring(val), tostring(condition)))
+end
+function Tester:assertne (val, condition, message)
+   self:assert_sub(val~=condition,string.format('%s\n%s val=%s, condition=%s',message,' NE(~=) violation ', tostring(val), tostring(condition)))
+end
+function Tester:assertTensorEq(ta, tb, condition, message)
+   local diff = ta-tb
+   local err = diff:abs():maxall()
+   self:assert_sub(err ' .. self.curtestname
+      io.write('\r' .. pstr)
+      io.flush()
+
+      local stat, message, pass = self:pcall(v)
+
+      if pass then
+         --io.write(string.format('\b_'))
+         statstr = statstr:sub(1,i-1) .. '_' .. statstr:sub(i+1)
+      else
+         statstr = statstr:sub(1,i-1) .. '*' .. statstr:sub(i+1)
+         --io.write(string.format('\b*'))
+      end
+
+      if not stat then
+         print()
+         print('Function call failed: Test No ' .. i .. ' ' .. self.testnames[i])
+         print(message)
+      end
+      collectgarbage()
+   end
+   --clear
+   io.write('\r' .. string.rep(' ', pstr:len()))
+   io.flush()
+   -- write finish
+   pstr = statstr .. ' ==> Done '
+   io.write('\r' .. pstr)
+   io.flush()
+   print()
+   print()
+   self:report()
+end
+
+function Tester:add(f,name)
+   name = name or 'unknown'
+   if type(f) == "table" then
+      for i,v in pairs(f) do
+         self:add(v,i)
+      end
+   elseif type(f) == "function" then
+      self.tests[#self.tests+1] = f
+      self.testnames[#self.tests] = name
+   else
+      error('Tester:add(f) expects a function or a table of functions')
+   end
+end
diff --git a/Timer.c b/Timer.c
new file mode 100644
index 00000000000..6935fae2029
--- /dev/null
+++ b/Timer.c
@@ -0,0 +1,157 @@
+#include "general.h"
+
+#ifdef _MSC_VER
+#include <time.h>
+#else
+#include <sys/time.h>
+#include <sys/resource.h>
+#endif
+
+#ifdef _MSC_VER
+static time_t base_time = 0;
+#endif
+
+static const void* torch_Timer_id = NULL;
+
+typedef struct _Timer
+{
+  int isRunning;
+
+  double totalrealtime;
+  double totalusertime;
+  double totalsystime;
+
+  double startrealtime;
+  double startusertime;
+  double startsystime;
+
+} Timer;
+
+static double torch_Timer_realtime()
+{
+  struct timeval current;
+  gettimeofday(&current, NULL);
+  return (current.tv_sec + current.tv_usec/1000000.0);
+}
+
+static double torch_Timer_usertime()
+{
+  struct rusage current;
+  getrusage(RUSAGE_SELF, &current);
+  return (current.ru_utime.tv_sec + current.ru_utime.tv_usec/1000000.0);
+}
+
+static double torch_Timer_systime()
+{
+  struct rusage current;
+  getrusage(RUSAGE_SELF, &current);
+  return (current.ru_stime.tv_sec + current.ru_stime.tv_usec/1000000.0);
+}
+
+static int torch_Timer_new(lua_State *L)
+{
+  Timer *timer = luaT_alloc(L, sizeof(Timer));
+#ifdef _MSC_VER
+  while(!base_time)
+    time(&base_time);
+#endif
+  timer->isRunning = 1;
+  timer->totalrealtime = 0;
+  timer->totalusertime = 0;
+  timer->totalsystime = 0;
+  timer->startrealtime = torch_Timer_realtime();
+  timer->startusertime = torch_Timer_usertime();
+  timer->startsystime = torch_Timer_systime();
+  luaT_pushudata(L, timer, torch_Timer_id);
+  return 1;
+}
+
+static int torch_Timer_reset(lua_State *L)
+{
+  Timer *timer = luaT_checkudata(L, 1, torch_Timer_id);
+  timer->totalrealtime = 0;
+  timer->totalusertime = 0;
+  timer->totalsystime = 0;
+  timer->startrealtime = torch_Timer_realtime();
+  timer->startusertime = torch_Timer_usertime();
+  timer->startsystime = torch_Timer_systime();
+  lua_settop(L, 1);
+  return 1;
+}
+
+static int torch_Timer_free(lua_State *L)
+{
+  Timer *timer = luaT_checkudata(L, 1, torch_Timer_id);
+  luaT_free(L, timer);
+  return 0;
+}
+
+static int torch_Timer_stop(lua_State *L)
+{
+  Timer *timer = luaT_checkudata(L, 1, torch_Timer_id);
+  if(timer->isRunning)
+  {
+    double realtime = torch_Timer_realtime() - timer->startrealtime;
+    double usertime = torch_Timer_usertime() - timer->startusertime;
+    double systime = torch_Timer_systime() - timer->startsystime;
+    timer->totalrealtime += realtime;
+    timer->totalusertime += usertime;
+    timer->totalsystime += systime;
+    timer->isRunning = 0;
+  }
+  lua_settop(L, 1);
+  return 1;
+}
+
+static int torch_Timer_resume(lua_State *L)
+{
+  Timer *timer = luaT_checkudata(L, 1, torch_Timer_id);
+  if(!timer->isRunning)
+  {
+    timer->isRunning = 1;
+    timer->startrealtime = torch_Timer_realtime();
+    timer->startusertime = torch_Timer_usertime();
+    timer->startsystime = torch_Timer_systime();
+  }
+  lua_settop(L, 1);
+  return 1;
+}
+
+static int torch_Timer_time(lua_State *L)
+{
+  Timer *timer = luaT_checkudata(L, 1, torch_Timer_id);
+  double realtime = (timer->isRunning ? (timer->totalrealtime + torch_Timer_realtime() - timer->startrealtime) : timer->totalrealtime);
+  double usertime = (timer->isRunning ? (timer->totalusertime + torch_Timer_usertime() - timer->startusertime) : timer->totalusertime);
+  double systime = (timer->isRunning ? (timer->totalsystime + torch_Timer_systime() - timer->startsystime) : timer->totalsystime);
+  lua_createtable(L, 0, 3);
+  lua_pushnumber(L, realtime);
+  lua_setfield(L, -2, "real");
+  lua_pushnumber(L, usertime);
+  lua_setfield(L, -2, "user");
+  lua_pushnumber(L, systime);
+  lua_setfield(L, -2, "sys");
+  return 1;
+}
+
+static int torch_Timer___tostring__(lua_State *L)
+{
+  Timer *timer = luaT_checkudata(L, 1, torch_Timer_id);
+  lua_pushfstring(L, "torch.Timer [status: %s]", (timer->isRunning ? "running" : "stopped"));
+  return 1;
+}
+
+static const struct luaL_Reg torch_Timer__ [] = {
+  {"reset", torch_Timer_reset},
+  {"stop", torch_Timer_stop},
+  {"resume", torch_Timer_resume},
+  {"time", torch_Timer_time},
+  {"__tostring__", torch_Timer___tostring__},
+  {NULL, NULL}
+};
+
+void torch_Timer_init(lua_State *L)
+{
+  torch_Timer_id = luaT_newmetatable(L, "torch.Timer", NULL, torch_Timer_new, torch_Timer_free, NULL);
+  luaL_register(L, NULL, torch_Timer__);
+  lua_pop(L, 1);
+}
diff --git a/dok/cmdline.dok b/dok/cmdline.dok
new file mode 100644
index 00000000000..47ec1cb9842
--- /dev/null
+++ b/dok/cmdline.dok
@@ -0,0 +1,115 @@
+====== CmdLine ======
+{{anchor:torch.CmdLine.dok}}
+
+This class provides a parameter parsing framework which is very
+useful when one needs to run several experiments that rely on
+different parameter settings that are passed on the command line.
+This class also overrides the default print function to direct
+all the output to a log file as well as to the screen at the same time.
+
+A sample ''lua'' file is given below that makes use of the ''CmdLine''
+class.
+
+cmd = torch.CmdLine()
+cmd:text()
+cmd:text()
+cmd:text('Training a simple network')
+cmd:text()
+cmd:text('Options')
+cmd:option('-seed',123,'initial random seed')
+cmd:option('-booloption',false,'boolean option')
+cmd:option('-stroption','mystring','string option')
+cmd:text()
+
+-- parse input params
+params = cmd:parse(arg)
+
+params.rundir = cmd:string('experiment', params, {dir=true})
+
+-- create log file
+cmd:log(params.rundir .. '/log', params)
+
+When this file is run on the lua command line as follows
+
+# lua myscript.lua
+
+It will produce the following output:
+
+[program started on Tue Jan 10 15:33:49 2012]
+[command line arguments]
+booloption false
+seed 123
+rundir experiment
+stroption mystring
+[----------------------]
+booloption false
+seed 123
+rundir experiment
+stroption mystring
+
+The same output will also be written to the file
+''experiment/log''. Whenever one of the options is passed on the
+command line with a value different from its default, the ''rundir''
+name is produced to reflect the parameter setting.
+
+# lua myscript.lua -seed 456 -stroption mycustomstring
+
+This will produce the following output:
+
+[program started on Tue Jan 10 15:36:55 2012]
+[command line arguments]
+booloption false
+seed 456
+rundir experiment,seed=456,stroption=mycustomstring
+stroption mycustomstring
+[----------------------]
+booloption false
+seed 456
+rundir experiment,seed=456,stroption=mycustomstring
+stroption mycustomstring
+
+and the output will be logged in
+''experiment,seed=456,stroption=mycustomstring/log''
+
+==== text(string) ====
+{{anchor:torch.CmdLine.text}}
+Logs a custom text message.
+
+==== option(name, default, help) ====
+{{anchor:torch.CmdLine.option}}
+Stores an option argument. The name should always start with '-'.
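A minimal usage sketch for ''option'' follows; the ''-lr'' and ''-gpu'' option names are made up for this illustration.

cmd = torch.CmdLine()
cmd:option('-lr', 0.1, 'learning rate')   -- number option, its type is taken from the default value
cmd:option('-gpu', false, 'use the GPU')  -- boolean option: passing -gpu on the command line flips the default
params = cmd:parse(arg)
print(params.lr, params.gpu)              -- parsed values are stored without the leading '-'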
+
+==== [table] parse(arg) ====
+{{anchor:torch.CmdLine.parse}}
+Parses the given table ''arg'', which is by default the argument table
+created by ''lua'' from the command line arguments passed to the
+executable. Returns a table of option values.
+
+==== [string] string(prefix, params, ignore) ====
+{{anchor:torch.CmdLine.string}}
+
+Returns a string representation of the options by concatenating the
+non-default options. ''ignore'' is a table such as ''{dir=true}'', which
+ensures that the option named ''dir'' will be ignored while creating the
+string representation.
+
+This function is useful for creating unique experiment directories that
+depend on the parameter settings.
+
+==== log(filename, parameter_table) ====
+{{anchor:torch.CmdLine.log}}
+
+Sets the log filename to ''filename'' and prints the values of the
+parameters in ''parameter_table''.
+
diff --git a/dok/diskfile.dok b/dok/diskfile.dok
new file mode 100644
index 00000000000..4691c84f7d6
--- /dev/null
+++ b/dok/diskfile.dok
@@ -0,0 +1,64 @@
+====== DiskFile ======
+{{anchor:torch.DiskFile.dok}}
+
+Parent classes: [[File|File]]
+
+A ''DiskFile'' is a particular ''File'' which is able to perform basic read/write operations
+on a file stored on disk. It implements all methods described in [[File|File]], and
+some additional methods relative to //endian// encoding.
+
+By default, a ''DiskFile'' is in [[File#torch.File.ascii|ASCII]] mode. If changed to
+the [[File#torch.File.binary|binary]] mode, the default endian encoding is the native
+computer one.
+
+The file may be opened in read, write, or read-write mode, depending on the parameter
+''mode'' (which can take the value ''"r"'', ''"w"'' or ''"rw"'' respectively)
+given to the [[#torch.DiskFile|torch.DiskFile(fileName, mode)]] constructor.
+
+===== torch.DiskFile(fileName, [mode], [quiet]) =====
+{{anchor:torch.DiskFile}}
+
+//Constructor// which opens ''fileName'' on disk, using the given ''mode''. Valid values for ''mode'' are
+''"r"'' (read), ''"w"'' (write) or ''"rw"'' (read-write). The default is read mode.
+
+In read-write mode, the file //will be created// if it does not exist. If it
+exists, it will be positioned at the beginning of the file after opening.
+
+If (and only if) ''quiet'' is ''true'', no error will be raised in case of
+a problem opening the file: instead ''nil'' will be returned.
+
+The file is opened in [[File#torch.File.ascii|ASCII]] mode by default.
+
+===== bigEndianEncoding() =====
+{{anchor:torch.DiskFile.bigEndianEncoding}}
+
+In [[file#torch.File.binary|binary]] mode, force encoding in //big endian//.
+(//big end first//: decreasing numeric significance with increasing memory
+addresses)
+
+===== [boolean] isBigEndianCPU() =====
+{{anchor:torch.DiskFile.isBigEndianCPU}}
+
+Returns ''true'' if, and only if, the computer CPU operates in //big endian//.
+//Big end first//: decreasing numeric significance with increasing
+memory addresses.
+
+===== [boolean] isLittleEndianCPU() =====
+{{anchor:torch.DiskFile.isLittleEndianCPU}}
+
+Returns ''true'' if, and only if, the computer CPU operates in //little endian//.
+//Little end first//: increasing numeric significance with increasing
+memory addresses.
+
+===== littleEndianEncoding() =====
+{{anchor:torch.DiskFile.littleEndianEncoding}}
+
+In [[file#torch.File.binary|binary]] mode, force encoding in //little endian//.
+(//little end first//: increasing numeric significance with increasing memory +addresses) + +===== nativeEndianEncoding() ===== +{{anchor:torch.DiskFile.nativeEndianEncoding}} + +In [[file#torch.File.binary|binary]] mode, force encoding in //native endian//. + diff --git a/dok/file.dok b/dok/file.dok new file mode 100644 index 00000000000..64fa6bdf149 --- /dev/null +++ b/dok/file.dok @@ -0,0 +1,333 @@ +====== File ====== +{{anchor:torch.File.dok}} + +This is an //abstract// class. It defines most methods implemented by its +child classes, like [[DiskFile|DiskFile]], +[[MemoryFile|MemoryFile]] and [[PipeFile|PipeFile]]. + +Methods defined here are intended for basic read/write functionalities. +Read/write methods might write in [[#torch.File.ascii|ASCII]] mode or +[[#torch.File.binary|binary]] mode. + +In [[#torch.File.ascii|ASCII]] mode, numbers are converted in human readable +format (characters). Booleans are converted into ''0'' (false) or ''1'' (true). +In [[#torch.File.binary|binary]] mode, numbers and boolean are directly encoded +as represented in a register of the computer. While not being human +readable and less portable, the binary mode is obviously faster. + +In [[#torch.File.ascii|ASCII]] mode, if the default option +[[#torch.File.autoSpacing|autoSpacing()]] is chosen, a space will be generated +after each written number or boolean. A carriage return will also be added +after each call to a write method. With this option, the spaces are +supposed to exist while reading. This option can be deactivated with +[[#torch.File.noAutoSpacing|noAutoSpacing()]]. + +A ''Lua'' error might or might be not generated in case of read/write error +or problem in the file. This depends on the choice made between +[[#torch.File.quiet|quiet()]] and [[#torch.File.pedantic|pedantic()]] options. It +is possible to query if an error occured in the last operation by calling +[[#torch.File.hasError|hasError()]]. + +===== Read methods ===== +{{anchor:torch.File.read}} +{{anchor:torch.File.readBool}} +{{anchor:torch.File.readByte}} +{{anchor:torch.File.readChar}} +{{anchor:torch.File.readShort}} +{{anchor:torch.File.readInt}} +{{anchor:torch.File.readLong}} +{{anchor:torch.File.readFloat}} +{{anchor:torch.File.readDouble}} + +They are three types of reading methods: + - ''[number] readTYPE()'' + - ''[TYPEStorage] readTYPE(n)'' + - ''[number] readTYPE(TYPEStorage)'' + +where ''TYPE'' can be either ''Byte'', ''Char'', ''Short'', ''Int'', ''Long'', ''Float'' or ''Double''. + +A convenience method also exist for boolean types: ''[boolean] readBool()''. It reads +a value on the file with ''readInt()'' and returns ''true'' if and only if this value is ''1''. It is not possible +to read storages of booleans. + +All these methods depends on the encoding choice: [[#torch.File.ascii|ASCII]] +or [[#torch.File.binary|binary]] mode. In [[#torch.File.ascii|ASCII]] mode, the +option [[#torch.File.autoSpacing|autoSpacing()]] and +[[#torch.File.noAutoSpacing|noAutoSpacing()]] have also an effect on these +methods. + +If no parameter is given, one element is returned. This element is +converted to a ''Lua'' number when reading. + +If ''n'' is given, ''n'' values of the specified type are read +and returned in a new [[Storage|Storage]] of that particular type. +The storage size corresponds to the number of elements actually read. + +If a ''Storage'' is given, the method will attempt to read a number of elements +equals to the size of the given storage, and fill up the storage with these elements. 
+The number of elements actually read is returned. + +In case of read error, these methods will call the ''Lua'' error function using the default +[[#torch.File.pedantic|pedantic]] option, or stay quiet with the [[#torch.File.quiet|quiet]] +option. In the latter case, one can check if an error occurred with +[[#torch.File.hasError|hasError()]]. + +===== Write methods ===== +{{anchor:torch.File.write}} +{{anchor:torch.File.writeBool}} +{{anchor:torch.File.writeByte}} +{{anchor:torch.File.writeChar}} +{{anchor:torch.File.writeShort}} +{{anchor:torch.File.writeInt}} +{{anchor:torch.File.writeLong}} +{{anchor:torch.File.writeFloat}} +{{anchor:torch.File.writeDouble}} + +They are two types of reading methods: + - ''[number] writeTYPE(number)'' + - ''[number] writeTYPE(TYPEStorage)'' + +where ''TYPE'' can be either ''Byte'', ''Char'', ''Short'', ''Int'', ''Long'', ''Float'' or ''Double''. + +A convenience method also exist for boolean types: ''writeBool(value)''. If ''value'' is ''nil'' or +not ''true'' a it is equivalent to a ''writeInt(0)'' call, else to ''writeInt(1)''. It is not possible +to write storages of booleans. + +All these methods depends on the encoding choice: [[#torch.File.ascii|ASCII]] +or [[#torch.File.ascii|binary]] mode. In [[#torch.File.ascii|ASCII]] mode, the +option [[#torch.File.autoSpacing|autoSpacing()]] and +[[#torch.File.noAutoSpacing|noAutoSpacing()]] have also an effect on these +methods. + +If one ''Lua'' number is given, this number is converted according to the +name of the method when writing (e.g. ''writeInt(3.14)'' will write ''3''). + +If a ''Storage'' is given, the method will attempt to write all the elements contained +in the storage. + +These methods return the number of elements actually written. + +In case of read error, these methods will call the ''Lua'' error function using the default +[[#torch.File.pedantic|pedantic]] option, or stay quiet with the [[#torch.File.quiet|quiet]] +option. In the latter case, one can check if an error occurred with +[[#torch.File.hasError|hasError()]]. + +===== Serialization methods ===== +{{anchor:torch.File.serialization}} + +These methods allow the user to save any serializable objects on disk and +reload it later in its original state. In other words, it can perform a +//deep// copy of an object into a given ''File''. + +Serializable objects are ''Torch'' objects having a ''read()'' and +''write()'' method. ''Lua'' objects such as ''table'', ''number'' or +''string'' or //pure Lua// functions are also serializable. + +If the object to save contains several other objects (let say it is a tree +of objects), then objects appearing several times in this tree will be +//saved only once//. This saves disk space, speedup loading/saving and +respect the dependencies between objects. + +Interestingly, if the ''File'' is a [[MemoryFile|MemoryFile]], it allows +the user to easily make a //clone// of any serializable object: + +file = torch.MemoryFile() -- creates a file in memory +file:writeObject(object) -- writes the object into file +file:seek(1) -- comes back at the beginning of the file +objectClone = file:readObject() -- gets a clone of object + + +==== readObject() ==== +{{anchor:torch.File.readObject}} + +Returns the next [[#torch.File.serialization|serializable]] object saved beforehand +in the file with [[#torch.File.writeObject|writeObject()]]. + +Note that objects which were [[#torch.File.writeObject|written]] with the same +reference have still the same reference after loading. 
+ +Example: + +-- creates an array which contains twice the same tensor +array = {} +x = torch.Tensor(1) +table.insert(array, x) +table.insert(array, x) + +-- array[1] and array[2] refer to the same address +-- x[1] == array[1][1] == array[2][1] == 3.14 +array[1][1] = 3.14 + +-- write the array on disk +file = torch.DiskFile('foo.asc', 'w') +file:writeObject(array) +file:close() -- make sure the data is written + +-- reload the array +file = torch.DiskFile('foo.asc', 'r') +arrayNew = file:readObject() + +-- arrayNew[1] and arrayNew[2] refer to the same address! +-- arrayNew[1][1] == arrayNew[2][1] == 3.14 +-- so if we do now: +arrayNew[1][1] = 2.72 +-- arrayNew[1][1] == arrayNew[2][1] == 2.72 ! + + +==== writeObject(object) ==== +{{anchor:torch.File.writeObject}} + +Writes ''object'' into the file. This object can be read later using +[[#torch.File.readObject|readObject()]]. Serializable objects are ''Torch'' +objects having a ''read()'' and ''write()'' method. ''Lua'' objects such as +''table'', ''number'' or ''string'' or pure Lua functions are also serializable. + +If the object has been already written in the file, only a //reference// to +this already saved object will be written: this saves space an speed-up +writing; it also allows to keep the dependencies between objects intact. + +In returns, if one writes an object, modify its member, and write the +object again in the same file, the modifications will not be recorded +in the file, as only a reference to the original will be written. See +[[#torch.File.readObject|readObject()]] for an example. + +==== [string] readString(format) ==== +{{anchor:torch.File.readString}} + +If ''format'' starts with ''"*l"'' then returns the next line in the ''File''. The end-of-line character is skipped. + +If ''format'' starts with ''"*a"'' then returns all the remaining contents of the ''File''. + +If no data is available, then an error is raised, except if ''File'' is in [[#torch.File.quiet|quiet()]] mode where +it then returns ''nil''. + +Because Torch is more precised on number typing, the ''Lua'' format ''"*n"'' is not supported: +instead use one of the [[#torch.File.read|number read methods]]. + +==== [number] writeString(str) ==== +{{anchor:torch.File.writeString}} + +Writes the string ''str'' in the ''File''. If the string cannot be written completely an error is raised, except +if ''File'' is in [[#torch.File.quiet|quiet()]] mode where it returns the number of character actually written. + +===== ascii() [default] ===== +{{anchor:torch.File.ascii}} + +The data read or written will be in ''ASCII'' mode: all numbers are converted +to characters (human readable format) and boolean are converted to ''0'' +(false) or ''1'' (true). The input-output format in this mode depends on the +options [[#torch.File.autoSpacing|autoSpacing()]] and +[[#torch.File.noAutoSpacing|noAutoSpacing()]]. + +===== autoSpacing() [default] ===== +{{anchor:torch.File.autoSpacing}} + +In [[#torch.File.ascii|ASCII]] mode, write additional spaces around the elements +written on disk: if writing a [[Storage|Storage]], a space will be +generated between each //element// and a //return line// after the last +element. If only writing one element, a //return line// will be generated +after this element. + +Those spaces are supposed to exist while reading in this mode. + +This is the default behavior. You can de-activate this option with the +[[#torch.File.noAutoSpacing|noAutoSpacing()]] method. 
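A minimal sketch of the default ASCII/autoSpacing behaviour; the file name ''counts.asc'' is made up for this example.

f = torch.DiskFile('counts.asc', 'w')   -- ASCII mode and autoSpacing are the defaults
f:writeInt(1)                           -- autoSpacing adds a return line after each written number
f:writeInt(2)
f:close()

f = torch.DiskFile('counts.asc', 'r')
print(f:readInt())                      -- 1
print(f:readInt())                      -- 2 (the generated spacing is expected while reading back)
f:close()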
+ +===== binary() ===== +{{anchor:torch.File.binary}} + +The data read or written will be in binary mode: the representation in the +''File'' is the same that the one in the computer memory/register (not human +readable). This mode is faster than [[#torch.File.ascii|ASCII]] but less +portable. + +===== clearError() ===== +{{anchor:torch.File.clearError}} + +Clear the error.flag returned by [[#torch.File.hasError|hasError()]]. + +===== close() ===== +{{anchor:torch.File.close}} + +Close the file. Any subsequent operation will generate a ''Lua'' error. + +===== noAutoSpacing() ===== +{{anchor:torch.File.noAutoSpacing}} + +In [[#torch.File.ascii|ASCII]] mode, do not put extra spaces between element +written on disk. This is the contrary of the option +[[#torch.File.autoSpacing|autoSpacing()]]. + +===== synchronize() ===== +{{anchor:torch.File.synchronize}} + +If the child class bufferize the data while writing, ensure that the data +is actually written. + + +===== pedantic() [default] ===== +{{anchor:torch.File.pedantic}} + +If this mode is chosen (which is the default), a ''Lua'' error will be +generated in case of error (which will cause the program to stop). + +It is possible to use [[#torch.File.quiet|quiet()]] to avoid ''Lua'' error generation +and set a flag instead. + +===== [number] position() ===== +{{anchor:torch.File.position}} + +Returns the current position (in bytes) in the file. +The first position is ''1'' (following Lua standard indexing). + +===== quiet() ===== +{{anchor:torch.File.quiet}} + +If this mode is chosen instead of [[#torch.File.pedantic|pedantic()]], no ''Lua'' +error will be generated in case of read/write error. Instead, a flag will +be raised, readable through [[#torch.File.hasError|hasError()]]. This flag can +be cleared with [[#torch.File.clearError|clearError()]] + +Checking if a file is quiet can be performed using [[#torch.File.isQuiet|isQuiet()]]. + +===== seek(position) ===== +{{anchor:torch.File.seek}} + +Jump into the file at the given ''position'' (in byte). Might generate/raise +an error in case of problem. The first position is ''1'' (following Lua standard indexing). + +===== seekEnd() ===== +{{anchor:torch.File.seekEnd}} + +Jump at the end of the file. Might generate/raise an error in case of +problem. + +===== File state query ===== + +These methods allow the user to query the state of the given ''File''. + +==== [boolean] hasError() ==== +{{anchor:torch.File.hasError}} + +Returns if an error occurred since the last [[#torch.File.clearError|clearError()]] call, or since +the opening of the file if ''clearError()'' has never been called. + +==== [boolean] isQuiet() ==== +{{anchor:torch.File.isQuiet}} + +Returns a boolean which tells if the file is in [[#torch.File.quiet|quiet]] mode or not. + +==== [boolean] isReadable() ==== +{{anchor:torch.File.isReadable}} + +Tells if one can read the file or not. + +==== [boolean] isWritable() ==== +{{anchor:torch.File.isWritable}} + +Tells if one can write in the file or not. + +==== [boolean] isAutoSpacing() ==== +{{anchor:torch.File.isAutoSpacing}} + +Return ''true'' if [[#torch.File.autoSpacing|autoSpacing]] has been chosen. diff --git a/dok/index.dok b/dok/index.dok new file mode 100644 index 00000000000..d43faaca8ea --- /dev/null +++ b/dok/index.dok @@ -0,0 +1,39 @@ +====== Torch Package Reference Manual ====== +{{anchor:torch.reference.dok}} + +The **''torch''** package contains basic classes used everywhere in ''Torch7''. 
+ +//Input-output management// is provided with [[File|File]] (abstract class), [[DiskFile|DiskFile]] (file on disk), +[[MemoryFile|MemoryFile]] (file in ''RAM'') and [[PipeFile|PipeFile]] (file from a piped command). These +classes also handle //serialization//. + +[[Storage|Storage]] and [[Tensor|Tensor]] are the basic bricks for //powerful numeric operations//. Tensors support +a wide variety of fundamental [[maths|math operations]]. + +[[Timer|Timer]] is provided for //measuring time//. + +[[Tester|Tester]] is provided as a generic testing framework and it is also used by [[..:nn:index|nn]] package. + +[[CmdLine|CmdLine]] is provided as a command line argument parsing utility. + +Finally, ''Torch'' provides some [[Utility|utility functions]] for creating and handling ''Torch'' //classes//, +as well as support for [[random|random number generation]]. + +===== Torch Packages ===== +{{anchor:torch.reference.dok}} + + * File I/O Interface Library + * [[File|File]] is an abstract interface for common file operations. + * [[DiskFile|Disk File]] defines operations on files stored on disk. + * [[MemoryFile|Memory File]] defines operations on stored in RAM. + * [[PipeFile|Pipe File]] defines operations for using piped commands. + * Tensor Library + * [[Storage|Storage]] defines a simple storage interface that controls the underlying storage for any tensor object. + * [[Tensor|Tensor]] defines the //all powerful// tensor object that defines multi-dimensional numerical arrays with type templating. + * [[maths|Mathemetical operations]] are defined for the tensor object types. + * Useful Utilities + * [[Timer|Timer]] provides functionality for //measuring time//. + * [[Tester|Tester]] is a generic tester framework. + * [[CmdLine|CmdLine]] is a command line argument parsing utility. + * [[Random|Random]] defines a random number generator package with various distributions. + * Finally useful [[Utility|utility] functions are provided for easy handling of torch tensor types and class inheritance. diff --git a/dok/maths.dok b/dok/maths.dok new file mode 100644 index 00000000000..187331b4812 --- /dev/null +++ b/dok/maths.dok @@ -0,0 +1,804 @@ +====== Math Functions ====== +{{anchor:torch.maths.dok}} + +Torch provides Matlab-like functions for manipulating +[[index#Tensor|Tensor]] objects. Functions fall into several types of +categories: + * [[#torch.construction.dok|constructors]] like [[#torch.zeros|zeros]], [[#torch.ones|ones]] + * extractors like [[#torch.diag|diag]] and [[#torch.triu|triu]], + * [[#torch.elementwise.dok|element-wise]] operations like [[#torch.abs|abs]] and [[#torch.pow|pow]], + * [[#torch.columnwise.dok|column or row-wise operations]] like [[#torch.sum|sum]] and [[#torch.max|max]], + * [[#torch.matrixwide.dok|matrix-wide operations]] like [[#torch.trace|trace]] and [[#torch.norm|norm]]. + * [[#torch.conv.dok|Convolution and cross-correlation]] operations like [[#torch.conv2|conv2]]. + * [[#torch.linalg.dok|Basic linear algebra operations]] like [[#torch.eig|eigen value/vector calculation]], [[#torch.svd|singular value decomposition (svd)]] and [[#torch.gesv|linear system solution]]. + +By default, all operations allocate a new tensor to return the +result. However, all functions also support passing the resulting(s) +tensor(s) as the first argument(s), in which case the resulting tensor(s) +will be resized accordingly and filled with result. + +For example, ''torch.conv2'' function can be used in the following manner. 
+ + +x = torch.rand(100,100) +k = torch.rand(10,10) +res1 = torch.conv2(x,k) + +res2 = torch.Tensor() +torch.conv2(res2,x,k) + +=res2:dist(res1) +0 + + + +The advantage of second case is, same ''res2'' tensor can be used successively in a loop without any new allocation. + + +-- no new memory allocations... +for i=1,100 do + torch.conv2(res2,x,k) +end +=res2:dist(res1) +0 + + +====== Construction or extraction functions ====== +{{anchor:torch.construction.dok}} + +===== torch.cat( [res,] x_1, x_2, [dimension] ) ===== +{{anchor:torch.cat}} + +''x=torch.cat(x_1,x_2,[dimension])'' returns a tensor ''x'' which is the concatenation of tensors x_1 and x_2 along dimension ''dimension''. + +If ''dimension'' is not specified it is 1. + +The other dimensions of x_1 and x_2 have to be equal. + +Examples: + +> print(torch.cat(torch.ones(3),torch.zeros(2))) + + 1 + 1 + 1 + 0 + 0 +[torch.Tensor of dimension 5] + + +> print(torch.cat(torch.ones(3,2),torch.zeros(2,2))) + + 1 1 + 1 1 + 1 1 + 0 0 + 0 0 +[torch.DoubleTensor of dimension 5x2] + + +> print(torch.cat(torch.ones(2,2),torch.zeros(2,2))) + 1 1 + 1 1 + 0 0 + 0 0 +[torch.DoubleTensor of dimension 4x2] + +> print(torch.cat(torch.ones(2,2),torch.zeros(2,2),2)) + 1 1 0 0 + 1 1 0 0 +[torch.DoubleTensor of dimension 2x4] + + +> print(torch.cat(torch.cat(torch.ones(2,2),torch.zeros(2,2)),torch.rand(3,2))) + + 1.0000 1.0000 + 1.0000 1.0000 + 0.0000 0.0000 + 0.0000 0.0000 + 0.3227 0.0493 + 0.9161 0.1086 + 0.2206 0.7449 +[torch.DoubleTensor of dimension 7x2] + + + + +===== torch.diag( [res,] x) ===== +{{anchor:torch.diag}} + +''y=torch.diag(x)'' when x is of dimension 1 returns a diagonal matrix with diagonal elements constructed from x. + +''y=torch.diag(x)'' when x is of dimension 2 returns a tensor of dimension 1 +with elements constructed from the diagonal of x. + +''y=torch.diag(x,k)'' returns the k-th diagonal of x, +wher k=0 is the main diagonal, k>0 is above the main diagonal and k<0 +is below the main diagonal. + +===== torch.eye( [res,] n) ===== +{{anchor:torch.eye}} + +''y=torch.eye(n)'' returns the n-by-n identity matrix. + +''y=torch.eye(m,n)'' returns an m-by-n identity matrix with ones on the diagonal and zeros elsewhere. + + +===== torch.linspace( [res,] x1,x2) ===== +{{anchor:torch.linspace}} + +''y=torch.linspace(x1,x2)'' returns a one-dimensional tensor of size 100 equally spaced points between x1 and x2. + +''y=torch.linspace(x1,x2,n)'' returns a one-dimensional tensor of n equally spaced points between x1 and x2. + + +===== torch.logspace( [res,] x1, x2) ===== +{{anchor:torch.logspace}} + +''y=torch.logspace(x1,x2)'' returns a one-dimensional tensor of 50 logarithmically eqally spaced points between x1 and x2. + +''y=torch.logspace(x1,x2,n)'' returns a one-dimensional tensor of n logarithmically equally spaced points between x1 and x2. + +===== torch.ones( [res,] m) ===== +{{anchor:torch.ones}} + +''y=torch.ones(n)'' returns a one-dimensional tensor of size n filled with ones. + +''y=torch.ones(m,n)'' returns a mxn tensor filled with ones. + +''y=torch.ones(m,n,k)'' returns a mxnxk tensor filled with ones. + +''y=torch.ones(d1,...,d_n)'' returns an n-dimensional tensor with sizes d1, ..., d_n filled with ones. + +===== torch.rand( [res,] m [, n, k, ...]) ===== +{{anchor:torch.rand}} + +''y=torch.rand(n)'' returns a one-dimensional tensor of size n filled with random numbers from a uniform distribution on the interval (0,1). + +''y=torch.rand(m,n)'' returns a mxn tensor of random numbers from a uniform distribution on the interval (0,1). 
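A short sketch of the constructor-style calls above, including the optional result tensor described at the top of this page.

x = torch.rand(3)       -- 3 uniform samples in (0,1)
y = torch.rand(2,4)     -- a 2x4 tensor of uniform samples
torch.rand(y, 2, 4)     -- passing y as the first argument reuses it as the result tensor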
+ +===== torch.randn( [res,] m [, n, k, ...]) ===== +{{anchor:torch.randn}} + +''y=torch.randn(n)'' returns a one-dimensional tensor of size n filled with random numbers from a normal distribution with mean zero and variance one. + +''y=torch.randn(m,n)'' returns a mxn tensor of random numbers from a normal distribution with mean zero and variance one. + +===== torch.range([res,] n,m) ===== +{{anchor:torch.range}} + +''y=torch.range(n,m)'' returns a tensor of size m-n+1x1 with integer +values n to m. + + +> print(torch.range(2,5)) + + 2 + 3 + 4 + 5 +[torch.Tensor of dimension 4] + + +''y=torch.range(n,m,incr)'' returns a tensor filled in range n to m with incr increments. + +print(torch.range(2,5,1.2)) + 2.0000 + 3.2000 + 4.4000 +[torch.DoubleTensor of dimension 3] + + +===== torch.randperm([res,] n) ===== +{{anchor:torch.randperm}} + +''y=torch.randperm(n)'' returns a randomly ordered nx1 tensor of the integers from 1 to n. + +===== torch.reshape([res,] x,m,n) ===== +{{anchor:torch.reshape}} + +''y=torch.reshape(x,m,n)'' returns a new mxn tensor y whose elements +are taken rowwise from x, which must have m*n elements. The elements are copied into the new tensor. + +===== torch.tril([res,] x) ===== +{{anchor:torch.tril}} + +''y=torch.tril(x)'' returns the lower triangular part of x, the other elements of y are set to 0. + +''torch.tril(x,k)'' returns the elements on and below the k-th diagonal of x as non-zero. k=0 is the main diagonal, k>0 is above the main diagonal and k<0 +is below the main diagonal. + +===== torch.triu([res,] x) ===== +{{anchor:torch.triu}} + +''y=torch.triu(x)'' returns the upper triangular part of x, +the other elements of y are set to 0. + +''torch.triu(x,k)'' returns the elements on and above the k-th diagonal of x as non-zero. k=0 is the main diagonal, k>0 is above the main diagonal and k<0 +is below the main diagonal. + +===== torch.zeros([res,] x) ===== +{{anchor:torch.zeros}} + +''y=torch.zeros(n)'' returns a one-dimensional tensor of size n filled with zeros. + +''y=torch.zeros(m,n)'' returns a mxn tensor filled with zeros. + + +====== Element-wise operations ====== +{{anchor:torch.elementwise.dok}} + +===== torch.abs([res,] x) ===== +{{anchor:torch.abs}} + +''y=torch.abs(x)'' returns the absolute values of the elements of x. + +===== torch.acos([res,] x) ===== +{{anchor:torch.acos}} + +''y=torch.acos(x)'' returns the arcosine of the elements of x. + +===== torch.asin([res,] x) ===== +{{anchor:torch.asin}} + +''y=torch.asin(x)'' returns the arcsine of the elements of x. + +===== torch.atan([res,] x) ===== +{{anchor:torch.atan}} + +''y=torch.atan(x)'' returns the arctangent of the elements of x. + +===== torch.ceil([res,] x) ===== +{{anchor:torch.ceil}} + +''y=torch.ceil(x)'' returns the values of the elements of x rounded up to the nearest integers. + +===== torch.cos([res,] x) ===== +{{anchor:torch.cos}} + +''y=torch.cos(x)'' returns the cosine of the elements of x. + +===== torch.cosh([res,] x) ===== +{{anchor:torch.cosh}} + +''y=torch.cosh(x)'' returns the hyberbolic cosine of the elements of x. + +===== torch.exp[res,] (x) ===== +{{anchor:torch.exp}} + +''y=torch.exp(x)'' returns, for each element in x, e (the base of natural logarithms) raised to the power of the element in x. + +===== torch.floor([res,] x) ===== +{{anchor:torch.floor}} + +''y=torch.floor(x)'' returns the values of the elements of x rounded down to the nearest integers. 
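A short sketch of an element-wise call; as with the other functions, a result tensor can be passed first to avoid a new allocation.

x = torch.randn(4)      -- 4 samples from a normal distribution with mean zero and variance one
y = torch.abs(x)        -- new tensor holding the absolute values of x
torch.floor(y, x)       -- reuses y to store the rounded-down values of x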
+ +===== torch.log[res,] (x) ===== +{{anchor:torch.log}} + +''y=torch.log(x)'' returns the natural logarithm of the elements of x. + +===== torch.pow([res,] x) ===== +{{anchor:torch.pow}} + +''y=torch.pow(x,n)'' returns the elements of x to the power of n. + +===== torch.sin([res,] x) ===== +{{anchor:torch.sin}} + +''y=torch.sin(x)'' returns the sine of the elements of x. + +===== torch.sinh([res,] x) ===== +{{anchor:torch.sinh}} + +''y=torch.sinh(x)'' returns the hyperbolic sine of the elements of x. + +===== torch.sqrt([res,] x) ===== +{{anchor:torch.sqrt}} + +''y=torch.sqrt(x)'' returns the square root of the elements of x. + +===== torch.tan([res,] x) ===== +{{anchor:torch.tan}} + +''y=torch.abs(x)'' returns the tangent of the elements of x. + +===== torch.tanh([res,] x) ===== +{{anchor:torch.tanh}} + +''y=torch.tanh(x)'' returns the hyperbolic tangent of the elements of x. + +====== Column or row-wise operations (dimension-wise operations) ====== +{{anchor:torch.columnwise.dok}} + +===== torch.cross([res,] a,b) ===== +{{anchor:torch.cross}} + +''y=torch.cross(a,b)'' returns the cross product of the tensors a and b. +a and b must be 3 element vectors. + +''y=cross(a,b)'' returns the cross product of a and b along the first dimension of length 3. + +''y=cross(a,b,n)'', where a and b returns the cross +product of vectors in dimension n of a and b. +a and b must have the same size, +and both a:size(n) and b:size(n) must be 3. + + +===== torch.cumprod([res,] x) ===== +{{anchor:torch.cumprod}} + +''y=torch.cumprod(x)'' returns the cumulative product of the elements of x, performing the operation over the last dimension. + +''y=torch.cumprod(x,n)'' returns the cumulative product of the elements of x, performing the operation over dimension n. + +===== torch.cumsum([res,] x) ===== +{{anchor:torch.cumsum}} + +''y=torch.cumsum(x)'' returns the cumulative product of the elements of x, performing the operation over the first dimension. + +''y=torch.cumsum(x,n)'' returns the cumulative product of the elements of x, performing the operation over dimension n. + +===== torch.max([resval, resind, ] x) ===== +{{anchor:torch.max}} + +''y,i=torch.max(x)'' returns a tensor y of the largest element in +each row of x, and a tensor i of their corresponding indices in x. + +''y,i=torch.max(x,1)'' performs the max operation for each row and +''y,i=torch.max(x,n)'' performs the max operation over the dimension n. + + +===== torch.mean([res,] x) ===== +{{anchor:torch.mean}} + +''y=torch.mean(x)'' returns a tensor y of the mean of the elements in +each row of x. + +''y=torch.mean(x,2)'' performs the mean operation for each row and +''y=torch.mean(x,n)'' performs the mean operation over the dimension n. + +===== torch.min([resval, resind, ] x) ===== +{{anchor:torch.min}} + +''y,i=torch.min(x)'' returns a tensor y of the smallest element in +each row of x, and a tensor i of their corresponding indices in x. + +''y,i=torch.min(x,2)'' performs the min operation for each row and +''y,i=torch.min(x,n)'' performs the min operation over the dimension n. + + +===== torch.prod([res,] x) ===== +{{anchor:torch.prod}} + +''y=torch.prod(x)'' returns a tensor y of the product of the elements in +each row of x. + +''y=torch.prod(x,2)'' performs the prod operation for each row and +''y=torch.prod(x,n)'' performs the prod operation over the dimension n. 
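A short sketch of the dimension-wise calls described above.

x = torch.rand(4,5)
y, i = torch.max(x)     -- y: largest element of each row, i: the corresponding indices
s = torch.sum(x, 2)     -- sum performed for each row
m = torch.mean(x)       -- mean of the elements in each row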
+ +===== torch.sort([resval, resind, ] x) ===== +{{anchor:torch.sort}} + +''y,i=torch.sort(x)'' returns a tensor y of the sorted +rows of x, and a tensor i of the corresponding indices from x. + +''y,i=torch.sort(x,2)'' performs the sort operation for each row and +''y,i=torch.sort(x,n)'' performs the sort operation over the dimension n. + +===== torch.std([res,] x) ===== +{{anchor:torch.std}} + +''y=torch.std(x)'' returns a tensor y of the standard deviation of the elements in +each row of x. + +''torch.std(x)'' normalizes by (n-1) where n is the number of elements. This +makes torch.sum(torch.pow(torch.std(x),2)) +the best unbiased estimate of the variance if x +is a sample from a normal distribution. + +''y=torch.std(x,true)'' performs the std operation normalizing by n instead of n-1. + +''y=torch.std(x,false)'' performs the std operation normalizing by n-1. + +''y=torch.std(x,flag,n)'' performs the std operation over the dimension n. + + +===== torch.sum([res,] x) ===== +{{anchor:torch.sum}} + +''y=torch.sum(x)'' returns a tensor y of the sum of the elements in +each row of x. + +''y=torch.sum(x,2)'' performs the sum operation for each row and +''y=torch.sum(x,n)'' performs the sum operation over the dimension n. + +===== torch.var([res,] x) ===== +{{anchor:torch.var}} + +''y=torch.var(x)'' returns a tensor y of the standard deviation of the elements in +each row of x. + +''torch.var(x)'' normalizes by (n-1) where n is the number of elements. This +makes torch.sum(torch.var(x)) +the best unbiased estimate of the variance if x +is a sample from a normal distribution. + +''y=torch.var(x,true)'' performs the var operation normalizing by n instead of n-1. + +''y=torch.var(x,false)'' performs the var operation normalizing by n-1. + +''y=torch.var(x,flag,n)'' performs the var operation over the dimension n. + +====== Matrix-wide operations (tensor-wide operations) ====== +{{anchor:torch.matrixwide.dok}} + +===== torch.norm(x) ===== +{{anchor:torch.norm}} + +''y=torch.norm(x)'' returns the 2-norm of the tensor x. + +''y=torch.norm(x,p)'' returns the p-norm of the tensor x. + + +===== torch.dist(x,y) ===== +{{anchor:torch.dist}} + +''y=torch.dist(x,y)'' returns the 2-norm of (x-y). + +''y=torch.dist(x,y,p)'' returns the p-norm of (x-y). + +===== torch.numel(x) ===== +{{anchor:torch.numel}} + +''y=torch.numel(x)'' returns the count of the number of elements in the matrix x. + +===== torch.trace(x) ===== +{{anchor:torch.trace}} + +''y=torch.trace(x)'' returns the trace (sum of the diagonal elements) +of a matrix x. This is equal to the sum of the eigenvalues of x. +The returned value ''y'' is a number, not a tensor. + +====== Convolution Operations ====== +{{anchor:torch.conv.dok}} + +These function implement convolution or cross-correlation of an input +image (or set of input images) with a kernel (or set of kernels). The +convolution function in Torch can handle different types of +input/kernel dimensions and produces corresponding outputs. The +general form of operations always remain the same. + +===== torch.conv2([res,] x, k, ['f' or 'v']) ===== +{{anchor:torch.conv2}} + +This function computes 2 dimensional convolutions between '' x '' and '' k ''. These operations are similar to BLAS operations when number of dimensions of input and kernel are reduced by 2. + + * '' x '' and '' k '' are 2D : convolution of a single image with a single kernel (2D output). This operation is similar to multiplication of two scalars. 
+ * '' x '' and '' k '' are 3D : convolution of each input slice with corresponding kernel (3D output). + * '' x (p x m x n) '' 3D, '' k (q x p x ki x kj)'' 4D : convolution of all input slices with the corresponding slice of kernel. Output is 3D '' (q x m x n) ''. This operation is similar to matrix vector product of matrix '' k '' and vector '' x ''. + +The last argument controls if the convolution is a full ('f') or valid ('v') convolution. The default is 'valid' convolution. + + +x=torch.rand(100,100) +k=torch.rand(10,10) +c = torch.conv2(x,k) +=c:size() + + 91 + 91 +[torch.LongStorage of size 2] + +c = torch.conv2(x,k,'f') +=c:size() + + 109 + 109 +[torch.LongStorage of size 2] + + + +===== torch.xcorr2([res,] x, k, ['f' or 'v']) ===== +{{anchor:torch.xcorr2}} + +This function operates with same options and input/output +configurations as [[#torch.conv2|torch.conv2]], but performs +cross-correlation of the input with the kernel '' k ''. + +===== torch.conv3([res,] x, k, ['f' or 'v']) ===== +{{anchor:torch.conv3}} + +This function computes 3 dimensional convolutions between '' x '' and '' k ''. These operations are similar to BLAS operations when number of dimensions of input and kernel are reduced by 3. + + * '' x '' and '' k '' are 3D : convolution of a single image with a single kernel (3D output). This operation is similar to multiplication of two scalars. + * '' x '' and '' k '' are 4D : convolution of each input slice with corresponding kernel (4D output). + * '' x (p x m x n x o) '' 4D, '' k (q x p x ki x kj x kk)'' 5D : convolution of all input slices with the corresponding slice of kernel. Output is 4D '' (q x m x n x o) ''. This operation is similar to matrix vector product of matrix '' k '' and vector '' x ''. + +The last argument controls if the convolution is a full ('f') or valid ('v') convolution. The default is 'valid' convolution. + + +x=torch.rand(100,100,100) +k=torch.rand(10,10,10) +c = torch.conv3(x,k) +=c:size() + + 91 + 91 + 91 +[torch.LongStorage of size 3] + +c = torch.conv3(x,k,'f') +=c:size() + + 109 + 109 + 109 +[torch.LongStorage of size 3] + + + +===== torch.xcorr3([res,] x, k, ['f' or 'v']) ===== +{{anchor:torch.xcorr3}} + +This function operates with same options and input/output +configurations as [[#torch.conv3|torch.conv3]], but performs +cross-correlation of the input with the kernel '' k ''. + +====== Eigenvalues, SVD, Linear System Solution ====== +{{anchor:torch.linalg.dok}} + +Functions in this section are implemented with an interface to LAPACK +libraries. If LAPACK libraries are not found during compilation step, +then these functions will not be available. + +===== torch.gesv([resb, resa,] b,a [, true]) ===== +{{anchor:torch.gesv}} + +Solution of '' AX=B '' and ''A'' has to be square and non-singular. '' +A '' is '' m x m '', '' X '' is '' m x k '', '' B '' is '' m x k ''. + +If ''resb'' and ''resa'' are given, then they will be used for +temporary storage and returning the result. + + * ''resa'' will contain L and U factors for ''LU'' factorization of ''A''. + * ''resb'' will contain the solution. + +If ''gesv'' is called with 3 parameters with last parameters ''true'', +then ''b'' and ''a'' will destroyed and their output values will be +same as ''resa'' and ''resb''. 
+ + +a=torch.Tensor({{6.80, -2.11, 5.66, 5.97, 8.23}, + {-6.05, -3.30, 5.36, -4.44, 1.08}, + {-0.45, 2.58, -2.70, 0.27, 9.04}, + {8.32, 2.71, 4.35, -7.17, 2.14}, + {-9.67, -5.14, -7.26, 6.08, -6.87}}):t() + +b=torch.Tensor({{4.02, 6.19, -8.22, -7.57, -3.03}, + {-1.56, 4.00, -8.67, 1.75, 2.86}, + {9.81, -4.09, -4.57, -8.61, 8.99}}):t() + + =b + 4.0200 -1.5600 9.8100 + 6.1900 4.0000 -4.0900 +-8.2200 -8.6700 -4.5700 +-7.5700 1.7500 -8.6100 +-3.0300 2.8600 8.9900 +[torch.DoubleTensor of dimension 5x3] + +=a + 6.8000 -6.0500 -0.4500 8.3200 -9.6700 +-2.1100 -3.3000 2.5800 2.7100 -5.1400 + 5.6600 5.3600 -2.7000 4.3500 -7.2600 + 5.9700 -4.4400 0.2700 -7.1700 6.0800 + 8.2300 1.0800 9.0400 2.1400 -6.8700 +[torch.DoubleTensor of dimension 5x5] + + +x=torch.gesv(b,a) + =x +-0.8007 -0.3896 0.9555 +-0.6952 -0.5544 0.2207 + 0.5939 0.8422 1.9006 + 1.3217 -0.1038 5.3577 + 0.5658 0.1057 4.0406 +[torch.DoubleTensor of dimension 5x3] + +=b:dist(a*x) +1.1682163181673e-14 + + + +===== torch.gels([resb, resa,] b,a) ===== +{{anchor:torch.gels}} + +Solution of least squares and least norm problems for a full rank '' A '' that is '' m x n''. + * If '' n %%<=%% m '', then solve '' ||AX-B||_F ''. + * If '' n > m '' , then solve '' min ||X||_F s.t. AX=B ''. + +On return, first '' n '' rows of '' X '' matrix contains the solution +and the rest contains residual information. Square root of sum squares +of elements of each column of '' X '' starting at row '' n + 1 '' is +the residual for corresponding column. + + + +a=torch.Tensor({{ 1.44, -9.96, -7.55, 8.34, 7.08, -5.45}, + {-7.84, -0.28, 3.24, 8.09, 2.52, -5.70}, + {-4.39, -3.24, 6.27, 5.28, 0.74, -1.19}, + {4.53, 3.83, -6.64, 2.06, -2.47, 4.70}}):t() + +b=torch.Tensor({{8.58, 8.26, 8.48, -5.28, 5.72, 8.93}, + {9.35, -4.43, -0.70, -0.26, -7.36, -2.52}}):t() + +=a + 1.4400 -7.8400 -4.3900 4.5300 +-9.9600 -0.2800 -3.2400 3.8300 +-7.5500 3.2400 6.2700 -6.6400 + 8.3400 8.0900 5.2800 2.0600 + 7.0800 2.5200 0.7400 -2.4700 +-5.4500 -5.7000 -1.1900 4.7000 +[torch.DoubleTensor of dimension 6x4] + +=b + 8.5800 9.3500 + 8.2600 -4.4300 + 8.4800 -0.7000 +-5.2800 -0.2600 + 5.7200 -7.3600 + 8.9300 -2.5200 +[torch.DoubleTensor of dimension 6x2] + +x = torch.gels(a,b) +=x + -0.4506 0.2497 + -0.8492 -0.9020 + 0.7066 0.6323 + 0.1289 0.1351 + 13.1193 -7.4922 + -4.8214 -7.1361 +[torch.DoubleTensor of dimension 6x2] + +=b:dist(a*x:narrow(1,1,4)) +17.390200628863 + +=math.sqrt(x:narrow(1,5,2):pow(2):sumall()) +17.390200628863 + + + +===== torch.eig([rese, resv,] a, [, 'n' or 'v']) ===== +{{anchor:torch.eig}} + +Eigen values and eigen vectors of a symmetric real matrix '' A '' of +size '' m x m ''. This function calculates all eigenvalues (and +vectors) of '' A '' such that '' A = V' diag(e) V ''. Since the input +matrix '' A '' is supposed to be symmetric, only upper triangular +portion is used. + +Last argument defines computation of eigenvectors or eigenvalues +only. If '' n '', only eignevalues are computed. If '' v '', both +eigenvalues and eigenvectors are computed. 
+ + + +a=torch.Tensor({{ 1.96, 0.00, 0.00, 0.00, 0.00}, + {-6.49, 3.80, 0.00, 0.00, 0.00}, + {-0.47, -6.39, 4.17, 0.00, 0.00}, + {-7.20, 1.50, -1.51, 5.70, 0.00}, + {-0.65, -6.34, 2.67, 1.80, -7.10}}):t() + +=a + 1.9600 -6.4900 -0.4700 -7.2000 -0.6500 + 0.0000 3.8000 -6.3900 1.5000 -6.3400 + 0.0000 0.0000 4.1700 -1.5100 2.6700 + 0.0000 0.0000 0.0000 5.7000 1.8000 + 0.0000 0.0000 0.0000 0.0000 -7.1000 +[torch.DoubleTensor of dimension 5x5] + +e = torch.eig(a) +=e +-11.0656 + -6.2287 + 0.8640 + 8.8655 + 16.0948 +[torch.DoubleTensor of dimension 5] + +e,v = torch.eig(a,'v') +=e +-11.0656 + -6.2287 + 0.8640 + 8.8655 + 16.0948 +[torch.DoubleTensor of dimension 5] + +=v +-0.2981 -0.6075 0.4026 -0.3745 0.4896 +-0.5078 -0.2880 -0.4066 -0.3572 -0.6053 +-0.0816 -0.3843 -0.6600 0.5008 0.3991 +-0.0036 -0.4467 0.4553 0.6204 -0.4564 +-0.8041 0.4480 0.1725 0.3108 0.1622 +[torch.DoubleTensor of dimension 5x5] + +=v*torch.diag(e)*v:t() + 1.9600 -6.4900 -0.4700 -7.2000 -0.6500 +-6.4900 3.8000 -6.3900 1.5000 -6.3400 +-0.4700 -6.3900 4.1700 -1.5100 2.6700 +-7.2000 1.5000 -1.5100 5.7000 1.8000 +-0.6500 -6.3400 2.6700 1.8000 -7.1000 +[torch.DoubleTensor of dimension 5x5] + +=a:dist(torch.triu(v*torch.diag(e)*v:t())) +1.0219480822443e-14 + + + +===== torch.svd([resu, ress, resv] a, [, 's' or 'a']) ===== +{{anchor:torch.svd}} + +Singular value decomposition of a real matrix '' A '' of size '' n x m +'' such that '' A = USV**T ''. The call to ''svd'' returns ''U,S,VT''. + +The last argument, if it is string, represents the number of singular +values to be computed. 's' stands for 'some' and 'a' stands for 'all'. + + + + +a=torch.Tensor({{8.79, 6.11, -9.15, 9.57, -3.49, 9.84}, + {9.93, 6.91, -7.93, 1.64, 4.02, 0.15}, + {9.83, 5.04, 4.86, 8.83, 9.80, -8.99}, + {5.45, -0.27, 4.85, 0.74, 10.00, -6.02}, + {3.16, 7.98, 3.01, 5.80, 4.27, -5.31}}):t() +=a + 8.7900 9.9300 9.8300 5.4500 3.1600 + 6.1100 6.9100 5.0400 -0.2700 7.9800 + -9.1500 -7.9300 4.8600 4.8500 3.0100 + 9.5700 1.6400 8.8300 0.7400 5.8000 + -3.4900 4.0200 9.8000 10.0000 4.2700 + 9.8400 0.1500 -8.9900 -6.0200 -5.3100 + +u,s,v = torch.svd(a) + +=u +-0.5911 0.2632 0.3554 0.3143 0.2299 +-0.3976 0.2438 -0.2224 -0.7535 -0.3636 +-0.0335 -0.6003 -0.4508 0.2334 -0.3055 +-0.4297 0.2362 -0.6859 0.3319 0.1649 +-0.4697 -0.3509 0.3874 0.1587 -0.5183 + 0.2934 0.5763 -0.0209 0.3791 -0.6526 +[torch.DoubleTensor of dimension 6x5] + +=s + 27.4687 + 22.6432 + 8.5584 + 5.9857 + 2.0149 +[torch.DoubleTensor of dimension 5] + +=v +-0.2514 -0.3968 -0.6922 -0.3662 -0.4076 + 0.8148 0.3587 -0.2489 -0.3686 -0.0980 +-0.2606 0.7008 -0.2208 0.3859 -0.4933 + 0.3967 -0.4507 0.2513 0.4342 -0.6227 +-0.2180 0.1402 0.5891 -0.6265 -0.4396 +[torch.DoubleTensor of dimension 5x5] + +=u*torch.diag(s)*v + 8.7900 9.9300 9.8300 5.4500 3.1600 + 6.1100 6.9100 5.0400 -0.2700 7.9800 + -9.1500 -7.9300 4.8600 4.8500 3.0100 + 9.5700 1.6400 8.8300 0.7400 5.8000 + -3.4900 4.0200 9.8000 10.0000 4.2700 + 9.8400 0.1500 -8.9900 -6.0200 -5.3100 +[torch.DoubleTensor of dimension 6x5] + + =a:dist(u*torch.diag(s)*v) +2.8923773593204e-14 + + + diff --git a/dok/memoryfile.dok b/dok/memoryfile.dok new file mode 100644 index 00000000000..3c963e38d13 --- /dev/null +++ b/dok/memoryfile.dok @@ -0,0 +1,36 @@ +====== MemoryFile ====== +{{anchor:torch.MemoryFile.dok}} + +Parent classes: [[File|File]] + +A ''MemoryFile'' is a particular ''File'' which is able to perform basic +read/write operations on a buffer in ''RAM''. It implements all methods +described in [[File|File]]. 
+ +The data of the this ''File'' is contained into a ''NULL'' terminated +[[Storage|CharStorage]]. + +===== torch.MemoryFile([mode]) ===== +{{anchor:torch.MemoryFile}} + +//Constructor// which returns a new ''MemoryFile'' object using ''mode''. Valid +''mode'' are ''"r"'' (read), ''"w"'' (write) or ''"rw"'' (read-write). Default is ''"rw"''. + + +===== torch.MemoryFile(storage, mode) ===== +{{anchor:torch.MemoryFile}} + +//Constructor// which returns a new ''MemoryFile'' object, using the given +[[Storage|storage]] (which must be a ''CharStorage'') and ''mode''. Valid +''mode'' are ''"r"'' (read), ''"w"'' (write) or ''"rw"'' (read-write). The last character +in this storage //must// be ''NULL'' or an error will be generated. This allow +to read existing memory. If used for writing, not that the ''storage'' might +be resized by this class if needed. + +===== [CharStorage] storage() ===== +{{anchor:torch.MemoryFile.storage}} + +Returns the [[Storage|storage]] which contains all the data of the +''File'' (note: this is //not// a copy, but a //reference// on this storage). The +size of the storage is the size of the data in the ''File'', plus one, the +last character being ''NULL''. diff --git a/dok/pipefile.dok b/dok/pipefile.dok new file mode 100644 index 00000000000..910ac8236b9 --- /dev/null +++ b/dok/pipefile.dok @@ -0,0 +1,21 @@ +====== PipeFile ====== +{{anchor:torch.PipeFile.dok}} + +Parent classes: [[DiskFile|DiskFile]] + +A ''PipeFile'' is a particular ''File'' which is able to perform basic read/write operations +on a command pipe. It implements all methods described in [[DiskFile|DiskFile]] and [[File|File]]. + +The file might be open in read or write mode, depending on the parameter +''mode'' (which can take the value ''"r"'' or ''"w"'') +given to the [[#torch.PipeFile|torch.PipeFile(fileName, mode)]]. Read-write mode is not allowed. + +===== torch.PipeFile(command, [mode], [quiet]) ===== +{{anchor:torch.PipeFile}} + +//Constructor// which execute ''command'' by opening a pipe in read or write +''mode''. Valid ''mode'' are ''"r"'' (read) or ''"w"'' (write). Default is read +mode. + +If (and only if) ''quiet'' is ''true'', no error will be raised in case of +problem opening the file: instead ''nil'' will be returned. diff --git a/dok/random.dok b/dok/random.dok new file mode 100644 index 00000000000..394b6fc69a0 --- /dev/null +++ b/dok/random.dok @@ -0,0 +1,105 @@ +====== Random Numbers ====== +{{anchor:torch.random.dok}} + +Torch provides accurate mathematical random generation, based on +[[http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/emt.html|Mersenne Twister]] +random number generator. + +====== Seed Handling ====== +{{anchor::torch.seed.dok}} + +If no seed is provided to the random generator (using +[[#torch.seed|seed()]] or [[#torch.manualSeed|manualSeed()]]), a +random seed will be set according to [[#torch.seed|seed()]] the first +time a random number is generated. + +Initial seed can be obtained using [[#torch.initialSeed|initialSeed()]]. + +Setting a particular seed allows the user to (re)-generate a particular serie of +random numbers. 
Example: + +> torch.manualSeed(123) +> = torch.uniform() +0.69646918727085 +> return torch.uniform() +0.71295532141812 +> return torch.uniform() +0.28613933874294 +> torch.manualSeed(123) +> return torch.uniform() +0.69646918727085 +> return torch.uniform() +0.71295532141812 +> return torch.uniform() +0.28613933874294 +> torch.manualSeed(torch.initialSeed()) +> return torch.uniform() +0.69646918727085 +> return torch.uniform() +0.71295532141812 +> return torch.uniform() +0.28613933874294 + + +===== [number] seed() ===== +{{anchor:torch.seed}} + +Set the seed of the random number generator according to the time of the +computer. Granularity is seconds. Returns the seed obtained. + +===== manualSeed(number) ===== +{{anchor:torch.manualSeed}} + +Set the seed of the random number generator to the given ''number''. + +===== initialSeed() ===== +{{anchor:torch.initialSeed}} + +Returns the initial seed used to initialize the random generator. + +====== [number] random() ====== +{{anchor:torch.random}} + +Returns a 32 bit integer random number. + +====== [number] uniform([a],[b]) ====== +{{anchor:torch.uniform}} + +Returns a random real number according to uniform distribution on [a,b[. By default ''a'' is 0 and ''b'' is 1. + +====== [number] normal([mean],[stdv]) ====== +{{anchor:torch.normal}} + +Returns a random real number according to a normal distribution with the given ''mean'' and standard deviation ''stdv''. +''stdv'' must be positive. + +====== [number] exponential(lambda) ====== +{{anchor:torch.exponential}} + +Returns a random real number according to the exponential distribution +''p(x) = lambda * exp(-lambda * x)'' + +====== [number] cauchy(median, sigma) ====== +{{anchor:torch.cauchy}} + +Returns a random real number according to the Cauchy distribution +''p(x) = sigma/(pi*(sigma^2 + (x-median)^2))'' + +====== [number] logNormal(mean, stdv) ====== +{{anchor:torch.logNormal}} + +Returns a random real number according to the log-normal distribution, with +the given ''mean'' and standard deviation ''stdv''. +''stdv'' must be positive. + +====== [number] geometric(p) ====== +{{anchor:torch.geometric}} + +Returns a random integer number according to a geometric distribution +''p(i) = (1-p) * p^(i-1)''. ''p'' must satisfy ''0 < p < 1''. + +====== [number] bernouilli([p]) ====== +{{anchor:torch.bernoulli}} + +Returns ''1'' with probability ''p'' and ''0'' with probability ''1-p''. ''p'' must satisfy ''0 < p < 1''. +By default ''p'' is equal to ''0.5''. diff --git a/dok/storage.dok b/dok/storage.dok new file mode 100644 index 00000000000..594f7665b8a --- /dev/null +++ b/dok/storage.dok @@ -0,0 +1,222 @@ +====== Storage ====== +{{anchor:torch.Storage.dok}} +{{anchor:torch.ByteStorage.dok}} +{{anchor:torch.CharStorage.dok}} +{{anchor:torch.ShortStorage.dok}} +{{anchor:torch.IntStorage.dok}} +{{anchor:torch.LongStorage.dok}} +{{anchor:torch.FloatStorage.dok}} +{{anchor:torch.DoubleStorage.dok}} + +//Storages// are basically a way for ''Lua'' to access memory of a ''C'' pointer +or array. //Storages// can also [[#__torch.StorageMap|map the contents of a file to memory]]. +A ''Storage'' is an array of //basic// ''C'' types. For arrays of ''Torch'' objects, +use the ''Lua'' tables. + +Several ''Storage'' classes for all the basic ''C'' types exist and have the +following self-explanatory names: ''ByteStorage'', ''CharStorage'', ''ShortStorage'', +''IntStorage'', ''LongStorage'', ''FloatStorage'', ''DoubleStorage''. + +Note that ''ByteStorage'' and ''CharStorage'' represent both arrays of bytes. 
''ByteStorage'' represents an array of +//unsigned// chars, while ''CharStorage'' represents an array of //signed// chars. + +Conversions between two ''Storage'' type might be done using ''copy'': + +x = torch.IntStorage(10):fill(1) +y = torch.DoubleStorage(10):copy(x) + + +[[#torch.Storage|Classical storages]] are [[File#torch.File.serialization|serializable]]. +[[#__torch.StorageMap|Storages mapping a file]] are also [[#FileSerialization|serializable]], +but //will be saved as a normal storage//. + +An alias ''torch.Storage()'' is made over your preferred Storage type, +controlled by the +[[utility#torch.setdefaulttensortype|torch.setdefaulttensortype]] +function. By default, this "points" on ''torch.DoubleStorage''. + +===== torch.TYPEStorage([size]) ===== +{{anchor:torch.Storage}} + +Returns a new ''Storage'' of type ''TYPE''. Valid ''TYPE'' are ''Byte'', ''Char'', ''Short'', +''Int'', ''Long'', ''Float'', and ''Double''. If ''size'' is given, resize the +''Storage'' accordingly, else create an empty ''Storage''. + +Example: + +-- Creates a Storage of 10 double: +x = torch.DoubleStorage(10) + + +The data in the ''Storage'' is //uninitialized//. + +===== torch.TYPEStorage(table) ===== +{{anchor:torch.Storage}} + +The argument is assumed to be a Lua array of numbers. The constructor returns a new storage of the specified 'TYPE', +of the size of the table, containing all the table elements converted + +Example: + +> = torch.IntStorage({1,2,3,4}) + + 1 + 2 + 3 + 4 +[torch.IntStorage of size 4] + + +===== torch.TYPEStorage(filename [, shared]) ===== +{{anchor:torch.Storage}} +{{anchor:__torch.StorageMap}} + +Returns a new kind of ''Storage'' which maps the contents of the given +''filename'' to memory. Valid ''TYPE'' are ''Byte'', ''Char'', ''Short'', ''Int'', ''Long'', +''Float'', and ''Double''. If the optional boolean argument ''shared'' is ''true'', +the mapped memory is shared amongst all processes on the computer. + +When ''shared'' is ''true'', the file must be accessible in read-write mode. Any +changes on the storage will be written in the file. The changes might be written +only after destruction of the storage. + +When ''shared'' is ''false'' (or not provided), the file must be at least +readable. Any changes on the storage will not affect the file. Note: +changes made on the file after creation of the storage have an unspecified +effect on the storage contents. + +The [[#torch.Storage.size|size]] of the returned ''Storage'' will be + +(size of file in byte)/(size of TYPE). + + +Example: + +$ echo "Hello World" > hello.txt +$ lua +Lua 5.1.3 Copyright (C) 1994-2008 Lua.org, PUC-Rio +> require 'torch' +> x = torch.CharStorage('hello.txt') +> = x + 72 + 101 + 108 + 108 + 111 + 32 + 87 + 111 + 114 + 108 + 100 + 10 +[torch.CharStorage of size 12] + +> = x:string() +Hello World + +> = x:fill(42):string() +************ +> +$ cat hello.txt +Hello World +$ lua +Lua 5.1.3 Copyright (C) 1994-2008 Lua.org, PUC-Rio +> require 'torch' +> x = torch.CharStorage('hello.txt', true) +> = x:string() +Hello World + +> x:fill(42) +> +$ cat hello.txt +************ + + +===== [number] #self ===== +{{anchor:__torch.StorageSharp}} + +Returns the number of elements in the storage. Equivalent to [[#torch.Storage.size|size()]]. + +===== [number] self[index] ===== +{{anchor:torch.Storage.__index__}} + +Returns or set the element at position ''index'' in the storage. Valid range +of ''index'' is 1 to [[#torch.Storage.size|size()]]. 
+ +Example: + +x = torch.DoubleStorage(10) +print(x[5]) + + +===== [self] copy(storage) ===== +{{anchor:torch.Storage.copy}} + +Copy another ''storage''. The types of the two storages might be different: in that case +a conversion of types occur (which might result, of course, in loss of precision or rounding). +This method returns self, allowing things like: + +x = torch.IntStorage(10):fill(1) +y = torch.DoubleStorage(10):copy(x) -- y won't be nil! + + +===== [self] fill(value) ===== +{{anchor:torch.Storage.fill}} + +Fill the ''Storage'' with the given value. This method returns self, allowing things like: + +x = torch.IntStorage(10):fill(0) -- x won't be nil! + + +===== [self] resize(size) ===== +{{anchor:torch.Storage.resize}} + +Resize the storage to the provide ''size''. //The new contents are undertermined//. + +This function returns self, allowing things like: + +x = torch.DoubleStorage(10):fill(1) +y = torch.DoubleStorage():resize(x:size()):copy(x) -- y won't be nil! + + +===== [number] size() ===== +{{anchor:torch.Storage.size}} + +Returns the number of elements in the storage. Equivalent to [[#__torch.StorageSharp|#]]. + +===== [self] string(str) ===== +{{anchor:torch.Storage.string}} + +This function is available only on ''ByteStorage'' and ''CharStorage''. + +This method resizes the storage to the length of the provided +string ''str'', and copy the contents of ''str'' into the storage. The ''NULL'' terminating character is not copied, +but ''str'' might contain ''NULL'' characters. The method returns the ''Storage''. + +> x = torch.CharStorage():string("blah blah") +> print(x) + 98 + 108 + 97 + 104 + 32 + 98 + 108 + 97 + 104 +[torch.CharStorage of size 9] + + +===== [string] string() ===== +{{anchor:torch.Storage.string}} + +This function is available only on ''ByteStorage'' and ''CharStorage''. + +The contents of the storage viewed as a string are returned. The string might contain +''NULL'' characters. + +> x = torch.CharStorage():string("blah blah") +> print(x:string()) +blah blah + diff --git a/dok/tensor.dok b/dok/tensor.dok new file mode 100644 index 00000000000..93d701a8370 --- /dev/null +++ b/dok/tensor.dok @@ -0,0 +1,1794 @@ +====== Tensor ====== +{{anchor:torch.Tensor.dok}} + +The ''Tensor'' class is probably the most important class in ''Torch''. Almost every package depends on this +class. It is ***the*** class for handling numeric data. Tensors are [[File#torch.File.serialization|serializable]]. + +**Multi-dimensional matrix** + +A ''Tensor'' is a potentially multi-dimensional matrix. The number of +dimensions is unlimited. Many methods have some convenience methods for for +a number of dimensions inferior or equal to ''4'', but can also be called using +[[Storage|LongStorage]] with more dimensions. Example: + + --- creation of a 4D-tensor 4x5x6x2 + z = torch.Tensor(4,5,6,2) + --- for more dimensions, (here a 6D tensor) one can do: + s = torch.LongStorage(6) + s[1] = 4; s[2] = 5; s[3] = 6; s[4] = 2; s[5] = 7; s[6] = 3; + x = torch.Tensor(s) + + +The number of dimensions of a ''Tensor'' can be queried by +[[#torch.Tensor.nDimension|nDimension()]] or [[#torch.Tensor.dim|dim()]]. Size of +the ''i-th'' dimension is returned by [[#torch.Tensor.size|size(i)]]. A +[[Storage|LongStorage]] containing all the dimensions can be returned by +[[#torch.Tensor.size|size()]]. 
+ +> print(x:nDimension()) +6 +> print(x:size()) + 4 + 5 + 6 + 2 + 7 + 3 +[torch.LongStorage of size 6] + + +**Internal data representation** + +The actual data of a ''Tensor'' is contained into a +[[Storage|Storage]]. It can be accessed using +[[#torch.Tensor.storage|''storage()'']]. While the memory of a ''Tensor'' has to be +contained in this unique ''Storage'', it might not be contiguous: +the first position used in the ''Storage'' is given by [[#torch.Tensor.storageOffset|''storageOffset()'']] +(starting at ''1''). And the //jump// needed to go from one element to another element +in the ''i-th'' dimension is given by [[#torch.Tensor.stride|''stride(i)'']]. In other words, given a 3D tensor + +x = torch.Tensor(7,7,7) + +accessing the element ''(3,4,5)'' can be done by + += x[3][4][5] + +or equivalently (but slowly!) + += x:storage()[x:storageOffset() + +(3-1)*x:stride(1)+(4-1)*x:stride(2)+(5-1)*x:stride(3)] + +One could say that a ''Tensor'' is a particular way of //viewing// a +''Storage'': a ''Storage'' only represents a chunk of memory, while the +''Tensor'' interprets this chunk of memory as having dimensions: + +> x = torch.Tensor(4,5) +> s = x:storage() +> for i=1,s:size() do -- fill up the Storage +>> s[i] = i +>> end +> print(x) -- s is interpreted by x as a 2D matrix + 1 2 3 4 5 + 6 7 8 9 10 + 11 12 13 14 15 + 16 17 18 19 20 +[torch.DoubleTensor of dimension 4x5] + + +Note also that in Torch7 **//elements in the same row//** [elements along the **last** dimension] +are contiguous in memory for a matrix [tensor]: + +> x = torch.Tensor(4,5) +> i = 0 +> +> x:apply(function() +>> i = i + 1 +>> return i +>> end) +> +> print(x) + 1 2 3 4 5 + 6 7 8 9 10 + 11 12 13 14 15 + 16 17 18 19 20 +[torch.DoubleTensor of dimension 4x5] + +> return x:stride() + 5 + 1 -- element in the last dimension are contiguous! +[torch.LongStorage of size 2] + +This is exactly like in C (and not ''Fortran''). + +**Tensors of different types** + +Actually, several types of ''Tensor'' exists: + +ByteTensor -- contains unsigned chars +CharTensor -- contains signed chars +ShortTensor -- contains shorts +IntTensor -- contains ints +FloatTensor -- contains floats +DoubleTensor -- contains doubles + + +Most numeric operations are implemented //only// for ''FloatTensor'' and ''DoubleTensor''. +Other Tensor types are useful if you want to save memory space. + +**Default Tensor type** + +For convenience, //an alias// ''torch.Tensor'' is provided, which allows the user to write +type-independent scripts, which can then ran after choosing the desired Tensor type with +a call like + +torch.setdefaulttensortype('torch.FloatTensor') + +See [[Utility#torch.setdefaulttensortype|torch.setdefaulttensortype]] for more details. +By default, the alias "points" on ''torch.DoubleTensor''. + +**Efficient memory management** + +//All// tensor operations in this class do //not// make any memory copy. All +these methods transform the existing tensor, or return a new tensor +referencing //the same storage//. This magical behavior is internally +obtained by good usage of the [[#torch.Tensor.stride|stride()]] and +[[#torch.Tensor.storageOffset|storageOffset()]]. 
Example: + +> x = torch.Tensor(5):zero() +> print(x) +0 +0 +0 +0 +0 +[torch.DoubleTensor of dimension 5] +> x:narrow(1, 2, 3):fill(1) -- narrow() returns a Tensor + -- referencing the same Storage than x +> print(x) + 0 + 1 + 1 + 1 + 0 +[torch.Tensor of dimension 5] + + +If you really need to copy a ''Tensor'', you can use the [[#torch.Tensor.copy|copy()]] method: + +> y = torch.Tensor(x:size()):copy(x) + +Or the convenience method + +> y = x:clone() + + +We now describe all the methods for ''Tensor''. If you want to specify the Tensor type, +just replace ''Tensor'' by the name of the Tensor variant (like ''CharTensor''). + +===== Tensor constructors ===== +{{anchor:torch.Tensor}} + +Here are several ways to construct a new ''Tensor''. + +==== torch.Tensor() ==== +{{anchor:torch.Tensor}} + +Returns an empty tensor. + +==== torch.Tensor(tensor) ==== +{{anchor:torch.Tensor}} + +Returns a new tensor which reference the same [[#torch.Tensor.storage|Storage]] +than the given ''tensor''. The [[#torch.Tensor.size|size]], +[[#torch.Tensor.stride|stride]], and [[#torch.Tensor.storageOffset|storage offset]] are +the same than the given tensor. + +The new ''Tensor'' is now going to "view" the same +[[Storage|storage]] than the given ''tensor''. As the result, any +modification in the elements of the ''Tensor'' will have a impact on the +elements of the given ''tensor'', and vice-versa. No memory copy! + +> x = torch.Tensor(2,5):fill(3.14) +> print(x) + + 3.1400 3.1400 3.1400 3.1400 3.1400 + 3.1400 3.1400 3.1400 3.1400 3.1400 +[torch.DoubleTensor of dimension 2x5] + +> y = torch.Tensor(x) +> print(y) + + 3.1400 3.1400 3.1400 3.1400 3.1400 + 3.1400 3.1400 3.1400 3.1400 3.1400 +[torch.DoubleTensor of dimension 2x5] + +> y:zero() +> print(x) -- elements of x are the same than y! + +0 0 0 0 0 +0 0 0 0 0 +[torch.DoubleTensor of dimension 2x5] + + + +==== torch.Tensor(sz1 [,sz2 [,sz3 [,sz4]]]]) ==== +{{anchor:torch.Tensor}} + +Create a tensor up to 4 dimensions. The tensor size will be ''sz1 x sz2 x sx3 x sz4''. + +==== torch.Tensor(sizes, [strides]) ==== +{{anchor:torch.Tensor}} + +Create a tensor of any number of dimensions. The [[Storage|LongStorage]] +''sizes'' gives the size in each dimension of the tensor. The optional +[[Storage|LongStorage]] ''strides'' gives the jump necessary to go from +one element to the next one in the each dimension. Of course, ''sizes'' and +''strides'' must have the same size. If not given, or if some elements of +''strides'' are //negative//, the [[#torch.Tensor.stride|stride()]] will be computed +such that the tensor is as contiguous as possible in memory. + +Example, create a 4D 4x4x3x2 tensor: + +x = torch.Tensor(torch.LongStorage({4,4,3,2})) + + +Playing with the strides can give some interesting things: + +x = torch.Tensor(torch.LongStorage({4}), torch.LongStorage({0})):zero() -- zeroes the tensor +x[1] = 1 -- all elements point to the same address! +print(x) + + 1 + 1 + 1 + 1 +[torch.DoubleTensor of dimension 4] + +Note that //negative strides are not allowed//, and, if given as argument when constructing the Tensor, will be interpreted +as //choose the right stride such that the Tensor is contiguous in memory//. + +==== torch.Tensor(storage, [storageOffset, sizes, [strides]]) ==== +{{anchor:torch.Tensor}} + +Returns a tensor which uses the existing [[Storage|Storage]] +''storage'', starting at position ''storageOffset'' (>=1). The size of each +dimension of the tensor is given by the [[Storage|LongStorage]] +''sizes''. 
+ +If only ''storage'' is provided, it will create a 1D Tensor viewing the all Storage. + +The jump necessary to go from one element to the next one in each dimension +is given by the optional argument [[Storage|LongStorage]] ''strides''. If +not given, or if some elements of ''strides'' are negative, the +[[#torch.Tensor.stride|stride()]] will be computed such that the tensor is as +contiguous as possible in memory. + +Any modification in the elements of the ''Storage'' will have a impact on the +elements of the new ''Tensor'', and vice-versa. There is no memory copy! + +-- creates a storage with 10 elements +> s = torch.Storage(10):fill(1) + + -- we want to see it as a 2x5 tensor +> x = torch.Tensor(s, 1, torch.LongStorage{2,5}) +> print(x) + + 1 1 1 1 1 + 1 1 1 1 1 +[torch.DoubleTensor of dimension 2x5] +> x:zero() +> print(s) -- the storage contents have been modified +> print(s) +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +[torch.DoubleStorage of size 10] + + +==== torch.Tensor(storage, [storageOffset, sz1 [, st1 ... [, sz4 [, st4]]]]) ==== +{{anchor:torch.Tensor}} + +Convenience constructor (for the previous constructor) assuming a number of +dimensions inferior or equal to 4. ''szi'' is the size in the ''i-th'' dimension, and ''sti'' it +the stride in the ''i-th'' dimension. + +==== torch.Tensor(table) ===== +{{anchor:torch.Tensor}} + +The argument is assumed to be a Lua array of numbers. The constructor +returns a new Tensor of the size of the table, containing all the table +elements. The table might be multi-dimensional. + +Example: + +> = torch.Tensor({{1,2,3,4}, {5,6,7,8}}) + 1 2 3 4 + 5 6 7 8 +[torch.DoubleTensor of dimension 2x4] + + +===== Cloning ===== + +==== [Tensor] clone() ==== +{{anchor:torch.Tensor.clone}} + +Returns a clone of a tensor. The memory is copied. + + +i = 0 +x = torch.Tensor(5):apply(function(x) +i = i + 1 +return i +end) += x + + 1 + 2 + 3 + 4 + 5 +[torch.DoubleTensor of dimension 5] + +-- create a clone of x +y = x:clone() + += y + + 1 + 2 + 3 + 4 + 5 +[torch.DoubleTensor of dimension 5] + +-- fill up y with 1 +y:fill(1) += y + + 1 + 1 + 1 + 1 + 1 +[torch.DoubleTensor of dimension 5] + +-- the contents of x were not changed: += x + + 1 + 2 + 3 + 4 + 5 +[torch.DoubleTensor of dimension 5] + + +==== [Tensor] contiguous ==== +{{anchor:torch.Tensor.contiguous}} + + * If the given Tensor contents are contiguous in memory, returns the exact same Tensor (no memory copy). + * Otherwise (//not contiguous in memory//), returns a [[#torch.Tensor.clone|clone]] (memory //copy//). + + +x = torch.Tensor(2,3):fill(1) += x + + 1 1 1 + 1 1 1 +[torch.DoubleTensor of dimension 2x3] + +-- x is contiguous, so y points to the same thing +y = x:contiguous():fill(2) += y + + 2 2 2 + 2 2 2 +[torch.DoubleTensor of dimension 2x3] + +-- contents of x have been changed += x + + 2 2 2 + 2 2 2 +[torch.DoubleTensor of dimension 2x3] + +-- x:t() is not contiguous, so z is a clone +z = x:t():contiguous():fill(3.14) += z + + 3.1400 3.1400 + 3.1400 3.1400 + 3.1400 3.1400 +[torch.DoubleTensor of dimension 3x2] + +-- contents of x have not been changed += x + + 2 2 2 + 2 2 2 +[torch.DoubleTensor of dimension 2x3] + + +==== [Tensor or string] type(type) ==== +{{anchor:torch.Tensor.type}} + +**If ''type'' is ''nil''**, returns the type name of the given tensor. + += torch.Tensor():type() +torch.DoubleTensor + + +**If ''type'' is a string** describing a Tensor type, and is equal to the +given tensor typename, returns the exact same tensor (//no memory copy//). 
+ +x = torch.Tensor(3):fill(3.14) += x + + 3.1400 + 3.1400 + 3.1400 +[torch.DoubleTensor of dimension 3] + +y = x:type('torch.DoubleTensor') += y + + 3.1400 + 3.1400 + 3.1400 +[torch.DoubleTensor of dimension 3] + +-- zero y contents +y:zero() + +-- contents of x have been changed += x + +0 +0 +0 +[torch.DoubleTensor of dimension 3] + + + +**If ''type'' is a string** describing a Tensor type, different from the type name of the given Tensor, +returns a new Tensor of the specified type, whose contents corresponds to the contents of the original Tensor, +casted to the given type (//memory copy occurs, with possible loss of precision//). + +x = torch.Tensor(3):fill(3.14) += x + + 3.1400 + 3.1400 + 3.1400 +[torch.DoubleTensor of dimension 3] + +y = x:type('torch.IntTensor') += y + + 3 + 3 + 3 +[torch.IntTensor of dimension 3] + + + +==== [Tensor] typeAs(tensor) ==== +{{anchor:torch.Tensor.typeAs}} + +Convenience method for the [[#torch.Tensor.type|type]] method. Equivalent to + +type(tensor:type()) + + + +==== [Tensor] byte(), char(), short(), int(), long(), float(), double() ==== +{{anchor:torch.Tensor.byte}} +{{anchor:torch.Tensor.char}} +{{anchor:torch.Tensor.short}} +{{anchor:torch.Tensor.int}} +{{anchor:torch.Tensor.long}} +{{anchor:torch.Tensor.float}} +{{anchor:torch.Tensor.double}} + +Convenience methods for the [[#torch.Tensor.type|type]] method. For e.g., + +x = torch.Tensor(3):fill(3.14) + += x + 3.1400 + 3.1400 + 3.1400 +[torch.DoubleTensor of dimension 3] + +-- calling type('torch.IntTensor') += x:type('torch.IntTensor') + + 3 + 3 + 3 +[torch.IntTensor of dimension 3] + + +-- is equivalent to calling int() += x:int() + + 3 + 3 + 3 +[torch.IntTensor of dimension 3] + + +===== Querying the size and structure ===== + +==== [number] nDimension() ==== +{{anchor:torch.Tensor.nDimension}} + +Returns the number of dimensions in a ''Tensor''. + +> x = torch.Tensor(4,5) -- a matrix +> = x:nDimension() +2 + + +==== [number] dim() ==== +{{anchor:torch.Tensor.dim}} + +Same as [[#torch.Tensor.nDimension|nDimension()]]. + +==== [number] size(dim) ==== +{{anchor:torch.Tensor.size}} + +Returns the size of the specified dimension ''dim''. Example: + +> x = torch.Tensor(4,5):zero() +> print(x) + +0 0 0 0 0 +0 0 0 0 0 +0 0 0 0 0 +0 0 0 0 0 +[torch.DoubleTensor of dimension 4x5] + +> return x:size(2) -- gets the number of columns +5 + + +==== [LongStorage] size() ==== +{{anchor:torch.Tensor.size}} + +Returns a [[Storage|LongStorage]] containing the size of each dimension +of the tensor. + +> x = torch.Tensor(4,5):zero() +> print(x) + +0 0 0 0 0 +0 0 0 0 0 +0 0 0 0 0 +0 0 0 0 0 +[torch.DoubleTensor of dimension 4x5] + +> return x:size() + 4 + 5 +[torch.LongStorage of size 2] + + +==== [LongStorage] #self ==== +{{anchor:torch.Tensor.size}} + +Same as previous method. + +==== [number] stride(dim) ==== +{{anchor:torch.Tensor.stride}} + +Returns the jump necessary to go from one element to the next one in the +specified dimension ''dim''. Example: + +> x = torch.Tensor(4,5):zero() +> print(x) + +0 0 0 0 0 +0 0 0 0 0 +0 0 0 0 0 +0 0 0 0 0 +[torch.DoubleTensor of dimension 4x5] + + --- elements in a column are contiguous in memory +> return x:stride(1) +1 + + --- to go from one element to the next one in a row + --- we need here to jump the size of the column +> return x:stride(1) +5 + + +Note also that in ''Torch'' //elements in the same row// [elements along the **last** dimension] +are contiguous in memory for a matrix [tensor]. 
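+
+As a small illustrative sketch (using only the methods documented in this section),
+the flat position of an element inside the underlying storage can be recovered from
+the storage offset and the strides, exactly as in the indexing formula shown earlier:
+
+-- locate element (2,3) of a 4x5 tensor in its storage
+x = torch.Tensor(4,5):zero()
+x[2][3] = 7
+flat = x:storageOffset() + (2-1)*x:stride(1) + (3-1)*x:stride(2)
+print(x:storage()[flat]) -- prints 7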
+ +==== [LongStorage] stride() ==== +{{anchor:torch.Tensor.stride}} + +Returns the jump necessary to go from one element to the next one in each dimension. Example: + +> x = torch.Tensor(4,5):zero() +> print(x) + +0 0 0 0 0 +0 0 0 0 0 +0 0 0 0 0 +0 0 0 0 0 +[torch.DoubleTensor of dimension 4x5] + +> return x:stride() + 5 + 1 -- elements are contiguous in a column [last dimension] +[torch.LongStorage of size 2] + + +Note also that in ''Torch'' //elements in the same row// [elements along the **last** dimension] +are contiguous in memory for a matrix [tensor]. + +==== [Storage] storage() ==== +{{anchor:torch.Tensor.storage}} + +Returns the [[Storage|Storage]] used to store all the elements of the ''Tensor''. +Basically, a ''Tensor'' is a particular way of //viewing// a ''Storage''. + +> x = torch.Tensor(4,5) +> s = x:storage() +> for i=1,s:size() do -- fill up the Storage +>> s[i] = i +>> end +> print(x) -- s is interpreted by x as a 2D matrix + + 1 2 3 4 5 + 6 7 8 9 10 + 11 12 13 14 15 + 16 17 18 19 20 +[torch.DoubleTensor of dimension 4x5] + + +==== [boolean] isContiguous() ==== +{{anchor:torch.Tensor.isContiguous}} + +Returns ''true'' iff the elements of the ''Tensor'' are contiguous in memory. + + -- normal tensors are contiguous in memory +> x = torch.Tensor(4,5):zero() +> = x:isContiguous() +true + -- y now "views" the 3rd column of x + -- the storage of y is the same than x + -- so the memory cannot be contiguous +> y = x:select(2, 3) +> = y:isContiguous() +false + -- indeed, to jump to one element to + -- the next one, the stride is 4 +> = y:stride() + 5 +[torch.LongStorage of size 1] + + +==== [number] nElement() ==== +{{anchor:torch.Tensor.nElement}} + +Returns the number of elements of a tensor. + +> x = torch.Tensor(4,5) +> = x:nElement() -- 4x5 = 20! +20 + + +==== [number] storageOffset() ==== +{{anchor:torch.Tensor.storageOffset}} + +Return the first index (starting at 1) used in the tensor's [[#torch.Tensor.storage|storage]]. + +===== Querying elements ===== +{{anchor:torch.Tensor.__index__}} + +Elements of a tensor can be retrieved with the ''[index]'' operator. + +If ''index'' is a number, ''[index]'' operator is equivalent to a [[#torch.Tensor.select|''select(last dim, start)'']] if the +tensor has more than one dimension. If the tensor is a 1D tensor, it returns the value +at ''index'' in this tensor. + +If ''index'' is a table, the table must contain //n// numbers, where //n// +is the [[#torch.Tensor.nDimension|number of dimensions]] of the Tensor. It will return the element +at the given position. + +In the same spirit, ''index'' might be a [[Storage|LongStorage]], +specifying the position (in the Tensor) of the element to be retrieved. + +Example: + +> x = torch.Tensor(3,3) +> i = 0; x:apply(function() i = i + 1; return i end) +> = x + + 1 2 3 + 4 5 6 + 7 8 9 +[torch.DoubleTensor of dimension 3x3] + +> = x[2] -- returns row 2 + + 4 + 5 + 6 +[torch.DoubleTensor of dimension 3] + +> = x[2][3] -- returns row 2, column 3 +6 + +> = x[{2,3}] -- another way to return row 2, column 3 +6 + +> = x[torch.LongStorage{2,3}] -- yet another way to return row 2, column 3 +6 + + + +===== Referencing a tensor to an existing tensor or chunk of memory ===== +{{anchor:torch.Tensor.set}} + +A ''Tensor'' being a way of //viewing// a [[Storage|Storage]], it is +possible to "set" a ''Tensor'' such that it views an existing [[Storage|Storage]]. 
+ +Note that if you want to perform a set on an empty ''Tensor'' like + +y = torch.Storage(10) +x = torch.Tensor() +x:set(y, 1, 10) + +you might want in that case to use one of the [[#torch.Tensor|equivalent constructor]]. + +y = torch.Storage(10) +x = torch.Tensor(y, 1, 10) + + +==== [self] set(tensor) ==== +{{anchor:torch.Tensor.set}} + +The ''Tensor'' is now going to "view" the same [[#torch.Tensor.storage|storage]] +than the given ''tensor''. As the result, any modification in the elements of +the ''Tensor'' will have a impact on the elements of the given ''tensor'', and +vice-versa. This is an efficient method, as there is no memory copy! + + +> x = torch.Tensor(2,5):fill(3.14) +> print(x) + + 3.1400 3.1400 3.1400 3.1400 3.1400 + 3.1400 3.1400 3.1400 3.1400 3.1400 +[torch.DoubleTensor of dimension 2x5] + +> y = torch.Tensor():set(x) +> print(y) + + 3.1400 3.1400 3.1400 3.1400 3.1400 + 3.1400 3.1400 3.1400 3.1400 3.1400 +[torch.DoubleTensor of dimension 2x5] + +> y:zero() +> print(x) -- elements of x are the same than y! + +0 0 0 0 0 +0 0 0 0 0 +[torch.DoubleTensor of dimension 2x5] + + +==== [self] set(storage, [storageOffset, sizes, [strides]]) ==== +{{anchor:torch.Tensor.set}} + +The ''Tensor'' is now going to "view" the given +[[Storage|''storage'']], starting at position ''storageOffset'' (>=1) +with the given [[#torch.Tensor.size|dimension ''sizes'']] and the optional given +[[#torch.Tensor.stride|''strides'']]. As the result, any modification in the +elements of the ''Storage'' will have a impact on the elements of the +''Tensor'', and vice-versa. This is an efficient method, as there is no +memory copy! + +If only ''storage'' is provided, the whole storage will be viewed as a 1D Tensor. + + + -- creates a storage with 10 elements +> s = torch.Storage(10):fill(1) + + -- we want to see it as a 2x5 tensor +> sz = torch.LongStorage({2,5}) +> x = torch.Tensor() +> x:set(s, 1, sz) +> print(x) + + 1 1 1 1 1 + 1 1 1 1 1 +[torch.DoubleTensor of dimension 2x5] +> x:zero() +> print(s) -- the storage contents have been modified +> print(s) +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +[torch.DoubleStorage of size 10] + + +==== [self] set(storage, [storageOffset, sz1 [, st1 ... [, sz4 [, st4]]]]) ==== +{{anchor:torch.Tensor.set}} + +This is a "shorcut" for previous method. +It works up to 4 dimensions. ''szi'' is the size of the ''i''-th dimension of the tensor. +''sti'' is the stride in the ''i''-th dimension. + +===== Copying and initializing ===== + +==== [self] copy(tensor) ==== +{{anchor:torch.Tensor.copy}} + +Copy the elements of the given ''tensor''. The [[#torch.Tensor.nElement|number of elements]] must match, but +the sizes might be different. + +> x = torch.Tensor(4):fill(1) +> y = torch.Tensor(2,2):copy(x) +> print(x) + + 1 + 1 + 1 + 1 +[torch.DoubleTensor of dimension 4] + +> print(y) + + 1 1 + 1 1 +[torch.DoubleTensor of dimension 2x2] + + +If a different type of ''tensor'' is given, then a type conversion occurs, +which, of course, might result in lost of precision. + +==== [self] fill(value) ==== +{{anchor:torch.Tensor.fill}} + +Fill the tensor with the given ''value''. + +> = torch.DoubleTensor(4):fill(3.14) + + 3.1400 + 3.1400 + 3.1400 + 3.1400 +[torch.DoubleTensor of dimension 4] + + +==== [self] zero() ==== +{{anchor:torch.Tensor.zero}} + +Fill the tensor with zeros. 
+ +> = torch.Tensor(4):zero() + +0 +0 +0 +0 +[torch.DoubleTensor of dimension 4] + + +===== Resizing ===== +{{anchor:torch.Tensor.resize.dok}} + +**When resizing to a larger size**, the underlying [[Storage|Storage]] is resized to fit +all the elements of the ''Tensor''. + +**When resizing to a smaller size**, the underlying [[#Storage|Storage]] is not resized. + +**Important note:** the content of a ''Tensor'' after resizing is //undertermined// as [[#torch.Tensor.stride|strides]] +might have been completely changed. In particular, //the elements of the resized tensor are contiguous in memory//. + +==== [self] resizeAs(tensor) ==== +{{anchor:torch.Tensor.resizeAs}} + +Resize the ''tensor'' as the given ''tensor'' (of the same type). + +==== [self] resize(sizes) ==== +{{anchor:torch.Tensor.resize}} + +Resize the ''tensor'' according to the given [[Storage|LongStorage]] ''size''. + +==== [self] resize(sz1 [,sz2 [,sz3 [,sz4]]]]) ==== +{{anchor:torch.Tensor.resize}} + +Convenience method of the previous method, working for a number of dimensions up to 4. + +===== Extracting sub-tensors ===== + +Each of these methods returns a ''Tensor'' which is a sub-tensor of the given +tensor, //with the same ''Storage''//. Hence, any modification in the memory of +the sub-tensor will have an impact on the primary tensor, and vice-versa. + +These methods are very fast, as they do not involve any memory copy. + +==== [Tensor] narrow(dim, index, size) ==== +{{anchor:torch.Tensor.narrow}} + +Returns a new ''Tensor'' which is a narrowed version of the current one: the dimension ''dim'' is narrowed +from ''index'' to ''index+size-1''. + + +> x = torch.Tensor(5, 6):zero() +> print(x) + +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +[torch.DoubleTensor of dimension 5x6] + +> y = x:narrow(1, 2, 3) -- narrow dimension 1 from index 2 to index 2+3-1 +> y:fill(1) -- fill with 1 +> print(y) + + 1 1 1 1 1 1 + 1 1 1 1 1 1 + 1 1 1 1 1 1 +[torch.DoubleTensor of dimension 3x6] + +> print(x) -- memory in x has been modified! + + 0 0 0 0 0 0 + 1 1 1 1 1 1 + 1 1 1 1 1 1 + 1 1 1 1 1 1 + 0 0 0 0 0 0 +[torch.DoubleTensor of dimension 5x6] + + +==== [Tensor] sub(dim1s, dim1e ... [, dim4s [, dim4e]]) ==== +{{anchor:torch.Tensor.sub}} + +This method is equivalent to do a serie of +[[#torch.Tensor.narrow|narrow]] up to the first 4 dimensions. It returns +a new ''Tensor'' which is a sub-tensor going from index ''dimis'' to ''dimie'' in +the ''i''-th dimension. Negative values are interpreted index starting from the end: +''-1'' is the last index, ''-2'' is the index before the last index, ... + + +> x = torch.Tensor(5, 6):zero() +> print(x) + +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +[torch.DoubleTensor of dimension 5x6] + +> y = x:sub(2,4):fill(1) -- y is sub-tensor of x: +> print(y) -- dimension 1 starts at index 2, ends at index 4 + + 1 1 1 1 1 1 + 1 1 1 1 1 1 + 1 1 1 1 1 1 +[torch.DoubleTensor of dimension 3x6] + +> print(x) -- x has been modified! 
+ + 0 0 0 0 0 0 + 1 1 1 1 1 1 + 1 1 1 1 1 1 + 1 1 1 1 1 1 + 0 0 0 0 0 0 +[torch.DoubleTensor of dimension 5x6] + +> z = x:sub(2,4,3,4):fill(2) -- we now take a new sub-tensor +> print(z) -- dimension 1 starts at index 2, ends at index 4 + -- dimension 2 starts at index 3, ends at index 4 + 2 2 + 2 2 + 2 2 +[torch.DoubleTensor of dimension 3x2] + +> print(x) -- x has been modified + + 0 0 0 0 0 0 + 1 1 2 2 1 1 + 1 1 2 2 1 1 + 1 1 2 2 1 1 + 0 0 0 0 0 0 +[torch.DoubleTensor of dimension 5x6] + +> print(y:sub(-1, -1, 3, 4)) -- negative values = bounds + + 2 2 +[torch.DoubleTensor of dimension 1x2] + + +==== [Tensor] select(dim, index) ==== +{{anchor:torch.Tensor.select}} + +Returns a new ''Tensor'' which is a tensor slice at the given ''index'' in the +dimension ''dim''. The returned tensor has one less dimension: the dimension +''dim'' is removed. As a result, it is not possible to ''select()'' on a 1D +tensor. + +Note that "selecting" on the first dimension is equivalent to use the [[#torch.Tensor.__index__ |[] operator]] + + +> x = torch.Tensor(5,6):zero() +> print(x) + +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +0 0 0 0 0 0 +[torch.DoubleTensor of dimension 5x6] + +> y = x:select(1, 2):fill(2) -- select row 2 and fill up +> print(y) + + 2 + 2 + 2 + 2 + 2 + 2 +[torch.DoubleTensor of dimension 6] + +> print(x) + + 0 0 0 0 0 0 + 2 2 2 2 2 2 + 0 0 0 0 0 0 + 0 0 0 0 0 0 + 0 0 0 0 0 0 +[torch.DoubleTensor of dimension 5x6] + +> z = x:select(2,5):fill(5) -- select column 5 and fill up +> print(z) + + 5 + 5 + 5 + 5 + 5 +[torch.DoubleTensor of dimension 5] + +> print(x) + + 0 0 0 0 5 0 + 2 2 2 2 5 2 + 0 0 0 0 5 0 + 0 0 0 0 5 0 + 0 0 0 0 5 0 +[torch.DoubleTensor of dimension 5x6] + + +===== Manipulating the tensor view ===== + +Each of these methods returns a ''Tensor'' which is another way of viewing +the ''Storage'' of the given tensor. Hence, any modification in the memory of +the sub-tensor will have an impact on the primary tensor, and vice-versa. + +These methods are very fast, are they do not involve any memory copy. + +==== [Tensor] transpose(dim1, dim2) ==== +{{anchor:torch.Tensor.transpose}} + +Returns a tensor where dimensions ''dim1'' and ''dim2'' have been swapped. For 2D tensors, +the convenience method of [[#torch.Tensor.t|t()]] is available. + +> x = torch.Tensor(3,4):zero() +> x:select(2,3):fill(7) -- fill column 3 with 7 +> print(x) + + 0 0 7 0 + 0 0 7 0 + 0 0 7 0 +[torch.DoubleTensor of dimension 3x4] + +> y = x:transpose(1,2) -- swap dimension 1 and 2 +> print(y) + + 0 0 0 + 0 0 0 + 7 7 7 + 0 0 0 +[torch.DoubleTensor of dimension 4x3] + +> y:select(2, 3):fill(8) -- fill column 3 with 8 +> print(y) + + 0 0 8 + 0 0 8 + 7 7 8 + 0 0 8 +[torch.DoubleTensor of dimension 4x3] + +> print(x) -- contents of x have changed as well + + 0 0 7 0 + 0 0 7 0 + 8 8 8 8 +[torch.DoubleTensor of dimension 3x4] + + + +==== [Tensor] t() ==== +{{anchor:torch.Tensor.t}} + +Convenience method of [[#torch.Tensor.transpose|transpose()]] for 2D +tensors. The given tensor must be 2 dimensional. Swap dimensions 1 and 2. + +> x = torch.Tensor(3,4):zero() +> x:select(2,3):fill(7) +> y = x:t() +> print(y) + + 0 0 0 + 0 0 0 + 7 7 7 + 0 0 0 +[torch.DoubleTensor of dimension 4x3] + +> print(x) + + 0 0 7 0 + 0 0 7 0 + 0 0 7 0 +[torch.DoubleTensor of dimension 3x4] + + +==== [Tensor] unfold(dim, size, step) ==== +{{anchor:torch.Tensor.unfold}} + +Returns a tensor which contains all slices of size ''size'' in the dimension ''dim''. Step between +two slices is given by ''step''. 
+ +If ''sizedim'' is the original size of dimension ''dim'', the size of dimension +''dim'' in the returned tensor will be ''(sizedim - size) / step + 1'' + +An additional dimension of size ''size'' is appended in the returned tensor. + + +> x = torch.Tensor(7) +> for i=1,7 do x[i] = i end +> print(x) + + 1 + 2 + 3 + 4 + 5 + 6 + 7 +[torch.DoubleTensor of dimension 7] + +> return x:unfold(1, 2, 1) + + 1 2 + 2 3 + 3 4 + 4 5 + 5 6 + 6 7 +[torch.DoubleTensor of dimension 6x2] + +> return x:unfold(1, 2, 2) + + 1 2 + 3 4 + 5 6 +[torch.DoubleTensor of dimension 3x2] + + +===== Applying a function to a tensor ===== + +These functions apply a function to each element of the tensor on which the +method is called (self). These methods are much faster than using a ''for'' +loop in ''Lua''. The results is stored in ''self'' (if the function returns +something). A similar function exists in the [[..:lab:index#map|lab]] +package, but where a new tensor containing the result is returned instead. + +==== [self] apply(function) ==== +{{anchor:torch.Tensor.apply}} + +Apply the given function to all elements of self. + +The function takes a number (the current element of the tensor) and might return +a number, in which case it will be stored in self. + +Examples: + +> i = 0 +> z = torch.Tensor(3,3) +> z:apply(function(x) +>> i = i + 1 +>> return i +>> end) -- fill up the tensor +> = z + + 1 2 3 + 4 5 6 + 7 8 9 +[torch.DoubleTensor of dimension 3x3] + +> z:apply(math.sin) -- apply the sin function +> = z + + 0.8415 0.9093 0.1411 +-0.7568 -0.9589 -0.2794 + 0.6570 0.9894 0.4121 +[torch.DoubleTensor of dimension 3x3] + +> sum = 0 +> z:apply(function(x) +>> sum = sum + x +>> end) -- compute the sum of the elements +> = sum +1.9552094821074 +> = z:sum() -- it is indeed correct! +1.9552094821074 + + +==== [self] map(tensor, function(xs, xt)) ==== +{{anchor:torch.Tensor.map}} + +Apply the given function to all elements of self and ''tensor''. The number of elements of both tensors +must match, but sizes do not matter. + +The function takes two numbers (the current element of self and ''tensor'') and might return +a number, in which case it will be stored in self. + +Example: + +> x = torch.Tensor(3,3) +> y = torch.Tensor(9) +> i = 0 +> x:apply(function() i = i + 1; return i end) -- fill-up x +> i = 0 +> y:apply(function() i = i + 1; return i end) -- fill-up y +> = x + + 1 2 3 + 4 5 6 + 7 8 9 +[torch.DoubleTensor of dimension 3x3] + +> = y + + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 +[torch.DoubleTensor of dimension 9] + +> x:map(y, function(xx, yy) return xx*yy end) -- element-wise multiplication +> = x + + 1 4 9 + 16 25 36 + 49 64 81 +[torch.DoubleTensor of dimension 3x3] + + +==== [self] map2(tensor1, tensor2, function(x, xt1, xt2)) ==== +{{anchor:torch.Tensor.map2}} + +Apply the given function to all elements of self, ''tensor1'' and ''tensor2''. The number of elements of all tensors +must match, but sizes do not matter. + +The function takes three numbers (the current element of self, ''tensor1'' and ''tensor2'') and might return +a number, in which case it will be stored in self. 
+ +Example: + +> x = torch.Tensor(3,3) +> y = torch.Tensor(9) +> z = torch.Tensor(3,3) +> +> i = 0; x:apply(function() i = i + 1; return math.cos(i)*math.cos(i) end) +> i = 0; y:apply(function() i = i + 1; return i end) +> i = 0; z:apply(function() i = i + 1; return i end) +> +> print(x) + + 0.2919 0.1732 0.9801 + 0.4272 0.0805 0.9219 + 0.5684 0.0212 0.8302 +[torch.DoubleTensor of dimension 3x3] + +> print(y) + + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 +[torch.DoubleTensor of dimension 9] + +> print(z) + + 1 2 3 + 4 5 6 + 7 8 9 +[torch.DoubleTensor of dimension 3x3] + +> +> x:map2(y, z, function(xx, yy, zz) return xx+yy*zz end) +> +> print(x) + + 1.2919 4.1732 9.9801 + 16.4272 25.0805 36.9219 + 49.5684 64.0212 81.8302 +[torch.DoubleTensor of dimension 3x3] + + +===== Math functions ===== + +These functions apply a function to the tensor, and return self. + +==== [self] log() ==== +{{anchor:torch.Tensor.Log}} + +Computes the natural logarithm. + +==== [self] log1p() ==== +{{anchor:torch.Tensor.Log1p}} + +''log1p(x)'' computes [[#torch.Tensor.Log|log(1+x)]], with precision more +accurate than standard ''log()'' function for small value of ''x''. + +==== [self] exp() ==== +{{anchor:torch.Tensor.Exp}} + +Computes the exponential function. + +==== [self] cos() ==== +{{anchor:torch.Tensor.cos}} + +Computes the cosine function. + +==== [self] acos() ==== +{{anchor:torch.Tensor.acos}} + +Computes the arc cosine function. + +==== [self] cosh() ==== +{{anchor:torch.Tensor.cosh}} + +Computes the hyperbolic cosine function. + +==== [self] sin() ==== +{{anchor:torch.Tensor.sin}} + +Computes the sinusoid function. + +==== [self] asin() ==== +{{anchor:torch.Tensor.asin}} + +Computes the arc sinusoid function. + +==== [self] sinh() ==== +{{anchor:torch.Tensor.sinh}} + +Computes the hyperbolic sinusoid function. + +==== [self] tan() ==== +{{anchor:torch.Tensor.tanh}} + +Computes the tangent function. + +==== [self] atan() ==== +{{anchor:torch.Tensor.atan}} + +Computes the arc tangent function. + +==== [self] tanh() ==== +{{anchor:torch.Tensor.tanh}} + +Computes the hyperbolic tangent function. + +==== [self] pow(value) ==== +{{anchor:torch.Tensor.pow}} + +Computes the power to the given ''value''. + +==== [self] sqrt() ==== +{{anchor:torch.Tensor.sqrt}} + +Computes the square root. Values must be positive. + +==== [self] ceil() ==== +{{anchor:torch.Tensor.ceil}} + +For each tensor value, computes the smallest integral value greater than or equal to this value. + +==== [self] floor() ==== +{{anchor:torch.Tensor.floor}} + +For each tensor value, computes the largest integral value less than or equal to this value. + +==== [self] abs() ==== +{{anchor:torch.Tensor.abs}} + +Computes the absolute value. + +===== Basic statistics ===== + +This functions return a statistic (scalar) on the full tensor. For more complex statistics +see the [[..:lab:index|lab]] package. + +==== [number] sum() ==== +{{anchor:torch.Tensor.sum}} + +Returns the sum of all the elements in the tensor. + +==== [number] mean() ==== +{{anchor:torch.Tensor.mean}} + +Returns the mean of all the elements in the tensor. + +==== [number] max() ==== +{{anchor:torch.Tensor.max}} + +Returns the maximum value of the tensor. + +==== [number] min() ==== +{{anchor:torch.Tensor.min}} + +Returns the minimum value of the tensor. + +==== [number] std() ==== +{{anchor:torch.Tensor.std}} + +Returns the unbiased standard deviation estimator of the elements in the tensor. (The normalization +factor is (n-1) and not n, where n is the number of elements in the tensor). 
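+
+A quick way to see the effect of the (n-1) normalization is to compare ''std()'' with
+a manual computation (a minimal illustrative sketch, nothing more):
+
+x = torch.Tensor({1, 2, 3, 4})
+m = x:mean()
+-- unbiased estimator: sum of squared deviations divided by (n-1)
+manual = math.sqrt(((x[1]-m)^2 + (x[2]-m)^2 + (x[3]-m)^2 + (x[4]-m)^2) / (x:nElement() - 1))
+print(x:std(), manual) -- the two values agree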
+
+==== [number] var() ====
+{{anchor:torch.Tensor.var}}
+
+Returns the unbiased variance estimator of the elements in the tensor. (The normalization
+factor is (n-1) and not n, where n is the number of elements in the tensor).
+
+==== [number] norm([p]) ====
+{{anchor:torch.Tensor.norm}}
+
+Returns the p-norm of the elements of the tensor seen as a vector.
+The default value for ''p'' is ''2''.
+
+==== [number] dist(tensor, [value]) ====
+{{anchor:torch.Tensor.dist}}
+
+Returns the p-norm of the difference between self and the given ''tensor''.
+
+===== Basic operations =====
+{{anchor:torch.Tensor.BasicOperations.dok}}
+
+All these operations affect the tensor on which the method is called (self).
+No additional memory is created.
+
+==== [self] add(value) ====
+{{anchor:torch.Tensor.add}}
+
+Add the given ''value'' to all elements in the tensor.
+
+==== [self] add(tensor) ====
+{{anchor:torch.Tensor.add}}
+
+Add the given ''tensor'' to self. The number of elements must match, but sizes do not matter.
+
+> x = torch.Tensor(2,2):fill(2)
+> y = torch.Tensor(4):fill(3)
+> x:add(y)
+> = x
+
+ 5 5
+ 5 5
+[torch.Tensor of dimension 2x2]
+
+
+==== [self] add(value, tensor) ====
+{{anchor:torch.Tensor.add}}
+
+Multiply the elements of ''tensor'' by the scalar ''value'' and add the result to self.
+The number of elements must match, but sizes do not matter.
+
+> x = torch.Tensor(2,2):fill(2)
+> y = torch.Tensor(4):fill(3)
+> x:add(2, y)
+> = x
+
+ 8 8
+ 8 8
+[torch.Tensor of dimension 2x2]
+
+
+==== [self] mul(value) ====
+{{anchor:torch.Tensor.mul}}
+
+Multiply all elements in the tensor by the given ''value''.
+
+==== [self] cmul(tensor) ====
+{{anchor:torch.Tensor.cmul}}
+
+Element-wise multiplication of ''tensor'' by self. The number of elements must match, but sizes do not matter.
+
+> x = torch.Tensor(2,2):fill(2)
+> y = torch.Tensor(4):fill(3)
+> x:cmul(y)
+> = x
+
+ 6 6
+ 6 6
+[torch.Tensor of dimension 2x2]
+
+
+==== [self] addcmul(value, tensor1, tensor2) ====
+{{anchor:torch.Tensor.addcmul}}
+
+Performs the element-wise multiplication of ''tensor1'' by ''tensor2'', multiplies the result by the scalar ''value''
+and adds it to self. The number of elements must match, but sizes do not matter.
+
+> x = torch.Tensor(2,2):fill(2)
+> y = torch.Tensor(4):fill(3)
+> z = torch.Tensor(2,2):fill(5)
+> x:addcmul(2, y, z)
+> = x
+
+ 32 32
+ 32 32
+[torch.Tensor of dimension 2x2]
+
+
+==== [self] div(value) ====
+{{anchor:torch.Tensor.div}}
+
+Divide all elements in the tensor by the given ''value''.
+
+==== [self] cdiv(tensor) ====
+{{anchor:torch.Tensor.cdiv}}
+
+Performs the element-wise division of self by ''tensor''.
+The number of elements must match, but sizes do not matter.
+
+> x = torch.Tensor(2,2):fill(1)
+> y = torch.Tensor(4)
+> for i=1,4 do y[i] = i end
+> x:cdiv(y)
+> = x
+
+ 1.0000 0.3333
+ 0.5000 0.2500
+[torch.Tensor of dimension 2x2]
+
+
+==== [self] addcdiv(value, tensor1, tensor2) ====
+{{anchor:torch.Tensor.addcdiv}}
+
+Performs the element-wise division of ''tensor1'' by ''tensor2'', multiplies the result by the scalar ''value''
+and adds it to self. The number of elements must match, but sizes do not matter.
+
+> x = torch.Tensor(2,2):fill(1)
+> y = torch.Tensor(4)
+> z = torch.Tensor(2,2):fill(5)
+> for i=1,4 do y[i] = i end
+> x:addcdiv(2, y, z)
+> = x
+
+ 1.4000 2.2000
+ 1.8000 2.6000
+[torch.Tensor of dimension 2x2]
+
+
+==== [number] dot(tensor) ====
+{{anchor:torch.Tensor.dot}}
+
+Performs the dot product between ''tensor'' and self.
The number of elements must match: both tensors are seen
as a 1D vector.

> x = torch.Tensor(2,2):fill(2)
> y = torch.Tensor(4):fill(3)
> = x:dot(y)
24


==== addmv(value, mat, vec) ====
{{anchor:torch.Tensor.addmv}}

Performs a matrix-vector multiplication between ''mat'' (2D tensor) and ''vec'' (1D tensor), multiplies the result
by the scalar ''value'' and adds it to self. In other words,

self = self + value * mat*vec


Sizes must respect the matrix-multiplication operation: if ''mat'' is a ''n x m'' matrix, ''vec'' must be
a vector of size ''m'' and self must be a vector of size ''n''.


> x = torch.Tensor(3):fill(0)
> M = torch.Tensor(3,2):fill(3)
> y = torch.Tensor(2):fill(2)
> x:addmv(1, M, y)
> = x

 12
 12
 12
[torch.Tensor of dimension 3]


==== addr(value, vec1, vec2) ====
{{anchor:torch.Tensor.addr}}

Performs the outer product between ''vec1'' (1D tensor) and ''vec2'' (1D tensor), multiplies the resulting matrix
by the scalar ''value'' and adds the result to self (which must be a 2D tensor).
In other words,

self_ij = self_ij + value * vec1_i * vec2_j


If ''vec1'' is a vector of size ''n'' and ''vec2'' is a vector of size ''m'', then self must be a matrix of size
''n x m''.

> x = torch.Tensor(3)
> y = torch.Tensor(2)
> for i=1,3 do x[i] = i end
> for i=1,2 do y[i] = i end
> M = torch.Tensor(3, 2):zero()
> M:addr(1, x, y)
> = M

 1 2
 2 4
 3 6
[torch.Tensor of dimension 3x2]


==== addmm(value, mat1, mat2) ====
{{anchor:torch.Tensor.addmm}}

Performs a matrix-matrix multiplication between ''mat1'' (2D tensor) and ''mat2'' (2D tensor), multiplies the resulting
matrix by the scalar ''value'' and adds the result to self (2D tensor).
In other words,

self = self + value * mat1*mat2


If ''mat1'' is a ''n x m'' matrix and ''mat2'' a ''m x p'' matrix, then self must be a ''n x p'' matrix.

===== Overloaded operators =====

It is possible to use basic mathematical operators like ''+'', ''-'', ''/'' and =*=
with tensors. These operators are provided as a convenience. While they
might be handy, they create and return a new tensor containing the
results. They are thus not as fast as the operations available in the
[[#torch.Tensor.BasicOperations.dok|previous section]].

==== Addition and subtraction ====

You can add a tensor to another one with the ''+'' operator. Subtraction is done with ''-''.
The number of elements in the tensors must match, but the sizes do not matter. The size
of the returned tensor will be the size of the first tensor.

> x = torch.Tensor(2,2):fill(2)
> y = torch.Tensor(4):fill(3)
> = x+y

 5 5
 5 5
[torch.Tensor of dimension 2x2]

> = y-x

 1
 1
 1
 1
[torch.Tensor of dimension 4]


A scalar can also be added to or subtracted from a tensor. The scalar may appear on the right or left of the operator.

> x = torch.Tensor(2,2):fill(2)
> = x+3

 5 5
 5 5
[torch.Tensor of dimension 2x2]

> = 3-x

 1 1
 1 1
[torch.Tensor of dimension 2x2]


==== Negation ====

A tensor can be negated with the ''-'' operator placed in front:

> x = torch.Tensor(2,2):fill(2)
> = -x

-2 -2
-2 -2
[torch.Tensor of dimension 2x2]


==== Multiplication ====

Multiplication between two tensors is supported with the =*= operator. The result of the multiplication
depends on the sizes of the tensors.
$ 1D and 1D: Returns the dot product between the two tensors (scalar).
$ 2D and 1D: Returns the matrix-vector operation between the two tensors (1D tensor).
+$ 2D and 2D: Returns the matrix-matrix operation between the two tensors (2D tensor). +$ 4D and 2D: Returns a tensor product (2D tensor). +Sizes must be relevant for the corresponding operation. + +A tensor might also be multiplied by a scalar. The scalar might be on the right or left of the operator. + +Examples: + +> M = torch.Tensor(2,2):fill(2) +> N = torch.Tensor(2,4):fill(3) +> x = torch.Tensor(2):fill(4) +> y = torch.Tensor(2):fill(5) +> = x*y -- dot product +40 +> = M*x --- matrix-vector + + 16 + 16 +[torch.Tensor of dimension 2] + +> = M*N -- matrix-matrix + + 12 12 12 12 + 12 12 12 12 +[torch.Tensor of dimension 2x4] + + + +==== Division ==== + +Only the division of a tensor by a scalar is supported with the operator ''/''. +Example: + +> x = torch.Tensor(2,2):fill(2) +> = x/3 + + 0.6667 0.6667 + 0.6667 0.6667 +[torch.Tensor of dimension 2x2] + diff --git a/dok/tester.dok b/dok/tester.dok new file mode 100644 index 00000000000..b79166cb1bf --- /dev/null +++ b/dok/tester.dok @@ -0,0 +1,130 @@ +====== Tester ====== +{{anchor:torch.Tester.dok}} + +This class provides a generic unit testing framework. It is already +being used in [[..:nn:index|nn]] package to verify the correctness of classes. + +The framework is generally used as follows. + + +mytest = {} + +tester = torch.Tester() + +function mytest.TestA() + local a = 10 + local b = 10 + tester:asserteq(a,b,'a == b') + tester:assertne(a,b,'a ~= b') +end + +function mytest.TestB() + local a = 10 + local b = 9 + tester:assertlt(a,b,'a < b') + tester:assertgt(a,b,'a > b') +end + +tester:add(mytest) +tester:run() + + + +Running this code will report 2 errors in 2 test functions. Generally it is +better to put single test cases in each test function unless several very related +test cases exit. The error report includes the message and line number of the error. + + + +Running 2 tests +** ==> Done + +Completed 2 tests with 2 errors + +-------------------------------------------------------------------------------- +TestB +a < b + LT(<) violation val=10, condition=9 + ...y/usr.t7/local.master/share/lua/5.1/torch/Tester.lua:23: in function 'assertlt' + [string "function mytest.TestB()..."]:4: in function 'f' + +-------------------------------------------------------------------------------- +TestA +a ~= b + NE(~=) violation val=10, condition=10 + ...y/usr.t7/local.master/share/lua/5.1/torch/Tester.lua:38: in function 'assertne' + [string "function mytest.TestA()..."]:5: in function 'f' + +-------------------------------------------------------------------------------- + + + + +==== torch.Tester() ==== +{{anchor:torch.Tester}} + +Returns a new instance of ''torch.Tester'' class. + +==== add(f, 'name') ==== +{{anchor:torch.Tester.add}} + +Adds a new test function with name ''name''. The test function is stored in ''f''. +The function is supposed to run without any arguments and not return any values. + +==== add(ftable) ==== +{{anchor:torch.Tester.add}} + +Recursively adds all function entries of the table ''ftable'' as tests. This table +can only have functions or nested tables of functions. + +==== assert(condition [, message]) ==== +{{anchor:torch.Tester.assert}} + +Saves an error if condition is not true with the optional message. + +==== assertlt(val, condition [, message]) ==== +{{anchor:torch.Tester.assertlt}} + +Saves an error if ''val < condition'' is not true with the optional message. 
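
For instance, an ''assertlt'' check inside a test function might look like the following (an illustrative
sketch only; ''tester'' and ''mytest'' are assumed to have been set up as in the example at the top of this
section):


function mytest.testSmallError()
   -- the measured error is expected to stay below the tolerance
   local err = 1e-5
   tester:assertlt(err, 1e-3, 'error should stay below 1e-3')
end
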
+ +==== assertgt(val, condition [, message]) ==== +{{anchor:torch.Tester.assertgt}} + +Saves an error if ''val > condition'' is not true with the optional message. + +==== assertle(val, condition [, message]) ==== +{{anchor:torch.Tester.assertle}} + +Saves an error if ''val <= condition'' is not true with the optional message. + +==== assertge(val, condition [, message]) ==== +{{anchor:torch.Tester.assertge}} + +Saves an error if ''val >= condition'' is not true with the optional message. + +==== asserteq(val, condition [, message]) ==== +{{anchor:torch.Tester.asserteq}} + +Saves an error if ''val == condition'' is not true with the optional message. + +==== assertne(val, condition [, message]) ==== +{{anchor:torch.Tester.assertne}} + +Saves an error if ''val ~= condition'' is not true with the optional message. + +==== assertTensorEq(ta, tb, condition [, message]) ==== +{{anchor:torch.Tester.assertTensorEq}} + +Saves an error if ''max(abs(ta-tb)) < condition'' is not true with the optional message. + +==== run() ==== +{{anchor:torch.Tester.run}} + +Runs all the test functions that are stored using [[#torch.Tester.add|add()]] function. +While running it reports progress and at the end gives a summary of all errors. + + + + + + diff --git a/dok/timer.dok b/dok/timer.dok new file mode 100644 index 00000000000..20554225771 --- /dev/null +++ b/dok/timer.dok @@ -0,0 +1,43 @@ +====== Timer ====== +{{anchor:torch.Timer.dok}} + +This class is able to measure time (in seconds) elapsed in a particular period. Example: + + timer = torch.Timer() -- the Timer starts to count now + x = 0 + for i=1,1000000 do + x = x + math.sin(x) + end + print('Time elapsed for 1,000,000 sin: ' .. timer:time().real .. ' seconds') + + +===== torch.Timer() ===== +{{anchor:torch.Timer}} + +Returns a new ''Timer''. The timer starts to count the time now. + +===== [self] reset() ===== +{{anchor:torch.Timer.reset}} + +Reset the timer accumulated time to ''0''. If the timer was running, the timer +restarts to count the time now. If the timer was stopped, it stays stopped. + +===== [self] resume() ===== +{{anchor:torch.Timer.resume}} + +Resume a stopped timer. The timer restarts to count the time, and addition +the accumulated time with the time already counted before being stopped. + +===== [self] stop() ===== +{{anchor:torch.Timer.stop}} + +Stop the timer. The accumulated time counted until now is stored. + +===== [table] time() ===== +{{anchor:torch.Timer.time}} + +Returns a table reporting the accumulated time elapsed until now. Following the UNIX shell ''time'' command, +there are three fields in the table: + * ''real'': the wall-clock elapsed time. + * ''user'': the elapsed CPU time. Note that the CPU time of a threaded program sums time spent in all threads. + * ''sys'': the time spent in system usage. diff --git a/dok/utility.dok b/dok/utility.dok new file mode 100644 index 00000000000..0ca45a8ffc2 --- /dev/null +++ b/dok/utility.dok @@ -0,0 +1,234 @@ +====== Torch utility functions ====== +{{anchor:torch.utility.dok}} + +This functions are used in all Torch package for creating and handling classes. +The most interesting function is probably [[#torch.class|torch.class()]] which allows +the user to create easily new classes. [[#torch.typename|torch.typename()]] might +also be interesting to check what is the class of a given Torch object. + +The other functions are more for advanced users. + +===== [metatable] torch.class(name, [parentName]) ===== +{{anchor:torch.class}} + +Creates a new ''Torch'' class called ''name''. 
If ''parentName'' is provided, the class will inherit +''parentName'' methods. A class is a table which has a particular metatable. + +If ''name'' is of the form ''package.className'' then the class ''className'' will be added to the specified ''package''. +In that case, ''package'' has to be a valid (and already loaded) package. If ''name'' does not contain any ''"."'', +then the class will be defined in the global environment. + +One [or two] (meta)tables are returned. These tables contain all the method +provided by the class [and its parent class if it has been provided]. After +a call to ''torch.class()'' you have to fill-up properly the metatable. + +After the class definition is complete, constructing a new class //name// will be achieved by a call to ''//name//()''. +This call will first call the method __init() if it exists, passing all arguments of ''//name//()''. + + + require "torch" + + -- for naming convenience + do + --- creates a class "Foo" + local Foo = torch.class('Foo') + + --- the initializer + function Foo:__init() + self.contents = "this is some text" + end + + --- a method + function Foo:print() + print(self.contents) + end + + --- another one + function Foo:bip() + print('bip') + end + + end + + --- now create an instance of Foo + foo = Foo() + + --- try it out + foo:print() + + --- create a class torch.Bar which + --- inherits from Foo + do + local Bar, parent = torch.class('torch.Bar', 'Foo') + + --- the initializer + function Bar:__init(stuff) + --- call the parent initializer on ourself + parent.__init(self) + + --- do some stuff + self.stuff = stuff + end + + --- a new method + function Bar:boing() + print('boing!') + end + + --- override parent's method + function Bar:print() + print(self.contents) + print(self.stuff) + end + end + + --- create a new instance and use it + bar = torch.Bar("ha ha!") + bar:print() -- overrided method + bar:boing() -- child method + bar:bip() -- parent's method + + + +For advanced users, it is worth mentionning that ''torch.class()'' actually +calls [[#torch.newmetatable|torch.newmetatable()]]. with a particular +constructor. The constructor creates a Lua table and set the right +metatable on it, and then calls __init() if it exists in the +metatable. It also sets a [[#torch.factory|factory]] field __factory such that it +is possible to create an empty object of this class. + +===== [string] torch.typename(object) ===== +{{anchor:torch.typename}} + +Checks if ''object'' has a metatable. If it does, and if it corresponds to a +''Torch'' class, then returns a string containing the name of the +class. Returns ''nil'' in any other cases. + +A Torch class is a class created with [[#torch.class|torch.class()]] or +[[#torch.newmetatable|torch.newmetatable()]]. + +===== [userdata] torch.typename2id(string) ===== +{{anchor:torch.typename2id}} + +Given a Torch class name specified by ''string'', returns a unique +corresponding id (defined by a ''lightuserdata'' pointing on the internal +structure of the class). This might be useful to do a //fast// check of the +class of an object (if used with [[#torch.id|torch.id()]]), avoiding string +comparisons. + +Returns ''nil'' if ''string'' does not specify a Torch object. + +===== [userdata] torch.id(object) ===== +{{anchor:torch.id}} + +Returns a unique id corresponding to the //class// of the given Torch object. +The id is defined by a ''lightuserdata'' pointing on the internal structure +of the class. + +Returns ''nil'' if ''object'' is not a Torch object. 
+ +This is different from the //object// id returned by [[#torch.pointer|torch.pointer()]]. + +===== [table] torch.newmetatable(name, parentName, constructor) ===== +{{anchor:torch.newmetatable}} + +Register a new metatable as a Torch type with the given string ''name''. The new metatable is returned. + +If the string ''parentName'' is not ''nil'' and is a valid Torch type (previously created +by ''torch.newmetatable()'') then set the corresponding metatable as a metatable to the returned new +metatable. + +If the given ''constructor'' function is not ''nil'', then assign to the variable ''name'' the given constructor. +The given ''name'' might be of the form ''package.className'', in which case the ''className'' will be local to the +specified ''package''. In that case, ''package'' must be a valid and already loaded package. + +===== [function] torch.factory(name) ===== +{{anchor:torch.factory}} + +Returns the factory function of the Torch class ''name''. If the class name is invalid or if the class +has no factory, then returns ''nil''. + +A Torch class is a class created with [[#torch.class|torch.class()]] or +[[#torch.newmetatable|torch.newmetatable()]]. + +A factory function is able to return a new (empty) object of its corresponding class. This is helpful for +[[File#torch.File.serialization|object serialization]]. + +===== [table] torch.getmetatable(string) ===== +{{anchor:torch.getmetatable}} + +Given a ''string'', returns a metatable corresponding to the Torch class described +by ''string''. Returns ''nil'' if the class does not exist. + +A Torch class is a class created with [[#torch.class|torch.class()]] or +[[#torch.newmetatable|torch.newmetatable()]]. + +Example: + +> for k,v in pairs(torch.getmetatable("torch.CharStorage")) do print(k,v) end +__index__ function: 0x1a4ba80 +__typename torch.CharStorage +write function: 0x1a49cc0 +__tostring__ function: 0x1a586e0 +__newindex__ function: 0x1a4ba40 +string function: 0x1a4d860 +__version 1 +copy function: 0x1a49c80 +read function: 0x1a4d840 +__len__ function: 0x1a37440 +fill function: 0x1a375c0 +resize function: 0x1a37580 +__index table: 0x1a4a080 +size function: 0x1a4ba20 + + +===== [boolean] torch.isequal(object1, object2) ===== +{{anchor:torch.isequal}} + +If the two objects given as arguments are ''Lua'' tables (or Torch objects), then returns ''true'' if and only if the +tables (or Torch objects) have the same address in memory. Returns ''false'' in any other cases. + +A Torch class is a class created with [[#TorchClass|torch.class()]] or +[[#torch.newmetatable|torch.newmetatable()]]. + +===== torch.setenv(function or userdata, table) ===== +{{anchor:torch.setenv}} + +Assign ''table'' as the Lua environment of the given ''function'' or the given +''userdata''. To know more about environments, please read the documentation +of [[http://www.lua.org/manual/5.1/manual.html#lua_setfenv|lua_setfenv()]] +and [[http://www.lua.org/manual/5.1/manual.html#lua_getfenv|lua_getfenv()]]. + +===== [table] torch.getenv(function or userdata) ===== +{{anchor:torch.getenv}} + +Returns the Lua ''table'' environment of the given ''function'' or the given +''userdata''. To know more about environments, please read the documentation +of [[http://www.lua.org/manual/5.1/manual.html#lua_setfenv|lua_setfenv()]] +and [[http://www.lua.org/manual/5.1/manual.html#lua_getfenv|lua_getfenv()]]. + +===== [number] torch.version(object) ===== +{{anchor:torch.version}} + +Returns the field __version of a given object. 
This might +be helpful to handle variations in a class over time. + +===== [number] torch.pointer(object) ===== +{{anchor:torch.pointer}} + +Returns a unique id (pointer) of the given ''object'', which can be a Torch +object, a table, a thread or a function. + +This is different from the //class// id returned by [[#torch.id|torch.id()]]. + +===== [object] torch.setmetatable(table, classname) ===== +{{anchor:torch.setmetatable}} + +Set the metatable of the given ''table'' to the metatable of the Torch object named ''classname''. +This function has to be used with a lot of care. + +===== [table] torch.getconstructortable(string) ===== +{{anchor:torch.getconstructortable}} + +BUGGY +Return the constructor table of the Torch class specified by ''string'. diff --git a/general.h b/general.h new file mode 100644 index 00000000000..da439ca6b27 --- /dev/null +++ b/general.h @@ -0,0 +1,18 @@ +#ifndef TORCH_GENERAL_INC +#define TORCH_GENERAL_INC + +#include +#include + +#include "luaT.h" +#include "TH.h" + +#ifdef _MSC_VER + +#define snprintf _snprintf +#define popen _popen +#define pclose _pclose + +#endif + +#endif diff --git a/generic/Storage.c b/generic/Storage.c new file mode 100644 index 00000000000..8612d3b35d4 --- /dev/null +++ b/generic/Storage.c @@ -0,0 +1,221 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/Storage.c" +#else + +static int torch_Storage_(new)(lua_State *L) +{ + THStorage *storage; + if(lua_type(L, 1) == LUA_TSTRING) + { + const char *fileName = luaL_checkstring(L, 1); + int isShared = luaT_optboolean(L, 2, 0); + storage = THStorage_(newWithMapping)(fileName, isShared); } + else if(lua_type(L, 1) == LUA_TTABLE) + { + long size = lua_objlen(L, 1); + long i; + storage = THStorage_(newWithSize)(size); + for(i = 1; i <= size; i++) + { + lua_rawgeti(L, 1, i); + if(!lua_isnumber(L, -1)) + { + THStorage_(free)(storage); + luaL_error(L, "element at index %d is not a number", i); + } + THStorage_(set)(storage, i-1, (real)lua_tonumber(L, -1)); + lua_pop(L, 1); + } + } + else + { + long size = luaL_optlong(L, 1, 0); + storage = THStorage_(newWithSize)(size); + } + luaT_pushudata(L, storage, torch_Storage_id); + return 1; +} + +static int torch_Storage_(free)(lua_State *L) +{ + THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); + THStorage_(free)(storage); + return 0; +} + +static int torch_Storage_(resize)(lua_State *L) +{ + THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); + long size = luaL_checklong(L, 2); +/* int keepContent = luaT_optboolean(L, 3, 0); */ + THStorage_(resize)(storage, size);/*, keepContent); */ + lua_settop(L, 1); + return 1; +} + +static int torch_Storage_(copy)(lua_State *L) +{ + THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); + void *src; + if( (src = luaT_toudata(L, 2, torch_Storage_id)) ) + THStorage_(copy)(storage, src); + else if( (src = luaT_toudata(L, 2, torch_ByteStorage_id)) ) + THStorage_(copyByte)(storage, src); + else if( (src = luaT_toudata(L, 2, torch_CharStorage_id)) ) + THStorage_(copyChar)(storage, src); + else if( (src = luaT_toudata(L, 2, torch_ShortStorage_id)) ) + THStorage_(copyShort)(storage, src); + else if( (src = luaT_toudata(L, 2, torch_IntStorage_id)) ) + THStorage_(copyInt)(storage, src); + else if( (src = luaT_toudata(L, 2, torch_LongStorage_id)) ) + THStorage_(copyLong)(storage, src); + else if( (src = luaT_toudata(L, 2, torch_FloatStorage_id)) ) + THStorage_(copyFloat)(storage, src); + else if( (src = luaT_toudata(L, 2, torch_DoubleStorage_id)) ) + THStorage_(copyDouble)(storage, 
src); + else + luaL_typerror(L, 2, "torch.*Storage"); + lua_settop(L, 1); + return 1; +} + +static int torch_Storage_(fill)(lua_State *L) +{ + THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); + double value = luaL_checknumber(L, 2); + THStorage_(fill)(storage, (real)value); + lua_settop(L, 1); + return 1; +} + +static int torch_Storage_(__len__)(lua_State *L) +{ + THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); + lua_pushnumber(L, storage->size); + return 1; +} + +static int torch_Storage_(__newindex__)(lua_State *L) +{ + if(lua_isnumber(L, 2)) + { + THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); + long index = luaL_checklong(L, 2) - 1; + double number = luaL_checknumber(L, 3); + THStorage_(set)(storage, index, (real)number); + lua_pushboolean(L, 1); + } + else + lua_pushboolean(L, 0); + + return 1; +} + +static int torch_Storage_(__index__)(lua_State *L) +{ + if(lua_isnumber(L, 2)) + { + THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); + long index = luaL_checklong(L, 2) - 1; + lua_pushnumber(L, THStorage_(get)(storage, index)); + lua_pushboolean(L, 1); + return 2; + } + else + { + lua_pushboolean(L, 0); + return 1; + } +} + +#if defined(TH_REAL_IS_CHAR) || defined(TH_REAL_IS_BYTE) +static int torch_Storage_(string)(lua_State *L) +{ + THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); + if(lua_isstring(L, -1)) + { + size_t len = 0; + const char *str = lua_tolstring(L, -1, &len); + THStorage_(resize)(storage, len); + memmove(storage->data, str, len); + lua_settop(L, 1); + } + else + lua_pushlstring(L, (char*)storage->data, storage->size); + + return 1; /* either storage or string */ +} +#endif + +static int torch_Storage_(totable)(lua_State *L) +{ + THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); + long i; + + lua_newtable(L); + for(i = 0; i < storage->size; i++) + { + lua_pushnumber(L, (lua_Number)storage->data[i]); + lua_rawseti(L, -2, i+1); + } + return 1; +} + +static int torch_Storage_(factory)(lua_State *L) +{ + THStorage *storage = THStorage_(new)(); + luaT_pushudata(L, storage, torch_Storage_id); + return 1; +} + +static int torch_Storage_(write)(lua_State *L) +{ + THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); + THFile *file = luaT_checkudata(L, 2, torch_File_id); + + THFile_writeLongScalar(file, storage->size); + THFile_writeRealRaw(file, storage->data, storage->size); + + return 0; +} + +static int torch_Storage_(read)(lua_State *L) +{ + THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); + THFile *file = luaT_checkudata(L, 2, torch_File_id); + long size = THFile_readLongScalar(file); + + THStorage_(resize)(storage, size); + THFile_readRealRaw(file, storage->data, storage->size); + + return 0; +} + +static const struct luaL_Reg torch_Storage_(_) [] = { + {"size", torch_Storage_(__len__)}, + {"__len__", torch_Storage_(__len__)}, + {"__newindex__", torch_Storage_(__newindex__)}, + {"__index__", torch_Storage_(__index__)}, + {"resize", torch_Storage_(resize)}, + {"fill", torch_Storage_(fill)}, + {"copy", torch_Storage_(copy)}, + {"totable", torch_Storage_(totable)}, + {"write", torch_Storage_(write)}, + {"read", torch_Storage_(read)}, +#if defined(TH_REAL_IS_CHAR) || defined(TH_REAL_IS_BYTE) + {"string", torch_Storage_(string)}, +#endif + {NULL, NULL} +}; + +void torch_Storage_(init)(lua_State *L) +{ + torch_File_id = luaT_checktypename2id(L, "torch.File"); + + torch_Storage_id = luaT_newmetatable(L, STRING_torchStorage, NULL, + torch_Storage_(new), torch_Storage_(free), 
torch_Storage_(factory)); + luaL_register(L, NULL, torch_Storage_(_)); + lua_pop(L, 1); +} + +#endif diff --git a/generic/Tensor.c b/generic/Tensor.c new file mode 100644 index 00000000000..575d4c53216 --- /dev/null +++ b/generic/Tensor.c @@ -0,0 +1,939 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/Tensor.c" +#else + +static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index, int allowNone, int allowTensor, int allowStorage, int allowStride, + THStorage **storage_, long *storageOffset_, THLongStorage **size_, THLongStorage **stride_); + +static void torch_Tensor_(c_readSizeStride)(lua_State *L, int index, int allowStride, THLongStorage **size_, THLongStorage **stride_); + +static int torch_Tensor_(size)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + if(lua_isnumber(L,2)) + { + int dim = luaL_checkint(L, 2)-1; + luaL_argcheck(L, dim >= 0 && dim < tensor->nDimension, 2, "out of range"); + lua_pushnumber(L, tensor->size[dim]); + } + else + { + THLongStorage *storage = THLongStorage_newWithSize(tensor->nDimension); + memmove(storage->data, tensor->size, sizeof(long)*tensor->nDimension); + luaT_pushudata(L, storage, torch_LongStorage_id); + } + return 1; +} + +static int torch_Tensor_(stride)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + if(lua_isnumber(L,2)) + { + int dim = luaL_checkint(L, 2)-1; + luaL_argcheck(L, dim >= 0 && dim < tensor->nDimension, 2, "out of range"); + lua_pushnumber(L, tensor->stride[dim]); + } + else + { + THLongStorage *storage = THLongStorage_newWithSize(tensor->nDimension); + memmove(storage->data, tensor->stride, sizeof(long)*tensor->nDimension); + luaT_pushudata(L, storage, torch_LongStorage_id); + } + return 1; +} + +static int torch_Tensor_(nDimension)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + lua_pushnumber(L, tensor->nDimension); + return 1; +} + +static int torch_Tensor_(storage)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + + if(tensor->storage) + { + THStorage_(retain)(tensor->storage); + luaT_pushudata(L, tensor->storage, torch_Storage_id); + } + else + lua_pushnil(L); + + return 1; +} + +static int torch_Tensor_(storageOffset)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + lua_pushnumber(L, tensor->storageOffset+1); + return 1; +} + +static int torch_Tensor_(new)(lua_State *L) +{ + THTensor *tensor; + long storageOffset; + THLongStorage *size, *stride; + + if(lua_type(L, 1) == LUA_TTABLE) + { + long i, j; + THLongStorage *counter; + long si = 0; + int dimension = 0; + int is_finished = 0; + + lua_settop(L, 1); + size = THLongStorage_new(); + + while( (lua_type(L, -1) == LUA_TTABLE) && (lua_objlen(L, -1) > 0) ) + { + THLongStorage_resize(size, dimension+1); + size->data[dimension] = lua_objlen(L, -1); + dimension++; + lua_rawgeti(L, -1, 1); + } + lua_pop(L, 1); + + counter = THLongStorage_newWithSize(size->size); + THLongStorage_fill(counter, 0); + + tensor = THTensor_(newWithSize)(size, NULL); + + if(size->size == 0) + is_finished = 1; + + while(!is_finished) + { + if(!lua_istable(L, -1)) + { + THLongStorage_free(size); + THLongStorage_free(counter); + THTensor_(free)(tensor); + luaL_error(L, "invalid tensor definition"); + } + + if(lua_objlen(L, -1) != size->data[size->size-1]) + { + THLongStorage_free(size); + THLongStorage_free(counter); + THTensor_(free)(tensor); + luaL_error(L, "invalid tensor sizes"); + } + + for(i = 0; i < 
size->data[size->size-1]; i++) + { + lua_rawgeti(L, -1, i+1); + if(!lua_isnumber(L, -1)) + { + THLongStorage_free(size); + THLongStorage_free(counter); + THTensor_(free)(tensor); + luaL_error(L, "invalid element (not a number)"); + } + THStorage_(set)(THTensor_(storage)(tensor), si++, (real)lua_tonumber(L, -1)); + lua_pop(L, 1); + } + + if(size->size == 1) + break; + + for(i = size->size-2; i >= 0; i--) + { + if(++counter->data[i] == size->data[i]) + { + if(i == 0) + { + is_finished = 1; + break; + } + else + { + counter->data[i] = 0; + lua_pop(L, 1); + } + } + else + { + lua_pop(L, 1); + for(j = i; j < size->size-1; j++) + { + if(!lua_istable(L, -1)) + { + THLongStorage_free(size); + THLongStorage_free(counter); + THTensor_(free)(tensor); + luaL_error(L, "invalid tensor definition"); + } + if(lua_objlen(L, -1) != size->data[j]) + { + THLongStorage_free(size); + THLongStorage_free(counter); + THTensor_(free)(tensor); + luaL_error(L, "invalid tensor sizes"); + } + lua_rawgeti(L, -1, counter->data[j]+1); + } + break; + } + } + } + + THLongStorage_free(size); + THLongStorage_free(counter); + } + else + { + THStorage *storage; + + torch_Tensor_(c_readTensorStorageSizeStride)(L, 1, 1, 1, 1, 1, + &storage, &storageOffset, &size, &stride); + + tensor = THTensor_(newWithStorage)(storage, storageOffset, size, stride); + + THLongStorage_free(size); + THLongStorage_free(stride); + } + + luaT_pushudata(L, tensor, torch_Tensor_id); + return 1; +} + +static int torch_Tensor_(set)(lua_State *L) +{ + THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id); + THStorage *storage; + long storageOffset; + THLongStorage *size, *stride; + + torch_Tensor_(c_readTensorStorageSizeStride)(L, 2, 1, 1, 1, 1, + &storage, &storageOffset, &size, &stride); + + THTensor_(setStorage)(self, storage, storageOffset, size, stride); + + THLongStorage_free(size); + THLongStorage_free(stride); + + lua_settop(L, 1); + return 1; +} + +static int torch_Tensor_(clone)(lua_State *L) +{ + THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id); + self = THTensor_(newClone)(self); + luaT_pushudata(L, self, torch_Tensor_id); + return 1; +} + +static int torch_Tensor_(contiguous)(lua_State *L) +{ + THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id); + self = THTensor_(newContiguous)(self); + luaT_pushudata(L, self, torch_Tensor_id); + return 1; +} + +/* Resize */ +static int torch_Tensor_(resizeAs)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *src = luaT_checkudata(L, 2, torch_Tensor_id); + THTensor_(resizeAs)(tensor, src); + lua_settop(L, 1); + return 1; +} + +static int torch_Tensor_(resize)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THLongStorage *size, *stride; + + torch_Tensor_(c_readSizeStride)(L, 2, 0, &size, &stride); + + THTensor_(resize)(tensor, size, stride); + + THLongStorage_free(size); + THLongStorage_free(stride); + + lua_settop(L, 1); + return 1; +} + +static int torch_Tensor_(narrow)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + int dimension = luaL_checkint(L, 2)-1; + long firstIndex = luaL_checklong(L, 3)-1; + long size = luaL_checklong(L, 4); + +/* THArgCheck( (dimension >= 0) && (dimension < tensor->nDimension), 2, "out of range"); + THArgCheck( (firstIndex >= 0) && (firstIndex < tensor->size[dimension]), 3, "out of range"); + THArgCheck( (size > 0) && (firstIndex+size <= tensor->size[dimension]), 4, "out of range"); +*/ + tensor = THTensor_(newWithTensor)(tensor); + THTensor_(narrow)(tensor, NULL, 
dimension, firstIndex, size); + luaT_pushudata(L, tensor, torch_Tensor_id); + return 1; +} + +static int torch_Tensor_(sub)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + long d0s = -1, d0e = -1, d1s = -1, d1e = -1, d2s = -1, d2e = -1, d3s = -1, d3e = -1; + + d0s = luaL_checklong(L, 2)-1; + d0e = luaL_checklong(L, 3)-1; + if(d0s < 0) + d0s += tensor->size[0]+1; + if(d0e < 0) + d0e += tensor->size[0]+1; + luaL_argcheck(L, tensor->nDimension > 0, 2, "invalid dimension"); + luaL_argcheck(L, d0s >= 0 && d0s < tensor->size[0], 2, "out of range"); + luaL_argcheck(L, d0e >= 0 && d0e < tensor->size[0], 3, "out of range"); + luaL_argcheck(L, d0e >= d0s, 3, "end smaller than beginning"); + + if(!lua_isnone(L, 4)) + { + d1s = luaL_checklong(L, 4)-1; + d1e = luaL_checklong(L, 5)-1; + if(d1s < 0) + d1s += tensor->size[1]+1; + if(d1e < 0) + d1e += tensor->size[1]+1; + luaL_argcheck(L, tensor->nDimension > 1, 4, "invalid dimension"); + luaL_argcheck(L, d1s >= 0 && d1s < tensor->size[1], 4, "out of range"); + luaL_argcheck(L, d1e >= 0 && d1e < tensor->size[1], 5, "out of range"); + luaL_argcheck(L, d1e >= d1s, 5, "end smaller than beginning"); + + if(!lua_isnone(L, 6)) + { + d2s = luaL_checklong(L, 6)-1; + d2e = luaL_checklong(L, 7)-1; + if(d2s < 0) + d2s += tensor->size[2]+1; + if(d2e < 0) + d2e += tensor->size[2]+1; + luaL_argcheck(L, tensor->nDimension > 2, 6, "invalid dimension"); + luaL_argcheck(L, d2s >= 0 && d2s < tensor->size[2], 6, "out of range"); + luaL_argcheck(L, d2e >= 0 && d2e < tensor->size[2], 7, "out of range"); + luaL_argcheck(L, d2e >= d2s, 7, "end smaller than beginning"); + + if(!lua_isnone(L, 8)) + { + d3s = luaL_checklong(L, 8)-1; + d3e = luaL_checklong(L, 9)-1; + if(d3s < 0) + d3s += tensor->size[3]+1; + if(d3e < 0) + d3e += tensor->size[3]+1; + luaL_argcheck(L, tensor->nDimension > 3, 8, "invalid dimension"); + luaL_argcheck(L, d3s >= 0 && d3s < tensor->size[3], 8, "out of range"); + luaL_argcheck(L, d3e >= 0 && d3e < tensor->size[3], 9, "out of range"); + luaL_argcheck(L, d3e >= d3s, 9, "end smaller than beginning"); + } + } + } + + tensor = THTensor_(newWithTensor)(tensor); + THTensor_(narrow)(tensor, NULL, 0, d0s, d0e-d0s+1); + if(d1s >= 0) + THTensor_(narrow)(tensor, NULL, 1, d1s, d1e-d1s+1); + if(d2s >= 0) + THTensor_(narrow)(tensor, NULL, 2, d2s, d2e-d2s+1); + if(d3s >= 0) + THTensor_(narrow)(tensor, NULL, 3, d3s, d3e-d3s+1); + luaT_pushudata(L, tensor, torch_Tensor_id); + return 1; +} + +static int torch_Tensor_(select)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + int dimension = luaL_checkint(L, 2)-1; + long sliceIndex = luaL_checklong(L, 3)-1; + +/* THArgCheck(src->nDimension > 1, 1, "cannot select on a vector"); + THArgCheck((dimension >= 0) && (dimension < src->nDimension), 2, "out of range"); + THArgCheck((sliceIndex >= 0) && (sliceIndex < src->size[dimension]), 3, "out of range"); +*/ + + if(tensor->nDimension > 1) + { + tensor = THTensor_(newWithTensor)(tensor); + THTensor_(select)(tensor, NULL, dimension, sliceIndex); + luaT_pushudata(L, tensor, torch_Tensor_id); + } + else + { + THArgCheck(tensor->nDimension == 1, 1, "empty Tensor"); + lua_pushnumber(L, THTensor_(get1d)(tensor, sliceIndex)); + } + + return 1; +} + + +static int torch_Tensor_(transpose)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + int dimension1 = luaL_checkint(L, 2)-1; + int dimension2 = luaL_checkint(L, 3)-1; + +/* + THArgCheck( (dimension1 >= 0) && (dimension1 < src->nDimension), 2, "out of 
range"); + THArgCheck( (dimension2 >= 0) && (dimension2 < src->nDimension), 3, "out of range"); +*/ + + tensor = THTensor_(newWithTensor)(tensor); + THTensor_(transpose)(tensor, NULL, dimension1, dimension2); + luaT_pushudata(L, tensor, torch_Tensor_id); + return 1; +} + +static int torch_Tensor_(t)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + + luaL_argcheck(L, tensor->nDimension == 2, 1, "Tensor must have 2 dimensions"); + + tensor = THTensor_(newWithTensor)(tensor); + THTensor_(transpose)(tensor, NULL, 0, 1); + luaT_pushudata(L, tensor, torch_Tensor_id); + return 1; +} + +static int torch_Tensor_(unfold)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + int dimension = luaL_checkint(L, 2)-1; + long size = luaL_checklong(L, 3); + long step = luaL_checklong(L, 4); + +/* + THArgCheck( (src->nDimension > 0), 1, "cannot unfold an empty tensor"); + THArgCheck(dimension < src->nDimension, 2, "out of range"); + THArgCheck(size <= src->size[dimension], 3, "out of range"); +*/ + + tensor = THTensor_(newWithTensor)(tensor); + THTensor_(unfold)(tensor, NULL, dimension, size, step); + luaT_pushudata(L, tensor, torch_Tensor_id); + return 1; +} + +/* is contiguous? [a bit like in TnXIterator] */ +static int torch_Tensor_(isContiguous)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + lua_pushboolean(L, THTensor_(isContiguous)(tensor)); + return 1; +} + +static int torch_Tensor_(nElement)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + lua_pushnumber(L, THTensor_(nElement)(tensor)); + return 1; +} + +static int torch_Tensor_(copy)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + void *src; + if( (src = luaT_toudata(L, 2, torch_Tensor_id)) ) + THTensor_(copy)(tensor, src); + else if( (src = luaT_toudata(L, 2, torch_ByteTensor_id)) ) + THTensor_(copyByte)(tensor, src); + else if( (src = luaT_toudata(L, 2, torch_CharTensor_id)) ) + THTensor_(copyChar)(tensor, src); + else if( (src = luaT_toudata(L, 2, torch_ShortTensor_id)) ) + THTensor_(copyShort)(tensor, src); + else if( (src = luaT_toudata(L, 2, torch_IntTensor_id)) ) + THTensor_(copyInt)(tensor, src); + else if( (src = luaT_toudata(L, 2, torch_LongTensor_id)) ) + THTensor_(copyLong)(tensor, src); + else if( (src = luaT_toudata(L, 2, torch_FloatTensor_id)) ) + THTensor_(copyFloat)(tensor, src); + else if( (src = luaT_toudata(L, 2, torch_DoubleTensor_id)) ) + THTensor_(copyDouble)(tensor, src); + else + luaL_typerror(L, 2, "torch.*Tensor"); + lua_settop(L, 1); + return 1; +} + +static int torch_Tensor_(__newindex__)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THLongStorage *idx = NULL; + + if(lua_isnumber(L, 2)) + { + long index = luaL_checklong(L,2)-1; + void *src; + if (lua_isnumber(L,3)) { + real value = (real)luaL_checknumber(L,3); + luaL_argcheck(L, tensor->nDimension == 1, 1, "must be a one dimensional tensor"); + luaL_argcheck(L, index >= 0 && index < tensor->size[0], 2, "out of range"); + THStorage_(set)(tensor->storage, tensor->storageOffset+index*tensor->stride[0], value); + } else if( (src = luaT_toudata(L, 3, torch_Tensor_id)) ) { + tensor = THTensor_(newWithTensor)(tensor); + THTensor_(narrow)(tensor, NULL, 0, index, 1); + THTensor_(copy)(tensor, src); + } else if( (src = luaT_toudata(L, 3, torch_ByteTensor_id)) ) { + tensor = THTensor_(newWithTensor)(tensor); + THTensor_(narrow)(tensor, NULL, 0, index, 1); + THTensor_(copyByte)(tensor, src); + } 
else if( (src = luaT_toudata(L, 3, torch_CharTensor_id)) ) { + tensor = THTensor_(newWithTensor)(tensor); + THTensor_(narrow)(tensor, NULL, 0, index, 1); + THTensor_(copyChar)(tensor, src); + } else if( (src = luaT_toudata(L, 3, torch_ShortTensor_id)) ) { + tensor = THTensor_(newWithTensor)(tensor); + THTensor_(narrow)(tensor, NULL, 0, index, 1); + THTensor_(copyShort)(tensor, src); + } else if( (src = luaT_toudata(L, 3, torch_IntTensor_id)) ) { + tensor = THTensor_(newWithTensor)(tensor); + THTensor_(narrow)(tensor, NULL, 0, index, 1); + THTensor_(copyInt)(tensor, src); + } else if( (src = luaT_toudata(L, 3, torch_LongTensor_id)) ) { + tensor = THTensor_(newWithTensor)(tensor); + THTensor_(narrow)(tensor, NULL, 0, index, 1); + THTensor_(copyLong)(tensor, src); + } else if( (src = luaT_toudata(L, 3, torch_FloatTensor_id)) ) { + tensor = THTensor_(newWithTensor)(tensor); + THTensor_(narrow)(tensor, NULL, 0, index, 1); + THTensor_(copyFloat)(tensor, src); + } else { + luaL_typerror(L, 3, "torch.*Tensor"); + } + lua_pushboolean(L, 1); + } + else if((idx = luaT_toudata(L, 2, torch_LongStorage_id))) + { + long index = THTensor_(storageOffset)(tensor); + real value = (real)luaL_checknumber(L,3); + int dim; + + luaL_argcheck(L, idx->size == tensor->nDimension, 2, "invalid size"); + + for(dim = 0; dim < idx->size; dim++) + { + long z = idx->data[dim]-1; + luaL_argcheck(L, (z >= 0) && (z < tensor->size[dim]), 2, "index out of bound"); + index += z*tensor->stride[dim]; + } + + THStorage_(set)(tensor->storage, index, value); + lua_pushboolean(L, 1); + } + else if(lua_istable(L, 2)) + { + long index = THTensor_(storageOffset)(tensor); + real value = (real)luaL_checknumber(L,3); + int dim; + + luaL_argcheck(L, lua_objlen(L,2) == tensor->nDimension, 2, "invalid size"); + + for(dim = 0; dim < tensor->nDimension; dim++) + { + long z; + + lua_rawgeti(L, 2, dim+1); + if(!lua_isnumber(L, -1)) + luaL_error(L, "number expected for each dimension"); + + z = lua_tonumber(L, -1)-1; + lua_pop(L, 1); + + luaL_argcheck(L, (z >= 0) && (z < tensor->size[dim]), 2, "index out of bound"); + index += z*tensor->stride[dim]; + } + THStorage_(set)(tensor->storage, index, value); + lua_pushboolean(L, 1); + } + else + lua_pushboolean(L, 0); + + return 1; +} + +static int torch_Tensor_(__index__)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THLongStorage *idx = NULL; + + if(lua_isnumber(L, 2)) + { + long index = luaL_checklong(L,2)-1; + + luaL_argcheck(L, tensor->nDimension > 0, 1, "empty tensor"); + luaL_argcheck(L, index >= 0 && index < tensor->size[0], 2, "out of range"); + + if(tensor->nDimension == 1) + { + lua_pushnumber(L, THStorage_(get)(tensor->storage, tensor->storageOffset+index*tensor->stride[0])); + } + else + { + tensor = THTensor_(newWithTensor)(tensor); + THTensor_(select)(tensor, NULL, 0, index); + luaT_pushudata(L, tensor, torch_Tensor_id); + } + lua_pushboolean(L, 1); + return 2; + } + else if((idx = luaT_toudata(L, 2, torch_LongStorage_id))) + { + long index = THTensor_(storageOffset)(tensor); + int dim; + + luaL_argcheck(L, idx->size == tensor->nDimension, 2, "invalid size"); + + for(dim = 0; dim < idx->size; dim++) + { + long z = idx->data[dim]-1; + luaL_argcheck(L, (z >= 0) && (z < tensor->size[dim]), 2, "index out of bound"); + index += z*tensor->stride[dim]; + } + lua_pushnumber(L, (double)THStorage_(get)(THTensor_(storage)(tensor), index)); + lua_pushboolean(L, 1); + return 2; + } + else if(lua_istable(L, 2)) + { + long index = THTensor_(storageOffset)(tensor); + int 
dim; + + luaL_argcheck(L, lua_objlen(L,2) == tensor->nDimension, 2, "invalid size"); + + for(dim = 0; dim < tensor->nDimension; dim++) + { + long z; + + lua_rawgeti(L, 2, dim+1); + if(!lua_isnumber(L, -1)) + luaL_error(L, "number expected for each dimension"); + + z = lua_tonumber(L, -1)-1; + lua_pop(L, 1); + + luaL_argcheck(L, (z >= 0) && (z < tensor->size[dim]), 2, "index out of bound"); + index += z*tensor->stride[dim]; + } + lua_pushnumber(L, (double)THStorage_(get)(THTensor_(storage)(tensor), index)); + lua_pushboolean(L, 1); + return 2; + } + else + { + lua_pushboolean(L, 0); + return 1; + } +} + +static int torch_Tensor_(free)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor_(free)(tensor); + return 0; +} + +/* helpful functions */ +static void torch_Tensor_(c_readSizeStride)(lua_State *L, int index, int allowStride, THLongStorage **size_, THLongStorage **stride_) +{ + THLongStorage *size = NULL; + THLongStorage *stride = NULL; + + if( (size = luaT_toudata(L, index, torch_LongStorage_id)) ) + { + if(!lua_isnoneornil(L, index+1)) + { + if( (stride = luaT_toudata(L, index+1, torch_LongStorage_id)) ) + luaL_argcheck(L, stride->size == size->size, index+1, "provided stride and size are inconsistent"); + else + luaL_argcheck(L, 0, index+1, "torch.LongStorage expected"); + } + THLongStorage_retain(size); + if(stride) + THLongStorage_retain(stride); + } + else + { + int i; + + size = THLongStorage_newWithSize(8); + stride = THLongStorage_newWithSize(8); + THLongStorage_fill(size, -1); + THLongStorage_fill(stride, -1); + + if(allowStride) + { + for(i = 0; i < 8; i++) + { + if(lua_isnone(L, index+2*i)) + break; + size->data[i] = luaL_checklong(L, index+2*i); + + if(lua_isnone(L, index+2*i+1)) + break; + stride->data[i] = luaL_checklong(L, index+2*i+1); + } + } + else + { + for(i = 0; i < 8; i++) + { + if(lua_isnone(L, index+i)) + break; + size->data[i] = luaL_checklong(L, index+i); + } + } + } + + *size_ = size; + *stride_ = stride; +} + +static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index, int allowNone, int allowTensor, int allowStorage, int allowStride, + THStorage **storage_, long *storageOffset_, THLongStorage **size_, THLongStorage **stride_) +{ + static char errMsg[64]; + THTensor *src = NULL; + THStorage *storage = NULL; + + int arg1Type = lua_type(L, index); + + if( allowNone && (arg1Type == LUA_TNONE) ) + { + *storage_ = NULL; + *storageOffset_ = 0; + *size_ = NULL; + *stride_ = NULL; + return; + } + else if( allowTensor && (arg1Type == LUA_TUSERDATA) && (src = luaT_toudata(L, index, torch_Tensor_id)) ) + { + *storage_ = src->storage; + *storageOffset_ = src->storageOffset; + *size_ = THTensor_(newSizeOf)(src); + *stride_ = THTensor_(newStrideOf)(src); + return; + } + else if( allowStorage && (arg1Type == LUA_TUSERDATA) && (storage = luaT_toudata(L, index, torch_Storage_id)) ) + { + *storage_ = storage; + if(lua_isnone(L, index+1)) + { + *storageOffset_ = 0; + *size_ = THLongStorage_newWithSize1(storage->size); + *stride_ = THLongStorage_newWithSize1(1); + } + else + { + *storageOffset_ = luaL_checklong(L, index+1)-1; + torch_Tensor_(c_readSizeStride)(L, index+2, allowStride, size_, stride_); + } + return; + } + else if( (arg1Type == LUA_TNUMBER) || (luaT_toudata(L, index, torch_LongStorage_id)) ) + { + *storage_ = NULL; + *storageOffset_ = 0; + torch_Tensor_(c_readSizeStride)(L, index, 0, size_, stride_); + + return; + } + + *storage_ = NULL; + *storageOffset_ = 0; + + sprintf(errMsg, "expecting number%s%s", 
(allowTensor ? " or Tensor" : ""), (allowStorage ? " or Storage" : "")); + luaL_argcheck(L, 0, index, errMsg); +} + +static int torch_Tensor_(apply)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + luaL_checktype(L, 2, LUA_TFUNCTION); + lua_settop(L, 2); + + TH_TENSOR_APPLY(real, tensor, + lua_pushvalue(L, 2); + lua_pushnumber(L, *tensor_data); + lua_call(L, 1, 1); + if(lua_isnumber(L, 3)) + { + *tensor_data = (real)lua_tonumber(L, 3); + lua_pop(L, 1); + } + else if(lua_isnil(L, 3)) + lua_pop(L, 1); + else + luaL_error(L, "given function should return a number or nil");); + + lua_settop(L, 1); + return 1; +} + +static int torch_Tensor_(map)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *src = luaT_checkudata(L, 2, torch_Tensor_id); + luaL_checktype(L, 3, LUA_TFUNCTION); + lua_settop(L, 3); + + TH_TENSOR_APPLY2(real, tensor, real, src, + lua_pushvalue(L, 3); + lua_pushnumber(L, *tensor_data); + lua_pushnumber(L, *src_data); + lua_call(L, 2, 1); + if(lua_isnumber(L, 4)) + { + *tensor_data = (real)lua_tonumber(L, 4); + lua_pop(L, 1); + } + else if(lua_isnil(L, 4)) + lua_pop(L, 1); + else + luaL_error(L, "given function should return a number or nil");); + + lua_settop(L, 1); + return 1; +} + +static int torch_Tensor_(map2)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *src1 = luaT_checkudata(L, 2, torch_Tensor_id); + THTensor *src2 = luaT_checkudata(L, 3, torch_Tensor_id); + luaL_checktype(L, 4, LUA_TFUNCTION); + lua_settop(L, 4); + + TH_TENSOR_APPLY3(real, tensor, real, src1, real, src2, + lua_pushvalue(L, 4); + lua_pushnumber(L, *tensor_data); + lua_pushnumber(L, *src1_data); + lua_pushnumber(L, *src2_data); + lua_call(L, 3, 1); + if(lua_isnumber(L, 5)) + { + *tensor_data = (real)lua_tonumber(L, 5); + lua_pop(L, 1); + } + else if(lua_isnil(L, 5)) + lua_pop(L, 1); + else + luaL_error(L, "given function should return a number or nothing");); + + lua_settop(L, 1); + return 1; +} + +static int torch_Tensor_(factory)(lua_State *L) +{ + THTensor *tensor = THTensor_(new)(); + luaT_pushudata(L, tensor, torch_Tensor_id); + return 1; +} + +static int torch_Tensor_(write)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THFile *file = luaT_checkudata(L, 2, torch_File_id); + + THFile_writeIntScalar(file, tensor->nDimension); + THFile_writeLongRaw(file, tensor->size, tensor->nDimension); + THFile_writeLongRaw(file, tensor->stride, tensor->nDimension); + THFile_writeLongScalar(file, tensor->storageOffset+1); /* to respect Lua convention */ + + lua_getfield(L, 2, "writeObject"); /* the method */ + lua_pushvalue(L, 2); /* the file */ + /* the storage */ + if(tensor->storage) + { + THStorage_(retain)(tensor->storage); + luaT_pushudata(L, tensor->storage, torch_Storage_id); + } + else + lua_pushnil(L); + + lua_call(L, 2, 0); /* call the method */ + + return 0; +} + +static int torch_Tensor_(read)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THFile *file = luaT_checkudata(L, 2, torch_File_id); + + tensor->nDimension = THFile_readIntScalar(file); + tensor->size = THAlloc(sizeof(long)*tensor->nDimension); + tensor->stride = THAlloc(sizeof(long)*tensor->nDimension); + THFile_readLongRaw(file, tensor->size, tensor->nDimension); + THFile_readLongRaw(file, tensor->stride, tensor->nDimension); + tensor->storageOffset = THFile_readLongScalar(file); + tensor->storageOffset--; /* to respect Lua convention */ + + lua_getfield(L, 2, 
"readObject"); /* the method */ + lua_pushvalue(L, 2); /* the file */ + lua_call(L, 1, 1); /* call the method */ + + tensor->storage = luaT_toudata(L, -1, torch_Storage_id); + if(tensor->storage) + THStorage_(retain)(tensor->storage); + + return 0; +} + +static const struct luaL_Reg torch_Tensor_(_) [] = { + {"contiguous", torch_Tensor_(contiguous)}, + {"size", torch_Tensor_(size)}, + {"__len__", torch_Tensor_(size)}, + {"stride", torch_Tensor_(stride)}, + {"dim", torch_Tensor_(nDimension)}, + {"nDimension", torch_Tensor_(nDimension)}, + {"set", torch_Tensor_(set)}, + {"storage", torch_Tensor_(storage)}, + {"storageOffset", torch_Tensor_(storageOffset)}, + {"clone", torch_Tensor_(clone)}, + {"contiguous", torch_Tensor_(contiguous)}, + {"resizeAs", torch_Tensor_(resizeAs)}, + {"resize", torch_Tensor_(resize)}, + {"narrow", torch_Tensor_(narrow)}, + {"sub", torch_Tensor_(sub)}, + {"select", torch_Tensor_(select)}, + {"transpose", torch_Tensor_(transpose)}, + {"t", torch_Tensor_(t)}, + {"unfold", torch_Tensor_(unfold)}, + {"isContiguous", torch_Tensor_(isContiguous)}, + {"nElement", torch_Tensor_(nElement)}, + {"copy", torch_Tensor_(copy)}, + {"apply", torch_Tensor_(apply)}, + {"map", torch_Tensor_(map)}, + {"map2", torch_Tensor_(map2)}, + {"read", torch_Tensor_(read)}, + {"write", torch_Tensor_(write)}, + {"__index__", torch_Tensor_(__index__)}, + {"__newindex__", torch_Tensor_(__newindex__)}, + {NULL, NULL} +}; + +void torch_Tensor_(init)(lua_State *L) +{ + torch_File_id = luaT_checktypename2id(L, "torch.File"); + torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage"); + torch_Storage_id = luaT_checktypename2id(L, STRING_torchStorage); + + torch_Tensor_id = luaT_newmetatable(L, STRING_torchTensor, NULL, + torch_Tensor_(new), torch_Tensor_(free), torch_Tensor_(factory)); + luaL_register(L, NULL, torch_Tensor_(_)); + lua_pop(L, 1); +} + +#endif diff --git a/generic/TensorConv.c b/generic/TensorConv.c new file mode 100644 index 00000000000..98e8f4fb205 --- /dev/null +++ b/generic/TensorConv.c @@ -0,0 +1,175 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/TensorConv.c" +#else + +static int torch_(convxcorr2)(lua_State *L, const char* ktype) +{ + int narg = lua_gettop(L); + THTensor *r_ = NULL; + THTensor *im = NULL; + THTensor *ker = NULL; + char type[2]; + int rgiven = 0; + + type[0] = 'v'; + type[1] = ktype[0]; + + if (narg == 2 + && (ker = luaT_toudata(L,2,torch_(Tensor_id))) + && (im = luaT_toudata(L,1,torch_(Tensor_id)))) + { + } + else if (narg == 3 + && (lua_type(L,3) == LUA_TSTRING) + && (ker = luaT_toudata(L,2,torch_(Tensor_id))) + && (im = luaT_toudata(L,1,torch_(Tensor_id)))) + { + type[0] = *(luaL_checkstring(L,3)); + luaL_argcheck(L, (type[0] == 'v' || type[0] == 'V' || type[0] == 'f' || type[0] == 'F'), + 3, "[Tensor, ] Tensor, Tensor [, x or c]"); + if (type[0] == 'V') type[0] = 'v'; + if (type[0] == 'F') type[0] = 'f'; + } + else if (narg == 4 + && (type[0] = *(luaL_checkstring(L,4))) + && (ker = luaT_toudata(L,3,torch_(Tensor_id))) + && (im = luaT_toudata(L,2,torch_(Tensor_id))) + && (r_ = luaT_toudata(L,1,torch_(Tensor_id)))) + { + rgiven = 1; + } + else + { + luaL_error(L,"[Tensor, ] Tensor, Tensor [, x or c]"); + } + + if (!r_) r_ = THTensor_(new)(); + + if (im->nDimension == 2 && ker->nDimension == 2) + { + THTensor_(conv2Dmul)(r_,0.0,1.0,im,ker,1,1,type); + } + else if (im->nDimension == 3 && ker->nDimension == 3) + { + THTensor_(conv2Dcmul)(r_,0.0,1.0,im,ker,1,1,type); + } + else if (im->nDimension == 3 && ker->nDimension == 4) + { + 
THTensor_(conv2Dmv)(r_,0.0,1.0,im,ker,1,1,type); + } + else + { + luaL_error(L," (2D,2D) or (3D,3D) or (3D,4D) "); + } + + pushreturn(rgiven, r_, torch_(Tensor_id)); + + return 1; +} + +static int torch_(convxcorr3)(lua_State *L, char* ktype) +{ + int narg = lua_gettop(L); + THTensor *r_ = NULL; + THTensor *im = NULL; + THTensor *ker = NULL; + char type[2]; + int rgiven = 0; + + type[0] = 'v'; + type[1] = ktype[0]; + + if (narg == 2 + && (ker = luaT_toudata(L,2,torch_(Tensor_id))) + && (im = luaT_toudata(L,1,torch_(Tensor_id)))) + { + } + else if (narg == 3 + && (lua_type(L,3) == LUA_TSTRING) + && (ker = luaT_toudata(L,2,torch_(Tensor_id))) + && (im = luaT_toudata(L,1,torch_(Tensor_id)))) + { + type[0] = *(luaL_checkstring(L,3)); + luaL_argcheck(L, (type[0] == 'v' || type[0] == 'V' || type[0] == 'f' || type[0] == 'F'), + 3, "[Tensor, ] Tensor, Tensor [, x or c]"); + if (type[0] == 'V') type[0] = 'v'; + if (type[0] == 'F') type[0] = 'f'; + } + else if (narg == 4 + && (type[0] = *(luaL_checkstring(L,4))) + && (ker = luaT_toudata(L,3,torch_(Tensor_id))) + && (im = luaT_toudata(L,2,torch_(Tensor_id))) + && (r_ = luaT_toudata(L,1,torch_(Tensor_id)))) + { + rgiven = 1; + } + else + { + luaL_error(L,"[Tensor, ] Tensor, Tensor [, x or c]"); + } + + if (!r_) r_ = THTensor_(new)(); + + if (im->nDimension == 3 && ker->nDimension == 3) + { + THTensor_(conv3Dmul)(r_,0.0,1.0,im,ker,1,1,1,type); + } + else if (im->nDimension == 4 && ker->nDimension == 4) + { + THTensor_(conv3Dcmul)(r_,0.0,1.0,im,ker,1,1,1,type); + } + else if (im->nDimension == 4 && ker->nDimension == 5) + { + THTensor_(conv3Dmv)(r_,0.0,1.0,im,ker,1,1,1,type); + } + else + { + luaL_error(L," (3D,3D) or (4D,4D) or (4D,5D) "); + } + + pushreturn(rgiven, r_, torch_(Tensor_id)); + + return 1; +} + +static int torch_(conv2)(lua_State *L) +{ + return torch_(convxcorr2)(L,"convolution"); +} +static int torch_(xcorr2)(lua_State *L) +{ + return torch_(convxcorr2)(L,"xcorrelation"); +} + + +static int torch_(conv3)(lua_State *L) +{ + return torch_(convxcorr3)(L,"convolution"); +} +static int torch_(xcorr3)(lua_State *L) +{ + return torch_(convxcorr3)(L,"xcorrelation"); +} + +static const struct luaL_Reg torch_(Conv__) [] = { + {"conv2", torch_(conv2)}, + {"xcorr2", torch_(xcorr2)}, + {"conv3", torch_(conv3)}, + {"xcorr3", torch_(xcorr3)}, + {NULL, NULL} +}; + +void torch_(Conv_init)(lua_State *L) +{ + torch_(Tensor_id) = luaT_checktypename2id(L, torch_string_(Tensor)); + + /* register everything into the "torch" field of the tensor metaclass */ + luaT_pushmetaclass(L, torch_(Tensor_id)); + lua_pushstring(L, "torch"); + lua_rawget(L, -2); + luaL_register(L, NULL, torch_(Conv__)); + lua_pop(L, 2); +} + +#endif + diff --git a/generic/TensorLapack.c b/generic/TensorLapack.c new file mode 100644 index 00000000000..11108d40f01 --- /dev/null +++ b/generic/TensorLapack.c @@ -0,0 +1,274 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/TensorLapack.c" +#else + +#define pushreturn(i,t,tid) \ + if (!i) \ + luaT_pushudata(L, t, tid); \ + else \ + lua_pushvalue(L,i) + +static int torch_(gesv)(lua_State *L) +{ + int narg = lua_gettop(L); + THTensor *ra_ = NULL; + THTensor *rb_ = NULL; + THTensor *a_ = NULL; + THTensor *b_ = NULL; + int ragiven = 0; + int rbgiven = 0; + + if (narg == 2 + && (a_ = luaT_toudata(L,2,torch_(Tensor_id))) + && (b_ = luaT_toudata(L,1,torch_(Tensor_id)))) + { + } + else if (narg == 3 + && (a_ = luaT_toudata(L,2,torch_(Tensor_id))) + && (b_ = luaT_toudata(L,1,torch_(Tensor_id)))) + { + if(lua_toboolean(L,3)) + { + ra_ = a_; + 
rb_ = b_; + a_ = NULL; + b_ = NULL; + ragiven = 2; + rbgiven = 1; + } + else + { + luaL_error(L,"[Tensor, Tensor], Tensor, Tensor, [,true]"); + } + } + else if (narg == 4 + && (a_ = luaT_toudata(L,4,torch_(Tensor_id))) + && (b_ = luaT_toudata(L,3,torch_(Tensor_id))) + && (ra_ = luaT_toudata(L,2,torch_(Tensor_id))) + && (rb_ = luaT_toudata(L,1,torch_(Tensor_id)))) + { + ragiven = 2; + rbgiven = 1; + } + else + { + luaL_error(L,"[Tensor, Tensor], Tensor, Tensor, [,true]"); + } + + if (!ra_) ra_ = THTensor_(new)(); + if (!rb_) rb_ = THTensor_(new)(); + + THTensor_(gesv)(rb_,ra_,b_,a_); + + pushreturn(rbgiven,rb_,torch_(Tensor_id)); + pushreturn(ragiven,ra_,torch_(Tensor_id)); + + return 2; +} + +static int torch_(gels)(lua_State *L) +{ + int narg = lua_gettop(L); + THTensor *ra_ = NULL; + THTensor *rb_ = NULL; + THTensor *a_ = NULL; + THTensor *b_ = NULL; + int ragiven = 0; + int rbgiven = 0; + + if (narg == 2 + && (a_ = luaT_toudata(L,2,torch_(Tensor_id))) + && (b_ = luaT_toudata(L,1,torch_(Tensor_id)))) + { + } + else if (narg == 3 + && (a_ = luaT_toudata(L,2,torch_(Tensor_id))) + && (b_ = luaT_toudata(L,1,torch_(Tensor_id)))) + { + if (lua_toboolean(L,3)) + { + ra_ = a_; + rb_ = b_; + a_ = NULL; + b_ = NULL; + ragiven = 2; + rbgiven = 1; + } + else + { + luaL_error(L,"[Tensor, Tensor], Tensor, Tensor, [,true]"); + } + } + else if (narg == 4 + && (a_ = luaT_toudata(L,4,torch_(Tensor_id))) + && (b_ = luaT_toudata(L,3,torch_(Tensor_id))) + && (ra_ = luaT_toudata(L,2,torch_(Tensor_id))) + && (rb_ = luaT_toudata(L,1,torch_(Tensor_id)))) + { + ragiven = 2; + rbgiven = 1; + } + else + { + luaL_error(L,"[Tensor, Tensor], Tensor, Tensor, [,true]"); + } + + if (!ra_) ra_ = THTensor_(new)(); + if (!rb_) rb_ = THTensor_(new)(); + + THTensor_(gels)(rb_,ra_,b_,a_); + + pushreturn(rbgiven,rb_,torch_(Tensor_id)); + pushreturn(ragiven,ra_,torch_(Tensor_id)); + + return 2; +} + +static int torch_(eig)(lua_State *L) +{ + int narg = lua_gettop(L); + THTensor *re_ = NULL; + THTensor *rv_ = NULL; + THTensor *a_ = NULL; + char type = 'N'; + char uplo = 'U'; + int regiven = 0; + int rvgiven = 0; + + if (narg == 1 + && (a_ = luaT_toudata(L,1,torch_(Tensor_id)))) + { + } + else if (narg == 2 + && (lua_type(L,2) == LUA_TSTRING) + && (a_ = luaT_toudata(L,1,torch_(Tensor_id)))) + { + type = *(luaL_checkstring(L,2)); + luaL_argcheck(L, (type == 'v' || type == 'V' || type == 'n' || type == 'N'), + 2, "[Tensor, ] [Tensor, ] Tensor [, N or V]"); + if (type == 'v') type = 'V'; + if (type == 'n') type = 'N'; + } + else if (narg == 2 + && (a_ = luaT_toudata(L,2,torch_(Tensor_id))) + && (re_ = luaT_toudata(L,1,torch_(Tensor_id)))) + { + regiven = 1; + } + else if (narg == 3 + && (a_ = luaT_toudata(L,3,torch_(Tensor_id))) + && (rv_ = luaT_toudata(L,2,torch_(Tensor_id))) + && (re_ = luaT_toudata(L,1,torch_(Tensor_id)))) + { + regiven = 1; + rvgiven = 2; + } + else if (narg == 4 + && (type = *(luaL_checkstring(L,4))) + && (a_ = luaT_toudata(L,3,torch_(Tensor_id))) + && (rv_ = luaT_toudata(L,2,torch_(Tensor_id))) + && (re_ = luaT_toudata(L,1,torch_(Tensor_id)))) + { + regiven = 1; + rvgiven = 2; + } + else + { + luaL_error(L,"[Tensor, ] [Tensor, ] Tensor [, N or V]"); + } + if (!re_) re_ = THTensor_(new)(); + if (!rv_) rv_ = THTensor_(new)(); + + THTensor_(syev)(re_,rv_,a_,&type,&uplo); + + pushreturn(regiven, re_, torch_(Tensor_id)); + pushreturn(rvgiven, rv_, torch_(Tensor_id)); + + return 2; +} + +static int torch_(svd)(lua_State *L) +{ + int narg = lua_gettop(L); + THTensor *ru_ = NULL; + THTensor *rs_ = NULL; + THTensor *rv_ 
= NULL; + THTensor *a_ = NULL; + char type = 'S'; + int rugiven = 0; + int rsgiven = 0; + int rvgiven = 0; + + if (narg == 1 + && (a_ = luaT_toudata(L,1,torch_(Tensor_id)))) + { + } + else if (narg ==2 + && (type = *(luaL_checkstring(L,2))) + && (a_ = luaT_toudata(L,1,torch_(Tensor_id)))) + { + luaL_argcheck(L, (type == 's' || type == 'S' || type == 'a' || type == 'A'), + 2, "[Tensor, ] [Tensor, ] [Tensor, ] Tensor [, A or S]"); + if (type == 's') type = 'S'; + if (type == 'a') type = 'A'; + } + else if (narg == 4 + && (a_ = luaT_toudata(L,4,torch_(Tensor_id))) + && (rv_ = luaT_toudata(L,3,torch_(Tensor_id))) + && (rs_ = luaT_toudata(L,2,torch_(Tensor_id))) + && (ru_ = luaT_toudata(L,1,torch_(Tensor_id)))) + { + rugiven = 1; + rsgiven = 2; + rvgiven = 3; + } + else if (narg == 5 + && (type = *(luaL_checkstring(L,5))) + && (a_ = luaT_toudata(L,4,torch_(Tensor_id))) + && (rv_ = luaT_toudata(L,3,torch_(Tensor_id))) + && (rs_ = luaT_toudata(L,2,torch_(Tensor_id))) + && (ru_ = luaT_toudata(L,1,torch_(Tensor_id)))) + { + rugiven = 1; + rsgiven = 2; + rvgiven = 3; + } + else + { + luaL_error(L,"[Tensor, Tensor, Tensor], Tensor, [, 'A' or 'S' ]"); + } + + if (!ru_) ru_ = THTensor_(new)(); + if (!rs_) rs_ = THTensor_(new)(); + if (!rv_) rv_ = THTensor_(new)(); + + THTensor_(gesvd)(ru_,rs_,rv_,a_,&type); + + pushreturn(rugiven,ru_,torch_(Tensor_id)); + pushreturn(rsgiven,rs_,torch_(Tensor_id)); + pushreturn(rvgiven,rv_,torch_(Tensor_id)); + + return 3; +} + +static const struct luaL_Reg torch_(lapack__) [] = { + {"gesv", torch_(gesv)}, + {"gels", torch_(gels)}, + {"eig", torch_(eig)}, + {"svd", torch_(svd)}, + {NULL, NULL} +}; + +void torch_(Lapack_init)(lua_State *L) +{ + torch_(Tensor_id) = luaT_checktypename2id(L, torch_string_(Tensor)); + + /* register everything into the "torch" field of the tensor metaclass */ + luaT_pushmetaclass(L, torch_(Tensor_id)); + lua_pushstring(L, "torch"); + lua_rawget(L, -2); + luaL_register(L, NULL, torch_(lapack__)); + lua_pop(L, 2); +} + +#endif diff --git a/generic/TensorOperator.c b/generic/TensorOperator.c new file mode 100644 index 00000000000..10d8f8e3a10 --- /dev/null +++ b/generic/TensorOperator.c @@ -0,0 +1,177 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/TensorOperator.c" +#else + +static const void* torch_Tensor_id; + +static int torch_TensorOperator_(__add__)(lua_State *L) +{ + THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor_id); + THTensor *tensor2 = luaT_toudata(L, 2, torch_Tensor_id); + THTensor *r; + + if(!tensor1 && !tensor2) + luaL_error(L, "expecting two Tensors or one Tensor and one number"); + else + { + r = THTensor_(new)(); + luaT_pushudata(L, r, torch_Tensor_id); + + if(!tensor1 && tensor2) + { + THTensor_(resizeAs)(r, tensor2); + THTensor_(copy)(r, tensor2); + THTensor_(add)(r, r, luaL_checknumber(L, 1)); + } + else if(tensor1 && !tensor2) + { + THTensor_(resizeAs)(r, tensor1); + THTensor_(copy)(r, tensor1); + THTensor_(add)(r, r, luaL_checknumber(L, 2)); + } + else + { + THTensor_(resizeAs)(r, tensor1); + THTensor_(copy)(r, tensor1); + THTensor_(cadd)(r, r, 1, tensor2); + } + } + return 1; +} + +static int torch_TensorOperator_(__sub__)(lua_State *L) +{ + THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor_id); + THTensor *tensor2 = luaT_toudata(L, 2, torch_Tensor_id); + THTensor *r; + + if(!tensor1 && !tensor2) + luaL_error(L, "expecting two Tensors or one Tensor and one number"); + else + { + r = THTensor_(new)(); + luaT_pushudata(L, r, torch_Tensor_id); + + if(!tensor1 && tensor2) + { + THTensor_(resizeAs)(r, 
tensor2); + THTensor_(fill)(r, luaL_checknumber(L, 1)); + THTensor_(cadd)(r, r, -1, tensor2); + } + else if(tensor1 && !tensor2) + { + THTensor_(resizeAs)(r, tensor1); + THTensor_(copy)(r, tensor1); + THTensor_(add)(r, r, -luaL_checknumber(L, 2)); + } + else + { + THTensor_(resizeAs)(r, tensor1); + THTensor_(copy)(r, tensor1); + THTensor_(cadd)(r, r, -1, tensor2); + } + } + return 1; +} + +static int torch_TensorOperator_(__unm__)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *r; + + r = THTensor_(new)(); + luaT_pushudata(L, r, torch_Tensor_id); + THTensor_(resizeAs)(r, tensor); + THTensor_(copy)(r, tensor); + THTensor_(mul)(r, r, -1); + + return 1; +} + +static int torch_TensorOperator_(__mul__)(lua_State *L) +{ + THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor_id); + THTensor *tensor2 = luaT_toudata(L, 2, torch_Tensor_id); + THTensor *r; + + if(!tensor1 && !tensor2) + luaL_error(L, "expecting two Tensors or one Tensor and one number"); + else + { + r = THTensor_(new)(); + luaT_pushudata(L, r, torch_Tensor_id); + + if(!tensor1 && tensor2) + { + THTensor_(resizeAs)(r, tensor2); + THTensor_(copy)(r, tensor2); + THTensor_(mul)(r, r, luaL_checknumber(L, 1)); + } + else if(tensor1 && !tensor2) + { + THTensor_(resizeAs)(r, tensor1); + THTensor_(copy)(r, tensor1); + THTensor_(mul)(r, r, luaL_checknumber(L, 2)); + } + else + { + int dimt = tensor1->nDimension; + int dims = tensor2->nDimension; + + if(dimt == 1 && dims == 1) + lua_pushnumber(L, THTensor_(dot)(tensor1, tensor2)); /* ok, we wasted r, but who cares */ + else if(dimt == 2 && dims == 1) + { + THTensor_(resize1d)(r, tensor1->size[0]); + THTensor_(zero)(r); + THTensor_(addmv)(r, 1, r, 1, tensor1, tensor2); + } + else if(dimt == 2 && dims == 2) + { + THTensor_(resize2d)(r, tensor1->size[0], tensor2->size[1]); + THTensor_(zero)(r); + THTensor_(addmm)(r, 1, r, 1, tensor1, tensor2); + } + else + luaL_error(L, "multiplication between %dD and %dD tensors not yet supported", tensor1->nDimension, tensor2->nDimension); + } + } + return 1; +} + +static int torch_TensorOperator_(__div__)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *r; + + luaL_argcheck(L, lua_isnumber(L,2), 2, "number expected"); + + r = THTensor_(new)(); + luaT_pushudata(L, r, torch_Tensor_id); + + THTensor_(resizeAs)(r, tensor); + THTensor_(copy)(r, tensor); + THTensor_(mul)(r, r, 1/lua_tonumber(L, 2)); + + return 1; +} + +static const struct luaL_Reg torch_TensorOperator_(_) [] = { + {"__add__", torch_TensorOperator_(__add__)}, + {"__sub__", torch_TensorOperator_(__sub__)}, + {"__unm__", torch_TensorOperator_(__unm__)}, + {"__mul__", torch_TensorOperator_(__mul__)}, + {"__div__", torch_TensorOperator_(__div__)}, + {NULL, NULL} +}; + +void torch_TensorOperator_(init)(lua_State *L) +{ + torch_Tensor_id = luaT_checktypename2id(L, STRING_torchTensor); + + luaT_pushmetaclass(L, torch_Tensor_id); + luaL_register(L, NULL, torch_TensorOperator_(_)); + lua_pop(L, 1); +} + +#endif diff --git a/generic/hist.c b/generic/hist.c new file mode 100644 index 00000000000..6b18fd6f091 --- /dev/null +++ b/generic/hist.c @@ -0,0 +1,44 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/lab.c" +#else + +#include "interfaces.c" + +static int lab_(histc)(lua_State *L) +{ + THTensor *r = luaT_checkudata(L, 1, torch_(Tensor_id)); + THTensor *h = luaT_checkudata(L, 2, torch_(Tensor_id)); + int nbins = luaL_checknumber(L, 3); + real *h_data = THTensor_(data)(h); + + TH_TENSOR_APPLY(real, r, \ + if 
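The __mul__ metamethod above dispatches on dimensionality: 1D*1D falls back to a dot product, 2D*1D to a matrix-vector product, and 2D*2D to a matrix-matrix product. A small sketch of the 2D*1D branch, using the same TH calls the operator uses (mat and vec are assumed, pre-filled THTensor* values):

/* r = mat * vec, computed as r = 1*r + 1*(mat x vec) with r zeroed first */
THTensor *r = THTensor_(new)();
THTensor_(resize1d)(r, mat->size[0]);
THTensor_(zero)(r);
THTensor_(addmv)(r, 1, r, 1, mat, vec);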
((*r_data <= nbins) && (*r_data >= 1)) { \ + *(h_data + (int)(*r_data) - 1) += 1; \ + }) + return 0; +} + +static const struct luaL_Reg lab_(stuff__) [] = { + {"_histc", lab_(histc)}, +#endif + {NULL, NULL} +}; + +void lab_(init)(lua_State *L) +{ + torch_(Tensor_id) = luaT_checktypename2id(L, torch_string_(Tensor)); + torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage"); + + /* register everything into the "lab" field of the tensor metaclass */ + luaT_pushmetaclass(L, torch_(Tensor_id)); + lua_pushstring(L, "lab"); + lua_newtable(L); + luaL_register(L, NULL, lab_(stuff__)); + lua_rawset(L, -3); + lua_pop(L, 1); + +/* luaT_registeratid(L, lab_(stuff__), torch_(Tensor_id)); */ +/* luaL_register(L, NULL, lab_(stuff__)); */ +} + +#endif diff --git a/hist.lua b/hist.lua new file mode 100644 index 00000000000..c3615e338b7 --- /dev/null +++ b/hist.lua @@ -0,0 +1,123 @@ +-- +-- rudimentary histogram diplay on the command line. +-- +-- Author: Marco Scoffier +-- Date : +-- Mod : Oct 21, 2011 +-- + made 80 columns default +-- + save index of max bin in h.max not pointer to bin +-- +function torch.histc__tostring(h, barHeight) + barHeight = barHeight or 10 + local lastm = h[h.max].nb + local incr = lastm/(barHeight+1) + local m = lastm - incr + local tl = torch.Tensor(#h):fill(0) + local toph = '|' + local topm = ':' + local topl = '.' + local bar = '|' + local blank = ' ' + local yaxis = '--------:' + local str = 'nsamples:' + str = str .. + string.format(' min:(bin:%d/#%d/cntr:%2.2f) max:(bin:%d/#%d/cntr:%2.2f)\n', + h.min,h[h.min].nb,h[h.min].val, + h.max,h[h.max].nb,h[h.max].val) + + str = str .. yaxis + for j = 1,#h do + str = str .. '-' + end + str = str .. '\n' + + for i = 1,barHeight do + -- y axis + if i%1==0 then + str = str .. string.format('%1.2e:',m) + end + for j = 1,#h do + if tl[j] == 1 then + str = str .. bar + elseif h[j].nb < m then + str = str .. blank + else + -- in the bracket + tl[j] = 1 + -- find 1/3rds + local p = (lastm - h[j].nb) / incr + if p > 0.66 then + str = str .. toph + elseif p > 0.33 then + str = str .. topm + else + str = str .. topl + end + end + end + str = str .. '\n' + lastm = m + m = m - incr + end + -- x axis + str = str .. yaxis + for j = 1,#h do + if ((j - 2) % 6 == 0)then + str = str .. '^' + else + str = str .. '-' + end + end + str = str .. '\ncenters ' + for j = 1,#h do + if ((j - 2) % 6 == 0)then + if h[j].val < 0 then + str = str .. '-' + else + str = str .. '+' + end + str = str .. string.format('%1.2f ',math.abs(h[j].val)) + end + end + return str +end + +-- a simple function that computes the histogram of a tensor +function torch.histc(...) 
+ -- get args + local args = {...} + local tensor = args[1] or error('usage: torch.histc (tensor [, nBins] [, min] [, max]') + local bins = args[2] or 80 - 8 + local min = args[3] or tensor:min() + local max = args[4] or tensor:max() + local raw = args[5] or false + + -- compute histogram + local hist = torch.zeros(bins) + local ten = torch.Tensor(tensor:nElement()):copy(tensor) + ten:add(-min):div(max-min):mul(bins - 1e-6):floor():add(1) + ten.torch._histc(ten, hist, bins) + + -- return raw histogram (no extra info) + if raw then return hist end + + -- cleanup hist + local cleanhist = {} + cleanhist.raw = hist + local _,mx = torch.max(cleanhist.raw) + local _,mn = torch.min(cleanhist.raw) + cleanhist.bins = bins + cleanhist.binwidth = (max-min)/bins + for i = 1,bins do + cleanhist[i] = {} + cleanhist[i].val = min + (i-0.5)*cleanhist.binwidth + cleanhist[i].nb = hist[i] + end + cleanhist.max = mx[1] + cleanhist.min = mn[1] + + -- print function + setmetatable(cleanhist, {__tostring=torch.histc__tostring}) + return cleanhist +end + diff --git a/init.c b/init.c new file mode 100644 index 00000000000..584f0609083 --- /dev/null +++ b/init.c @@ -0,0 +1,99 @@ +#include "general.h" +#include "utils.h" + +extern void torch_utils_init(lua_State *L); +extern void torch_random_init(lua_State *L); +extern void torch_File_init(lua_State *L); +extern void torch_File_init_storage_id(lua_State *L); +extern void torch_DiskFile_init(lua_State *L); +extern void torch_MemoryFile_init(lua_State *L); +extern void torch_PipeFile_init(lua_State *L); +extern void torch_Timer_init(lua_State *L); + +extern void torch_ByteStorage_init(lua_State *L); +extern void torch_CharStorage_init(lua_State *L); +extern void torch_ShortStorage_init(lua_State *L); +extern void torch_IntStorage_init(lua_State *L); +extern void torch_LongStorage_init(lua_State *L); +extern void torch_FloatStorage_init(lua_State *L); +extern void torch_DoubleStorage_init(lua_State *L); + +extern void torch_ByteTensor_init(lua_State *L); +extern void torch_CharTensor_init(lua_State *L); +extern void torch_ShortTensor_init(lua_State *L); +extern void torch_IntTensor_init(lua_State *L); +extern void torch_LongTensor_init(lua_State *L); +extern void torch_FloatTensor_init(lua_State *L); +extern void torch_DoubleTensor_init(lua_State *L); + +extern void torch_ByteTensorOperator_init(lua_State *L); +extern void torch_CharTensorOperator_init(lua_State *L); +extern void torch_ShortTensorOperator_init(lua_State *L); +extern void torch_IntTensorOperator_init(lua_State *L); +extern void torch_LongTensorOperator_init(lua_State *L); +extern void torch_FloatTensorOperator_init(lua_State *L); +extern void torch_DoubleTensorOperator_init(lua_State *L); + +extern void torch_TensorMath_init(lua_State *L); + +static lua_State *globalL; +static void luaTorchErrorHandlerFunction(const char *msg) +{ + luaL_error(globalL, msg); +} + +static void luaTorchArgCheckHandlerFunction(int condition, int argNumber, const char *msg) +{ + luaL_argcheck(globalL, condition, argNumber, msg); +} + +DLL_EXPORT int luaopen_libtorch(lua_State *L) +{ + globalL = L; + THSetErrorHandler(luaTorchErrorHandlerFunction); + THSetArgCheckHandler(luaTorchArgCheckHandlerFunction); + + lua_newtable(L); + lua_pushvalue(L, -1); + lua_setfield(L, LUA_GLOBALSINDEX, "torch"); + + torch_File_init(L); + + torch_ByteStorage_init(L); + torch_CharStorage_init(L); + torch_ShortStorage_init(L); + torch_IntStorage_init(L); + torch_LongStorage_init(L); + torch_FloatStorage_init(L); + torch_DoubleStorage_init(L); + + 
torch_ByteTensor_init(L); + torch_CharTensor_init(L); + torch_ShortTensor_init(L); + torch_IntTensor_init(L); + torch_LongTensor_init(L); + torch_FloatTensor_init(L); + torch_DoubleTensor_init(L); + + torch_File_init_storage_id(L); + + torch_ByteTensorOperator_init(L); + torch_CharTensorOperator_init(L); + torch_ShortTensorOperator_init(L); + torch_IntTensorOperator_init(L); + torch_LongTensorOperator_init(L); + torch_FloatTensorOperator_init(L); + torch_DoubleTensorOperator_init(L); + + torch_Timer_init(L); + torch_DiskFile_init(L); + torch_PipeFile_init(L); + torch_MemoryFile_init(L); + + torch_TensorMath_init(L); + + torch_utils_init(L); + torch_random_init(L); + + return 1; +} diff --git a/init.lua b/init.lua new file mode 100644 index 00000000000..42f2b335e99 --- /dev/null +++ b/init.lua @@ -0,0 +1,78 @@ + +-- We are using paths.require to appease mkl +require "paths" +paths.require "libtorch" +require "libtorch" + +--- package stuff +function torch.packageLuaPath(name) + if not name then + local ret = string.match(torch.packageLuaPath('torch'), '(.*)/') + if not ret then --windows? + ret = string.match(torch.packageLuaPath('torch'), '(.*)\\') + end + return ret + end + for path in string.gmatch(package.path, "(.-);") do + path = string.gsub(path, "%?", name) + local f = io.open(path) + if f then + f:close() + local ret = string.match(path, "(.*)/") + if not ret then --windows? + ret = string.match(path, "(.*)\\") + end + return ret + end + end +end + +function torch.include(package, file) + dofile(torch.packageLuaPath(package) .. '/' .. file) +end + +function torch.class(tname, parenttname) + + local function constructor(...) + local self = {} + torch.setmetatable(self, tname) + if self.__init then + self:__init(...) + end + return self + end + + local function factory() + local self = {} + torch.setmetatable(self, tname) + return self + end + + local mt = torch.newmetatable(tname, parenttname, constructor, nil, factory) + local mpt + if parenttname then + mpt = torch.getmetatable(parenttname) + end + return mt, mpt +end + +function torch.setdefaulttensortype(typename) + assert(type(typename) == 'string', 'string expected') + if torch.getconstructortable(typename) then + torch.Tensor = torch.getconstructortable(typename) + torch.Storage = torch.getconstructortable(torch.typename(torch.Tensor(1):storage())) + torch.__setdefaulttensortype(typename) + else + error(string.format("<%s> is not a string describing a torch object", typename)) + end +end + +torch.setdefaulttensortype('torch.DoubleTensor') + +torch.include('torch', 'Tensor.lua') +torch.include('torch', 'File.lua') +torch.include('torch', 'CmdLine.lua') +torch.include('torch', 'Tester.lua') +torch.include('torch', 'TensorMath.lua') +torch.include('torch', 'test.lua') +return torch diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt new file mode 100644 index 00000000000..cb04528a1db --- /dev/null +++ b/lib/CMakeLists.txt @@ -0,0 +1,2 @@ +ADD_SUBDIRECTORY(TH) +ADD_SUBDIRECTORY(luaT) diff --git a/lib/TH/CMakeLists.txt b/lib/TH/CMakeLists.txt new file mode 100644 index 00000000000..173603f6a9d --- /dev/null +++ b/lib/TH/CMakeLists.txt @@ -0,0 +1,117 @@ +# -*- cmake -*- + +SET(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake ${CMAKE_MODULE_PATH}) + +SET(hdr + THGeneral.h THStorage.h THTensor.h THTensorApply.h + THBlas.h THLapack.h THLogAdd.h THRandom.h THVector.h) +SET(src + THGeneral.c THStorage.c THTensor.c THBlas.c THLapack.c + THLogAdd.c THRandom.c + THFile.c THDiskFile.c THMemoryFile.c) + +SET(src ${src} ${hdr}) + 
+IF(UNIX) + INCLUDE(CheckFunctionExists) + SET(CMAKE_EXTRA_INCLUDE_FILES "sys/mman.h") + CHECK_FUNCTION_EXISTS(mmap HAVE_MMAP) + IF(HAVE_MMAP) + ADD_DEFINITIONS(-DHAVE_MMAP=1) + ENDIF(HAVE_MMAP) +ENDIF(UNIX) + +ADD_LIBRARY(TH SHARED ${src}) + +FIND_PACKAGE(BLAS) +FIND_PACKAGE(LAPACK) + +IF (LAPACK_FOUND) + SET(CMAKE_C_FLAGS "-D__LAPACK__ ${CMAKE_C_FLAGS}") +ENDIF(LAPACK_FOUND) + +FIND_PACKAGE(SSE) + +IF (SSE2_FOUND) + SET(CMAKE_C_FLAGS "-msse2 -D__SSE2__ ${CMAKE_C_FLAGS}") +ENDIF (SSE2_FOUND) +IF (SSE3_FOUND) + SET(CMAKE_C_FLAGS "-msse3 -D__SSE3__ ${CMAKE_C_FLAGS}") +ENDIF (SSE3_FOUND) +IF (SSSE3_FOUND) + SET(CMAKE_C_FLAGS "-mssse3 -D__SSSE3__ ${CMAKE_C_FLAGS}") +ENDIF (SSSE3_FOUND) +IF (SSE4.1_FOUND) + SET(CMAKE_C_FLAGS "-msse4.1 -D__SSE4_1__ ${CMAKE_C_FLAGS}") +ENDIF (SSE4.1_FOUND) + +IF(BLAS_FOUND) + ADD_DEFINITIONS(-DUSE_LAPACK) +# INCLUDE_DIRECTORIES(${CBLAS_INCLUDE_DIR}) + TARGET_LINK_LIBRARIES(TH ${BLAS_LIBRARIES}) +ENDIF(BLAS_FOUND) + +#CONFIGURE_FILE("THCBlas.h.in" "${CMAKE_CURRENT_BINARY_DIR}/THCBlas.h") +#INCLUDE_DIRECTORIES("${CMAKE_CURRENT_BINARY_DIR}") +#INSTALL(FILES "${CMAKE_CURRENT_BINARY_DIR}/THCBlas.h" +# DESTINATION "${Torch_INSTALL_INCLUDE_SUBDIR}/TH") + +INSTALL(TARGETS TH + RUNTIME DESTINATION "${Torch_INSTALL_BIN_SUBDIR}" + LIBRARY DESTINATION "${Torch_INSTALL_LIB_SUBDIR}" + ARCHIVE DESTINATION "${Torch_INSTALL_LIB_SUBDIR}") + +INSTALL(FILES + TH.h + THBlas.h + THDiskFile.h + THFile.h + THFilePrivate.h + THGeneral.h + THGenerateAllTypes.h + THGenerateFloatTypes.h + THGenerateIntTypes.h + THLapack.h + THLogAdd.h + THMemoryFile.h + THRandom.h + THStorage.h + THTensor.h + THTensorApply.h + THTensorDimApply.h + THTensorMacros.h + THVector.h + DESTINATION "${Torch_INSTALL_INCLUDE_SUBDIR}/TH") + +INSTALL(FILES + generic/THBlas.c + generic/THBlas.h + generic/THLapack.c + generic/THLapack.h + generic/THStorage.c + generic/THStorage.h + generic/THStorageCopy.c + generic/THStorageCopy.h + generic/THTensor.c + generic/THTensor.h + generic/THTensorConv.c + generic/THTensorConv.h + generic/THTensorCopy.c + generic/THTensorCopy.h + generic/THTensorLapack.c + generic/THTensorLapack.h + generic/THTensorMath.c + generic/THTensorMath.h + generic/THTensorRandom.c + generic/THTensorRandom.h + generic/THVector.c + DESTINATION "${Torch_INSTALL_INCLUDE_SUBDIR}/TH/generic") + +# Create THConfig.cmake +GET_TARGET_PROPERTY(TH_OUTPUT_NAME TH LOCATION) +GET_FILENAME_COMPONENT(TH_OUTPUT_NAME ${TH_OUTPUT_NAME} NAME) +SET(TH_LIBRARIES "${Torch_INSTALL_LIB}/${TH_OUTPUT_NAME}") +SET(TH_INCLUDE_DIR "${Torch_INSTALL_INCLUDE}/TH") +CONFIGURE_FILE(THConfig.cmake.in "${Torch_BINARY_DIR}/cmake-external/THConfig.cmake") +INSTALL(FILES "${Torch_BINARY_DIR}/cmake-external/THConfig.cmake" + DESTINATION "${Torch_INSTALL_CMAKE_SUBDIR}") diff --git a/lib/TH/TH.h b/lib/TH/TH.h new file mode 100644 index 00000000000..41e9c00cb70 --- /dev/null +++ b/lib/TH/TH.h @@ -0,0 +1,23 @@ +#ifndef TH_INC +#define TH_INC + +#include "THBlas.h" + +#ifdef __LAPACK__ +#include "THLapack.h" +#endif + +#include "THVector.h" +#include "THGeneral.h" +#include "THLogAdd.h" +#include "THRandom.h" +#include "THStorage.h" +#include "THTensor.h" +#include "THTensorApply.h" +#include "THTensorDimApply.h" + +#include "THFile.h" +#include "THDiskFile.h" +#include "THMemoryFile.h" + +#endif diff --git a/lib/TH/THBlas.c b/lib/TH/THBlas.c new file mode 100644 index 00000000000..5884e5c1a63 --- /dev/null +++ b/lib/TH/THBlas.c @@ -0,0 +1,5 @@ +#include "THBlas.h" + +/* #include "THCBlas.h" */ +#include "generic/THBlas.c" +#include 
"THGenerateAllTypes.h" diff --git a/lib/TH/THBlas.h b/lib/TH/THBlas.h new file mode 100644 index 00000000000..5fef0febcd5 --- /dev/null +++ b/lib/TH/THBlas.h @@ -0,0 +1,11 @@ +#ifndef TH_BLAS_INC +#define TH_BLAS_INC + +#include "THGeneral.h" + +#define THBlas_(NAME) TH_CONCAT_4(TH,Real,Blas_,NAME) + +#include "generic/THBlas.h" +#include "THGenerateAllTypes.h" + +#endif diff --git a/lib/TH/THCBlas.h.in b/lib/TH/THCBlas.h.in new file mode 100644 index 00000000000..82243465043 --- /dev/null +++ b/lib/TH/THCBlas.h.in @@ -0,0 +1,8 @@ +/* -*- C -*- */ + +#cmakedefine USE_CBLAS @USE_CBLAS@ + +#if USE_CBLAS +# include "@CBLAS_INCLUDE_FILE@" +#endif + diff --git a/lib/TH/THConfig.cmake.in b/lib/TH/THConfig.cmake.in new file mode 100644 index 00000000000..306cd878bc7 --- /dev/null +++ b/lib/TH/THConfig.cmake.in @@ -0,0 +1,9 @@ +# Find the TH includes and library +# +# TH_INCLUDE_DIR -- where to find the includes +# TH_LIBRARIES -- list of libraries to link against +# TH_FOUND -- set to 1 if found + +SET(TH_FOUND 1) +SET(TH_INCLUDE_DIR "@TH_INCLUDE_DIR@") +SET(TH_LIBRARIES "@TH_LIBRARIES@") diff --git a/lib/TH/THDiskFile.c b/lib/TH/THDiskFile.c new file mode 100644 index 00000000000..63a7dd2a221 --- /dev/null +++ b/lib/TH/THDiskFile.c @@ -0,0 +1,592 @@ +#include "THGeneral.h" +#include "THDiskFile.h" +#include "THFilePrivate.h" + +typedef struct THDiskFile__ +{ + THFile file; + + FILE *handle; + char *name; + int isNativeEncoding; + +} THDiskFile; + +static int THDiskFile_isOpened(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)self; + return (dfself->handle != NULL); +} + +const char *THDiskFile_name(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)self; + return dfself->name; +} + + +#define READ_WRITE_METHODS(TYPE, TYPEC, ASCII_READ_ELEM, ASCII_WRITE_ELEM) \ + static long THDiskFile_read##TYPEC(THFile *self, TYPE *data, long n) \ + { \ + THDiskFile *dfself = (THDiskFile*)(self); \ + long nread = 0L; \ + \ + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); \ + THArgCheck(dfself->file.isReadable, 1, "attempt to read in a write-only file"); \ + \ + if(dfself->file.isBinary) \ + { \ + nread = fread(data, sizeof(TYPE), n, dfself->handle); \ + if(!dfself->isNativeEncoding && (sizeof(TYPE) > 1) && (nread > 0)) \ + THDiskFile_reverseMemory(data, data, sizeof(TYPE), nread); \ + } \ + else \ + { \ + long i; \ + for(i = 0; i < n; i++) \ + { \ + ASCII_READ_ELEM; /* increment here result and break if wrong */ \ + } \ + if(dfself->file.isAutoSpacing && (n > 0)) \ + { \ + int c = fgetc(dfself->handle); \ + if( (c != '\n') && (c != EOF) ) \ + ungetc(c, dfself->handle); \ + } \ + } \ + \ + if(nread != n) \ + { \ + dfself->file.hasError = 1; /* shouldn't we put hasError to 0 all the time ? 
*/ \ + if(!dfself->file.isQuiet) \ + THError("read error: read %d blocks instead of %d", nread, n); \ + } \ + \ + return nread; \ + } \ + \ + static long THDiskFile_write##TYPEC(THFile *self, TYPE *data, long n) \ + { \ + THDiskFile *dfself = (THDiskFile*)(self); \ + long nwrite = 0L; \ + \ + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); \ + THArgCheck(dfself->file.isWritable, 1, "attempt to write in a read-only file"); \ + \ + if(dfself->file.isBinary) \ + { \ + if(dfself->isNativeEncoding) \ + { \ + nwrite = fwrite(data, sizeof(TYPE), n, dfself->handle); \ + } \ + else \ + { \ + if(sizeof(TYPE) > 1) \ + { \ + char *buffer = THAlloc(sizeof(TYPE)*n); \ + THDiskFile_reverseMemory(buffer, data, sizeof(TYPE), n); \ + nwrite = fwrite(buffer, sizeof(TYPE), n, dfself->handle); \ + THFree(buffer); \ + } \ + else \ + nwrite = fwrite(data, sizeof(TYPE), n, dfself->handle); \ + } \ + } \ + else \ + { \ + long i; \ + for(i = 0; i < n; i++) \ + { \ + ASCII_WRITE_ELEM; \ + if( dfself->file.isAutoSpacing && (i < n-1) ) \ + fprintf(dfself->handle, " "); \ + } \ + if(dfself->file.isAutoSpacing && (n > 0)) \ + fprintf(dfself->handle, "\n"); \ + } \ + \ + if(nwrite != n) \ + { \ + dfself->file.hasError = 1; \ + if(!dfself->file.isQuiet) \ + THError("write error: wrote %d blocks instead of %d", nwrite, n); \ + } \ + \ + return nwrite; \ +} + +static int THDiskFile_mode(const char *mode, int *isReadable, int *isWritable) +{ + *isReadable = 0; + *isWritable = 0; + if(strlen(mode) == 1) + { + if(*mode == 'r') + { + *isReadable = 1; + return 1; + } + else if(*mode == 'w') + { + *isWritable = 1; + return 1; + } + } + else if(strlen(mode) == 2) + { + if(mode[0] == 'r' && mode[1] == 'w') + { + *isReadable = 1; + *isWritable = 1; + return 1; + } + } + return 0; +} + +static void THDiskFile_synchronize(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)(self); + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); + fflush(dfself->handle); +} + +static void THDiskFile_seek(THFile *self, long position) +{ + THDiskFile *dfself = (THDiskFile*)(self); + + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); + THArgCheck(position >= 0, 2, "position must be positive"); + + if(fseek(dfself->handle, position, SEEK_SET) < 0) + { + dfself->file.hasError = 1; + if(!dfself->file.isQuiet) + THError("unable to seek at position %d", position); + } +} + +static void THDiskFile_seekEnd(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)(self); + + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); + + if(fseek(dfself->handle, 0L, SEEK_END) < 0) + { + dfself->file.hasError = 1; + if(!dfself->file.isQuiet) + THError("unable to seek at end of file"); + } +} + +static long THDiskFile_position(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)(self); + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); + return ftell(dfself->handle); +} + +static void THDiskFile_close(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)(self); + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); + fclose(dfself->handle); + dfself->handle = NULL; +} + +/* Little and Big Endian */ + +static void THDiskFile_reverseMemory(void *dst, const void *src, long blockSize, long numBlocks) +{ + if(blockSize != 1) + { + long halfBlockSize = blockSize/2; + char *charSrc = (char*)src; + char *charDst = (char*)dst; + long b, i; + for(b = 0; b < numBlocks; b++) + { + for(i = 0; i < halfBlockSize; i++) + { + char z = charSrc[i]; + charDst[i] = 
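The seek/seekEnd/position implementations above give the usual idiom for querying a file's size through the THFile_ front-ends declared later in THFile.h. A short sketch, assuming f is an open THFile* obtained from THDiskFile_new:

THFile_seekEnd(f);               /* jump to the end of the file */
long size = THFile_position(f);  /* byte offset, i.e. the file size */
THFile_seek(f, 0);               /* rewind before reading */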
charSrc[blockSize-1-i]; + charDst[blockSize-1-i] = z; + } + charSrc += blockSize; + charDst += blockSize; + } + } +} + +int THDiskFile_isLittleEndianCPU(void) +{ + int x = 7; + char *ptr = (char *)&x; + + if(ptr[0] == 0) + return 0; + else + return 1; +} + +int THDiskFile_isBigEndianCPU(void) +{ + return(!THDiskFile_isLittleEndianCPU()); +} + +void THDiskFile_nativeEndianEncoding(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)(self); + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); + dfself->isNativeEncoding = 1; +} + +void THDiskFile_littleEndianEncoding(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)(self); + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); + dfself->isNativeEncoding = THDiskFile_isLittleEndianCPU(); +} + +void THDiskFile_bigEndianEncoding(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)(self); + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); + dfself->isNativeEncoding = !THDiskFile_isLittleEndianCPU(); +} + +/* End of Little and Big Endian Stuff */ + +static void THDiskFile_free(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)(self); + if(dfself->handle) + fclose(dfself->handle); + THFree(dfself->name); + THFree(dfself); +} + +/* READ_WRITE_METHODS(int, Bool, */ +/* int value = 0; int ret = fscanf(file->handle, "%d", &value); array[i] = (value ? 1 : 0); if(ret <= 0) break; else result++, */ +/* int value = (array[i] ? 1 : 0); nElemWritten = fprintf(file->handle, "%d", value), */ +/* true) */ + +/* Note that we do a trick */ +READ_WRITE_METHODS(unsigned char, Byte, + nread = fread(data, 1, n, dfself->handle); break, + nwrite = fwrite(data, 1, n, dfself->handle); break) + +READ_WRITE_METHODS(char, Char, + nread = fread(data, 1, n, dfself->handle); break, + nwrite = fwrite(data, 1, n, dfself->handle); break) + +READ_WRITE_METHODS(short, Short, + int ret = fscanf(dfself->handle, "%hd", &data[i]); if(ret <= 0) break; else nread++, + int ret = fprintf(dfself->handle, "%hd", data[i]); if(ret <= 0) break; else nwrite++) + +READ_WRITE_METHODS(int, Int, + int ret = fscanf(dfself->handle, "%d", &data[i]); if(ret <= 0) break; else nread++, + int ret = fprintf(dfself->handle, "%d", data[i]); if(ret <= 0) break; else nwrite++) + +READ_WRITE_METHODS(long, Long, + int ret = fscanf(dfself->handle, "%ld", &data[i]); if(ret <= 0) break; else nread++, + int ret = fprintf(dfself->handle, "%ld", data[i]); if(ret <= 0) break; else nwrite++) + +READ_WRITE_METHODS(float, Float, + int ret = fscanf(dfself->handle, "%g", &data[i]); if(ret <= 0) break; else nread++, + int ret = fprintf(dfself->handle, "%g", data[i]); if(ret <= 0) break; else nwrite++) + +READ_WRITE_METHODS(double, Double, + int ret = fscanf(dfself->handle, "%lg", &data[i]); if(ret <= 0) break; else nread++, + int ret = fprintf(dfself->handle, "%lg", data[i]); if(ret <= 0) break; else nwrite++) + +static long THDiskFile_readString(THFile *self, const char *format, char **str_) +{ + THDiskFile *dfself = (THDiskFile*)(self); + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); + THArgCheck(dfself->file.isReadable, 1, "attempt to read in a write-only file"); + THArgCheck((strlen(format) >= 2 ? (format[0] == '*') && (format[1] == 'a' || format[1] == 'l') : 0), 2, "format must be '*a' or '*l'"); + +/* note: the string won't survive long, as it is copied into lua */ +/* so 1024 is not that big... 
*/ +#define TBRS_BSZ 1024L + + if(format[1] == 'a') + { + char *p = THAlloc(TBRS_BSZ); + long total = TBRS_BSZ; + long pos = 0L; + + for (;;) + { + if(total-pos == 0) /* we need more space! */ + { + total += TBRS_BSZ; + p = THRealloc(p, total); + } + pos += fread(p+pos, 1, total-pos, dfself->handle); + if (pos < total) /* eof? */ + { + if(pos == 0L) + { + THFree(p); + dfself->file.hasError = 1; + if(!dfself->file.isQuiet) + THError("read error: read 0 blocks instead of 1"); + + *str_ = NULL; + return 0; + } + *str_ = p; + return pos; + } + } + } + else + { + char *p = THAlloc(TBRS_BSZ); + long total = TBRS_BSZ; + long pos = 0L; + long size; + + for (;;) + { + if(total-pos <= 1) /* we can only write '\0' in there! */ + { + total += TBRS_BSZ; + p = THRealloc(p, total); + } + if (fgets(p+pos, total-pos, dfself->handle) == NULL) /* eof? */ + { + if(pos == 0L) + { + THFree(p); + dfself->file.hasError = 1; + if(!dfself->file.isQuiet) + THError("read error: read 0 blocks instead of 1"); + + *str_ = NULL; + return 0; + } + *str_ = p; + return pos; + } + size = strlen(p+pos); + if (size == 0L || (p+pos)[size-1] != '\n') + { + pos += size; + } + else + { + pos += size-1L; /* do not include `eol' */ + *str_ = p; + return pos; + } + } + } + + *str_ = NULL; + return 0; +} + + +static long THDiskFile_writeString(THFile *self, const char *str, long size) +{ + THDiskFile *dfself = (THDiskFile*)(self); + long nwrite; + + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); + THArgCheck(dfself->file.isWritable, 1, "attempt to write in a read-only file"); + + nwrite = fwrite(str, 1, size, dfself->handle); + if(nwrite != size) + { + dfself->file.hasError = 1; + if(!dfself->file.isQuiet) + THError("write error: wrote %ld blocks instead of %ld", nwrite, size); + } + + return nwrite; +} + +THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet) +{ + static struct THFileVTable vtable = { + THDiskFile_isOpened, + + THDiskFile_readByte, + THDiskFile_readChar, + THDiskFile_readShort, + THDiskFile_readInt, + THDiskFile_readLong, + THDiskFile_readFloat, + THDiskFile_readDouble, + THDiskFile_readString, + + THDiskFile_writeByte, + THDiskFile_writeChar, + THDiskFile_writeShort, + THDiskFile_writeInt, + THDiskFile_writeLong, + THDiskFile_writeFloat, + THDiskFile_writeDouble, + THDiskFile_writeString, + + THDiskFile_synchronize, + THDiskFile_seek, + THDiskFile_seekEnd, + THDiskFile_position, + THDiskFile_close, + THDiskFile_free + }; + + int isReadable; + int isWritable; + FILE *handle; + THDiskFile *self; + + THArgCheck(THDiskFile_mode(mode, &isReadable, &isWritable), 2, "file mode should be 'r','w' or 'rw'"); + + if( isReadable && isWritable ) + { + handle = fopen(name, "r+b"); + if(!handle) + { + handle = fopen(name, "wb"); + if(handle) + { + fclose(handle); + handle = fopen(name, "r+b"); + } + } + } + else + handle = fopen(name, (isReadable ? "rb" : "wb")); + + if(!handle) + { + if(isQuiet) + return 0; + else + THError("cannot open <%s> in mode %c%c", name, (isReadable ? 'r' : ' '), (isWritable ? 
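THDiskFile_readString above implements the '*a' (read everything) and '*l' (read one line) formats; per the note in THFile.h the caller owns the returned buffer, which is allocated with THAlloc/THRealloc and therefore released with THFree. A hedged usage sketch, assuming f is a readable THFile*:

char *line = NULL;
long len = THFile_readStringRaw(f, "*l", &line);  /* one line, without the trailing newline */
if(len > 0)
{
  /* ... use line[0..len-1] ... */
  THFree(line);
}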
'w' : ' ')); + } + + self = THAlloc(sizeof(THDiskFile)); + + self->handle = handle; + self->name = THAlloc(strlen(name)+1); + strcpy(self->name, name); + self->isNativeEncoding = 1; + + self->file.vtable = &vtable; + self->file.isQuiet = isQuiet; + self->file.isReadable = isReadable; + self->file.isWritable = isWritable; + self->file.isBinary = 0; + self->file.isAutoSpacing = 1; + self->file.hasError = 0; + + return (THFile*)self; +} + +/* PipeFile */ + +static int THPipeFile_mode(const char *mode, int *isReadable, int *isWritable) +{ + *isReadable = 0; + *isWritable = 0; + if(strlen(mode) == 1) + { + if(*mode == 'r') + { + *isReadable = 1; + return 1; + } + else if(*mode == 'w') + { + *isWritable = 1; + return 1; + } + } + return 0; +} + +static void THPipeFile_free(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)(self); + if(dfself->handle) + pclose(dfself->handle); + THFree(dfself->name); + THFree(dfself); +} + +THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet) +{ + static struct THFileVTable vtable = { + THDiskFile_isOpened, + + THDiskFile_readByte, + THDiskFile_readChar, + THDiskFile_readShort, + THDiskFile_readInt, + THDiskFile_readLong, + THDiskFile_readFloat, + THDiskFile_readDouble, + THDiskFile_readString, + + THDiskFile_writeByte, + THDiskFile_writeChar, + THDiskFile_writeShort, + THDiskFile_writeInt, + THDiskFile_writeLong, + THDiskFile_writeFloat, + THDiskFile_writeDouble, + THDiskFile_writeString, + + THDiskFile_synchronize, + THDiskFile_seek, + THDiskFile_seekEnd, + THDiskFile_position, + THDiskFile_close, + THPipeFile_free + }; + + int isReadable; + int isWritable; + FILE *handle; + THDiskFile *self; + + THArgCheck(THPipeFile_mode(mode, &isReadable, &isWritable), 2, "file mode should be 'r','w'"); + +#ifdef _WIN32 + handle = popen(name, (isReadable ? "rb" : "wb")); +#else + handle = popen(name, (isReadable ? "r" : "w")); +#endif + + if(!handle) + { + if(isQuiet) + return 0; + else + THError("cannot open <%s> in mode %c%c", name, (isReadable ? 'r' : ' '), (isWritable ? 
'w' : ' ')); + } + + self = THAlloc(sizeof(THDiskFile)); + + self->handle = handle; + self->name = THAlloc(strlen(name)+1); + strcpy(self->name, name); + self->isNativeEncoding = 1; + + self->file.vtable = &vtable; + self->file.isQuiet = isQuiet; + self->file.isReadable = isReadable; + self->file.isWritable = isWritable; + self->file.isBinary = 0; + self->file.isAutoSpacing = 1; + self->file.hasError = 0; + + return (THFile*)self; +} diff --git a/lib/TH/THDiskFile.h b/lib/TH/THDiskFile.h new file mode 100644 index 00000000000..69f21941ab3 --- /dev/null +++ b/lib/TH/THDiskFile.h @@ -0,0 +1,17 @@ +#ifndef TH_DISK_FILE_INC +#define TH_DISK_FILE_INC + +#include "THFile.h" + +THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet); +THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet); + +const char *THDiskFile_name(THFile *self); + +int THDiskFile_isLittleEndianCPU(void); +int THDiskFile_isBigEndianCPU(void); +void THDiskFile_nativeEndianEncoding(THFile *self); +void THDiskFile_littleEndianEncoding(THFile *self); +void THDiskFile_bigEndianEncoding(THFile *self); + +#endif diff --git a/lib/TH/THFile.c b/lib/TH/THFile.c new file mode 100644 index 00000000000..6d15e412817 --- /dev/null +++ b/lib/TH/THFile.c @@ -0,0 +1,154 @@ +#include "THFile.h" +#include "THFilePrivate.h" + +#define IMPLEMENT_THFILE_RW(TYPEC, TYPE) \ + long THFile_read##TYPEC##Raw(THFile *self, TYPE *data, long n) \ + { \ + return (*self->vtable->read##TYPEC)(self, data, n); \ + } \ + \ + long THFile_write##TYPEC##Raw(THFile *self, TYPE *data, long n) \ + { \ + return (*self->vtable->write##TYPEC)(self, data, n); \ + } + +IMPLEMENT_THFILE_RW(Byte, unsigned char) +IMPLEMENT_THFILE_RW(Char, char) +IMPLEMENT_THFILE_RW(Short, short) +IMPLEMENT_THFILE_RW(Int, int) +IMPLEMENT_THFILE_RW(Long, long) +IMPLEMENT_THFILE_RW(Float, float) +IMPLEMENT_THFILE_RW(Double, double) + +long THFile_readStringRaw(THFile *self, const char *format, char **str_) +{ + return self->vtable->readString(self, format, str_); +} + +long THFile_writeStringRaw(THFile *self, const char *str, long size) +{ + return self->vtable->writeString(self, str, size); +} + +void THFile_synchronize(THFile *self) +{ + self->vtable->synchronize(self); +} + +void THFile_seek(THFile *self, long position) +{ + self->vtable->seek(self, position); +} + +void THFile_seekEnd(THFile *self) +{ + self->vtable->seekEnd(self); +} + +long THFile_position(THFile *self) +{ + return self->vtable->position(self); +} + +void THFile_close(THFile *self) +{ + self->vtable->close(self); +} + +void THFile_free(THFile *self) +{ + self->vtable->free(self); +} + +int THFile_isOpened(THFile *self) +{ + return self->vtable->isOpened(self); +} + +#define IMPLEMENT_THFILE_FLAGS(FLAG) \ + int THFile_##FLAG(THFile *self) \ + { \ + return self->FLAG; \ + } + +IMPLEMENT_THFILE_FLAGS(isQuiet) +IMPLEMENT_THFILE_FLAGS(isReadable) +IMPLEMENT_THFILE_FLAGS(isWritable) +IMPLEMENT_THFILE_FLAGS(isBinary) +IMPLEMENT_THFILE_FLAGS(isAutoSpacing) +IMPLEMENT_THFILE_FLAGS(hasError) + +void THFile_binary(THFile *self) +{ + self->isBinary = 1; +} + +void THFile_ascii(THFile *self) +{ + self->isBinary = 0; +} + +void THFile_autoSpacing(THFile *self) +{ + self->isAutoSpacing = 1; +} + +void THFile_noAutoSpacing(THFile *self) +{ + self->isAutoSpacing = 0; +} + +void THFile_quiet(THFile *self) +{ + self->isQuiet = 1; +} + +void THFile_pedantic(THFile *self) +{ + self->isQuiet = 0; +} + +void THFile_clearError(THFile *self) +{ + self->hasError = 0; +} + +#define IMPLEMENT_THFILE_SCALAR(TYPEC, TYPE) 
\ + TYPE THFile_read##TYPEC##Scalar(THFile *self) \ + { \ + TYPE scalar; \ + THFile_read##TYPEC##Raw(self, &scalar, 1); \ + return scalar; \ + } \ + \ + void THFile_write##TYPEC##Scalar(THFile *self, TYPE scalar) \ + { \ + THFile_write##TYPEC##Raw(self, &scalar, 1); \ + } + +IMPLEMENT_THFILE_SCALAR(Byte, unsigned char) +IMPLEMENT_THFILE_SCALAR(Char, char) +IMPLEMENT_THFILE_SCALAR(Short, short) +IMPLEMENT_THFILE_SCALAR(Int, int) +IMPLEMENT_THFILE_SCALAR(Long, long) +IMPLEMENT_THFILE_SCALAR(Float, float) +IMPLEMENT_THFILE_SCALAR(Double, double) + +#define IMPLEMENT_THFILE_STORAGE(TYPEC, TYPE) \ + long THFile_read##TYPEC(THFile *self, TH##TYPEC##Storage *storage) \ + { \ + return THFile_read##TYPEC##Raw(self, storage->data, storage->size); \ + } \ + \ + long THFile_write##TYPEC(THFile *self, TH##TYPEC##Storage *storage) \ + { \ + return THFile_write##TYPEC##Raw(self, storage->data, storage->size); \ + } + +IMPLEMENT_THFILE_STORAGE(Byte, unsigned char) +IMPLEMENT_THFILE_STORAGE(Char, char) +IMPLEMENT_THFILE_STORAGE(Short, short) +IMPLEMENT_THFILE_STORAGE(Int, int) +IMPLEMENT_THFILE_STORAGE(Long, long) +IMPLEMENT_THFILE_STORAGE(Float, float) +IMPLEMENT_THFILE_STORAGE(Double, double) diff --git a/lib/TH/THFile.h b/lib/TH/THFile.h new file mode 100644 index 00000000000..7571e4d1a80 --- /dev/null +++ b/lib/TH/THFile.h @@ -0,0 +1,84 @@ +#ifndef TH_FILE_INC +#define TH_FILE_INC + +#include "THStorage.h" + +typedef struct THFile__ THFile; + +int THFile_isOpened(THFile *self); +int THFile_isQuiet(THFile *self); +int THFile_isReadable(THFile *self); +int THFile_isWritable(THFile *self); +int THFile_isBinary(THFile *self); +int THFile_isAutoSpacing(THFile *self); +int THFile_hasError(THFile *self); + +void THFile_binary(THFile *self); +void THFile_ascii(THFile *self); +void THFile_autoSpacing(THFile *self); +void THFile_noAutoSpacing(THFile *self); +void THFile_quiet(THFile *self); +void THFile_pedantic(THFile *self); +void THFile_clearError(THFile *self); + +/* scalar */ +unsigned char THFile_readByteScalar(THFile *self); +char THFile_readCharScalar(THFile *self); +short THFile_readShortScalar(THFile *self); +int THFile_readIntScalar(THFile *self); +long THFile_readLongScalar(THFile *self); +float THFile_readFloatScalar(THFile *self); +double THFile_readDoubleScalar(THFile *self); + +void THFile_writeByteScalar(THFile *self, unsigned char scalar); +void THFile_writeCharScalar(THFile *self, char scalar); +void THFile_writeShortScalar(THFile *self, short scalar); +void THFile_writeIntScalar(THFile *self, int scalar); +void THFile_writeLongScalar(THFile *self, long scalar); +void THFile_writeFloatScalar(THFile *self, float scalar); +void THFile_writeDoubleScalar(THFile *self, double scalar); + +/* storage */ +long THFile_readByte(THFile *self, THByteStorage *storage); +long THFile_readChar(THFile *self, THCharStorage *storage); +long THFile_readShort(THFile *self, THShortStorage *storage); +long THFile_readInt(THFile *self, THIntStorage *storage); +long THFile_readLong(THFile *self, THLongStorage *storage); +long THFile_readFloat(THFile *self, THFloatStorage *storage); +long THFile_readDouble(THFile *self, THDoubleStorage *storage); + +long THFile_writeByte(THFile *self, THByteStorage *storage); +long THFile_writeChar(THFile *self, THCharStorage *storage); +long THFile_writeShort(THFile *self, THShortStorage *storage); +long THFile_writeInt(THFile *self, THIntStorage *storage); +long THFile_writeLong(THFile *self, THLongStorage *storage); +long THFile_writeFloat(THFile *self, THFloatStorage *storage); 
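The storage-level helpers defined just above read or write exactly storage->size elements in one call. A sketch, with the caveat that THLongStorage_newWithSize and THLongStorage_free are the standard THStorage constructors/destructors rather than something shown in this hunk:

THLongStorage *s = THLongStorage_newWithSize(16);
long nread = THFile_readLong(f, s);   /* fills s->data with s->size longs */
/* ... consume s->data ... */
THLongStorage_free(s);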
+long THFile_writeDouble(THFile *self, THDoubleStorage *storage); + +/* raw */ +long THFile_readByteRaw(THFile *self, unsigned char *data, long n); +long THFile_readCharRaw(THFile *self, char *data, long n); +long THFile_readShortRaw(THFile *self, short *data, long n); +long THFile_readIntRaw(THFile *self, int *data, long n); +long THFile_readLongRaw(THFile *self, long *data, long n); +long THFile_readFloatRaw(THFile *self, float *data, long n); +long THFile_readDoubleRaw(THFile *self, double *data, long n); +long THFile_readStringRaw(THFile *self, const char *format, char **str_); /* you must deallocate str_ */ + +long THFile_writeByteRaw(THFile *self, unsigned char *data, long n); +long THFile_writeCharRaw(THFile *self, char *data, long n); +long THFile_writeShortRaw(THFile *self, short *data, long n); +long THFile_writeIntRaw(THFile *self, int *data, long n); +long THFile_writeLongRaw(THFile *self, long *data, long n); +long THFile_writeFloatRaw(THFile *self, float *data, long n); +long THFile_writeDoubleRaw(THFile *self, double *data, long n); +long THFile_writeStringRaw(THFile *self, const char *str, long size); + +void THFile_synchronize(THFile *self); +void THFile_seek(THFile *self, long position); +void THFile_seekEnd(THFile *self); +long THFile_position(THFile *self); +void THFile_close(THFile *self); +void THFile_free(THFile *self); + +#endif diff --git a/lib/TH/THFilePrivate.h b/lib/TH/THFilePrivate.h new file mode 100644 index 00000000000..9097fb9798e --- /dev/null +++ b/lib/TH/THFilePrivate.h @@ -0,0 +1,43 @@ +struct THFile__ +{ + struct THFileVTable *vtable; + + int isQuiet; + int isReadable; + int isWritable; + int isBinary; + int isAutoSpacing; + int hasError; +}; + +/* virtual table definition */ + +struct THFileVTable +{ + int (*isOpened)(THFile *self); + + long (*readByte)(THFile *self, unsigned char *data, long n); + long (*readChar)(THFile *self, char *data, long n); + long (*readShort)(THFile *self, short *data, long n); + long (*readInt)(THFile *self, int *data, long n); + long (*readLong)(THFile *self, long *data, long n); + long (*readFloat)(THFile *self, float *data, long n); + long (*readDouble)(THFile *self, double *data, long n); + long (*readString)(THFile *self, const char *format, char **str_); + + long (*writeByte)(THFile *self, unsigned char *data, long n); + long (*writeChar)(THFile *self, char *data, long n); + long (*writeShort)(THFile *self, short *data, long n); + long (*writeInt)(THFile *self, int *data, long n); + long (*writeLong)(THFile *self, long *data, long n); + long (*writeFloat)(THFile *self, float *data, long n); + long (*writeDouble)(THFile *self, double *data, long n); + long (*writeString)(THFile *self, const char *str, long size); + + void (*synchronize)(THFile *self); + void (*seek)(THFile *self, long position); + void (*seekEnd)(THFile *self); + long (*position)(THFile *self); + void (*close)(THFile *self); + void (*free)(THFile *self); +}; diff --git a/lib/TH/THGeneral.c b/lib/TH/THGeneral.c new file mode 100644 index 00000000000..f4007a3b36b --- /dev/null +++ b/lib/TH/THGeneral.c @@ -0,0 +1,110 @@ +#include "THGeneral.h" + +/* Torch Error Handling */ +static void defaultTorchErrorHandlerFunction(const char *msg) +{ + printf("$ Error: %s\n", msg); + exit(-1); +} + +static void (*torchErrorHandlerFunction)(const char *msg) = defaultTorchErrorHandlerFunction; + +void THError(const char *fmt, ...) +{ + static char msg[1024]; + va_list args; + + /* vasprintf not standard */ + /* vsnprintf: how to handle if does not exists? 
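Putting the pieces together, a small round trip through the disk-file implementation, using only calls declared in THFile.h and THDiskFile.h above (error handling omitted):

/* write two scalars in native-endian binary */
THFile *out = THDiskFile_new("tmp.bin", "w", 0);
THFile_binary(out);
THFile_writeIntScalar(out, 42);
THFile_writeDoubleScalar(out, 3.5);
THFile_free(out);                        /* free() also closes the handle */

/* read them back */
THFile *in = THDiskFile_new("tmp.bin", "r", 0);
THFile_binary(in);
int i = THFile_readIntScalar(in);
double d = THFile_readDoubleScalar(in);
THFile_free(in);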
*/ + va_start(args, fmt); + vsnprintf(msg, 1024, fmt, args); + va_end(args); + + (*torchErrorHandlerFunction)(msg); +} + +void THSetErrorHandler( void (*torchErrorHandlerFunction_)(const char *msg) ) +{ + if(torchErrorHandlerFunction_) + torchErrorHandlerFunction = torchErrorHandlerFunction_; + else + torchErrorHandlerFunction = defaultTorchErrorHandlerFunction; +} + +/* Torch Arg Checking Handling */ +static void defaultTorchArgCheckHandlerFunction(int condition, int argNumber, const char *msg) +{ + if(!condition) + { + if(msg) + printf("$ Invalid argument %d: %s\n", argNumber, msg); + else + printf("$ Invalid argument %d\n", argNumber); + exit(-1); + } +} +static void (*torchArgCheckHandlerFunction)(int condition, int argNumber, const char *msg) = defaultTorchArgCheckHandlerFunction; + +void THArgCheck(int condition, int argNumber, const char *msg) +{ + (*torchArgCheckHandlerFunction)(condition, argNumber, msg); +} + +void THSetArgCheckHandler( void (*torchArgCheckHandlerFunction_)(int condition, int argNumber, const char *msg) ) +{ + if(torchArgCheckHandlerFunction_) + torchArgCheckHandlerFunction = torchArgCheckHandlerFunction_; + else + torchArgCheckHandlerFunction = defaultTorchArgCheckHandlerFunction; +} + +void* THAlloc(long size) +{ + void *ptr; + + if(size < 0) + THError("$ Torch: invalid memory size -- maybe an overflow?"); + + if(size == 0) + return NULL; + + ptr = malloc(size); + if(!ptr) + THError("$ Torch: not enough memory: you tried to allocate %dGB. Buy new RAM!", size/1073741824); + + return ptr; +} + +void* THRealloc(void *ptr, long size) +{ + if(!ptr) + return(THAlloc(size)); + + if(size == 0) + { + THFree(ptr); + return NULL; + } + + if(size < 0) + THError("$ Torch: invalid memory size -- maybe an overflow?"); + + ptr = realloc(ptr, size); + if(!ptr) + THError("$ Torch: not enough memory: you tried to reallocate %dGB. 
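THError and THArgCheck route through replaceable handler pointers, which is exactly how init.c redirects them into luaL_error. A sketch of a host program installing and later restoring its own handler; myErrorHandler and install are illustrative names only:

#include "THGeneral.h"
#include <stdio.h>

static void myErrorHandler(const char *msg)
{
  /* note: returning from here means THError returns to its caller,
     unlike the default handler which calls exit(-1) */
  fprintf(stderr, "TH error: %s\n", msg);
}

void install(void)
{
  THSetErrorHandler(myErrorHandler);  /* all subsequent THError calls go here */
  /* ... TH calls ... */
  THSetErrorHandler(NULL);            /* NULL restores the default handler */
}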
Buy new RAM!", size/1073741824); + return ptr; +} + +void THFree(void *ptr) +{ + free(ptr); +} + +#ifdef _MSC_VER +double log1p(const double x) +{ + volatile double y; + y = 1 + x; + return log(y) - ((y-1)-x)/y ; /* cancels errors with IEEE arithmetic */ +} +#endif diff --git a/lib/TH/THGeneral.h b/lib/TH/THGeneral.h new file mode 100644 index 00000000000..33c8cd65ded --- /dev/null +++ b/lib/TH/THGeneral.h @@ -0,0 +1,72 @@ +#ifndef TH_GENERAL_INC +#define TH_GENERAL_INC + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __cplusplus +# define TH_EXTERNC extern "C" +#else +# define TH_EXTERNC extern +#endif + +#ifdef WIN32 +# ifdef TH_EXPORTS +# define TH_API TH_EXTERNC __declspec(dllexport) +# else +# define TH_API TH_EXTERNC __declspec(dllimport) +# endif +#else +# define TH_API TH_EXTERNC +#endif + +#define THInf DBL_MAX + +#if !defined(inline) +# define inline +#endif + +#ifndef M_PI +# define M_PI 3.14159265358979323846 +#endif + +#ifdef _MSC_VER +TH_API double log1p(const double x); +#endif + +TH_API void THError(const char *fmt, ...); +TH_API void THSetErrorHandler( void (*torchErrorHandlerFunction)(const char *msg) ); +TH_API void THArgCheck(int condition, int argNumber, const char *msg); +TH_API void THSetArgCheckHandler( void (*torchArgCheckHandlerFunction)(int condition, int argNumber, const char *msg) ); +TH_API void* THAlloc(long size); +TH_API void* THRealloc(void *ptr, long size); +TH_API void THFree(void *ptr); + +#define TH_CONCAT_STRING_2(x,y) TH_CONCAT_STRING_2_EXPAND(x,y) +#define TH_CONCAT_STRING_2_EXPAND(x,y) #x #y + +#define TH_CONCAT_STRING_3(x,y,z) TH_CONCAT_STRING_3_EXPAND(x,y,z) +#define TH_CONCAT_STRING_3_EXPAND(x,y,z) #x #y #z + +#define TH_CONCAT_STRING_4(x,y,z,w) TH_CONCAT_STRING_4_EXPAND(x,y,z,w) +#define TH_CONCAT_STRING_4_EXPAND(x,y,z,w) #x #y #z #w + +#define TH_CONCAT_2(x,y) TH_CONCAT_2_EXPAND(x,y) +#define TH_CONCAT_2_EXPAND(x,y) x ## y + +#define TH_CONCAT_3(x,y,z) TH_CONCAT_3_EXPAND(x,y,z) +#define TH_CONCAT_3_EXPAND(x,y,z) x ## y ## z + +#define TH_CONCAT_4_EXPAND(x,y,z,w) x ## y ## z ## w +#define TH_CONCAT_4(x,y,z,w) TH_CONCAT_4_EXPAND(x,y,z,w) + +#define THMin(X, Y) ((X) < (Y) ? (X) : (Y)) +#define THMax(X, Y) ((X) > (Y) ? 
(X) : (Y)) + +#endif diff --git a/lib/TH/THGenerateAllTypes.h b/lib/TH/THGenerateAllTypes.h new file mode 100644 index 00000000000..54ae1474eae --- /dev/null +++ b/lib/TH/THGenerateAllTypes.h @@ -0,0 +1,83 @@ +#ifndef TH_GENERIC_FILE +#error "You must define TH_GENERIC_FILE before including THGenerateAllTypes.h" +#endif + +#define real unsigned char +#define accreal long +#define Real Byte +#define TH_REAL_IS_BYTE +#line 1 TH_GENERIC_FILE +/*#line 1 "THByteStorage.h"*/ +#include TH_GENERIC_FILE +#undef real +#undef accreal +#undef Real +#undef TH_REAL_IS_BYTE + +#define real char +#define accreal long +#define Real Char +#define TH_REAL_IS_CHAR +#line 1 TH_GENERIC_FILE +#include TH_GENERIC_FILE +#undef real +#undef accreal +#undef Real +#undef TH_REAL_IS_CHAR + +#define real short +#define accreal long +#define Real Short +#define TH_REAL_IS_SHORT +#line 1 TH_GENERIC_FILE +#include TH_GENERIC_FILE +#undef real +#undef accreal +#undef Real +#undef TH_REAL_IS_SHORT + +#define real int +#define accreal long +#define Real Int +#define TH_REAL_IS_INT +#line 1 TH_GENERIC_FILE +#include TH_GENERIC_FILE +#undef real +#undef accreal +#undef Real +#undef TH_REAL_IS_INT + +#define real long +#define accreal long +#define Real Long +#define TH_REAL_IS_LONG +#line 1 TH_GENERIC_FILE +#include TH_GENERIC_FILE +#undef real +#undef accreal +#undef Real +#undef TH_REAL_IS_LONG + +#define real float +#define accreal double +#define Real Float +#define TH_REAL_IS_FLOAT +#line 1 TH_GENERIC_FILE +#include TH_GENERIC_FILE +#undef real +#undef accreal +#undef Real +#undef TH_REAL_IS_FLOAT + +#define real double +#define accreal double +#define Real Double +#define TH_REAL_IS_DOUBLE +#line 1 TH_GENERIC_FILE +#include TH_GENERIC_FILE +#undef real +#undef accreal +#undef Real +#undef TH_REAL_IS_DOUBLE + +#undef TH_GENERIC_FILE diff --git a/lib/TH/THGenerateFloatTypes.h b/lib/TH/THGenerateFloatTypes.h new file mode 100644 index 00000000000..5feea4e2aa2 --- /dev/null +++ b/lib/TH/THGenerateFloatTypes.h @@ -0,0 +1,27 @@ +#ifndef TH_GENERIC_FILE +#error "You must define TH_GENERIC_FILE before including THGenerateAllTypes.h" +#endif + +#define real float +#define accreal double +#define Real Float +#define TH_REAL_IS_FLOAT +#line 1 TH_GENERIC_FILE +#include TH_GENERIC_FILE +#undef accreal +#undef real +#undef Real +#undef TH_REAL_IS_FLOAT + +#define real double +#define accreal double +#define Real Double +#define TH_REAL_IS_DOUBLE +#line 1 TH_GENERIC_FILE +#include TH_GENERIC_FILE +#undef accreal +#undef real +#undef Real +#undef TH_REAL_IS_DOUBLE + +#undef TH_GENERIC_FILE diff --git a/lib/TH/THGenerateIntTypes.h b/lib/TH/THGenerateIntTypes.h new file mode 100644 index 00000000000..d340b0eba8c --- /dev/null +++ b/lib/TH/THGenerateIntTypes.h @@ -0,0 +1,60 @@ +#ifndef TH_GENERIC_FILE +#error "You must define TH_GENERIC_FILE before including THGenerateIntTypes.h" +#endif + +#define real unsigned char +#define accreal long +#define Real Byte +#define TH_REAL_IS_BYTE +#line 1 TH_GENERIC_FILE +#include TH_GENERIC_FILE +#undef real +#undef accreal +#undef Real +#undef TH_REAL_IS_BYTE + +#define real char +#define accreal long +#define Real Char +#define TH_REAL_IS_CHAR +#line 1 TH_GENERIC_FILE +#include TH_GENERIC_FILE +#undef real +#undef accreal +#undef Real +#undef TH_REAL_IS_CHAR + +#define real short +#define accreal long +#define Real Short +#define TH_REAL_IS_SHORT +#line 1 TH_GENERIC_FILE +#include TH_GENERIC_FILE +#undef real +#undef accreal +#undef Real +#undef TH_REAL_IS_SHORT + +#define real int +#define accreal 
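THGenerateAllTypes.h and THGenerateFloatTypes.h implement the "include the same generic file once per element type" trick used throughout TH (THBlas.c above is one consumer). A compressed sketch of how a new module would plug into it; THFoo and THFoo_(half) are hypothetical names used only for illustration:

/* THFoo.h */
#define THFoo_(NAME) TH_CONCAT_4(TH,Real,Foo_,NAME)
#include "generic/THFoo.h"
#include "THGenerateAllTypes.h"

/* generic/THFoo.h */
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THFoo.h"
#else
TH_API real THFoo_(half)(real x);  /* instantiated as THByteFoo_half ... THDoubleFoo_half */
#endif

/* generic/THFoo.c */
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THFoo.c"
#else
real THFoo_(half)(real x)
{
  return (real)(x/2);
}
#endif

/* THFoo.c */
#include "THFoo.h"
#include "generic/THFoo.c"
#include "THGenerateAllTypes.h"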
long +#define Real Int +#define TH_REAL_IS_INT +#line 1 TH_GENERIC_FILE +#include TH_GENERIC_FILE +#undef real +#undef accreal +#undef Real +#undef TH_REAL_IS_INT + +#define real long +#define accreal long +#define Real Long +#define TH_REAL_IS_LONG +#line 1 TH_GENERIC_FILE +#include TH_GENERIC_FILE +#undef real +#undef accreal +#undef Real +#undef TH_REAL_IS_LONG + +#undef TH_GENERIC_FILE diff --git a/lib/TH/THLapack.c b/lib/TH/THLapack.c new file mode 100644 index 00000000000..01bdb5ff161 --- /dev/null +++ b/lib/TH/THLapack.c @@ -0,0 +1,5 @@ +#include "THLapack.h" + +/* #include "THCBlas.h" */ +#include "generic/THLapack.c" +#include "THGenerateFloatTypes.h" diff --git a/lib/TH/THLapack.h b/lib/TH/THLapack.h new file mode 100644 index 00000000000..bed8ae2bf61 --- /dev/null +++ b/lib/TH/THLapack.h @@ -0,0 +1,11 @@ +#ifndef TH_LAPACK_INC +#define TH_LAPACK_INC + +#include "THGeneral.h" + +#define THLapack_(NAME) TH_CONCAT_4(TH,Real,Lapack_,NAME) + +#include "generic/THLapack.h" +#include "THGenerateAllTypes.h" + +#endif diff --git a/lib/TH/THLogAdd.c b/lib/TH/THLogAdd.c new file mode 100644 index 00000000000..542b9b074f5 --- /dev/null +++ b/lib/TH/THLogAdd.c @@ -0,0 +1,86 @@ +#include "THLogAdd.h" + +#ifdef USE_DOUBLE +#define MINUS_LOG_THRESHOLD -39.14 +#else +#define MINUS_LOG_THRESHOLD -18.42 +#endif + +const double THLog2Pi=1.83787706640934548355; +const double THLogZero=-THInf; +const double THLogOne=0; + +double THLogAdd(double log_a, double log_b) +{ + double minusdif; + + if (log_a < log_b) + { + double tmp = log_a; + log_a = log_b; + log_b = tmp; + } + + minusdif = log_b - log_a; +#ifdef DEBUG + if (isnan(minusdif)) + THError("THLogAdd: minusdif (%f) log_b (%f) or log_a (%f) is nan", minusdif, log_b, log_a); +#endif + if (minusdif < MINUS_LOG_THRESHOLD) + return log_a; + else + return log_a + log1p(exp(minusdif)); +} + +double THLogSub(double log_a, double log_b) +{ + double minusdif; + + if (log_a < log_b) + THError("LogSub: log_a (%f) should be greater than log_b (%f)", log_a, log_b); + + minusdif = log_b - log_a; +#ifdef DEBUG + if (isnan(minusdif)) + THError("LogSub: minusdif (%f) log_b (%f) or log_a (%f) is nan", minusdif, log_b, log_a); +#endif + if (log_a == log_b) + return THLogZero; + else if (minusdif < MINUS_LOG_THRESHOLD) + return log_a; + else + return log_a + log1p(-exp(minusdif)); +} + +/* Credits to Leon Bottou */ +double THExpMinusApprox(double x) +{ +#define EXACT_EXPONENTIAL 0 +#if EXACT_EXPONENTIAL + return exp(-x); +#else + /* fast approximation of exp(-x) for x positive */ +# define A0 (1.0) +# define A1 (0.125) +# define A2 (0.0078125) +# define A3 (0.00032552083) +# define A4 (1.0172526e-5) + if (x < 13.0) + { +/* assert(x>=0); */ + double y; + y = A0+x*(A1+x*(A2+x*(A3+x*A4))); + y *= y; + y *= y; + y *= y; + y = 1/y; + return y; + } + return 0; +# undef A0 +# undef A1 +# undef A2 +# undef A3 +# undef A4 +#endif +} diff --git a/lib/TH/THLogAdd.h b/lib/TH/THLogAdd.h new file mode 100644 index 00000000000..9319b8f4643 --- /dev/null +++ b/lib/TH/THLogAdd.h @@ -0,0 +1,14 @@ +#ifndef TH_LOG_ADD_INC +#define TH_LOG_ADD_INC + +#include "THGeneral.h" + +TH_API const double THLog2Pi; +TH_API const double THLogZero; +TH_API const double THLogOne; + +TH_API double THLogAdd(double log_a, double log_b); +TH_API double THLogSub(double log_a, double log_b); +TH_API double THExpMinusApprox(const double x); + +#endif diff --git a/lib/TH/THMemoryFile.c b/lib/TH/THMemoryFile.c new file mode 100644 index 00000000000..29f04e8ff36 --- /dev/null +++ b/lib/TH/THMemoryFile.c @@ 
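THLogAdd and THLogSub above accumulate probabilities directly in log space, falling back to log_a once the difference drops below MINUS_LOG_THRESHOLD. A short usage sketch:

#include "THLogAdd.h"
#include <math.h>

double sum_in_log_space(void)
{
  double logp = THLogZero;            /* log(0) = -inf */
  logp = THLogAdd(logp, log(0.25));   /* now log(0.25) */
  logp = THLogAdd(logp, log(0.25));   /* now log(0.5)  */
  return logp;
}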
-0,0 +1,492 @@ +#include "THMemoryFile.h" +#include "THFilePrivate.h" + +typedef struct THMemoryFile__ +{ + THFile file; + THCharStorage *storage; + long size; + long position; + +} THMemoryFile; + +static int THMemoryFile_isOpened(THFile *self) +{ + THMemoryFile *mfself = (THMemoryFile*)self; + return (mfself->storage != NULL); +} + +static char *THMemoryFile_strnextspace(char *str_, char *c_) +{ + char c; + + while( (c = *str_) ) + { + if( (c != ' ') && (c != '\n') && (c != ':') && (c != ';') ) + break; + str_++; + } + + while( (c = *str_) ) + { + if( (c == ' ') || (c == '\n') || (c == ':') || (c == ';') ) + { + *c_ = c; + *str_ = '\0'; + return(str_); + } + str_++; + } + return NULL; +} + +static void THMemoryFile_grow(THMemoryFile *self, long size) +{ + long missingSpace; + + if(size <= self->size) + return; + else + { + if(size < self->storage->size) /* note the "<" and not "<=" */ + { + self->size = size; + self->storage->data[self->size] = '\0'; + return; + } + } + + missingSpace = size-self->storage->size+1; /* +1 for the '\0' */ + THCharStorage_resize(self->storage, (self->storage->size/2 > missingSpace ? + self->storage->size + (self->storage->size/2) + : self->storage->size + missingSpace)); +} + +static int THMemoryFile_mode(const char *mode, int *isReadable, int *isWritable) +{ + *isReadable = 0; + *isWritable = 0; + if(strlen(mode) == 1) + { + if(*mode == 'r') + { + *isReadable = 1; + return 1; + } + else if(*mode == 'w') + { + *isWritable = 1; + return 1; + } + } + else if(strlen(mode) == 2) + { + if(mode[0] == 'r' && mode[1] == 'w') + { + *isReadable = 1; + *isWritable = 1; + return 1; + } + } + return 0; +} + +/********************************************************/ + +#define READ_WRITE_METHODS(TYPE, TYPEC, ASCII_READ_ELEM, ASCII_WRITE_ELEM, INSIDE_SPACING) \ + static long THMemoryFile_read##TYPEC(THFile *self, TYPE *data, long n) \ + { \ + THMemoryFile *mfself = (THMemoryFile*)self; \ + long nread = 0L; \ + \ + THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file"); \ + THArgCheck(mfself->file.isReadable, 1, "attempt to read in a write-only file"); \ + \ + if(mfself->file.isBinary) \ + { \ + long nByte = sizeof(TYPE)*n; \ + long nByteRemaining = (mfself->position + nByte <= mfself->size ? nByte : mfself->size-mfself->position); \ + nread = nByteRemaining/sizeof(TYPE); \ + memmove(data, mfself->storage->data+mfself->position, nread*sizeof(TYPE)); \ + mfself->position += nread*sizeof(TYPE); \ + } \ + else \ + { \ + long i; \ + for(i = 0; i < n; i++) \ + { \ + long nByteRead = 0; \ + char spaceChar = 0; \ + char *spacePtr = THMemoryFile_strnextspace(mfself->storage->data+mfself->position, &spaceChar); \ + ASCII_READ_ELEM; \ + if(ret == EOF) \ + { \ + while(mfself->storage->data[mfself->position]) \ + mfself->position++; \ + } \ + else \ + mfself->position += nByteRead; \ + if(spacePtr) \ + *spacePtr = spaceChar; \ + } \ + if(mfself->file.isAutoSpacing && (n > 0)) \ + { \ + if( (mfself->position < mfself->size) && (mfself->storage->data[mfself->position] == '\n') ) \ + mfself->position++; \ + } \ + } \ + \ + if(nread != n) \ + { \ + mfself->file.hasError = 1; /* shouldn't we put hasError to 0 all the time ? 
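(as the code stands, hasError starts at 0 when the memory file is created and is only ever set here on a short read; nothing in this macro clears it again)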
*/ \ + if(!mfself->file.isQuiet) \ + THError("read error: read %d blocks instead of %d", nread, n); \ + } \ + \ + return nread; \ + } \ + \ + static long THMemoryFile_write##TYPEC(THFile *self, TYPE *data, long n) \ + { \ + THMemoryFile *mfself = (THMemoryFile*)self; \ + long nread = 0L; \ + \ + THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file"); \ + THArgCheck(mfself->file.isWritable, 1, "attempt to write in a read-only file"); \ + \ + if(mfself->file.isBinary) \ + { \ + long nByte = sizeof(TYPE)*n; \ + THMemoryFile_grow(mfself, mfself->position+nByte); \ + memmove(mfself->storage->data+mfself->position, data, nByte); \ + mfself->position += nByte; \ + if(mfself->position > mfself->size) \ + { \ + mfself->size = mfself->position; \ + mfself->storage->data[mfself->size] = '\0'; \ + } \ + } \ + else \ + { \ + long i; \ + for(i = 0; i < n; i++) \ + { \ + long nByteWritten; \ + while (1) \ + { \ + ASCII_WRITE_ELEM; \ + if( (nByteWritten > -1) && (nByteWritten < mfself->storage->size-mfself->position) ) \ + { \ + mfself->position += nByteWritten; \ + break; \ + } \ + THMemoryFile_grow(mfself, mfself->storage->size + (mfself->storage->size/2) + 2); \ + } \ + if(mfself->file.isAutoSpacing) \ + { \ + if(i < n-1) \ + { \ + THMemoryFile_grow(mfself, mfself->position+1); \ + sprintf(mfself->storage->data+mfself->position, " "); \ + mfself->position++; \ + } \ + if(i == n-1) \ + { \ + THMemoryFile_grow(mfself, mfself->position+1); \ + sprintf(mfself->storage->data+mfself->position, "\n"); \ + mfself->position++; \ + } \ + } \ + } \ + if(mfself->position > mfself->size) \ + { \ + mfself->size = mfself->position; \ + mfself->storage->data[mfself->size] = '\0'; \ + } \ + } \ + \ + return n; \ + } + + +THCharStorage *THMemoryFile_storage(THFile *self) +{ + THMemoryFile *mfself = (THMemoryFile*)self; + THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file"); + + THCharStorage_resize(mfself->storage, mfself->size+1); + + return mfself->storage; +} + +static void THMemoryFile_synchronize(THFile *self) +{ + THMemoryFile *mfself = (THMemoryFile*)self; + THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file"); +} + +static void THMemoryFile_seek(THFile *self, long position) +{ + THMemoryFile *mfself = (THMemoryFile*)self; + + THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file"); + THArgCheck(position >= 0, 2, "position must be positive"); + + if(position <= mfself->size) + mfself->position = position; + else + { + mfself->file.hasError = 1; + if(!mfself->file.isQuiet) + THError("unable to seek at position %d", position); + } +} + +static void THMemoryFile_seekEnd(THFile *self) +{ + THMemoryFile *mfself = (THMemoryFile*)self; + THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file"); + + mfself->position = mfself->size; +} + +static long THMemoryFile_position(THFile *self) +{ + THMemoryFile *mfself = (THMemoryFile*)self; + THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file"); + return mfself->position; +} + +static void THMemoryFile_close(THFile *self) +{ + THMemoryFile *mfself = (THMemoryFile*)self; + THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file"); + THCharStorage_free(mfself->storage); + mfself->storage = NULL; +} + +static void THMemoryFile_free(THFile *self) +{ + THMemoryFile *mfself = (THMemoryFile*)self; + + if(mfself->storage) + THCharStorage_free(mfself->storage); + + THFree(mfself); +} + +/* READ_WRITE_METHODS(bool, Bool, */ +/* int value = 0; int ret = 
sscanf(mfself->storage->data+mfself->position, "%d%n", &value, &nByteRead); data[i] = (value ? 1 : 0), */ +/* int value = (data[i] ? 1 : 0); nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%d", value), */ +/* 1) */ + +READ_WRITE_METHODS(unsigned char, Byte, + long ret = (mfself->position + n <= mfself->size ? n : mfself->size-mfself->position); \ + if(spacePtr) *spacePtr = spaceChar; \ + nByteRead = ret; \ + nread = ret; \ + i = n-1; \ + memmove(data, mfself->storage->data+mfself->position, nByteRead), + nByteWritten = (n < mfself->storage->size-mfself->position ? n : -1); \ + i = n-1; \ + if(nByteWritten > -1) + memmove(mfself->storage->data+mfself->position, data, nByteWritten), + 0) + +/* DEBUG: we should check if %n is count or not as a element (so ret might need to be ret-- on some systems) */ +/* Note that we do a trick for char */ +READ_WRITE_METHODS(char, Char, + long ret = (mfself->position + n <= mfself->size ? n : mfself->size-mfself->position); \ + if(spacePtr) *spacePtr = spaceChar; \ + nByteRead = ret; \ + nread = ret; \ + i = n-1; \ + memmove(data, mfself->storage->data+mfself->position, nByteRead), + nByteWritten = (n < mfself->storage->size-mfself->position ? n : -1); \ + i = n-1; \ + if(nByteWritten > -1) + memmove(mfself->storage->data+mfself->position, data, nByteWritten), + 0) + +READ_WRITE_METHODS(short, Short, + int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%hd%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++, + nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%hd", data[i]), + 1) + +READ_WRITE_METHODS(int, Int, + int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%d%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++, + nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%d", data[i]), + 1) + +READ_WRITE_METHODS(long, Long, + int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%ld%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++, + nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%ld", data[i]), + 1) + +READ_WRITE_METHODS(float, Float, + int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%g%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++, + nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%g", data[i]), + 1) + +READ_WRITE_METHODS(double, Double, + int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%lg%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++, + nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%lg", data[i]), + 1) + +static char* THMemoryFile_cloneString(const char *str, long size) +{ + char *cstr = THAlloc(size); + memcpy(cstr, str, size); + return cstr; +} + +static long THMemoryFile_readString(THFile *self, const char *format, char **str_) +{ + THMemoryFile *mfself = (THMemoryFile*)self; + + THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file"); + THArgCheck(mfself->file.isReadable, 1, "attempt to read in a write-only file"); + THArgCheck((strlen(format) >= 2 ? 
(format[0] == '*') && (format[1] == 'a' || format[1] == 'l') : 0), 2, "format must be '*a' or '*l'"); + + if(mfself->position == mfself->size) /* eof ? */ + { + mfself->file.hasError = 1; + if(!mfself->file.isQuiet) + THError("read error: read 0 blocks instead of 1"); + + *str_ = NULL; + return 0; + } + + if(format[1] == 'a') + { + long str_size = mfself->size-mfself->position; + + *str_ = THMemoryFile_cloneString(mfself->storage->data+mfself->position, str_size); + mfself->position = mfself->size; + + return str_size; + } + else + { + char *p = mfself->storage->data+mfself->position; + long posEol = -1; + long i; + for(i = 0L; i < mfself->size-mfself->position; i++) + { + if(p[i] == '\n') + { + posEol = i; + break; + } + } + + if(posEol >= 0) + { + *str_ = THMemoryFile_cloneString(mfself->storage->data+mfself->position, posEol); + mfself->position += posEol+1; + return posEol; + } + else /* well, we read all! */ + { + long str_size = mfself->size-mfself->position; + + *str_ = THMemoryFile_cloneString(mfself->storage->data+mfself->position, str_size); + mfself->position = mfself->size; + + return str_size; + } + } + + *str_ = NULL; + return 0; +} + +static long THMemoryFile_writeString(THFile *self, const char *str, long size) +{ + THMemoryFile *mfself = (THMemoryFile*)self; + + THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file"); + THArgCheck(mfself->file.isWritable, 1, "attempt to write in a read-only file"); + + THMemoryFile_grow(mfself, mfself->position+size); + memmove(mfself->storage->data+mfself->position, str, size); + mfself->position += size; + if(mfself->position > mfself->size) + { + mfself->size = mfself->position; + mfself->storage->data[mfself->size] = '\0'; + } + + return size; +} + +THFile *THMemoryFile_newWithStorage(THCharStorage *storage, const char *mode) +{ + static struct THFileVTable vtable = { + THMemoryFile_isOpened, + + THMemoryFile_readByte, + THMemoryFile_readChar, + THMemoryFile_readShort, + THMemoryFile_readInt, + THMemoryFile_readLong, + THMemoryFile_readFloat, + THMemoryFile_readDouble, + THMemoryFile_readString, + + THMemoryFile_writeByte, + THMemoryFile_writeChar, + THMemoryFile_writeShort, + THMemoryFile_writeInt, + THMemoryFile_writeLong, + THMemoryFile_writeFloat, + THMemoryFile_writeDouble, + THMemoryFile_writeString, + + THMemoryFile_synchronize, + THMemoryFile_seek, + THMemoryFile_seekEnd, + THMemoryFile_position, + THMemoryFile_close, + THMemoryFile_free + }; + + THMemoryFile *mfself; + int isReadable; + int isWritable; + + if(storage) + { + THArgCheck(storage->data[storage->size-1] == '\0', 1, "provided CharStorage must be terminated by 0"); + THArgCheck(THMemoryFile_mode(mode, &isReadable, &isWritable), 2, "file mode should be 'r','w' or 'rw'"); + THCharStorage_retain(storage); + } + else + { + THArgCheck(THMemoryFile_mode(mode, &isReadable, &isWritable), 2, "file mode should be 'r','w' or 'rw'"); + storage = THCharStorage_newWithSize(1); + storage->data[0] = '\0'; + } + + mfself = THAlloc(sizeof(THMemoryFile)); + + mfself->storage = storage; + mfself->size = (storage ? 
storage->size-1 : 0); + mfself->position = 0; + + mfself->file.vtable = &vtable; + mfself->file.isQuiet = 0; + mfself->file.isReadable = isReadable; + mfself->file.isWritable = isWritable; + mfself->file.isBinary = 0; + mfself->file.isAutoSpacing = 1; + mfself->file.hasError = 0; + + return (THFile*)mfself; +} + +THFile *THMemoryFile_new(const char *mode) +{ + return THMemoryFile_newWithStorage(NULL, mode); +} diff --git a/lib/TH/THMemoryFile.h b/lib/TH/THMemoryFile.h new file mode 100644 index 00000000000..48871f04e23 --- /dev/null +++ b/lib/TH/THMemoryFile.h @@ -0,0 +1,12 @@ +#ifndef TH_MEMORY_FILE_INC +#define TH_MEMORY_FILE_INC + +#include "THFile.h" +#include "THStorage.h" + +THFile *THMemoryFile_newWithStorage(THCharStorage *storage, const char *mode); +THFile *THMemoryFile_new(const char *mode); + +THCharStorage *THMemoryFile_storage(THFile *self); + +#endif diff --git a/lib/TH/THRandom.c b/lib/TH/THRandom.c new file mode 100644 index 00000000000..d34ed70982a --- /dev/null +++ b/lib/TH/THRandom.c @@ -0,0 +1,238 @@ +#include "THGeneral.h" +#include "THRandom.h" + +/* The initial seed. */ +static unsigned long the_initial_seed; + +/* Code for the Mersenne Twister random generator.... */ +#define n 624 +#define m 397 +static int left = 1; +static int initf = 0; +static unsigned long *next; +static unsigned long state[n]; /* the array for the state vector */ +/********************************/ + +/* For normal distribution */ +static double normal_x; +static double normal_y; +static double normal_rho; +static int normal_is_valid = 0; + +unsigned long THRandom_seed() +{ + unsigned long s = (unsigned long)time(0); + THRandom_manualSeed(s); + return s; +} + +/* The next 4 methods are taken from http:www.math.keio.ac.jpmatumotoemt.html + Here is the copyright: + Some minor modifications have been made to adapt to "my" C... */ + +/* + A C-program for MT19937, with initialization improved 2002/2/10. + Coded by Takuji Nishimura and Makoto Matsumoto. + This is a faster version by taking Shawn Cokus's optimization, + Matthe Bellew's simplification, Isaku Wada's double version. + + Before using, initialize the state by using init_genrand(seed) + or init_by_array(init_key, key_length). + + Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + 3. The names of its contributors may not be used to endorse or promote + products derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR + CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + Any feedback is very welcome. + http://www.math.keio.ac.jp/matumoto/emt.html + email: matumoto@math.keio.ac.jp +*/ + +/* Macros for the Mersenne Twister random generator... */ +/* Period parameters */ +/* #define n 624 */ +/* #define m 397 */ +#define MATRIX_A 0x9908b0dfUL /* constant vector a */ +#define UMASK 0x80000000UL /* most significant w-r bits */ +#define LMASK 0x7fffffffUL /* least significant r bits */ +#define MIXBITS(u,v) ( ((u) & UMASK) | ((v) & LMASK) ) +#define TWIST(u,v) ((MIXBITS(u,v) >> 1) ^ ((v)&1UL ? MATRIX_A : 0UL)) +/*********************************************************** That's it. */ + +void THRandom_manualSeed(unsigned long the_seed_) +{ + int j; + the_initial_seed = the_seed_; + state[0]= the_initial_seed & 0xffffffffUL; + for(j = 1; j < n; j++) + { + state[j] = (1812433253UL * (state[j-1] ^ (state[j-1] >> 30)) + j); + /* See Knuth TAOCP Vol2. 3rd Ed. P.106 for multiplier. */ + /* In the previous versions, mSBs of the seed affect */ + /* only mSBs of the array state[]. */ + /* 2002/01/09 modified by makoto matsumoto */ + state[j] &= 0xffffffffUL; /* for >32 bit machines */ + } + left = 1; + initf = 1; +} + +unsigned long THRandom_initialSeed() +{ + if(initf == 0) + { + THRandom_seed(); + } + + return the_initial_seed; +} + +void THRandom_nextState() +{ + unsigned long *p=state; + int j; + + /* if init_genrand() has not been called, */ + /* a default initial seed is used */ + if(initf == 0) + THRandom_seed(); + + left = n; + next = state; + + for(j = n-m+1; --j; p++) + *p = p[m] ^ TWIST(p[0], p[1]); + + for(j = m; --j; p++) + *p = p[m-n] ^ TWIST(p[0], p[1]); + + *p = p[m-n] ^ TWIST(p[0], state[0]); +} + +unsigned long THRandom_random() +{ + unsigned long y; + + if (--left == 0) + THRandom_nextState(); + y = *next++; + + /* Tempering */ + y ^= (y >> 11); + y ^= (y << 7) & 0x9d2c5680UL; + y ^= (y << 15) & 0xefc60000UL; + y ^= (y >> 18); + + return y; +} + +/* generates a random number on [0,1)-double-interval */ +static double __uniform__() +{ + unsigned long y; + + if(--left == 0) + THRandom_nextState(); + y = *next++; + + /* Tempering */ + y ^= (y >> 11); + y ^= (y << 7) & 0x9d2c5680UL; + y ^= (y << 15) & 0xefc60000UL; + y ^= (y >> 18); + + return (double)y * (1.0/4294967296.0); + /* divided by 2^32 */ +} + +/********************************************************* + + Thanks *a lot* Takuji Nishimura and Makoto Matsumoto! + + Now my own code... + +*********************************************************/ + +double THRandom_uniform(double a, double b) +{ + return(__uniform__() * (b - a) + a); +} + +double THRandom_normal(double mean, double stdv) +{ + THArgCheck(stdv > 0, 2, "standard deviation must be strictly positive"); + + if(!normal_is_valid) + { + normal_x = __uniform__(); + normal_y = __uniform__(); + normal_rho = sqrt(-2. 
* log(1.0-normal_y)); + normal_is_valid = 1; + } + else + normal_is_valid = 0; + + if(normal_is_valid) + return normal_rho*cos(2.*M_PI*normal_x)*stdv+mean; + else + return normal_rho*sin(2.*M_PI*normal_x)*stdv+mean; +} + +double THRandom_exponential(double lambda) +{ + return(-1. / lambda * log(1-__uniform__())); +} + +double THRandom_cauchy(double median, double sigma) +{ + return(median + sigma * tan(M_PI*(__uniform__()-0.5))); +} + +/* Faut etre malade pour utiliser ca. + M'enfin. */ +double THRandom_logNormal(double mean, double stdv) +{ + double zm = mean*mean; + double zs = stdv*stdv; + THArgCheck(stdv > 0, 2, "standard deviation must be strictly positive"); + return(exp(THRandom_normal(log(zm/sqrt(zs + zm)), sqrt(log(zs/zm+1)) ))); +} + +int THRandom_geometric(double p) +{ + THArgCheck(p > 0 && p < 1, 1, "must be > 0 and < 1"); + return((int)(log(1-__uniform__()) / log(p)) + 1); +} + +int THRandom_bernoulli(double p) +{ + THArgCheck(p > 0 && p < 1, 1, "must be > 0 and < 1"); + return(__uniform__() <= p); +} diff --git a/lib/TH/THRandom.h b/lib/TH/THRandom.h new file mode 100644 index 00000000000..537f8bb917b --- /dev/null +++ b/lib/TH/THRandom.h @@ -0,0 +1,52 @@ +#ifndef TH_RANDOM_INC +#define TH_RANDOM_INC + +#include "THGeneral.h" + +/* Initializes the random number generator with the current time (granularity: seconds) and returns the seed. */ +TH_API unsigned long THRandom_seed(); + +/* Initializes the random number generator with the given long "the_seed_". */ +TH_API void THRandom_manualSeed(unsigned long the_seed_); + +/* Returns the starting seed used. */ +TH_API unsigned long THRandom_initialSeed(); + +/* Generates a uniform 32 bits integer. */ +TH_API unsigned long THRandom_random(); + +/* Generates a uniform random number on [0,1[. */ +TH_API double THRandom_uniform(double a, double b); + +/** Generates a random number from a normal distribution. + (With mean #mean# and standard deviation #stdv >= 0#). +*/ +TH_API double THRandom_normal(double mean, double stdv); + +/** Generates a random number from an exponential distribution. + The density is $p(x) = lambda * exp(-lambda * x)$, where + lambda is a positive number. +*/ +TH_API double THRandom_exponential(double lambda); + +/** Returns a random number from a Cauchy distribution. + The Cauchy density is $p(x) = sigma/(pi*(sigma^2 + (x-median)^2))$ +*/ +TH_API double THRandom_cauchy(double median, double sigma); + +/** Generates a random number from a log-normal distribution. + (#mean > 0# is the mean of the log-normal distribution + and #stdv# is its standard deviation). +*/ +TH_API double THRandom_logNormal(double mean, double stdv); + +/** Generates a random number from a geometric distribution. + It returns an integer #i#, where $p(i) = (1-p) * p^(i-1)$. + p must satisfy $0 < p < 1$. +*/ +TH_API int THRandom_geometric(double p); + +/* Returns true with probability $p$ and false with probability $1-p$ (p > 0). 
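In THRandom.c the argument is checked to satisfy $0 < p < 1$ and the draw is implemented as __uniform__() <= p.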
*/ +TH_API int THRandom_bernoulli(double p); + +#endif diff --git a/lib/TH/THStorage.c b/lib/TH/THStorage.c new file mode 100644 index 00000000000..bdbd46790a7 --- /dev/null +++ b/lib/TH/THStorage.c @@ -0,0 +1,7 @@ +#include "THStorage.h" + +#include "generic/THStorage.c" +#include "THGenerateAllTypes.h" + +#include "generic/THStorageCopy.c" +#include "THGenerateAllTypes.h" diff --git a/lib/TH/THStorage.h b/lib/TH/THStorage.h new file mode 100644 index 00000000000..52b91fef5aa --- /dev/null +++ b/lib/TH/THStorage.h @@ -0,0 +1,33 @@ +#ifndef TH_STORAGE_INC +#define TH_STORAGE_INC + +#include "THGeneral.h" + +/* stuff for mapped files */ +#ifdef _WIN32 +#include +#endif + +#if HAVE_MMAP +#include +#include +#include +#include +#include +#endif +/* end of stuff for mapped files */ + +#define THStorage TH_CONCAT_3(TH,Real,Storage) +#define THStorage_(NAME) TH_CONCAT_4(TH,Real,Storage_,NAME) + +/* fast access methods */ +#define TH_STORAGE_GET(storage, idx) ((storage)->data[(idx)]) +#define TH_STORAGE_SET(storage, idx, value) ((storage)->data[(idx)] = (value)) + +#include "generic/THStorage.h" +#include "THGenerateAllTypes.h" + +#include "generic/THStorageCopy.h" +#include "THGenerateAllTypes.h" + +#endif diff --git a/lib/TH/THTensor.c b/lib/TH/THTensor.c new file mode 100644 index 00000000000..7a12887b31f --- /dev/null +++ b/lib/TH/THTensor.c @@ -0,0 +1,24 @@ +#include "THTensor.h" +#include "THVector.h" +#include "THBlas.h" +#include "THLapack.h" +#include "THRandom.h" +#include "THTensorDimApply.h" + +#include "generic/THTensor.c" +#include "THGenerateAllTypes.h" + +#include "generic/THTensorCopy.c" +#include "THGenerateAllTypes.h" + +#include "generic/THTensorRandom.c" +#include "THGenerateAllTypes.h" + +#include "generic/THTensorMath.c" +#include "THGenerateAllTypes.h" + +#include "generic/THTensorConv.c" +#include "THGenerateAllTypes.h" + +#include "generic/THTensorLapack.c" +#include "THGenerateFloatTypes.h" diff --git a/lib/TH/THTensor.h b/lib/TH/THTensor.h new file mode 100644 index 00000000000..e513da6c222 --- /dev/null +++ b/lib/TH/THTensor.h @@ -0,0 +1,35 @@ +#ifndef TH_TENSOR_INC +#define TH_TENSOR_INC + +#include "THStorage.h" +#include "THTensorApply.h" + +#define THTensor TH_CONCAT_3(TH,Real,Tensor) +#define THTensor_(NAME) TH_CONCAT_4(TH,Real,Tensor_,NAME) + +/* basics */ +#include "generic/THTensor.h" +#include "THGenerateAllTypes.h" + +#include "generic/THTensorCopy.h" +#include "THGenerateAllTypes.h" + +#include "THTensorMacros.h" + +/* random numbers */ +#include "generic/THTensorRandom.h" +#include "THGenerateAllTypes.h" + +/* maths */ +#include "generic/THTensorMath.h" +#include "THGenerateAllTypes.h" + +/* convolutions */ +#include "generic/THTensorConv.h" +#include "THGenerateAllTypes.h" + +/* lapack support */ +#include "generic/THTensorLapack.h" +#include "THGenerateFloatTypes.h" + +#endif diff --git a/lib/TH/THTensorApply.h b/lib/TH/THTensorApply.h new file mode 100644 index 00000000000..761623c43f9 --- /dev/null +++ b/lib/TH/THTensorApply.h @@ -0,0 +1,428 @@ +#ifndef TH_TENSOR_APPLY_INC +#define TH_TENSOR_APPLY_INC + +#define TH_TENSOR_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, CODE) \ +{ \ + TYPE1 *TENSOR1##_data = NULL; \ + long *TENSOR1##_counter = NULL; \ + long TENSOR1##_stride = 0, TENSOR1##_size = 0, TENSOR1##_dim = 0, TENSOR1##_i, TENSOR1##_n; \ + TYPE2 *TENSOR2##_data = NULL; \ + long *TENSOR2##_counter = NULL; \ + long TENSOR2##_stride = 0, TENSOR2##_size = 0, TENSOR2##_dim = 0, TENSOR2##_i, TENSOR2##_n; \ + TYPE2 *TENSOR3##_data = NULL; \ + 
long *TENSOR3##_counter = NULL; \ + long TENSOR3##_stride = 0, TENSOR3##_size = 0, TENSOR3##_dim = 0, TENSOR3##_i, TENSOR3##_n; \ + int TH_TENSOR_APPLY_hasFinished = 0; \ +\ + TENSOR1##_n = (TENSOR1->nDimension ? 1 : 0); \ + for(TENSOR1##_i = 0; TENSOR1##_i < TENSOR1->nDimension; TENSOR1##_i++) \ + TENSOR1##_n *= TENSOR1->size[TENSOR1##_i]; \ +\ + TENSOR2##_n = (TENSOR2->nDimension ? 1 : 0); \ + for(TENSOR2##_i = 0; TENSOR2##_i < TENSOR2->nDimension; TENSOR2##_i++) \ + TENSOR2##_n *= TENSOR2->size[TENSOR2##_i]; \ +\ + TENSOR3##_n = (TENSOR3->nDimension ? 1 : 0); \ + for(TENSOR3##_i = 0; TENSOR3##_i < TENSOR3->nDimension; TENSOR3##_i++) \ + TENSOR3##_n *= TENSOR3->size[TENSOR3##_i]; \ +\ + if(TENSOR1##_n != TENSOR2##_n || TENSOR1##_n != TENSOR3##_n) /* should we do the check in the function instead? i think so */ \ + THError("inconsistent tensor size"); \ +\ + if(TENSOR1->nDimension == 0) \ + TH_TENSOR_APPLY_hasFinished = 1; \ + else \ + { \ + TENSOR1##_data = TENSOR1->storage->data+TENSOR1->storageOffset; \ + for(TENSOR1##_dim = TENSOR1->nDimension-1; TENSOR1##_dim >= 0; TENSOR1##_dim--) \ + { \ + if(TENSOR1->size[TENSOR1##_dim] != 1) \ + break; \ + } \ + TENSOR1##_stride = (TENSOR1##_dim == -1 ? 0 : TENSOR1->stride[TENSOR1##_dim]); \ + TENSOR1##_size = 1; \ + for(TENSOR1##_dim = TENSOR1->nDimension-1; TENSOR1##_dim >= 0; TENSOR1##_dim--) \ + { \ + if(TENSOR1->size[TENSOR1##_dim] != 1) \ + { \ + if(TENSOR1->stride[TENSOR1##_dim] == TENSOR1##_size) \ + TENSOR1##_size *= TENSOR1->size[TENSOR1##_dim]; \ + else \ + break; \ + } \ + } \ + TENSOR1##_counter = (long*)THAlloc(sizeof(long)*(TENSOR1##_dim+1)); \ + for(TENSOR1##_i = 0; TENSOR1##_i <= TENSOR1##_dim; TENSOR1##_i++) \ + TENSOR1##_counter[TENSOR1##_i] = 0; \ +\ + TENSOR2##_data = TENSOR2->storage->data+TENSOR2->storageOffset; \ + for(TENSOR2##_dim = TENSOR2->nDimension-1; TENSOR2##_dim >= 0; TENSOR2##_dim--) \ + { \ + if(TENSOR2->size[TENSOR2##_dim] != 1) \ + break; \ + } \ + TENSOR2##_stride = (TENSOR2##_dim == -1 ? 0 : TENSOR2->stride[TENSOR2##_dim]); \ + TENSOR2##_size = 1; \ + for(TENSOR2##_dim = TENSOR2->nDimension-1; TENSOR2##_dim >= 0; TENSOR2##_dim--) \ + { \ + if(TENSOR2->size[TENSOR2##_dim] != 1) \ + { \ + if(TENSOR2->stride[TENSOR2##_dim] == TENSOR2##_size) \ + TENSOR2##_size *= TENSOR2->size[TENSOR2##_dim]; \ + else \ + break; \ + } \ + } \ + TENSOR2##_counter = (long*)THAlloc(sizeof(long)*(TENSOR2##_dim+1)); \ + for(TENSOR2##_i = 0; TENSOR2##_i <= TENSOR2##_dim; TENSOR2##_i++) \ + TENSOR2##_counter[TENSOR2##_i] = 0; \ +\ + TENSOR3##_data = TENSOR3->storage->data+TENSOR3->storageOffset; \ + for(TENSOR3##_dim = TENSOR3->nDimension-1; TENSOR3##_dim >= 0; TENSOR3##_dim--) \ + { \ + if(TENSOR3->size[TENSOR3##_dim] != 1) \ + break; \ + } \ + TENSOR3##_stride = (TENSOR3##_dim == -1 ? 
0 : TENSOR3->stride[TENSOR3##_dim]); \ + TENSOR3##_size = 1; \ + for(TENSOR3##_dim = TENSOR3->nDimension-1; TENSOR3##_dim >= 0; TENSOR3##_dim--) \ + { \ + if(TENSOR3->size[TENSOR3##_dim] != 1) \ + { \ + if(TENSOR3->stride[TENSOR3##_dim] == TENSOR3##_size) \ + TENSOR3##_size *= TENSOR3->size[TENSOR3##_dim]; \ + else \ + break; \ + } \ + } \ + TENSOR3##_counter = (long*)THAlloc(sizeof(long)*(TENSOR3##_dim+1)); \ + for(TENSOR3##_i = 0; TENSOR3##_i <= TENSOR3##_dim; TENSOR3##_i++) \ + TENSOR3##_counter[TENSOR3##_i] = 0; \ + } \ +\ + TENSOR1##_i = 0; \ + TENSOR2##_i = 0; \ + TENSOR3##_i = 0; \ + while(!TH_TENSOR_APPLY_hasFinished) \ + { \ + for(; TENSOR1##_i < TENSOR1##_size && TENSOR2##_i < TENSOR2##_size && TENSOR3##_i < TENSOR3##_size; TENSOR1##_i++, TENSOR2##_i++, TENSOR3##_i++, TENSOR1##_data += TENSOR1##_stride, TENSOR2##_data += TENSOR2##_stride, TENSOR3##_data += TENSOR3##_stride) /* 0 et pas TENSOR##_dim! */ \ + { \ + CODE \ + } \ +\ + if(TENSOR1##_i == TENSOR1##_size) \ + { \ + if(TENSOR1##_dim == -1) \ + break; \ +\ + TENSOR1##_data -= TENSOR1##_size*TENSOR1##_stride; \ + for(TENSOR1##_i = TENSOR1##_dim; TENSOR1##_i >= 0; TENSOR1##_i--) \ + { \ + TENSOR1##_counter[TENSOR1##_i]++; \ + TENSOR1##_data += TENSOR1->stride[TENSOR1##_i]; \ +\ + if(TENSOR1##_counter[TENSOR1##_i] == TENSOR1->size[TENSOR1##_i]) \ + { \ + if(TENSOR1##_i == 0) \ + { \ + TH_TENSOR_APPLY_hasFinished = 1; \ + break; \ + } \ + else \ + { \ + TENSOR1##_data -= TENSOR1##_counter[TENSOR1##_i]*TENSOR1->stride[TENSOR1##_i]; \ + TENSOR1##_counter[TENSOR1##_i] = 0; \ + } \ + } \ + else \ + break; \ + } \ + TENSOR1##_i = 0; \ + } \ +\ + if(TENSOR2##_i == TENSOR2##_size) \ + { \ + if(TENSOR2##_dim == -1) \ + break; \ +\ + TENSOR2##_data -= TENSOR2##_size*TENSOR2##_stride; \ + for(TENSOR2##_i = TENSOR2##_dim; TENSOR2##_i >= 0; TENSOR2##_i--) \ + { \ + TENSOR2##_counter[TENSOR2##_i]++; \ + TENSOR2##_data += TENSOR2->stride[TENSOR2##_i]; \ +\ + if(TENSOR2##_counter[TENSOR2##_i] == TENSOR2->size[TENSOR2##_i]) \ + { \ + if(TENSOR2##_i == 0) \ + { \ + TH_TENSOR_APPLY_hasFinished = 1; \ + break; \ + } \ + else \ + { \ + TENSOR2##_data -= TENSOR2##_counter[TENSOR2##_i]*TENSOR2->stride[TENSOR2##_i]; \ + TENSOR2##_counter[TENSOR2##_i] = 0; \ + } \ + } \ + else \ + break; \ + } \ + TENSOR2##_i = 0; \ + } \ +\ + if(TENSOR3##_i == TENSOR3##_size) \ + { \ + if(TENSOR3##_dim == -1) \ + break; \ +\ + TENSOR3##_data -= TENSOR3##_size*TENSOR3##_stride; \ + for(TENSOR3##_i = TENSOR3##_dim; TENSOR3##_i >= 0; TENSOR3##_i--) \ + { \ + TENSOR3##_counter[TENSOR3##_i]++; \ + TENSOR3##_data += TENSOR3->stride[TENSOR3##_i]; \ +\ + if(TENSOR3##_counter[TENSOR3##_i] == TENSOR3->size[TENSOR3##_i]) \ + { \ + if(TENSOR3##_i == 0) \ + { \ + TH_TENSOR_APPLY_hasFinished = 1; \ + break; \ + } \ + else \ + { \ + TENSOR3##_data -= TENSOR3##_counter[TENSOR3##_i]*TENSOR3->stride[TENSOR3##_i]; \ + TENSOR3##_counter[TENSOR3##_i] = 0; \ + } \ + } \ + else \ + break; \ + } \ + TENSOR3##_i = 0; \ + } \ + } \ + THFree(TENSOR1##_counter); \ + THFree(TENSOR2##_counter); \ + THFree(TENSOR3##_counter); \ +} + +#define TH_TENSOR_APPLY2(TYPE1, TENSOR1, TYPE2, TENSOR2, CODE) \ +{ \ + TYPE1 *TENSOR1##_data = NULL; \ + long *TENSOR1##_counter = NULL; \ + long TENSOR1##_stride = 0, TENSOR1##_size = 0, TENSOR1##_dim = 0, TENSOR1##_i, TENSOR1##_n; \ + TYPE2 *TENSOR2##_data = NULL; \ + long *TENSOR2##_counter = NULL; \ + long TENSOR2##_stride = 0, TENSOR2##_size = 0, TENSOR2##_dim = 0, TENSOR2##_i, TENSOR2##_n; \ + int TH_TENSOR_APPLY_hasFinished = 0; \ +\ + TENSOR1##_n = 
(TENSOR1->nDimension ? 1 : 0); \ + for(TENSOR1##_i = 0; TENSOR1##_i < TENSOR1->nDimension; TENSOR1##_i++) \ + TENSOR1##_n *= TENSOR1->size[TENSOR1##_i]; \ +\ + TENSOR2##_n = (TENSOR2->nDimension ? 1 : 0); \ + for(TENSOR2##_i = 0; TENSOR2##_i < TENSOR2->nDimension; TENSOR2##_i++) \ + TENSOR2##_n *= TENSOR2->size[TENSOR2##_i]; \ +\ + if(TENSOR1##_n != TENSOR2##_n) /* should we do the check in the function instead? i think so */ \ + THError("inconsistent tensor size"); \ +\ + if(TENSOR1->nDimension == 0) \ + TH_TENSOR_APPLY_hasFinished = 1; \ + else \ + { \ + TENSOR1##_data = TENSOR1->storage->data+TENSOR1->storageOffset; \ + for(TENSOR1##_dim = TENSOR1->nDimension-1; TENSOR1##_dim >= 0; TENSOR1##_dim--) \ + { \ + if(TENSOR1->size[TENSOR1##_dim] != 1) \ + break; \ + } \ + TENSOR1##_stride = (TENSOR1##_dim == -1 ? 0 : TENSOR1->stride[TENSOR1##_dim]); \ + TENSOR1##_size = 1; \ + for(TENSOR1##_dim = TENSOR1->nDimension-1; TENSOR1##_dim >= 0; TENSOR1##_dim--) \ + { \ + if(TENSOR1->size[TENSOR1##_dim] != 1) \ + { \ + if(TENSOR1->stride[TENSOR1##_dim] == TENSOR1##_size) \ + TENSOR1##_size *= TENSOR1->size[TENSOR1##_dim]; \ + else \ + break; \ + } \ + } \ + TENSOR1##_counter = (long*)THAlloc(sizeof(long)*(TENSOR1##_dim+1)); \ + for(TENSOR1##_i = 0; TENSOR1##_i <= TENSOR1##_dim; TENSOR1##_i++) \ + TENSOR1##_counter[TENSOR1##_i] = 0; \ +\ + TENSOR2##_data = TENSOR2->storage->data+TENSOR2->storageOffset; \ + for(TENSOR2##_dim = TENSOR2->nDimension-1; TENSOR2##_dim >= 0; TENSOR2##_dim--) \ + { \ + if(TENSOR2->size[TENSOR2##_dim] != 1) \ + break; \ + } \ + TENSOR2##_stride = (TENSOR2##_dim == -1 ? 0 : TENSOR2->stride[TENSOR2##_dim]); \ + TENSOR2##_size = 1; \ + for(TENSOR2##_dim = TENSOR2->nDimension-1; TENSOR2##_dim >= 0; TENSOR2##_dim--) \ + { \ + if(TENSOR2->size[TENSOR2##_dim] != 1) \ + { \ + if(TENSOR2->stride[TENSOR2##_dim] == TENSOR2##_size) \ + TENSOR2##_size *= TENSOR2->size[TENSOR2##_dim]; \ + else \ + break; \ + } \ + } \ + TENSOR2##_counter = (long*)THAlloc(sizeof(long)*(TENSOR2##_dim+1)); \ + for(TENSOR2##_i = 0; TENSOR2##_i <= TENSOR2##_dim; TENSOR2##_i++) \ + TENSOR2##_counter[TENSOR2##_i] = 0; \ + } \ +\ + TENSOR1##_i = 0; \ + TENSOR2##_i = 0; \ + while(!TH_TENSOR_APPLY_hasFinished) \ + { \ + for(; TENSOR1##_i < TENSOR1##_size && TENSOR2##_i < TENSOR2##_size; TENSOR1##_i++, TENSOR2##_i++, TENSOR1##_data += TENSOR1##_stride, TENSOR2##_data += TENSOR2##_stride) /* 0 et pas TENSOR##_dim! 
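(in English: 0, and not TENSOR##_dim!; the same aside appears in the other TH_TENSOR_APPLY macros)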
*/ \ + { \ + CODE \ + } \ +\ + if(TENSOR1##_i == TENSOR1##_size) \ + { \ + if(TENSOR1##_dim == -1) \ + break; \ +\ + TENSOR1##_data -= TENSOR1##_size*TENSOR1##_stride; \ + for(TENSOR1##_i = TENSOR1##_dim; TENSOR1##_i >= 0; TENSOR1##_i--) \ + { \ + TENSOR1##_counter[TENSOR1##_i]++; \ + TENSOR1##_data += TENSOR1->stride[TENSOR1##_i]; \ +\ + if(TENSOR1##_counter[TENSOR1##_i] == TENSOR1->size[TENSOR1##_i]) \ + { \ + if(TENSOR1##_i == 0) \ + { \ + TH_TENSOR_APPLY_hasFinished = 1; \ + break; \ + } \ + else \ + { \ + TENSOR1##_data -= TENSOR1##_counter[TENSOR1##_i]*TENSOR1->stride[TENSOR1##_i]; \ + TENSOR1##_counter[TENSOR1##_i] = 0; \ + } \ + } \ + else \ + break; \ + } \ + TENSOR1##_i = 0; \ + } \ +\ + if(TENSOR2##_i == TENSOR2##_size) \ + { \ + if(TENSOR2##_dim == -1) \ + break; \ +\ + TENSOR2##_data -= TENSOR2##_size*TENSOR2##_stride; \ + for(TENSOR2##_i = TENSOR2##_dim; TENSOR2##_i >= 0; TENSOR2##_i--) \ + { \ + TENSOR2##_counter[TENSOR2##_i]++; \ + TENSOR2##_data += TENSOR2->stride[TENSOR2##_i]; \ +\ + if(TENSOR2##_counter[TENSOR2##_i] == TENSOR2->size[TENSOR2##_i]) \ + { \ + if(TENSOR2##_i == 0) \ + { \ + TH_TENSOR_APPLY_hasFinished = 1; \ + break; \ + } \ + else \ + { \ + TENSOR2##_data -= TENSOR2##_counter[TENSOR2##_i]*TENSOR2->stride[TENSOR2##_i]; \ + TENSOR2##_counter[TENSOR2##_i] = 0; \ + } \ + } \ + else \ + break; \ + } \ + TENSOR2##_i = 0; \ + } \ + } \ + THFree(TENSOR1##_counter); \ + THFree(TENSOR2##_counter); \ +} + +#define TH_TENSOR_APPLY(TYPE, TENSOR, CODE) \ +{ \ + TYPE *TENSOR##_data = NULL; \ + long *TENSOR##_counter = NULL; \ + long TENSOR##_stride = 0, TENSOR##_size = 0, TENSOR##_dim = 0, TENSOR##_i; \ + int TH_TENSOR_APPLY_hasFinished = 0; \ +\ + if(TENSOR->nDimension == 0) \ + TH_TENSOR_APPLY_hasFinished = 1; \ + else \ + { \ + TENSOR##_data = TENSOR->storage->data+TENSOR->storageOffset; \ +\ + /* what is the first stride (ignore first dims=1)? */ \ + /* it will be used for the whole largest contiguous section */ \ + for(TENSOR##_dim = TENSOR->nDimension-1; TENSOR##_dim >= 0; TENSOR##_dim--) \ + { \ + if(TENSOR->size[TENSOR##_dim] != 1) \ + break; \ + } \ + TENSOR##_stride = (TENSOR##_dim == -1 ? 0 : TENSOR->stride[TENSOR##_dim]); \ +\ + /* what is the largest contiguous section? */ \ + TENSOR##_size = 1; \ + for(TENSOR##_dim = TENSOR->nDimension-1; TENSOR##_dim >= 0; TENSOR##_dim--) \ + { \ + if(TENSOR->size[TENSOR##_dim] != 1) \ + { \ + if(TENSOR->stride[TENSOR##_dim] == TENSOR##_size) \ + TENSOR##_size *= TENSOR->size[TENSOR##_dim]; \ + else \ + break; \ + } \ + } \ +\ + /* counter over found dimensions */ \ + TENSOR##_counter = (long*)THAlloc(sizeof(long)*(TENSOR##_dim+1)); \ + for(TENSOR##_i = 0; TENSOR##_i <= TENSOR##_dim; TENSOR##_i++) \ + TENSOR##_counter[TENSOR##_i] = 0; \ + } \ +\ + while(!TH_TENSOR_APPLY_hasFinished) \ + { \ + for(TENSOR##_i = 0; TENSOR##_i < TENSOR##_size; TENSOR##_i++, TENSOR##_data += TENSOR##_stride) /* 0 et pas TENSOR##_dim! 
*/ \ + { \ + CODE \ + } \ +\ + if(TENSOR##_dim == -1) \ + break; \ + \ + TENSOR##_data -= TENSOR##_i*TENSOR##_stride; \ + for(TENSOR##_i = TENSOR##_dim; TENSOR##_i >= 0; TENSOR##_i--) \ + { \ + TENSOR##_counter[TENSOR##_i]++; \ + TENSOR##_data += TENSOR->stride[TENSOR##_i]; \ +\ + if(TENSOR##_counter[TENSOR##_i] == TENSOR->size[TENSOR##_i]) \ + { \ + if(TENSOR##_i == 0) \ + { \ + TH_TENSOR_APPLY_hasFinished = 1; \ + break; \ + } \ + else \ + { \ + TENSOR##_data -= TENSOR##_counter[TENSOR##_i]*TENSOR->stride[TENSOR##_i]; \ + TENSOR##_counter[TENSOR##_i] = 0; \ + } \ + } \ + else \ + break; \ + } \ + } \ + THFree(TENSOR##_counter); \ +} + +#endif diff --git a/lib/TH/THTensorDimApply.h b/lib/TH/THTensorDimApply.h new file mode 100644 index 00000000000..40822aa6bad --- /dev/null +++ b/lib/TH/THTensorDimApply.h @@ -0,0 +1,232 @@ +#ifndef TH_TENSOR_DIM_APPLY_INC +#define TH_TENSOR_DIM_APPLY_INC + +#define TH_TENSOR_DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIMENSION, CODE) \ +{ \ + TYPE1 *TENSOR1##_data = NULL; \ + long TENSOR1##_stride = 0, TENSOR1##_size = 0; \ + TYPE2 *TENSOR2##_data = NULL; \ + long TENSOR2##_stride = 0, TENSOR2##_size = 0; \ + TYPE3 *TENSOR3##_data = NULL; \ + long TENSOR3##_stride = 0, TENSOR3##_size = 0; \ + long *TH_TENSOR_DIM_APPLY_counter = NULL; \ + int TH_TENSOR_DIM_APPLY_hasFinished = 0; \ + int TH_TENSOR_DIM_APPLY_i; \ +\ + if( (DIMENSION < 0) || (DIMENSION >= TENSOR1->nDimension) ) \ + THError("invalid dimension"); \ + if( TENSOR1->nDimension != TENSOR2->nDimension ) \ + THError("inconsistent tensor sizes"); \ + if( TENSOR1->nDimension != TENSOR3->nDimension ) \ + THError("inconsistent tensor sizes"); \ + for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \ + { \ + if(TH_TENSOR_DIM_APPLY_i == DIMENSION) \ + continue; \ + if(TENSOR1->size[TH_TENSOR_DIM_APPLY_i] != TENSOR2->size[TH_TENSOR_DIM_APPLY_i]) \ + THError("inconsistent tensor sizes"); \ + if(TENSOR1->size[TH_TENSOR_DIM_APPLY_i] != TENSOR3->size[TH_TENSOR_DIM_APPLY_i]) \ + THError("inconsistent tensor sizes"); \ + } \ +\ + TH_TENSOR_DIM_APPLY_counter = (long*)THAlloc(sizeof(long)*(TENSOR1->nDimension)); \ + for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \ + TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \ +\ + TENSOR1##_data = (TENSOR1)->storage->data+(TENSOR1)->storageOffset; \ + TENSOR1##_stride = (TENSOR1)->stride[DIMENSION]; \ + TENSOR1##_size = TENSOR1->size[DIMENSION]; \ +\ + TENSOR2##_data = (TENSOR2)->storage->data+(TENSOR2)->storageOffset; \ + TENSOR2##_stride = (TENSOR2)->stride[DIMENSION]; \ + TENSOR2##_size = TENSOR2->size[DIMENSION]; \ +\ + TENSOR3##_data = (TENSOR3)->storage->data+(TENSOR3)->storageOffset; \ + TENSOR3##_stride = (TENSOR3)->stride[DIMENSION]; \ + TENSOR3##_size = TENSOR3->size[DIMENSION]; \ +\ + while(!TH_TENSOR_DIM_APPLY_hasFinished) \ + { \ + CODE \ +\ + if(TENSOR1->nDimension == 1) \ + break; \ + \ + for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \ + { \ + if(TH_TENSOR_DIM_APPLY_i == DIMENSION) \ + { \ + if(TH_TENSOR_DIM_APPLY_i == TENSOR1->nDimension-1) \ + { \ + TH_TENSOR_DIM_APPLY_hasFinished = 1; \ + break; \ + } \ + continue; \ + } \ +\ + TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]++; \ + TENSOR1##_data += TENSOR1->stride[TH_TENSOR_DIM_APPLY_i]; \ + TENSOR2##_data += TENSOR2->stride[TH_TENSOR_DIM_APPLY_i]; \ + TENSOR3##_data += TENSOR3->stride[TH_TENSOR_DIM_APPLY_i]; \ +\ + 
if(TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] == TENSOR1->size[TH_TENSOR_DIM_APPLY_i]) \ + { \ + if(TH_TENSOR_DIM_APPLY_i == TENSOR1->nDimension-1) \ + { \ + TH_TENSOR_DIM_APPLY_hasFinished = 1; \ + break; \ + } \ + else \ + { \ + TENSOR1##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR1->stride[TH_TENSOR_DIM_APPLY_i]; \ + TENSOR2##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR2->stride[TH_TENSOR_DIM_APPLY_i]; \ + TENSOR3##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR3->stride[TH_TENSOR_DIM_APPLY_i]; \ + TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \ + } \ + } \ + else \ + break; \ + } \ + } \ + THFree(TH_TENSOR_DIM_APPLY_counter); \ +} + +#define TH_TENSOR_DIM_APPLY2(TYPE1, TENSOR1, TYPE2, TENSOR2, DIMENSION, CODE) \ +{ \ + TYPE1 *TENSOR1##_data = NULL; \ + long TENSOR1##_stride = 0, TENSOR1##_size = 0; \ + TYPE2 *TENSOR2##_data = NULL; \ + long TENSOR2##_stride = 0, TENSOR2##_size = 0; \ + long *TH_TENSOR_DIM_APPLY_counter = NULL; \ + int TH_TENSOR_DIM_APPLY_hasFinished = 0; \ + int TH_TENSOR_DIM_APPLY_i; \ +\ + if( (DIMENSION < 0) || (DIMENSION >= TENSOR1->nDimension) ) \ + THError("invalid dimension"); \ + if( TENSOR1->nDimension != TENSOR2->nDimension ) \ + THError("inconsistent tensor sizes"); \ + for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \ + { \ + if(TH_TENSOR_DIM_APPLY_i == DIMENSION) \ + continue; \ + if(TENSOR1->size[TH_TENSOR_DIM_APPLY_i] != TENSOR2->size[TH_TENSOR_DIM_APPLY_i]) \ + THError("inconsistent tensor sizes"); \ + } \ +\ + TH_TENSOR_DIM_APPLY_counter = (long*)THAlloc(sizeof(long)*(TENSOR1->nDimension)); \ + for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \ + TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \ +\ + TENSOR1##_data = (TENSOR1)->storage->data+(TENSOR1)->storageOffset; \ + TENSOR1##_stride = (TENSOR1)->stride[DIMENSION]; \ + TENSOR1##_size = TENSOR1->size[DIMENSION]; \ +\ + TENSOR2##_data = (TENSOR2)->storage->data+(TENSOR2)->storageOffset; \ + TENSOR2##_stride = (TENSOR2)->stride[DIMENSION]; \ + TENSOR2##_size = TENSOR2->size[DIMENSION]; \ +\ + while(!TH_TENSOR_DIM_APPLY_hasFinished) \ + { \ + CODE \ +\ + if(TENSOR1->nDimension == 1) \ + break; \ + \ + for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \ + { \ + if(TH_TENSOR_DIM_APPLY_i == DIMENSION) \ + { \ + if(TH_TENSOR_DIM_APPLY_i == TENSOR1->nDimension-1) \ + { \ + TH_TENSOR_DIM_APPLY_hasFinished = 1; \ + break; \ + } \ + continue; \ + } \ +\ + TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]++; \ + TENSOR1##_data += TENSOR1->stride[TH_TENSOR_DIM_APPLY_i]; \ + TENSOR2##_data += TENSOR2->stride[TH_TENSOR_DIM_APPLY_i]; \ +\ + if(TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] == TENSOR1->size[TH_TENSOR_DIM_APPLY_i]) \ + { \ + if(TH_TENSOR_DIM_APPLY_i == TENSOR1->nDimension-1) \ + { \ + TH_TENSOR_DIM_APPLY_hasFinished = 1; \ + break; \ + } \ + else \ + { \ + TENSOR1##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR1->stride[TH_TENSOR_DIM_APPLY_i]; \ + TENSOR2##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR2->stride[TH_TENSOR_DIM_APPLY_i]; \ + TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \ + } \ + } \ + else \ + break; \ + } \ + } \ + THFree(TH_TENSOR_DIM_APPLY_counter); \ +} + +#define TH_TENSOR_DIM_APPLY(TYPE, TENSOR, DIMENSION, CODE) \ +{ \ + TYPE *TENSOR##_data = NULL; \ + long TENSOR##_stride = 0, 
TENSOR##_size = 0; \ + long *TH_TENSOR_DIM_APPLY_counter = NULL; \ + int TH_TENSOR_DIM_APPLY_hasFinished = 0; \ + int TH_TENSOR_DIM_APPLY_i; \ +\ + if( (DIMENSION < 0) || (DIMENSION >= TENSOR->nDimension) ) \ + THError("invalid dimension"); \ +\ + TENSOR##_data = (TENSOR)->storage->data+(TENSOR)->storageOffset; \ + TENSOR##_stride = (TENSOR)->stride[DIMENSION]; \ + TENSOR##_size = TENSOR->size[DIMENSION]; \ + TH_TENSOR_DIM_APPLY_counter = (long*)THAlloc(sizeof(long)*(TENSOR->nDimension)); \ + for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR->nDimension; TH_TENSOR_DIM_APPLY_i++) \ + TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \ +\ + while(!TH_TENSOR_DIM_APPLY_hasFinished) \ + { \ + CODE \ +\ + if(TENSOR->nDimension == 1) \ + break; \ + \ + for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR->nDimension; TH_TENSOR_DIM_APPLY_i++) \ + { \ + if(TH_TENSOR_DIM_APPLY_i == DIMENSION) \ + { \ + if(TH_TENSOR_DIM_APPLY_i == TENSOR->nDimension-1) \ + { \ + TH_TENSOR_DIM_APPLY_hasFinished = 1; \ + break; \ + } \ + continue; \ + } \ +\ + TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]++; \ + TENSOR##_data += TENSOR->stride[TH_TENSOR_DIM_APPLY_i]; \ +\ + if(TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] == TENSOR->size[TH_TENSOR_DIM_APPLY_i]) \ + { \ + if(TH_TENSOR_DIM_APPLY_i == TENSOR->nDimension-1) \ + { \ + TH_TENSOR_DIM_APPLY_hasFinished = 1; \ + break; \ + } \ + else \ + { \ + TENSOR##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR->stride[TH_TENSOR_DIM_APPLY_i]; \ + TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \ + } \ + } \ + else \ + break; \ + } \ + } \ + THFree(TH_TENSOR_DIM_APPLY_counter); \ +} + +#endif diff --git a/lib/TH/THTensorMacros.h b/lib/TH/THTensorMacros.h new file mode 100644 index 00000000000..15b67665e7a --- /dev/null +++ b/lib/TH/THTensorMacros.h @@ -0,0 +1,30 @@ +#ifndef TH_TENSOR_MACROS_INC +#define TH_TENSOR_MACROS_INC + +/* fast method to access to tensor data */ + +#define THTensor_fastGet1d(self, x0) \ + (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]]) + +#define THTensor_fastGet2d(self, x0, x1) \ + (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]]) + +#define THTensor_fastGet3d(self, x0, x1, x2) \ + (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]+(x2)*(self)->stride[2]]) + +#define THTensor_fastGet4d(self, x0, x1, x2, x3) \ + (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]+(x2)*(self)->stride[2]+(x3)*(self)->stride[3]]) + +#define THTensor_fastSet1d(self, x0, value) \ + (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]] = value) + +#define THTensor_fastSet2d(self, x0, x1, value) \ + (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]] = value) + +#define THTensor_fastSet3d(self, x0, x1, x2, value) \ + (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]+(x2)*(self)->stride[2]] = value) + +#define THTensor_fastSet4d(self, x0, x1, x2, x3, value) \ + (((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]+(x2)*(self)->stride[2]+(x3)*(self)->stride[3]] = value) + +#endif diff --git a/lib/TH/THVector.h b/lib/TH/THVector.h new file mode 100644 index 00000000000..f82b9a7ca70 --- /dev/null +++ b/lib/TH/THVector.h @@ -0,0 +1,240 @@ +#ifndef TH_VECTOR_INC +#define TH_VECTOR_INC + +#include "THGeneral.h" + +#define 
THVector_(NAME) TH_CONCAT_4(TH,Real,Vector_,NAME) + +#if defined __SSE2__ || defined __SSE3__ || defined __SSSE3__ \ + || defined __SSE4_1__ || defined __SSE4_2__ + +#ifdef __SSE2__ +#include +#endif + +#ifdef __SSE3__ +#include +#endif + +#ifdef __SSSE3__ +#include +#endif + +#if defined (__SSE4_2__) || defined (__SSE4_1__) +#include +#endif + +#define THDoubleVector_fill(x, c, n) { \ + long i; \ + __m128d XMM0 = _mm_set1_pd(c); \ + for (i=0; i<=((n)-8); i+=8) { \ + _mm_storeu_pd((x)+i , XMM0); \ + _mm_storeu_pd((x)+i+2, XMM0); \ + _mm_storeu_pd((x)+i+4, XMM0); \ + _mm_storeu_pd((x)+i+6, XMM0); \ + } \ + long off = (n) - ((n)%8); \ + for (i=0; i<((n)%8); i++) { \ + x[off+i] = c; \ + } \ + } + + +#define THDoubleVector_add(y, x, c, n) { \ + long i = 0; \ + __m128d XMM7 = _mm_set1_pd(c); \ + __m128d XMM0,XMM2; \ + for (; i<=((n)-2); i+=2) { \ + XMM0 = _mm_loadu_pd((x)+i); \ + XMM2 = _mm_loadu_pd((y)+i); \ + XMM0 = _mm_mul_pd(XMM0, XMM7); \ + XMM2 = _mm_add_pd(XMM2, XMM0); \ + _mm_storeu_pd((y)+i , XMM2); \ + } \ + for (; i<(n); i++) { \ + y[i] += c * x[i]; \ + } \ + } + +#define THDoubleVector_diff(z, x, y, n) { \ + long i; \ + for (i=0; i<=((n)-8); i+=8) { \ + __m128d XMM0 = _mm_loadu_pd((x)+i ); \ + __m128d XMM1 = _mm_loadu_pd((x)+i+2); \ + __m128d XMM2 = _mm_loadu_pd((x)+i+4); \ + __m128d XMM3 = _mm_loadu_pd((x)+i+6); \ + __m128d XMM4 = _mm_loadu_pd((y)+i ); \ + __m128d XMM5 = _mm_loadu_pd((y)+i+2); \ + __m128d XMM6 = _mm_loadu_pd((y)+i+4); \ + __m128d XMM7 = _mm_loadu_pd((y)+i+6); \ + XMM0 = _mm_sub_pd(XMM0, XMM4); \ + XMM1 = _mm_sub_pd(XMM1, XMM5); \ + XMM2 = _mm_sub_pd(XMM2, XMM6); \ + XMM3 = _mm_sub_pd(XMM3, XMM7); \ + _mm_storeu_pd((z)+i , XMM0); \ + _mm_storeu_pd((z)+i+2, XMM1); \ + _mm_storeu_pd((z)+i+4, XMM2); \ + _mm_storeu_pd((z)+i+6, XMM3); \ + } \ + long off = (n) - ((n)%8); \ + for (i=0; i<((n)%8); i++) { \ + z[off+i] = x[off+i] - y[off+i]; \ + } \ + } + +#define THDoubleVector_scale(y, c, n) { \ + long i; \ + __m128d XMM7 = _mm_set1_pd(c); \ + for (i=0; i<=((n)-4); i+=4) { \ + __m128d XMM0 = _mm_loadu_pd((y)+i ); \ + __m128d XMM1 = _mm_loadu_pd((y)+i+2); \ + XMM0 = _mm_mul_pd(XMM0, XMM7); \ + XMM1 = _mm_mul_pd(XMM1, XMM7); \ + _mm_storeu_pd((y)+i , XMM0); \ + _mm_storeu_pd((y)+i+2, XMM1); \ + } \ + long off = (n) - ((n)%4); \ + for (i=0; i<((n)%4); i++) { \ + y[off+i] *= c; \ + } \ + } + +#define THDoubleVector_mul(y, x, n) { \ + long i; \ + for (i=0; i<=((n)-8); i+=8) { \ + __m128d XMM0 = _mm_loadu_pd((x)+i ); \ + __m128d XMM1 = _mm_loadu_pd((x)+i+2); \ + __m128d XMM2 = _mm_loadu_pd((x)+i+4); \ + __m128d XMM3 = _mm_loadu_pd((x)+i+6); \ + __m128d XMM4 = _mm_loadu_pd((y)+i ); \ + __m128d XMM5 = _mm_loadu_pd((y)+i+2); \ + __m128d XMM6 = _mm_loadu_pd((y)+i+4); \ + __m128d XMM7 = _mm_loadu_pd((y)+i+6); \ + XMM4 = _mm_mul_pd(XMM4, XMM0); \ + XMM5 = _mm_mul_pd(XMM5, XMM1); \ + XMM6 = _mm_mul_pd(XMM6, XMM2); \ + XMM7 = _mm_mul_pd(XMM7, XMM3); \ + _mm_storeu_pd((y)+i , XMM4); \ + _mm_storeu_pd((y)+i+2, XMM5); \ + _mm_storeu_pd((y)+i+4, XMM6); \ + _mm_storeu_pd((y)+i+6, XMM7); \ + } \ + long off = (n) - ((n)%8); \ + for (i=0; i<((n)%8); i++) { \ + y[off+i] *= x[off+i]; \ + } \ + } + +#define THFloatVector_fill(x, c, n) { \ + long i; \ + __m128 XMM0 = _mm_set_ps1(c); \ + for (i=0; i<=((n)-16); i+=16) { \ + _mm_storeu_ps((x)+i , XMM0); \ + _mm_storeu_ps((x)+i+4, XMM0); \ + _mm_storeu_ps((x)+i+8, XMM0); \ + _mm_storeu_ps((x)+i+12, XMM0); \ + } \ + long off = (n) - ((n)%16); \ + for (i=0; i<((n)%16); i++) { \ + x[off+i] = c; \ + } \ + } + +#define THFloatVector_add(y, x, c, n) { \ + 
long i = 0; \ + __m128 XMM7 = _mm_set_ps1(c); \ + __m128 XMM0,XMM2; \ + for (; i<=((n)-4); i+=4) { \ + XMM0 = _mm_loadu_ps((x)+i); \ + XMM2 = _mm_loadu_ps((y)+i); \ + XMM0 = _mm_mul_ps(XMM0, XMM7); \ + XMM2 = _mm_add_ps(XMM2, XMM0); \ + _mm_storeu_ps((y)+i , XMM2); \ + } \ + for (; i<(n); i++) { \ + y[i] += c * x[i]; \ + } \ + } + +#define THFloatVector_diff(z, x, y, n) { \ + long i; \ + for (i=0; i<=((n)-16); i+=16) { \ + __m128 XMM0 = _mm_loadu_ps((x)+i ); \ + __m128 XMM1 = _mm_loadu_ps((x)+i+ 4); \ + __m128 XMM2 = _mm_loadu_ps((x)+i+ 8); \ + __m128 XMM3 = _mm_loadu_ps((x)+i+12); \ + __m128 XMM4 = _mm_loadu_ps((y)+i ); \ + __m128 XMM5 = _mm_loadu_ps((y)+i+ 4); \ + __m128 XMM6 = _mm_loadu_ps((y)+i+ 8); \ + __m128 XMM7 = _mm_loadu_ps((y)+i+12); \ + XMM0 = _mm_sub_ps(XMM0, XMM4); \ + XMM1 = _mm_sub_ps(XMM1, XMM5); \ + XMM2 = _mm_sub_ps(XMM2, XMM6); \ + XMM3 = _mm_sub_ps(XMM3, XMM7); \ + _mm_storeu_ps((z)+i , XMM0); \ + _mm_storeu_ps((z)+i+ 4, XMM1); \ + _mm_storeu_ps((z)+i+ 8, XMM2); \ + _mm_storeu_ps((z)+i+12, XMM3); \ + } \ + long off = (n) - ((n)%16); \ + for (i=0; i<((n)%16); i++) { \ + z[off+i] = x[off+i] - y[off+i]; \ + } \ + } + +#define THFloatVector_scale(y, c, n) { \ + long i; \ + __m128 XMM7 = _mm_set_ps1(c); \ + for (i=0; i<=((n)-8); i+=8) { \ + __m128 XMM0 = _mm_loadu_ps((y)+i ); \ + __m128 XMM1 = _mm_loadu_ps((y)+i+4); \ + XMM0 = _mm_mul_ps(XMM0, XMM7); \ + XMM1 = _mm_mul_ps(XMM1, XMM7); \ + _mm_storeu_ps((y)+i , XMM0); \ + _mm_storeu_ps((y)+i+4, XMM1); \ + } \ + long off = (n) - ((n)%8); \ + for (i=0; i<((n)%8); i++) { \ + y[off+i] *= c; \ + } \ + } + +#define THFloatVector_mul(y, x, n) { \ + long i; \ + for (i=0; i<=((n)-16); i+=16) { \ + __m128 XMM0 = _mm_loadu_ps((x)+i ); \ + __m128 XMM1 = _mm_loadu_ps((x)+i+ 4); \ + __m128 XMM2 = _mm_loadu_ps((x)+i+ 8); \ + __m128 XMM3 = _mm_loadu_ps((x)+i+12); \ + __m128 XMM4 = _mm_loadu_ps((y)+i ); \ + __m128 XMM5 = _mm_loadu_ps((y)+i+ 4); \ + __m128 XMM6 = _mm_loadu_ps((y)+i+ 8); \ + __m128 XMM7 = _mm_loadu_ps((y)+i+12); \ + XMM4 = _mm_mul_ps(XMM4, XMM0); \ + XMM5 = _mm_mul_ps(XMM5, XMM1); \ + XMM6 = _mm_mul_ps(XMM6, XMM2); \ + XMM7 = _mm_mul_ps(XMM7, XMM3); \ + _mm_storeu_ps((y)+i , XMM4); \ + _mm_storeu_ps((y)+i+ 4, XMM5); \ + _mm_storeu_ps((y)+i+ 8, XMM6); \ + _mm_storeu_ps((y)+i+12, XMM7); \ + } \ + long off = (n) - ((n)%16); \ + for (i=0; i<((n)%16); i++) { \ + y[off+i] *= x[off+i]; \ + } \ + } + +#else + +/* If SSE2 not defined, then generate plain C operators */ +#include "generic/THVector.c" +#include "THGenerateFloatTypes.h" + +#endif + +/* For non-float types, generate plain C operators */ +#include "generic/THVector.c" +#include "THGenerateIntTypes.h" + +#endif diff --git a/lib/TH/cmake/FindBLAS.cmake b/lib/TH/cmake/FindBLAS.cmake new file mode 100644 index 00000000000..a1fd9969c29 --- /dev/null +++ b/lib/TH/cmake/FindBLAS.cmake @@ -0,0 +1,212 @@ +# - Find BLAS library +# This module finds an installed fortran library that implements the BLAS +# linear-algebra interface (see http://www.netlib.org/blas/). +# The list of libraries searched for is taken +# from the autoconf macro file, acx_blas.m4 (distributed at +# http://ac-archive.sourceforge.net/ac-archive/acx_blas.html). +# +# This module sets the following variables: +# BLAS_FOUND - set to true if a library implementing the BLAS interface is found. +# BLAS_INFO - name of the detected BLAS library. 
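+# BLAS_VERSION - version of the detected BLAS library (set only via MKL_VERSION when MKL is found)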
+# BLAS_F2C - set to true if following the f2c return convention +# BLAS_LIBRARIES - list of libraries to link against to use BLAS +# BLAS_INCLUDE_DIR - include directory + +SET(BLAS_LIBRARIES) +SET(BLAS_INCLUDE_DIR) +SET(BLAS_INFO) +SET(BLAS_F2C) + +# CBLAS in Intel mkl +FIND_PACKAGE(MKL) +IF (MKL_FOUND AND NOT BLAS_LIBRARIES) + SET(BLAS_INFO imkl) + SET(BLAS_LIBRARIES ${MKL_LIBRARIES}) + SET(BLAS_INCLUDE_DIR ${MKL_INCLUDE_DIR}) + SET(BLAS_VERSION ${MKL_VERSION}) +ENDIF (MKL_FOUND AND NOT BLAS_LIBRARIES) + +# Old FindBlas +INCLUDE(CheckCSourceRuns) +INCLUDE(CheckFortranFunctionExists) +SET(_verbose TRUE) + +MACRO(Check_Fortran_Libraries LIBRARIES _prefix _name _flags _list) + # This macro checks for the existence of the combination of fortran libraries + # given by _list. If the combination is found, this macro checks (using the + # Check_Fortran_Function_Exists macro) whether can link against that library + # combination using the name of a routine given by _name using the linker + # flags given by _flags. If the combination of libraries is found and passes + # the link test, LIBRARIES is set to the list of complete library paths that + # have been found. Otherwise, LIBRARIES is set to NOTFOUND. + # N.B. _prefix is the prefix applied to the names of all cached variables that + # are generated internally and marked advanced by this macro. + + set(__list) + foreach(_elem ${_list}) + if(__list) + set(__list "${__list} - ${_elem}") + else(__list) + set(__list "${_elem}") + endif(__list) + endforeach(_elem) + if(_verbose) + message(STATUS "Checking for [${__list}]") + endif(_verbose) + + set(_libraries_work TRUE) + set(${LIBRARIES}) + set(_combined_name) + foreach(_library ${_list}) + set(_combined_name ${_combined_name}_${_library}) + if(_libraries_work) + if ( WIN32 ) + find_library(${_prefix}_${_library}_LIBRARY + NAMES ${_library} + PATHS ENV LIB + PATHS ENV PATH ) + endif ( WIN32 ) + if ( APPLE ) + find_library(${_prefix}_${_library}_LIBRARY + NAMES ${_library} + PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64 + ENV DYLD_LIBRARY_PATH ) + else ( APPLE ) + find_library(${_prefix}_${_library}_LIBRARY + NAMES ${_library} + PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64 + ENV LD_LIBRARY_PATH ) + endif( APPLE ) + mark_as_advanced(${_prefix}_${_library}_LIBRARY) + set(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY}) + set(_libraries_work ${${_prefix}_${_library}_LIBRARY}) + endif(_libraries_work) + endforeach(_library ${_list}) + if(_libraries_work) + # Test this combination of libraries. + set(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}}) + if (CMAKE_Fortran_COMPILER_WORKS) + check_fortran_function_exists(${_name} ${_prefix}${_combined_name}_WORKS) + else (CMAKE_Fortran_COMPILER_WORKS) + check_function_exists("${_name}_" ${_prefix}${_combined_name}_WORKS) + endif (CMAKE_Fortran_COMPILER_WORKS) + set(CMAKE_REQUIRED_LIBRARIES) + mark_as_advanced(${_prefix}${_combined_name}_WORKS) + set(_libraries_work ${${_prefix}${_combined_name}_WORKS}) + endif(_libraries_work) + if(NOT _libraries_work) + set(${LIBRARIES} NOTFOUND) + endif(NOT _libraries_work) +endmacro(Check_Fortran_Libraries) + + +# Apple BLAS library? 
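+# (the Accelerate framework is tried first, then vecLib)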
+if(NOT BLAS_LIBRARIES) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "Accelerate") + if (BLAS_LIBRARIES) + set(BLAS_INFO "accelerate") + endif (BLAS_LIBRARIES) +endif(NOT BLAS_LIBRARIES) +if ( NOT BLAS_LIBRARIES ) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "vecLib") + if (BLAS_LIBRARIES) + set(BLAS_INFO "veclib") + endif (BLAS_LIBRARIES) +endif ( NOT BLAS_LIBRARIES ) + +# BLAS in ATLAS library? (http://math-atlas.sourceforge.net/) +if(NOT BLAS_LIBRARIES) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "cblas;f77blas;atlas") + if (BLAS_LIBRARIES) + set(BLAS_INFO "atlas") + endif (BLAS_LIBRARIES) +endif(NOT BLAS_LIBRARIES) + +# Generic BLAS library? +if(NOT BLAS_LIBRARIES) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "blas") + if (BLAS_LIBRARIES) + set(BLAS_INFO "generic") + endif (BLAS_LIBRARIES) +endif(NOT BLAS_LIBRARIES) + +# Determine if blas was compiled with the f2c conventions +IF (BLAS_LIBRARIES) + SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES}) + CHECK_C_SOURCE_RUNS(" +#include +float x[4] = { 1, 2, 3, 4 }; +float y[4] = { .1, .01, .001, .0001 }; +int four = 4; +int one = 1; +extern double sdot_(); +int main() { + int i; + double r = sdot_(&four, x, &one, y, &one); + exit((float)r != (float).1234); +}" BLAS_F2C_DOUBLE_WORKS ) + CHECK_C_SOURCE_RUNS(" +#include +float x[4] = { 1, 2, 3, 4 }; +float y[4] = { .1, .01, .001, .0001 }; +int four = 4; +int one = 1; +extern float sdot_(); +int main() { + int i; + double r = sdot_(&four, x, &one, y, &one); + exit((float)r != (float).1234); +}" BLAS_F2C_FLOAT_WORKS ) + IF (BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS) + IF (_verbose) + MESSAGE(STATUS "This BLAS uses the F2C return conventions") + ENDIF(_verbose) + SET(BLAS_F2C TRUE) + ELSE (BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS) + SET(BLAS_F2C FALSE) + ENDIF (BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS) +ENDIF(BLAS_LIBRARIES) + +# epilogue + +if(BLAS_LIBRARIES) + set(BLAS_FOUND TRUE) +else(BLAS_LIBRARIES) + set(BLAS_FOUND FALSE) +endif(BLAS_LIBRARIES) + +IF (NOT BLAS_FOUND AND BLAS_FIND_REQUIRED) + message(FATAL_ERROR "Cannot find a library with BLAS API. Please specify library location.") +ENDIF (NOT BLAS_FOUND AND BLAS_FIND_REQUIRED) +IF(NOT BLAS_FIND_QUIETLY) + IF(BLAS_FOUND) + MESSAGE(STATUS "Found a library with BLAS API (${BLAS_INFO}).") + ELSE(BLAS_FOUND) + MESSAGE(STATUS "Cannot find a library with BLAS API. Not using BLAS.") + ENDIF(BLAS_FOUND) +ENDIF(NOT BLAS_FIND_QUIETLY) + + + + + diff --git a/lib/TH/cmake/FindLAPACK.cmake b/lib/TH/cmake/FindLAPACK.cmake new file mode 100644 index 00000000000..9a755b451fe --- /dev/null +++ b/lib/TH/cmake/FindLAPACK.cmake @@ -0,0 +1,166 @@ +# - Find LAPACK library +# This module finds an installed fortran library that implements the LAPACK +# linear-algebra interface (see http://www.netlib.org/lapack/). +# +# The approach follows that taken for the autoconf macro file, acx_lapack.m4 +# (distributed at http://ac-archive.sourceforge.net/ac-archive/acx_lapack.html). 
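+# Detection first runs FIND_PACKAGE(BLAS) and then tries to link the LAPACK
+# routine cheev against each candidate library together with the BLAS
+# libraries that were found.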
+# +# This module sets the following variables: +# LAPACK_FOUND - set to true if a library implementing the LAPACK interface is found +# LAPACK_LIBRARIES - list of libraries (using full path name) for LAPACK + +SET(LAPACK_LIBRARIES) +SET(LAPACK_INFO) + +IF(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED) + FIND_PACKAGE(BLAS) +ELSE(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED) + FIND_PACKAGE(BLAS REQUIRED) +ENDIF(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED) + +# LAPACK in Intel mkl +IF (MKL_FOUND AND NOT LAPACK_LIBRARIES) + SET(LAPACK_LIBRARIES ${MKL_LAPACK_LIBRARIES} ${MKL_LIBRARIES}) + SET(LAPACK_INCLUDE_DIR ${MKL_INCLUDE_DIR}) + SET(LAPACK_INFO "mkl") +ENDIF (MKL_FOUND AND NOT LAPACK_LIBRARIES) + +# Old search lapack script +include(CheckFortranFunctionExists) + +macro(Check_Lapack_Libraries LIBRARIES _prefix _name _flags _list _blas) + # This macro checks for the existence of the combination of fortran libraries + # given by _list. If the combination is found, this macro checks (using the + # Check_Fortran_Function_Exists macro) whether can link against that library + # combination using the name of a routine given by _name using the linker + # flags given by _flags. If the combination of libraries is found and passes + # the link test, LIBRARIES is set to the list of complete library paths that + # have been found. Otherwise, LIBRARIES is set to FALSE. + # N.B. _prefix is the prefix applied to the names of all cached variables that + # are generated internally and marked advanced by this macro. + set(_libraries_work TRUE) + set(${LIBRARIES}) + set(_combined_name) + foreach(_library ${_list}) + set(_combined_name ${_combined_name}_${_library}) + if(_libraries_work) + if (WIN32) + find_library(${_prefix}_${_library}_LIBRARY + NAMES ${_library} PATHS ENV LIB PATHS ENV PATH) + else (WIN32) + if(APPLE) + find_library(${_prefix}_${_library}_LIBRARY + NAMES ${_library} + PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64 + ENV DYLD_LIBRARY_PATH) + else(APPLE) + find_library(${_prefix}_${_library}_LIBRARY + NAMES ${_library} + PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64 + ENV LD_LIBRARY_PATH) + endif(APPLE) + endif(WIN32) + mark_as_advanced(${_prefix}_${_library}_LIBRARY) + set(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY}) + set(_libraries_work ${${_prefix}_${_library}_LIBRARY}) + endif(_libraries_work) + endforeach(_library ${_list}) + if(_libraries_work) + # Test this combination of libraries. + set(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}} ${_blas}) + if (CMAKE_Fortran_COMPILER_WORKS) + check_fortran_function_exists(${_name} ${_prefix}${_combined_name}_WORKS) + else (CMAKE_Fortran_COMPILER_WORKS) + check_function_exists("${_name}_" ${_prefix}${_combined_name}_WORKS) + endif (CMAKE_Fortran_COMPILER_WORKS) + set(CMAKE_REQUIRED_LIBRARIES) + mark_as_advanced(${_prefix}${_combined_name}_WORKS) + set(_libraries_work ${${_prefix}${_combined_name}_WORKS}) + endif(_libraries_work) + if(NOT _libraries_work) + set(${LIBRARIES} FALSE) + endif(NOT _libraries_work) +endmacro(Check_Lapack_Libraries) + + +if(BLAS_FOUND) + + #acml lapack + if(NOT LAPACK_LIBRARIES) + check_lapack_libraries( + LAPACK_LIBRARIES + LAPACK + cheev + "" + "acml" + "${BLAS_LIBRARIES}" + ) + if(LAPACK_LIBRARIES) + SET(LAPACK_INFO "acml") + endif(LAPACK_LIBRARIES) + endif(NOT LAPACK_LIBRARIES) + + # Apple LAPACK library? 
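+  # As in FindBLAS, the candidates below (Accelerate, vecLib, then a generic
+  # liblapack) are tried in order; the first combination that links cheev
+  # successfully sets LAPACK_LIBRARIES and LAPACK_INFO.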
+ if(NOT LAPACK_LIBRARIES) + check_lapack_libraries( + LAPACK_LIBRARIES + LAPACK + cheev + "" + "Accelerate" + "${BLAS_LIBRARIES}" + ) + if(LAPACK_LIBRARIES) + SET(LAPACK_INFO "Accelerate") + endif(LAPACK_LIBRARIES) + endif(NOT LAPACK_LIBRARIES) + + if ( NOT LAPACK_LIBRARIES ) + check_lapack_libraries( + LAPACK_LIBRARIES + LAPACK + cheev + "" + "vecLib" + "${BLAS_LIBRARIES}" + ) + if(LAPACK_LIBRARIES) + SET(LAPACK_INFO "veclib") + endif(LAPACK_LIBRARIES) + endif ( NOT LAPACK_LIBRARIES ) + + # Generic LAPACK library? + if ( NOT LAPACK_LIBRARIES ) + check_lapack_libraries( + LAPACK_LIBRARIES + LAPACK + cheev + "" + "lapack" + "${BLAS_LIBRARIES}" + ) + if(LAPACK_LIBRARIES) + SET(LAPACK_INFO "generic") + endif(LAPACK_LIBRARIES) + endif ( NOT LAPACK_LIBRARIES ) + +else(BLAS_FOUND) + message(STATUS "LAPACK requires BLAS") +endif(BLAS_FOUND) + +if(LAPACK_LIBRARIES) + set(LAPACK_FOUND TRUE) +else(LAPACK_LIBRARIES) + set(LAPACK_FOUND FALSE) +endif(LAPACK_LIBRARIES) + +IF (NOT LAPACK_FOUND AND LAPACK_FIND_REQUIRED) + message(FATAL_ERROR "Cannot find a library with LAPACK API. Please specify library location.") +ENDIF (NOT LAPACK_FOUND AND LAPACK_FIND_REQUIRED) +IF(NOT LAPACK_FIND_QUIETLY) + IF(LAPACK_FOUND) + MESSAGE(STATUS "Found a library with LAPACK API. (${LAPACK_INFO})") + ELSE(LAPACK_FOUND) + MESSAGE(STATUS "Cannot find a library with LAPACK API. Not using LAPACK.") + ENDIF(LAPACK_FOUND) +ENDIF(NOT LAPACK_FIND_QUIETLY) diff --git a/lib/TH/cmake/FindMKL.cmake b/lib/TH/cmake/FindMKL.cmake new file mode 100644 index 00000000000..f528647ce48 --- /dev/null +++ b/lib/TH/cmake/FindMKL.cmake @@ -0,0 +1,274 @@ +# - Find INTEL MKL library +# +# This module finds the Intel Mkl libraries. +# +# This module sets the following variables: +# MKL_FOUND - set to true if a library implementing the CBLAS interface is found +# MKL_VERSION - best guess +# MKL_INCLUDE_DIR - path to include dir. +# MKL_LIBRARIES - list of libraries for base mkl +# MKL_LAPACK_LIBRARIES - list of libraries to add for lapack +# MKL_SCALAPACK_LIBRARIES - list of libraries to add for scalapack +# MKL_SOLVER_LIBRARIES - list of libraries to add for the solvers +# MKL_CDFT_LIBRARIES - list of libraries to add for the solvers + + +# Do nothing if MKL_FOUND was set before! 
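+# The whole search is wrapped in IF (NOT MKL_FOUND) so that results cached by
+# a previous configure run are reused. For a non-standard install the search
+# can be seeded from the command line, e.g. (hypothetical paths):
+#   cmake -DINTEL_COMPILER_DIR=/opt/intel -DINTEL_MKL_DIR=/opt/intel/mkl ..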
+IF (NOT MKL_FOUND) + +SET(MKL_VERSION) +SET(MKL_INCLUDE_DIR) +SET(MKL_LIBRARIES) +SET(MKL_LAPACK_LIBRARIES) +SET(MKL_SCALAPACK_LIBRARIES) +SET(MKL_SOLVER_LIBRARIES) +SET(MKL_CDFT_LIBRARIES) + +# Includes +INCLUDE(CheckTypeSize) +INCLUDE(CheckFunctionExists) + +# Prints diagnostic +# SET(_verbose TRUE) + +# Intel Compiler Suite +SET(INTEL_COMPILER_DIR CACHE STRING + "Root directory of the Intel Compiler Suite (contains ipp, mkl, etc.)") +SET(INTEL_MKL_DIR CACHE STRING + "Root directory of the Intel MKL (standalone)") +SET(INTEL_MKL_SEQUENTIAL OFF CACHE BOOL + "Force using the sequential (non threaded) libraries") + +# Checks +CHECK_TYPE_SIZE("void*" SIZE_OF_VOIDP) +IF ("${SIZE_OF_VOIDP}" EQUAL 8) + SET(mklvers "em64t") + SET(iccvers "intel64") + SET(mkl64s "_lp64") +ELSE ("${SIZE_OF_VOIDP}" EQUAL 8) + SET(mklvers "32") + SET(iccvers "ia32") + SET(mkl64s) +ENDIF ("${SIZE_OF_VOIDP}" EQUAL 8) +IF (CMAKE_COMPILER_IS_GNUCC) + SET(mklthreads "mkl_gnu_thread" "mkl_intel_thread") + SET(mklifaces "gf" "intel") +ELSE (CMAKE_COMPILER_IS_GNUCC) + SET(mklthreads "mkl_intel_thread") + SET(mklifaces "intel") +ENDIF (CMAKE_COMPILER_IS_GNUCC) +SET(mklrtls "iomp5" "guide") + +# Kernel libraries dynamically loaded +SET(mklkerlibs "mc" "mc3" "nc" "p4n" "p4m" "p4m3" "p4p" "def") +SET(mklseq) + + + +# Paths +SET(saved_CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH}) +SET(saved_CMAKE_INCLUDE_PATH ${CMAKE_INCLUDE_PATH}) +IF (INTEL_COMPILER_DIR) + # TODO: diagnostic if dir does not exist + SET(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH} + "${INTEL_COMPILER_DIR}/lib/${iccvers}") + IF (NOT INTEL_MKL_DIR) + SET(INTEL_MKL_DIR "${INTEL_COMPILER_DIR}/mkl") + ENDIF (NOT INTEL_MKL_DIR) +ENDIF (INTEL_COMPILER_DIR) +IF (INTEL_MKL_DIR) + # TODO: diagnostic if dir does not exist + SET(CMAKE_INCLUDE_PATH ${CMAKE_INCLUDE_PATH} + "${INTEL_MKL_DIR}/include") + SET(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH} + "${INTEL_MKL_DIR}/lib/${mklvers}") +ENDIF (INTEL_MKL_DIR) + +# Try linking multiple libs +MACRO(CHECK_ALL_LIBRARIES LIBRARIES _name _list _flags) + # This macro checks for the existence of the combination of libraries given by _list. + # If the combination is found, this macro whether we can link against that library + # combination using the name of a routine given by _name using the linker + # flags given by _flags. If the combination of libraries is found and passes + # the link test, LIBRARIES is set to the list of complete library paths that + # have been found. Otherwise, LIBRARIES is set to FALSE. + # N.B. _prefix is the prefix applied to the names of all cached variables that + # are generated internally and marked advanced by this macro. + SET(_prefix "${LIBRARIES}") + IF (_verbose) + SET(__list) + FOREACH(_elem ${_list}) + IF(__list) + SET(__list "${__list} - ${_elem}") + ELSE(__list) + SET(__list "${_elem}") + ENDIF(__list) + ENDFOREACH(_elem) + ENDIF(_verbose) + # start checking + SET(_libraries_work TRUE) + SET(${LIBRARIES}) + SET(_combined_name) + SET(_paths) + FOREACH(_library ${_list}) + SET(_combined_name ${_combined_name}_${_library}) + IF(_libraries_work) + FIND_LIBRARY(${_prefix}_${_library}_LIBRARY NAMES ${_library}) + MARK_AS_ADVANCED(${_prefix}_${_library}_LIBRARY) + SET(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY}) + SET(_libraries_work ${${_prefix}_${_library}_LIBRARY}) + ENDIF(_libraries_work) + ENDFOREACH(_library ${_list}) + # Test this combination of libraries. 
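+  # The link test: CHECK_FUNCTION_EXISTS compiles a tiny program referencing
+  # ${_name} (cblas_sgemm for the interface checks below) with the candidate
+  # libraries on the link line; _libraries_work stays TRUE only if it links.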
+ IF(_libraries_work) + SET(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}}) + CHECK_FUNCTION_EXISTS(${_name} ${_prefix}${_combined_name}_WORKS) + SET(CMAKE_REQUIRED_LIBRARIES) + MARK_AS_ADVANCED(${_prefix}${_combined_name}_WORKS) + SET(_libraries_work ${${_prefix}${_combined_name}_WORKS}) + ENDIF(_libraries_work) + # Fin + IF(_libraries_work) + IF (_verbose) + MESSAGE(STATUS "FindMKL: ${__list} : ok") + ENDIF (_verbose) + ELSE (_libraries_work) + SET(${LIBRARIES}) + MARK_AS_ADVANCED(${LIBRARIES}) + IF (_verbose) + MESSAGE(STATUS "FindMKL: ${__list} : no") + ENDIF (_verbose) + ENDIF(_libraries_work) +ENDMACRO(CHECK_ALL_LIBRARIES) + + +# Check for version 10/11 +IF (NOT MKL_LIBRARIES) + SET(MKL_VERSION 1011) +ENDIF (NOT MKL_LIBRARIES) +FOREACH(mklrtl ${mklrtls}) + FOREACH(mkliface ${mklifaces}) + FOREACH(mkl64 ${mkl64s} "") + FOREACH(mklthread ${mklthreads}) + IF (NOT MKL_LIBRARIES AND NOT INTEL_MKL_SEQUENTIAL) + CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm + "mkl_${mkliface}${mkl64};${mklthread};mkl_core;${mklrtl};pthread;m" "") + ENDIF (NOT MKL_LIBRARIES AND NOT INTEL_MKL_SEQUENTIAL) + ENDFOREACH(mklthread) + ENDFOREACH(mkl64) + ENDFOREACH(mkliface) +ENDFOREACH(mklrtl) +FOREACH(mklrtl ${mklrtls}) + FOREACH(mkliface ${mklifaces}) + FOREACH(mkl64 ${mkl64s} "") + IF (NOT MKL_LIBRARIES) + CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm + "mkl_${mkliface}${mkl64};mkl_sequential;mkl_core;m" "") + IF (MKL_LIBRARIES) + SET(mklseq "_sequential") + ENDIF (MKL_LIBRARIES) + ENDIF (NOT MKL_LIBRARIES) + ENDFOREACH(mkl64) + ENDFOREACH(mkliface) +ENDFOREACH(mklrtl) +FOREACH(mklrtl ${mklrtls}) + FOREACH(mkliface ${mklifaces}) + FOREACH(mkl64 ${mkl64s} "") + FOREACH(mklthread ${mklthreads}) + IF (NOT MKL_LIBRARIES) + CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm + "mkl_${mkliface}${mkl64};${mklthread};mkl_core;${mklrtl};pthread;m" "") + ENDIF (NOT MKL_LIBRARIES) + ENDFOREACH(mklthread) + ENDFOREACH(mkl64) + ENDFOREACH(mkliface) +ENDFOREACH(mklrtl) + +# Check for older versions +IF (NOT MKL_LIBRARIES) + SET(MKL_VERSION 900) + CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm + "mkl;guide;pthread;m" "") +ENDIF (NOT MKL_LIBRARIES) + +# Include files +IF (MKL_LIBRARIES) + FIND_PATH(MKL_INCLUDE_DIR "mkl_cblas.h") + MARK_AS_ADVANCED(MKL_INCLUDE_DIR) +ENDIF (MKL_LIBRARIES) + +# Other libraries +IF (MKL_LIBRARIES) + FOREACH(mkl64 ${mkl64s} "_core" "") + FOREACH(mkls ${mklseq} "") + IF (NOT MKL_LAPACK_LIBRARIES) + FIND_LIBRARY(MKL_LAPACK_LIBRARIES NAMES "mkl_lapack${mkl64}${mkls}") + MARK_AS_ADVANCED(MKL_LAPACK_LIBRARIES) + ENDIF (NOT MKL_LAPACK_LIBRARIES) + IF (NOT MKL_SCALAPACK_LIBRARIES) + FIND_LIBRARY(MKL_SCALAPACK_LIBRARIES NAMES "mkl_scalapack${mkl64}${mkls}") + MARK_AS_ADVANCED(MKL_SCALAPACK_LIBRARIES) + ENDIF (NOT MKL_SCALAPACK_LIBRARIES) + IF (NOT MKL_SOLVER_LIBRARIES) + FIND_LIBRARY(MKL_SOLVER_LIBRARIES NAMES "mkl_solver${mkl64}${mkls}") + MARK_AS_ADVANCED(MKL_SOLVER_LIBRARIES) + ENDIF (NOT MKL_SOLVER_LIBRARIES) + IF (NOT MKL_CDFT_LIBRARIES) + FIND_LIBRARY(MKL_CDFT_LIBRARIES NAMES "mkl_cdft${mkl64}${mkls}") + MARK_AS_ADVANCED(MKL_CDFT_LIBRARIES) + ENDIF (NOT MKL_CDFT_LIBRARIES) + ENDFOREACH(mkls) + ENDFOREACH(mkl64) +ENDIF (MKL_LIBRARIES) + +# LibIRC: intel compiler always links this; +# gcc does not; but mkl kernels sometimes need it. 
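+# When building with gcc (or any non-Intel compiler), libirc is looked up
+# explicitly below and, if found, appended to MKL_LIBRARIES so that MKL
+# kernel libraries which reference Intel's runtime helpers still link.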
+IF (MKL_LIBRARIES) + IF (CMAKE_COMPILER_IS_GNUCC) + FIND_LIBRARY(MKL_KERNEL_libirc "irc") + ELSEIF (CMAKE_C_COMPILER_ID AND NOT CMAKE_C_COMPILER_ID STREQUAL "Intel") + FIND_LIBRARY(MKL_KERNEL_libirc "irc") + ENDIF (CMAKE_COMPILER_IS_GNUCC) + MARK_AS_ADVANCED(MKL_KERNEL_libirc) + IF (MKL_KERNEL_libirc) + SET(MKL_LIBRARIES ${MKL_LIBRARIES} ${MKL_KERNEL_libirc}) + ENDIF (MKL_KERNEL_libirc) +ENDIF (MKL_LIBRARIES) + +# Final +SET(CMAKE_LIBRARY_PATH ${saved_CMAKE_LIBRARY_PATH}) +SET(CMAKE_INCLUDE_PATH ${saved_CMAKE_INCLUDE_PATH}) +IF (MKL_LIBRARIES) + SET(MKL_FOUND TRUE) +ELSE (MKL_LIBRARIES) + SET(MKL_FOUND FALSE) + SET(MKL_VERSION) +ENDIF (MKL_LIBRARIES) + +# Results +IF (_verbose) + MESSAGE(STATUS "*** MKL_FOUND = ${MKL_FOUND}") + MESSAGE(STATUS "*** MKL_INCLUDE_DIR = ${MKL_INCLUDE_DIR}") + MESSAGE(STATUS "*** MKL_LIBRARIES = ${MKL_LIBRARIES}") + MESSAGE(STATUS "*** MKL_LAPACK_LIBRARIES = ${MKL_LAPACK_LIBRARIES}") + MESSAGE(STATUS "*** MKL_SCALAPACK_LIBRARIES = ${MKL_SCALAPACK_LIBRARIES}") + MESSAGE(STATUS "*** MKL_SOLVER_LIBRARIES = ${MKL_SOLVER_LIBRARIES}") + MESSAGE(STATUS "*** MKL_CDFT_LIBRARIES = ${MKL_CDFT_LIBRARIES}") +ENDIF(_verbose) + +# Standard termination +IF(NOT MKL_FOUND AND MKL_FIND_REQUIRED) + MESSAGE(FATAL_ERROR "MKL library not found. Please specify library location") +ENDIF(NOT MKL_FOUND AND MKL_FIND_REQUIRED) +IF(NOT MKL_FIND_QUIETLY) + IF(MKL_FOUND) + MESSAGE(STATUS "MKL library found") + ELSE(MKL_FOUND) + MESSAGE(STATUS "MKL library not found") + ENDIF(MKL_FOUND) +ENDIF(NOT MKL_FIND_QUIETLY) + +# Do nothing if MKL_FOUND was set before! +ENDIF (NOT MKL_FOUND) + + diff --git a/lib/TH/cmake/FindSSE.cmake b/lib/TH/cmake/FindSSE.cmake new file mode 100644 index 00000000000..6ece8768968 --- /dev/null +++ b/lib/TH/cmake/FindSSE.cmake @@ -0,0 +1,104 @@ +# Check if SSE instructions are available on the machine where +# the project is compiled. 
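+#
+# Detection is host-based: on Linux the flags are grepped out of
+# /proc/cpuinfo, on Darwin they are read from `sysctl -n machdep.cpu.features`,
+# and on Windows (and any other system) SSE2 is assumed while the higher
+# levels are left disabled. Results are cached as SSE2_FOUND, SSE3_FOUND,
+# SSSE3_FOUND and SSE4_1_FOUND; a consumer might gate SIMD code paths on
+# them, e.g. (hypothetical definition name):
+#   IF(SSE2_FOUND)
+#     ADD_DEFINITIONS(-DUSE_SSE2)
+#   ENDIF(SSE2_FOUND)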
+ +IF(CMAKE_SYSTEM_NAME MATCHES "Linux") + EXEC_PROGRAM(cat ARGS "/proc/cpuinfo" OUTPUT_VARIABLE CPUINFO) + + STRING(REGEX REPLACE "^.*(sse2).*$" "\\1" SSE_THERE ${CPUINFO}) + STRING(COMPARE EQUAL "sse2" "${SSE_THERE}" SSE2_TRUE) + IF (SSE2_TRUE) + set(SSE2_FOUND true CACHE BOOL "SSE2 available on host") + ELSE (SSE2_TRUE) + set(SSE2_FOUND false CACHE BOOL "SSE2 available on host") + ENDIF (SSE2_TRUE) + + # /proc/cpuinfo apparently omits sse3 :( + STRING(REGEX REPLACE "^.*[^s](sse3).*$" "\\1" SSE_THERE ${CPUINFO}) + STRING(COMPARE EQUAL "sse3" "${SSE_THERE}" SSE3_TRUE) + IF (NOT SSE3_TRUE) + STRING(REGEX REPLACE "^.*(T2300).*$" "\\1" SSE_THERE ${CPUINFO}) + STRING(COMPARE EQUAL "T2300" "${SSE_THERE}" SSE3_TRUE) + ENDIF (NOT SSE3_TRUE) + + STRING(REGEX REPLACE "^.*(ssse3).*$" "\\1" SSE_THERE ${CPUINFO}) + STRING(COMPARE EQUAL "ssse3" "${SSE_THERE}" SSSE3_TRUE) + IF (SSE3_TRUE OR SSSE3_TRUE) + set(SSE3_FOUND true CACHE BOOL "SSE3 available on host") + ELSE (SSE3_TRUE OR SSSE3_TRUE) + set(SSE3_FOUND false CACHE BOOL "SSE3 available on host") + ENDIF (SSE3_TRUE OR SSSE3_TRUE) + IF (SSSE3_TRUE) + set(SSSE3_FOUND true CACHE BOOL "SSSE3 available on host") + ELSE (SSSE3_TRUE) + set(SSSE3_FOUND false CACHE BOOL "SSSE3 available on host") + ENDIF (SSSE3_TRUE) + + STRING(REGEX REPLACE "^.*(sse4_1).*$" "\\1" SSE_THERE ${CPUINFO}) + STRING(COMPARE EQUAL "sse4_1" "${SSE_THERE}" SSE41_TRUE) + IF (SSE41_TRUE) + set(SSE4_1_FOUND true CACHE BOOL "SSE4.1 available on host") + ELSE (SSE41_TRUE) + set(SSE4_1_FOUND false CACHE BOOL "SSE4.1 available on host") + ENDIF (SSE41_TRUE) +ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Darwin") + EXEC_PROGRAM("/usr/sbin/sysctl -n machdep.cpu.features" OUTPUT_VARIABLE + CPUINFO) + + STRING(REGEX REPLACE "^.*[^S](SSE2).*$" "\\1" SSE_THERE ${CPUINFO}) + STRING(COMPARE EQUAL "SSE2" "${SSE_THERE}" SSE2_TRUE) + IF (SSE2_TRUE) + set(SSE2_FOUND true CACHE BOOL "SSE2 available on host") + ELSE (SSE2_TRUE) + set(SSE2_FOUND false CACHE BOOL "SSE2 available on host") + ENDIF (SSE2_TRUE) + + STRING(REGEX REPLACE "^.*[^S](SSE3).*$" "\\1" SSE_THERE ${CPUINFO}) + STRING(COMPARE EQUAL "SSE3" "${SSE_THERE}" SSE3_TRUE) + IF (SSE3_TRUE) + set(SSE3_FOUND true CACHE BOOL "SSE3 available on host") + ELSE (SSE3_TRUE) + set(SSE3_FOUND false CACHE BOOL "SSE3 available on host") + ENDIF (SSE3_TRUE) + + STRING(REGEX REPLACE "^.*(SSSE3).*$" "\\1" SSE_THERE ${CPUINFO}) + STRING(COMPARE EQUAL "SSSE3" "${SSE_THERE}" SSSE3_TRUE) + IF (SSSE3_TRUE) + set(SSSE3_FOUND true CACHE BOOL "SSSE3 available on host") + ELSE (SSSE3_TRUE) + set(SSSE3_FOUND false CACHE BOOL "SSSE3 available on host") + ENDIF (SSSE3_TRUE) + + STRING(REGEX REPLACE "^.*(SSE4.1).*$" "\\1" SSE_THERE ${CPUINFO}) + STRING(COMPARE EQUAL "SSE4.1" "${SSE_THERE}" SSE41_TRUE) + IF (SSE41_TRUE) + set(SSE4_1_FOUND true CACHE BOOL "SSE4.1 available on host") + ELSE (SSE41_TRUE) + set(SSE4_1_FOUND false CACHE BOOL "SSE4.1 available on host") + ENDIF (SSE41_TRUE) +ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Windows") + # TODO + set(SSE2_FOUND true CACHE BOOL "SSE2 available on host") + set(SSE3_FOUND false CACHE BOOL "SSE3 available on host") + set(SSSE3_FOUND false CACHE BOOL "SSSE3 available on host") + set(SSE4_1_FOUND false CACHE BOOL "SSE4.1 available on host") +ELSE(CMAKE_SYSTEM_NAME MATCHES "Linux") + set(SSE2_FOUND true CACHE BOOL "SSE2 available on host") + set(SSE3_FOUND false CACHE BOOL "SSE3 available on host") + set(SSSE3_FOUND false CACHE BOOL "SSSE3 available on host") + set(SSE4_1_FOUND false CACHE BOOL "SSE4.1 available on host") +ENDIF(CMAKE_SYSTEM_NAME 
MATCHES "Linux") + +if(NOT SSE2_FOUND) + MESSAGE(STATUS "Could not find hardware support for SSE2 on this machine.") +endif(NOT SSE2_FOUND) +if(NOT SSE3_FOUND) + MESSAGE(STATUS "Could not find hardware support for SSE3 on this machine.") +endif(NOT SSE3_FOUND) +if(NOT SSSE3_FOUND) + MESSAGE(STATUS "Could not find hardware support for SSSE3 on this machine.") +endif(NOT SSSE3_FOUND) +if(NOT SSE4_1_FOUND) + MESSAGE(STATUS "Could not find hardware support for SSE4.1 on this machine.") +endif(NOT SSE4_1_FOUND) + +mark_as_advanced(SSE2_FOUND SSE3_FOUND SSSE3_FOUND SSE4_1_FOUND) diff --git a/lib/TH/generic/THBlas.c b/lib/TH/generic/THBlas.c new file mode 100644 index 00000000000..89e186b6e4f --- /dev/null +++ b/lib/TH/generic/THBlas.c @@ -0,0 +1,382 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THBlas.c" +#else + +void THBlas_(swap)(long n, real *x, long incx, real *y, long incy) +{ + if(n == 1) + { + incx = 1; + incy = 1; + } + +#if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) + if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) + { + int i_n = (int)n; + int i_incx = (int)incx; + int i_incy = (int)incy; + +#if defined(TH_REAL_IS_DOUBLE) + extern void dswap_(int *n, double *x, int *incx, double *y, int *incy); + dswap_(&i_n, x, &i_incx, y, &i_incy); +#else + extern void sswap_(int *n, float *x, int *incx, float *y, int *incy); + sswap_(&i_n, x, &i_incx, y, &i_incy); +#endif + return; + } +#endif + { + long i; + for(i = 0; i < n; i++) + { + real z = x[i*incx]; + x[i*incx] = y[i*incy]; + y[i*incy] = z; + } + } +} + +void THBlas_(scal)(long n, real a, real *x, long incx) +{ + if(n == 1) + incx = 1; + +#if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) + if( (n <= INT_MAX) && (incx <= INT_MAX) ) + { + int i_n = (int)n; + int i_incx = (int)incx; + +#if defined(TH_REAL_IS_DOUBLE) + extern void dscal_(int *n, double *a, double *x, int *incx); + dscal_(&i_n, &a, x, &i_incx); +#else + extern void sscal_(int *n, float *a, float *x, int *incx); + sscal_(&i_n, &a, x, &i_incx); +#endif + return; + } +#endif + { + long i; + for(i = 0; i < n; i++) + x[i*incx] *= a; + } +} + +void THBlas_(copy)(long n, real *x, long incx, real *y, long incy) +{ + if(n == 1) + { + incx = 1; + incy = 1; + } + +#if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) + if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) + { + int i_n = (int)n; + int i_incx = (int)incx; + int i_incy = (int)incy; + +#if defined(TH_REAL_IS_DOUBLE) + extern void dcopy_(int *n, double *x, int *incx, double *y, int *incy); + dcopy_(&i_n, x, &i_incx, y, &i_incy); +#else + extern void scopy_(int *n, float *x, int *incx, float *y, int *incy); + scopy_(&i_n, x, &i_incx, y, &i_incy); +#endif + return; + } +#endif + { + long i; + for(i = 0; i < n; i++) + y[i*incy] = x[i*incx]; + } +} + +void THBlas_(axpy)(long n, real a, real *x, long incx, real *y, long incy) +{ + if(n == 1) + { + incx = 1; + incy = 1; + } + +#if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) + if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) + { + int i_n = (int)n; + int i_incx = (int)incx; + int i_incy = (int)incy; + +#if defined(TH_REAL_IS_DOUBLE) + extern void daxpy_(int *n, double *a, double *x, int *incx, double *y, int *incy); + daxpy_(&i_n, &a, x, &i_incx, y, &i_incy); +#else + extern void saxpy_(int *n, float *a, float *x, int *incx, float *y, int *incy); + saxpy_(&i_n, &a, x, &i_incx, y, &i_incy); 
+#endif + return; + } +#endif + { + long i; + for(i = 0; i < n; i++) + y[i*incy] += a*x[i*incx]; + } +} + +real THBlas_(dot)(long n, real *x, long incx, real *y, long incy) +{ + if(n == 1) + { + incx = 1; + incy = 1; + } + +#if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) + if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) + { + int i_n = (int)n; + int i_incx = (int)incx; + int i_incy = (int)incy; + +#if defined(TH_REAL_IS_DOUBLE) + extern double ddot_(int *n, double *x, int *incx, double *y, int *incy); + return ddot_(&i_n, x, &i_incx, y, &i_incy); +#else + extern float sdot_(int *n, float *x, int *incx, float *y, int *incy); + return sdot_(&i_n, x, &i_incx, y, &i_incy); +#endif + } +#endif + { + long i; + real sum = 0; + for(i = 0; i < n; i++) + sum += x[i*incx]*y[i*incy]; + return sum; + } +} + +void THBlas_(gemv)(char trans, long m, long n, real alpha, real *a, long lda, real *x, long incx, real beta, real *y, long incy) +{ + if(n == 1) + lda = m; + +#if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) + if( (m <= INT_MAX) && (n <= INT_MAX) && + (lda > 0) && (lda <= INT_MAX) && + (incx > 0) && (incx <= INT_MAX) && + (incy > 0) && (incy <= INT_MAX) ) + { + int i_m = (int)m; + int i_n = (int)n; + int i_lda = (int)lda; + int i_incx = (int)incx; + int i_incy = (int)incy; + +#if defined(TH_REAL_IS_DOUBLE) + extern void dgemv_(char *trans, int *m, int *n, double *alpha, double *a, int *lda, double *x, int *incx, double *beta, double *y, int *incy); + dgemv_(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy); +#else + extern void sgemv_(char *trans, int *m, int *n, float *alpha, float *a, int *lda, float *x, int *incx, float *beta, float *y, int *incy); + sgemv_(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy); +#endif + return; + } +#endif + { + long i, j; + + if( (trans == 'T') || (trans == 't') ) + { + for(i = 0; i < n; i++) + { + real sum = 0; + real *row_ = a+lda*i; + for(j = 0; j < m; j++) + sum += x[j*incx]*row_[j]; + y[i*incy] = beta*y[i*incy] + alpha*sum; + } + } + else + { + if(beta != 1) + THBlas_(scal)(m, beta, y, incy); + + for(j = 0; j < n; j++) + { + real *column_ = a+lda*j; + real z = alpha*x[j*incx]; + for(i = 0; i < m; i++) + y[i*incy] += z*column_[i]; + } + } + } +} + +void THBlas_(ger)(long m, long n, real alpha, real *x, long incx, real *y, long incy, real *a, long lda) +{ + if(n == 1) + lda = m; + +#if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) + if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) + { + int i_m = (int)m; + int i_n = (int)n; + int i_lda = (int)lda; + int i_incx = (int)incx; + int i_incy = (int)incy; + +#if defined(TH_REAL_IS_DOUBLE) + extern void dger_(int *m, int *n, double *alpha, double *x, int *incx, real *y, int *incy, double *a, int *lda); + dger_(&i_m, &i_n, &alpha, x, &i_incx, y, &i_incy, a, &i_lda); +#else + extern void sger_(int *m, int *n, float *alpha, float *x, int *incx, real *y, int *incy, float *a, int *lda); + sger_(&i_m, &i_n, &alpha, x, &i_incx, y, &i_incy, a, &i_lda); +#endif + return; + } +#endif + { + long i, j; + for(j = 0; j < n; j++) + { + real *column_ = a+j*lda; + real z = alpha*y[j*incy]; + for(i = 0; i < m; i++) + column_[i] += z*x[i*incx] ; + } + } +} + +void THBlas_(gemm)(char transa, char transb, long m, long n, long k, real alpha, real *a, long lda, real *b, long ldb, real beta, real *c, long ldc) +{ + int transa_ = 
((transa == 't') || (transa == 'T')); + int transb_ = ((transb == 't') || (transb == 'T')); + + if(n == 1) + ldc = m; + + if(transa_) + { + if(m == 1) + lda = k; + } + else + { + if(k == 1) + lda = m; + } + + if(transb_) + { + if(k == 1) + ldb = n; + } + else + { + if(n == 1) + ldb = k; + } + +#if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) + if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) ) + { + int i_m = (int)m; + int i_n = (int)n; + int i_k = (int)k; + int i_lda = (int)lda; + int i_ldb = (int)ldb; + int i_ldc = (int)ldc; + +#if defined(TH_REAL_IS_DOUBLE) + extern void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, double *a, int *lda, double *b, int *ldb, double *beta, double *c, int *ldc); + dgemm_(&transa, &transb, &i_m, &i_n, &i_k, &alpha, a, &i_lda, b, &i_ldb, &beta, c, &i_ldc); +#else + extern void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, float *a, int *lda, float *b, int *ldb, float *beta, float *c, int *ldc); + sgemm_(&transa, &transb, &i_m, &i_n, &i_k, &alpha, a, &i_lda, b, &i_ldb, &beta, c, &i_ldc); +#endif + return; + } +#endif + { + long i, j, l; + if(!transa_ && !transb_) + { + real *a_ = a; + for(i = 0; i < m; i++) + { + real *b_ = b; + for(j = 0; j < n; j++) + { + real sum = 0; + for(l = 0; l < k; l++) + sum += a_[l*lda]*b_[l]; + b_ += ldb; + c[j*ldc+i] = beta*c[j*ldc+i]+alpha*sum; + } + a_++; + } + } + else if(transa_ && !transb_) + { + real *a_ = a; + for(i = 0; i < m; i++) + { + real *b_ = b; + for(j = 0; j < n; j++) + { + real sum = 0; + for(l = 0; l < k; l++) + sum += a_[l]*b_[l]; + b_ += ldb; + c[j*ldc+i] = beta*c[j*ldc+i]+alpha*sum; + } + a_ += lda; + } + } + else if(!transa_ && transb_) + { + real *a_ = a; + for(i = 0; i < m; i++) + { + real *b_ = b; + for(j = 0; j < n; j++) + { + real sum = 0; + for(l = 0; l < k; l++) + sum += a_[l*lda]*b_[l*ldb]; + b_++; + c[j*ldc+i] = beta*c[j*ldc+i]+alpha*sum; + } + a_++; + } + } + else + { + real *a_ = a; + for(i = 0; i < m; i++) + { + real *b_ = b; + for(j = 0; j < n; j++) + { + real sum = 0; + for(l = 0; l < k; l++) + sum += a_[l]*b_[l*ldb]; + b_++; + c[j*ldc+i] = beta*c[j*ldc+i]+alpha*sum; + } + a_ += lda; + } + } + } +} + +#endif diff --git a/lib/TH/generic/THBlas.h b/lib/TH/generic/THBlas.h new file mode 100644 index 00000000000..244bb3cba7b --- /dev/null +++ b/lib/TH/generic/THBlas.h @@ -0,0 +1,19 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THBlas.h" +#else + +/* Level 1 */ +void THBlas_(swap)(long n, real *x, long incx, real *y, long incy); +void THBlas_(scal)(long n, real a, real *x, long incx); +void THBlas_(copy)(long n, real *x, long incx, real *y, long incy); +void THBlas_(axpy)(long n, real a, real *x, long incx, real *y, long incy); +real THBlas_(dot)(long n, real *x, long incx, real *y, long incy); + +/* Level 2 */ +void THBlas_(gemv)(char trans, long m, long n, real alpha, real *a, long lda, real *x, long incx, real beta, real *y, long incy); +void THBlas_(ger)(long m, long n, real alpha, real *x, long incx, real *y, long incy, real *a, long lda); + +/* Level 3 */ +void THBlas_(gemm)(char transa, char transb, long m, long n, long k, real alpha, real *a, long lda, real *b, long ldb, real beta, real *c, long ldc); + +#endif diff --git a/lib/TH/generic/THLapack.c b/lib/TH/generic/THLapack.c new file mode 100644 index 00000000000..e0691f524a8 --- /dev/null +++ b/lib/TH/generic/THLapack.c @@ -0,0 +1,66 @@ +#ifndef TH_GENERIC_FILE 
+#define TH_GENERIC_FILE "generic/THLapack.c" +#else + +void THLapack_(gesv)(int n, int nrhs, real *a, int lda, int *ipiv, real *b, int ldb, int* info) +{ +#ifdef __LAPACK__ +#if defined(TH_REAL_IS_DOUBLE) + extern void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info); + dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info); +#else + extern void sgesv_(int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info); + sgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info); +#endif +#else + THError("gesv : Lapack library not found in compile time\n"); +#endif + return; +} + +void THLapack_(gels)(char trans, int m, int n, int nrhs, real *a, int lda, real *b, int ldb, real *work, int lwork, int *info) +{ +#ifdef __LAPACK__ +#if defined(TH_REAL_IS_DOUBLE) + extern void dgels_(char *trans, int *m, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, double *work, int *lwork, int *info); + dgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info); +#else + extern void sgels_(char *trans, int *m, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, float *work, int *lwork, int *info); + sgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info); +#endif +#else + THError("gels : Lapack library not found in compile time\n"); +#endif +} + +void THLapack_(syev)(char jobz, char uplo, int n, real *a, int lda, real *w, real *work, int lwork, int *info) +{ +#ifdef __LAPACK__ +#if defined(TH_REAL_IS_DOUBLE) + extern void dsyev_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *info); + dsyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info); +#else + extern void ssyev_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *info); + ssyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info); +#endif +#else + THError("syev : Lapack library not found in compile time\n"); +#endif +} + +void THLapack_(gesvd)(char jobu, char jobvt, int m, int n, real *a, int lda, real *s, real *u, int ldu, real *vt, int ldvt, real *work, int lwork, int *info) +{ +#ifdef __LAPACK__ +#if defined(TH_REAL_IS_DOUBLE) + extern void dgesvd_(char *jobu, char *jobvt, int *m, int *n, double *a, int *lda, double *s, double *u, int *ldu, double *vt, int *ldvt, double *work, int *lwork, int *info); + dgesvd_( &jobu, &jobvt, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, info); +#else + extern void sgesvd_(char *jobu, char *jobvt, int *m, int *n, float *a, int *lda, float *s, float *u, int *ldu, float *vt, int *ldvt, float *work, int *lwork, int *info); + sgesvd_( &jobu, &jobvt, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, info); +#endif +#else + THError("gesvd : Lapack library not found in compile time\n"); +#endif +} + +#endif diff --git a/lib/TH/generic/THLapack.h b/lib/TH/generic/THLapack.h new file mode 100644 index 00000000000..ed3564f904a --- /dev/null +++ b/lib/TH/generic/THLapack.h @@ -0,0 +1,15 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THLapack.h" +#else + + + +/* AX=B */ +void THLapack_(gesv)(int n, int nrhs, real *a, int lda, int *ipiv, real *b, int ldb, int* info); +/* ||AX-B|| */ +void THLapack_(gels)(char trans, int m, int n, int nrhs, real *a, int lda, real *b, int ldb, real *work, int lwork, int *info); +/* Eigenvals */ +void THLapack_(syev)(char jobz, char uplo, int n, real *a, int lda, real *w, real *work, int lwork, int *info); +/* svd */ +void THLapack_(gesvd)(char jobu, char jobvt, int m, int n, real *a, int lda, real *s, real 
*u, int ldu, real *vt, int ldvt, real *work, int lwork, int *info); +#endif diff --git a/lib/TH/generic/THStorage.c b/lib/TH/generic/THStorage.c new file mode 100644 index 00000000000..863ec6881e2 --- /dev/null +++ b/lib/TH/generic/THStorage.c @@ -0,0 +1,259 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THStorage.c" +#else + +THStorage* THStorage_(new)(void) +{ + return THStorage_(newWithSize)(0); +} + +THStorage* THStorage_(newWithSize)(long size) +{ + THStorage *storage = THAlloc(sizeof(THStorage)); + storage->data = THAlloc(sizeof(real)*size); + storage->size = size; + storage->refcount = 1; + storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; + return storage; +} + +THStorage* THStorage_(newWithSize1)(real data0) +{ + THStorage *self = THStorage_(newWithSize)(1); + self->data[0] = data0; + return self; +} + +THStorage* THStorage_(newWithSize2)(real data0, real data1) +{ + THStorage *self = THStorage_(newWithSize)(2); + self->data[0] = data0; + self->data[1] = data1; + return self; +} + +THStorage* THStorage_(newWithSize3)(real data0, real data1, real data2) +{ + THStorage *self = THStorage_(newWithSize)(3); + self->data[0] = data0; + self->data[1] = data1; + self->data[2] = data2; + return self; +} + +THStorage* THStorage_(newWithSize4)(real data0, real data1, real data2, real data3) +{ + THStorage *self = THStorage_(newWithSize)(4); + self->data[0] = data0; + self->data[1] = data1; + self->data[2] = data2; + self->data[3] = data3; + return self; +} + +#if defined(_WIN32) || defined(HAVE_MMAP) + +THStorage* THStorage_(newWithMapping)(const char *fileName, int isShared) +{ + THStorage *storage = THAlloc(sizeof(THStorage)); + long size; + + /* check size */ + FILE *f = fopen(fileName, "rb"); + if(f == NULL) + THError("unable to open file <%s> for mapping (read-only mode)", fileName); + fseek(f, 0, SEEK_END); + size = ftell(f); + fclose(f); + size /= sizeof(real); + +#ifdef _WIN32 + { + HANDLE hfile; + HANDLE hmfile; + DWORD size_hi, size_lo; + + /* open file */ + if(isShared) + { + hfile = CreateFileA(fileName, GENERIC_READ|GENERIC_WRITE, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0); + if (hfile == INVALID_HANDLE_VALUE) + THError("could not open file <%s> in read-write mode", fileName); + } + else + { + hfile = CreateFileA(fileName, GENERIC_READ, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0); + if (hfile == INVALID_HANDLE_VALUE) + THError("could not open file <%s> in read-only mode", fileName); + } + +#if SIZEOF_SIZE_T > 4 + size_hi = (DWORD)((size*sizeof(real)) >> 32); + size_lo = (DWORD)((size*sizeof(real)) & 0xFFFFFFFF); +#else + size_hi = 0; + size_lo = (DWORD)(size*sizeof(real)); +#endif + + /* get map handle */ + if(isShared) + { + if( (hmfile = CreateFileMapping(hfile, NULL, PAGE_READWRITE, size_hi, size_lo, NULL)) == NULL ) + THError("could not create a map on file <%s>", fileName); + } + else + { + if( (hmfile = CreateFileMapping(hfile, NULL, PAGE_WRITECOPY, size_hi, size_lo, NULL)) == NULL ) + THError("could not create a map on file <%s>", fileName); + } + + /* map the stuff */ + storage = THStorage_(new)(); + if(isShared) + storage->data = MapViewOfFile(hmfile, FILE_MAP_ALL_ACCESS, 0, 0, 0); + else + storage->data = MapViewOfFile(hmfile, FILE_MAP_COPY, 0, 0, 0); + + storage->size = size; + if(storage->data == NULL) + { + THStorage_(free)(storage); + THError("memory map failed on file <%s>", fileName); + } + CloseHandle(hfile); + CloseHandle(hmfile); + } +#else + { + 
/* open file */ + int fd; + if(isShared) + { + fd = open(fileName, O_RDWR); + if(fd == -1) + THError("unable to open file <%s> in read-write mode", fileName); + } + else + { + fd = open(fileName, O_RDONLY); + if(fd == -1) + THError("unable to open file <%s> in read-only mode", fileName); + } + + /* map it */ + storage = THStorage_(new)(); + if(isShared) + storage->data = mmap(NULL, size*sizeof(real), PROT_READ|PROT_WRITE, MAP_SHARED, fd, 0); + else + storage->data = mmap(NULL, size*sizeof(real), PROT_READ|PROT_WRITE, MAP_PRIVATE, fd, 0); + + storage->size = size; + if(storage->data == MAP_FAILED) + { + storage->data = NULL; /* let's be sure it is NULL before calling free() */ + THStorage_(free)(storage); + THError("memory map failed on file <%s>", fileName); + } + close (fd); + } +#endif + + storage->refcount = 1; + storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_MAPPED | TH_STORAGE_FREEMEM;; + return storage; +} + +#else + +THStorage* THStorage_(newWithMapping)(const char *fileName, int isShared) +{ + THError("Mapped file Storages are not supported on your system"); +} + +#endif + +void THStorage_(setFlag)(THStorage *storage, const char flag) +{ + storage->flag |= flag; +} + +void THStorage_(clearFlag)(THStorage *storage, const char flag) +{ + storage->flag &= ~flag; +} + +void THStorage_(retain)(THStorage *storage) +{ + if(storage && (storage->flag & TH_STORAGE_REFCOUNTED)) + ++storage->refcount; +} + +void THStorage_(free)(THStorage *storage) +{ + if(!storage) + return; + + if((storage->flag & TH_STORAGE_REFCOUNTED) && (storage->refcount > 0)) + { + if(--storage->refcount == 0) + { + if(storage->flag & TH_STORAGE_FREEMEM) + { +#if defined(_WIN32) || defined(HAVE_MMAP) + if(storage->flag & TH_STORAGE_MAPPED) + { +#ifdef _WIN32 + if(!UnmapViewOfFile((LPINT)storage->data)) +#else + if (munmap(storage->data, storage->size*sizeof(real))) +#endif + THError("could not unmap the shared memory file"); + } + else +#endif + THFree(storage->data); + } + THFree(storage); + } + } +} + +THStorage* THStorage_(newWithData)(real *data, long size) +{ + THStorage *storage = THAlloc(sizeof(THStorage)); + storage->data = data; + storage->size = size; + storage->refcount = 1; + storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; + return storage; +} + +void THStorage_(resize)(THStorage *storage, long size) +{ + if(storage->flag & TH_STORAGE_RESIZABLE) + { + storage->data = THRealloc(storage->data, sizeof(real)*size); + storage->size = size; + } +} + +void THStorage_(fill)(THStorage *storage, real value) +{ + long i; + for(i = 0; i < storage->size; i++) + storage->data[i] = value; +} + +void THStorage_(set)(THStorage *self, long idx, real value) +{ + THArgCheck((idx >= 0) && (idx < self->size), 2, "out of bounds"); + self->data[idx] = value; +} + +real THStorage_(get)(THStorage *self, long idx) +{ + THArgCheck((idx >= 0) && (idx < self->size), 2, "out of bounds"); + return self->data[idx]; +} + +#endif diff --git a/lib/TH/generic/THStorage.h b/lib/TH/generic/THStorage.h new file mode 100644 index 00000000000..01e253a4e77 --- /dev/null +++ b/lib/TH/generic/THStorage.h @@ -0,0 +1,59 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THStorage.h" +#else + +/* on pourrait avoir un liste chainee + qui initialise math, lab structures (or more). + mouais -- complique. + + Pb: THMapStorage is kind of a class + THLab_()... comment je m'en sors? + + en template, faudrait que je les instancie toutes!!! oh boy! + Et comment je sais que c'est pour Cuda? 
Le type float est le meme dans les <> + + au bout du compte, ca serait sur des pointeurs float/double... etc... = facile. + primitives?? + */ + +#define TH_STORAGE_REFCOUNTED 1 +#define TH_STORAGE_RESIZABLE 2 +#define TH_STORAGE_MAPPED 4 +#define TH_STORAGE_FREEMEM 8 + +typedef struct THStorage +{ + real *data; + long size; + int refcount; + char flag; + +} THStorage; + +TH_API real* THStorage_(data)(THStorage*); +TH_API long THStorage_(size)(THStorage*); + +/* slow access -- checks everything */ +TH_API void THStorage_(set)(THStorage*, long, real); +TH_API real THStorage_(get)(THStorage*, long); + +TH_API THStorage* THStorage_(new)(void); +TH_API THStorage* THStorage_(newWithSize)(long size); +TH_API THStorage* THStorage_(newWithSize1)(real); +TH_API THStorage* THStorage_(newWithSize2)(real, real); +TH_API THStorage* THStorage_(newWithSize3)(real, real, real); +TH_API THStorage* THStorage_(newWithSize4)(real, real, real, real); +TH_API THStorage* THStorage_(newWithMapping)(const char *fileName, int isShared); +TH_API THStorage* THStorage_(newWithData)(real *data, long size); + +/* should not differ with API */ +TH_API void THStorage_(setFlag)(THStorage *storage, const char flag); +TH_API void THStorage_(clearFlag)(THStorage *storage, const char flag); +TH_API void THStorage_(retain)(THStorage *storage); + +/* might differ with other API (like CUDA) */ +TH_API void THStorage_(free)(THStorage *storage); +TH_API void THStorage_(resize)(THStorage *storage, long size); +TH_API void THStorage_(fill)(THStorage *storage, real value); + +#endif diff --git a/lib/TH/generic/THStorageCopy.c b/lib/TH/generic/THStorageCopy.c new file mode 100644 index 00000000000..63a26dc1be6 --- /dev/null +++ b/lib/TH/generic/THStorageCopy.c @@ -0,0 +1,36 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THStorageCopy.c" +#else + +void THStorage_(rawCopy)(THStorage *storage, real *src) +{ + long i; + for(i = 0; i < storage->size; i++) + storage->data[i] = src[i]; +} + +void THStorage_(copy)(THStorage *storage, THStorage *src) +{ + THArgCheck(storage->size == src->size, 2, "size mismatch"); + THStorage_(rawCopy)(storage, src->data); +} + + +#define IMPLEMENT_THStorage_COPY(TYPENAMESRC) \ +void THStorage_(copy##TYPENAMESRC)(THStorage *storage, TH##TYPENAMESRC##Storage *src) \ +{ \ + long i; \ + THArgCheck(storage->size == src->size, 2, "size mismatch"); \ + for(i = 0; i < storage->size; i++) \ + storage->data[i] = (real)src->data[i]; \ +} + +IMPLEMENT_THStorage_COPY(Byte) +IMPLEMENT_THStorage_COPY(Char) +IMPLEMENT_THStorage_COPY(Short) +IMPLEMENT_THStorage_COPY(Int) +IMPLEMENT_THStorage_COPY(Long) +IMPLEMENT_THStorage_COPY(Float) +IMPLEMENT_THStorage_COPY(Double) + +#endif diff --git a/lib/TH/generic/THStorageCopy.h b/lib/TH/generic/THStorageCopy.h new file mode 100644 index 00000000000..f853a82b569 --- /dev/null +++ b/lib/TH/generic/THStorageCopy.h @@ -0,0 +1,17 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THStorageCopy.h" +#else + +/* Support for copy between different Storage types */ + +TH_API void THStorage_(rawCopy)(THStorage *storage, real *src); +TH_API void THStorage_(copy)(THStorage *storage, THStorage *src); +TH_API void THStorage_(copyByte)(THStorage *storage, struct THByteStorage *src); +TH_API void THStorage_(copyChar)(THStorage *storage, struct THCharStorage *src); +TH_API void THStorage_(copyShort)(THStorage *storage, struct THShortStorage *src); +TH_API void THStorage_(copyInt)(THStorage *storage, struct THIntStorage *src); +TH_API void THStorage_(copyLong)(THStorage 
*storage, struct THLongStorage *src); +TH_API void THStorage_(copyFloat)(THStorage *storage, struct THFloatStorage *src); +TH_API void THStorage_(copyDouble)(THStorage *storage, struct THDoubleStorage *src); + +#endif diff --git a/lib/TH/generic/THTensor.c b/lib/TH/generic/THTensor.c new file mode 100644 index 00000000000..98737c29eec --- /dev/null +++ b/lib/TH/generic/THTensor.c @@ -0,0 +1,728 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THTensor.c" +#else + +/**** access methods ****/ +THStorage *THTensor_(storage)(THTensor *self) +{ + return self->storage; +} + +long THTensor_(storageOffset)(THTensor *self) +{ + return self->storageOffset; +} + +int THTensor_(nDimension)(THTensor *self) +{ + return self->nDimension; +} + +long THTensor_(size)(THTensor *self, int dim) +{ + THArgCheck((dim >= 0) && (dim < self->nDimension), 2, "out of range"); + return self->size[dim]; +} + +long THTensor_(stride)(THTensor *self, int dim) +{ + THArgCheck((dim >= 0) && (dim < self->nDimension), 2, "out of range"); + return self->stride[dim]; +} + +THLongStorage *THTensor_(newSizeOf)(THTensor *self) +{ + THLongStorage *size = THLongStorage_newWithSize(self->nDimension); + THLongStorage_rawCopy(size, self->size); + return size; +} + +THLongStorage *THTensor_(newStrideOf)(THTensor *self) +{ + THLongStorage *stride = THLongStorage_newWithSize(self->nDimension); + THLongStorage_rawCopy(stride, self->stride); + return stride; +} + +real *THTensor_(data)(THTensor *self) +{ + if(self->storage) + return (self->storage->data+self->storageOffset); + else + return NULL; +} + +void THTensor_(setFlag)(THTensor *self, const char flag) +{ + self->flag |= flag; +} + +void THTensor_(clearFlag)(THTensor *self, const char flag) +{ + self->flag &= ~flag; +} + +/**** creation methods ****/ + +static void THTensor_(rawInit)(THTensor *self); +static void THTensor_(rawSet)(THTensor *self, THStorage *storage, long storageOffset, int nDimension, long *size, long *stride); +static void THTensor_(rawResize)(THTensor *self, int nDimension, long *size, long *stride); + + +/* Empty init */ +THTensor *THTensor_(new)(void) +{ + THTensor *self = THAlloc(sizeof(THTensor)); + THTensor_(rawInit)(self); + return self; +} + +/* Pointer-copy init */ +THTensor *THTensor_(newWithTensor)(THTensor *tensor) +{ + THTensor *self = THAlloc(sizeof(THTensor)); + THTensor_(rawInit)(self); + THTensor_(rawSet)(self, + tensor->storage, + tensor->storageOffset, + tensor->nDimension, + tensor->size, + tensor->stride); + return self; +} + +/* Storage init */ +THTensor *THTensor_(newWithStorage)(THStorage *storage, long storageOffset, THLongStorage *size, THLongStorage *stride) +{ + THTensor *self = THAlloc(sizeof(THTensor)); + if(size && stride) + THArgCheck(size->size == stride->size, 4, "inconsistent size"); + + THTensor_(rawInit)(self); + THTensor_(rawSet)(self, + storage, + storageOffset, + (size ? size->size : (stride ? stride->size : 0)), + (size ? size->data : NULL), + (stride ? 
stride->data : NULL)); + + return self; +} +THTensor *THTensor_(newWithStorage1d)(THStorage *storage, long storageOffset, + long size0, long stride0) +{ + return THTensor_(newWithStorage4d)(storage, storageOffset, size0, stride0, -1, -1, -1, -1, -1, -1); +} + +THTensor *THTensor_(newWithStorage2d)(THStorage *storage, long storageOffset, + long size0, long stride0, + long size1, long stride1) +{ + return THTensor_(newWithStorage4d)(storage, storageOffset, size0, stride0, size1, stride1, -1, -1, -1, -1); +} + +THTensor *THTensor_(newWithStorage3d)(THStorage *storage, long storageOffset, + long size0, long stride0, + long size1, long stride1, + long size2, long stride2) +{ + return THTensor_(newWithStorage4d)(storage, storageOffset, size0, stride0, size1, stride1, size2, stride2, -1, -1); +} + +THTensor *THTensor_(newWithStorage4d)(THStorage *storage, long storageOffset, + long size0, long stride0, + long size1, long stride1, + long size2, long stride2, + long size3, long stride3) +{ + long size[4] = {size0, size1, size2, size3}; + long stride[4] = {stride0, stride1, stride2, stride3}; + + THTensor *self = THAlloc(sizeof(THTensor)); + THTensor_(rawInit)(self); + THTensor_(rawSet)(self, storage, storageOffset, 4, size, stride); + + return self; +} + +THTensor *THTensor_(newWithSize)(THLongStorage *size, THLongStorage *stride) +{ + return THTensor_(newWithStorage)(NULL, 0, size, stride); +} + +THTensor *THTensor_(newWithSize1d)(long size0) +{ + return THTensor_(newWithSize4d)(size0, -1, -1, -1); +} + +THTensor *THTensor_(newWithSize2d)(long size0, long size1) +{ + return THTensor_(newWithSize4d)(size0, size1, -1, -1); +} + +THTensor *THTensor_(newWithSize3d)(long size0, long size1, long size2) +{ + return THTensor_(newWithSize4d)(size0, size1, size2, -1); +} + +THTensor *THTensor_(newWithSize4d)(long size0, long size1, long size2, long size3) +{ + long size[4] = {size0, size1, size2, size3}; + + THTensor *self = THAlloc(sizeof(THTensor)); + THTensor_(rawInit)(self); + THTensor_(rawResize)(self, 4, size, NULL); + + return self; +} + +THTensor *THTensor_(newClone)(THTensor *self) +{ + THTensor *tensor = THTensor_(new)(); + THTensor_(resizeAs)(tensor, self); + THTensor_(copy)(tensor, self); + return tensor; +} + +THTensor *THTensor_(newContiguous)(THTensor *self) +{ + if(!THTensor_(isContiguous)(self)) + return THTensor_(newClone)(self); + else + { + THTensor_(retain)(self); + return self; + } +} + +THTensor *THTensor_(newSelect)(THTensor *tensor, int dimension_, long sliceIndex_) +{ + THTensor *self = THTensor_(newWithTensor)(tensor); + THTensor_(select)(self, NULL, dimension_, sliceIndex_); + return self; +} + +THTensor *THTensor_(newNarrow)(THTensor *tensor, int dimension_, long firstIndex_, long size_) +{ + THTensor *self = THTensor_(newWithTensor)(tensor); + THTensor_(narrow)(self, NULL, dimension_, firstIndex_, size_); + return self; +} + +THTensor *THTensor_(newTranspose)(THTensor *tensor, int dimension1_, int dimension2_) +{ + THTensor *self = THTensor_(newWithTensor)(tensor); + THTensor_(transpose)(self, NULL, dimension1_, dimension2_); + return self; +} + +THTensor *THTensor_(newUnfold)(THTensor *tensor, int dimension_, long size_, long step_) +{ + THTensor *self = THTensor_(newWithTensor)(tensor); + THTensor_(unfold)(self, NULL, dimension_, size_, step_); + return self; +} + +/* Resize */ +void THTensor_(resize)(THTensor *self, THLongStorage *size, THLongStorage *stride) +{ + THArgCheck(size != NULL, 2, "invalid size"); + if(stride) + THArgCheck(stride->size == size->size, 3, "invalid 
stride"); + + THTensor_(rawResize)(self, size->size, size->data, (stride ? stride->data : NULL)); +} + +void THTensor_(resizeAs)(THTensor *self, THTensor *src) +{ + int isSame = 0; + int d; + if(self->nDimension == src->nDimension) + { + isSame = 1; + for(d = 0; d < self->nDimension; d++) + { + if(self->size[d] != src->size[d]) + { + isSame = 0; + break; + } + } + } + + if(!isSame) + THTensor_(rawResize)(self, src->nDimension, src->size, NULL); +} + +void THTensor_(resize1d)(THTensor *tensor, long size0) +{ + THTensor_(resize4d)(tensor, size0, -1, -1, -1); +} + +void THTensor_(resize2d)(THTensor *tensor, long size0, long size1) +{ + THTensor_(resize4d)(tensor, size0, size1, -1, -1); +} + +void THTensor_(resize3d)(THTensor *tensor, long size0, long size1, long size2) +{ + THTensor_(resize4d)(tensor, size0, size1, size2, -1); +} + +void THTensor_(resize4d)(THTensor *self, long size0, long size1, long size2, long size3) +{ + long size[4] = {size0, size1, size2, size3}; + + THTensor_(rawResize)(self, 4, size, NULL); +} + +void THTensor_(resize5d)(THTensor *self, long size0, long size1, long size2, long size3, long size4) +{ + long size[5] = {size0, size1, size2, size3, size4}; + + THTensor_(rawResize)(self, 5, size, NULL); +} + +void THTensor_(set)(THTensor *self, THTensor *src) +{ + if(self != src) + THTensor_(rawSet)(self, + src->storage, + src->storageOffset, + src->nDimension, + src->size, + src->stride); +} + +void THTensor_(setStorage)(THTensor *self, THStorage *storage_, long storageOffset_, THLongStorage *size_, THLongStorage *stride_) +{ + if(size_ && stride_) + THArgCheck(size_->size == stride_->size, 5, "inconsistent size/stride sizes"); + + THTensor_(rawSet)(self, + storage_, + storageOffset_, + (size_ ? size_->size : (stride_ ? stride_->size : 0)), + (size_ ? size_->data : NULL), + (stride_ ? 
stride_->data : NULL)); +} + +void THTensor_(setStorage1d)(THTensor *self, THStorage *storage_, long storageOffset_, + long size0_, long stride0_) +{ + THTensor_(setStorage4d)(self, storage_, storageOffset_, + size0_, stride0_, + -1, -1, + -1, -1, + -1, -1); +} + +void THTensor_(setStorage2d)(THTensor *self, THStorage *storage_, long storageOffset_, + long size0_, long stride0_, + long size1_, long stride1_) +{ + THTensor_(setStorage4d)(self, storage_, storageOffset_, + size0_, stride0_, + size1_, stride1_, + -1, -1, + -1, -1); +} + +void THTensor_(setStorage3d)(THTensor *self, THStorage *storage_, long storageOffset_, + long size0_, long stride0_, + long size1_, long stride1_, + long size2_, long stride2_) +{ + THTensor_(setStorage4d)(self, storage_, storageOffset_, + size0_, stride0_, + size1_, stride1_, + size2_, stride2_, + -1, -1); +} + +void THTensor_(setStorage4d)(THTensor *self, THStorage *storage_, long storageOffset_, + long size0_, long stride0_, + long size1_, long stride1_, + long size2_, long stride2_, + long size3_, long stride3_) +{ + + long size[4] = {size0_, size1_, size2_, size3_}; + long stride[4] = {stride0_, stride1_, stride2_, stride3_}; + + THTensor_(rawSet)(self, storage_, storageOffset_, 4, size, stride); +} + + +void THTensor_(narrow)(THTensor *self, THTensor *src, int dimension, long firstIndex, long size) +{ + if(!src) + src = self; + + THArgCheck( (dimension >= 0) && (dimension < src->nDimension), 3, "out of range"); + THArgCheck( (firstIndex >= 0) && (firstIndex < src->size[dimension]), 4, "out of range"); + THArgCheck( (size > 0) && (firstIndex+size <= src->size[dimension]), 5, "out of range"); + + THTensor_(set)(self, src); + + if(firstIndex > 0) + self->storageOffset += firstIndex*self->stride[dimension]; + + self->size[dimension] = size; +} + +void THTensor_(select)(THTensor *self, THTensor *src, int dimension, long sliceIndex) +{ + int d; + + if(!src) + src = self; + + THArgCheck(src->nDimension > 1, 1, "cannot select on a vector"); + THArgCheck((dimension >= 0) && (dimension < src->nDimension), 3, "out of range"); + THArgCheck((sliceIndex >= 0) && (sliceIndex < src->size[dimension]), 4, "out of range"); + + THTensor_(set)(self, src); + THTensor_(narrow)(self, NULL, dimension, sliceIndex, 1); + for(d = dimension; d < self->nDimension-1; d++) + { + self->size[d] = self->size[d+1]; + self->stride[d] = self->stride[d+1]; + } + self->nDimension--; +} + +void THTensor_(transpose)(THTensor *self, THTensor *src, int dimension1, int dimension2) +{ + long z; + + if(!src) + src = self; + + THArgCheck( (dimension1 >= 0) && (dimension1 < src->nDimension), 1, "out of range"); + THArgCheck( (dimension2 >= 0) && (dimension2 < src->nDimension), 2, "out of range"); + + THTensor_(set)(self, src); + + if(dimension1 == dimension2) + return; + + z = self->stride[dimension1]; + self->stride[dimension1] = self->stride[dimension2]; + self->stride[dimension2] = z; + z = self->size[dimension1]; + self->size[dimension1] = self->size[dimension2]; + self->size[dimension2] = z; +} + +void THTensor_(unfold)(THTensor *self, THTensor *src, int dimension, long size, long step) +{ + long *newSize; + long *newStride; + int d; + + if(!src) + src = self; + + THArgCheck( (src->nDimension > 0), 1, "cannot unfold an empty tensor"); + THArgCheck(dimension < src->nDimension, 2, "out of range"); + THArgCheck(size <= src->size[dimension], 3, "out of range"); + THArgCheck(step > 0, 4, "invalid step"); + + THTensor_(set)(self, src); + + newSize = THAlloc(sizeof(long)*(self->nDimension+1)); + 
newStride = THAlloc(sizeof(long)*(self->nDimension+1)); + + newSize[self->nDimension] = size; + newStride[self->nDimension] = self->stride[dimension]; + for(d = 0; d < self->nDimension; d++) + { + if(d == dimension) + { + newSize[d] = (self->size[d] - size) / step + 1; + newStride[d] = step*self->stride[d]; + } + else + { + newSize[d] = self->size[d]; + newStride[d] = self->stride[d]; + } + } + + THFree(self->size); + THFree(self->stride); + + self->size = newSize; + self->stride = newStride; + self->nDimension++; +} + +/* we have to handle the case where the result is a number */ +void THTensor_(squeeze)(THTensor *self, THTensor *src) +{ + int ndim = 0; + int d; + + if(!src) + src = self; + + THTensor_(set)(self, src); + + for(d = 0; d < src->nDimension; d++) + { + if(src->size[d] != 1) + { + if(d != ndim) + { + self->size[ndim] = src->size[d]; + self->stride[ndim] = src->stride[d]; + } + ndim++; + } + } + + /* right now, we do not handle 0-dimension tensors */ + if(ndim == 0 && src->nDimension > 0) + { + self->size[0] = 1; + self->stride[0] = 1; + ndim = 1; + } + self->nDimension = ndim; +} + +void THTensor_(squeeze1d)(THTensor *self, THTensor *src, int dimension) +{ + int d; + + if(!src) + src = self; + + THArgCheck(dimension < src->nDimension, 3, "dimension out of range"); + + THTensor_(set)(self, src); + + if(src->size[dimension] == 1 && src->nDimension > 1) + { + for(d = dimension; d < self->nDimension-1; d++) + { + self->size[d] = self->size[d+1]; + self->stride[d] = self->stride[d+1]; + } + self->nDimension--; + } +} + +int THTensor_(isContiguous)(THTensor *self) +{ + long z = 1; + int d; + for(d = self->nDimension-1; d >= 0; d--) + { + if(self->size[d] != 1) + { + if(self->stride[d] == z) + z *= self->size[d]; + else + return 0; + } + } + return 1; +} + +long THTensor_(nElement)(THTensor *self) +{ + if(self->nDimension == 0) + return 0; + else + { + long nElement = 1; + int d; + for(d = 0; d < self->nDimension; d++) + nElement *= self->size[d]; + return nElement; + } +} + +void THTensor_(retain)(THTensor *self) +{ + if(self->flag & TH_TENSOR_REFCOUNTED) + ++self->refcount; +} + +void THTensor_(free)(THTensor *self) +{ + if(!self) + return; + + if(self->flag & TH_TENSOR_REFCOUNTED) + { + if(--self->refcount == 0) + { + THFree(self->size); + THFree(self->stride); + if(self->storage) + THStorage_(free)(self->storage); + THFree(self); + } + } +} + +void THTensor_(freeCopyTo)(THTensor *self, THTensor *dst) +{ + if(self != dst) + THTensor_(copy)(dst, self); + + THTensor_(free)(self); +} + +/*******************************************************************************/ + +static void THTensor_(rawInit)(THTensor *self) +{ + self->refcount = 1; + self->storage = NULL; + self->storageOffset = 0; + self->size = NULL; + self->stride = NULL; + self->nDimension = 0; + self->flag = TH_TENSOR_REFCOUNTED; +} + +static void THTensor_(rawSet)(THTensor *self, THStorage *storage, long storageOffset, int nDimension, long *size, long *stride) +{ + /* storage */ + if(self->storage != storage) + { + if(self->storage) + THStorage_(free)(self->storage); + + if(storage) + { + self->storage = storage; + THStorage_(retain)(self->storage); + } + else + self->storage = NULL; + } + + /* storageOffset */ + if(storageOffset < 0) + THError("Tensor: invalid storage offset"); + self->storageOffset = storageOffset; + + /* size and stride */ + THTensor_(rawResize)(self, nDimension, size, stride); +} + +static void THTensor_(rawResize)(THTensor *self, int nDimension, long *size, long *stride) +{ + int d; + int 
nDimension_; + long totalSize; + + nDimension_ = 0; + for(d = 0; d < nDimension; d++) + { + if(size[d] > 0) + nDimension_++; + else + break; + } + nDimension = nDimension_; + + if(nDimension > 0) + { + if(nDimension != self->nDimension) + { + self->size = THRealloc(self->size, sizeof(long)*nDimension); + self->stride = THRealloc(self->stride, sizeof(long)*nDimension); + self->nDimension = nDimension; + } + + totalSize = 1; + for(d = self->nDimension-1; d >= 0; d--) + { + self->size[d] = size[d]; + if(stride && (stride[d] >= 0) ) + self->stride[d] = stride[d]; + else + { + if(d == self->nDimension-1) + self->stride[d] = 1; + else + self->stride[d] = self->size[d+1]*self->stride[d+1]; + } + totalSize += (self->size[d]-1)*self->stride[d]; + } + + if(totalSize+self->storageOffset > 0) + { + if(!self->storage) + self->storage = THStorage_(new)(); + if(totalSize+self->storageOffset > self->storage->size) + THStorage_(resize)(self->storage, totalSize+self->storageOffset); + } + } + else + self->nDimension = 0; +} + +void THTensor_(set1d)(THTensor *tensor, long x0, real value) +{ + THArgCheck(tensor->nDimension == 1, 1, "tensor must have one dimension"); + THArgCheck( (x0 >= 0) && (x0 < tensor->size[0]), 2, "out of range"); + THStorage_(set)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0], value); +} + +real THTensor_(get1d)(THTensor *tensor, long x0) +{ + THArgCheck(tensor->nDimension == 1, 1, "tensor must have one dimension"); + THArgCheck( (x0 >= 0) && (x0 < tensor->size[0]), 2, "out of range"); + return THStorage_(get)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]); +} + +void THTensor_(set2d)(THTensor *tensor, long x0, long x1, real value) +{ + THArgCheck(tensor->nDimension == 2, 1, "tensor must have two dimensions"); + THArgCheck((x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]), 2, "out of range"); + THStorage_(set)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1], value); +} + +real THTensor_(get2d)(THTensor *tensor, long x0, long x1) +{ + THArgCheck(tensor->nDimension == 2, 1, "tensor must have two dimensions"); + THArgCheck((x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]), 2, "out of range"); + return THStorage_(get)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1]); +} + +void THTensor_(set3d)(THTensor *tensor, long x0, long x1, long x2, real value) +{ + THArgCheck(tensor->nDimension == 3, 1, "tensor must have three dimensions"); + THArgCheck( (x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]) && (x2 >= 0) && (x2 < tensor->size[2]), 2, "out of range"); + THStorage_(set)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1]+x2*tensor->stride[2], value); +} + +real THTensor_(get3d)(THTensor *tensor, long x0, long x1, long x2) +{ + THArgCheck(tensor->nDimension == 3, 1, "tensor must have three dimensions"); + THArgCheck( (x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]) && (x2 >= 0) && (x2 < tensor->size[2]), 2, "out of range"); + return THStorage_(get)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1]+x2*tensor->stride[2]); +} + +void THTensor_(set4d)(THTensor *tensor, long x0, long x1, long x2, long x3, real value) +{ + THArgCheck(tensor->nDimension == 4, 1, "tensor must have four dimensions"); + THArgCheck((x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]) && (x2 >= 0) && (x2 < tensor->size[2]) && (x3 >= 0) && (x3 < 
tensor->size[3]), 2, "out of range"); + THStorage_(set)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1]+x2*tensor->stride[2]+x3*tensor->stride[3], value); +} + +real THTensor_(get4d)(THTensor *tensor, long x0, long x1, long x2, long x3) +{ + THArgCheck(tensor->nDimension == 4, 1, "tensor must have four dimensions"); + THArgCheck((x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]) && (x2 >= 0) && (x2 < tensor->size[2]) && (x3 >= 0) && (x3 < tensor->size[3]), 2, "out of range"); + return THStorage_(get)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1]+x2*tensor->stride[2]+x3*tensor->stride[3]); +} + +#endif diff --git a/lib/TH/generic/THTensor.h b/lib/TH/generic/THTensor.h new file mode 100644 index 00000000000..ca0f00c8958 --- /dev/null +++ b/lib/TH/generic/THTensor.h @@ -0,0 +1,123 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THTensor.h" +#else + +/* a la lua? dim, storageoffset, ... et les methodes ? */ + +#define TH_TENSOR_REFCOUNTED 1 + +typedef struct THTensor +{ + long *size; + long *stride; + int nDimension; + + THStorage *storage; + long storageOffset; + int refcount; + + char flag; + +} THTensor; + + +/**** access methods ****/ +TH_API THStorage* THTensor_(storage)(THTensor *self); +TH_API long THTensor_(storageOffset)(THTensor *self); +TH_API int THTensor_(nDimension)(THTensor *self); +TH_API long THTensor_(size)(THTensor *self, int dim); +TH_API long THTensor_(stride)(THTensor *self, int dim); +TH_API THLongStorage *THTensor_(newSizeOf)(THTensor *self); +TH_API THLongStorage *THTensor_(newStrideOf)(THTensor *self); +TH_API real *THTensor_(data)(THTensor *self); + +TH_API void THTensor_(setFlag)(THTensor *self, const char flag); +TH_API void THTensor_(clearFlag)(THTensor *self, const char flag); + + +/**** creation methods ****/ +TH_API THTensor *THTensor_(new)(void); +TH_API THTensor *THTensor_(newWithTensor)(THTensor *tensor); +/* stride might be NULL */ +TH_API THTensor *THTensor_(newWithStorage)(THStorage *storage_, long storageOffset_, THLongStorage *size_, THLongStorage *stride_); +TH_API THTensor *THTensor_(newWithStorage1d)(THStorage *storage_, long storageOffset_, + long size0_, long stride0_); +TH_API THTensor *THTensor_(newWithStorage2d)(THStorage *storage_, long storageOffset_, + long size0_, long stride0_, + long size1_, long stride1_); +TH_API THTensor *THTensor_(newWithStorage3d)(THStorage *storage_, long storageOffset_, + long size0_, long stride0_, + long size1_, long stride1_, + long size2_, long stride2_); +TH_API THTensor *THTensor_(newWithStorage4d)(THStorage *storage_, long storageOffset_, + long size0_, long stride0_, + long size1_, long stride1_, + long size2_, long stride2_, + long size3_, long stride3_); + +/* stride might be NULL */ +TH_API THTensor *THTensor_(newWithSize)(THLongStorage *size_, THLongStorage *stride_); +TH_API THTensor *THTensor_(newWithSize1d)(long size0_); +TH_API THTensor *THTensor_(newWithSize2d)(long size0_, long size1_); +TH_API THTensor *THTensor_(newWithSize3d)(long size0_, long size1_, long size2_); +TH_API THTensor *THTensor_(newWithSize4d)(long size0_, long size1_, long size2_, long size3_); + +TH_API THTensor *THTensor_(newClone)(THTensor *self); +TH_API THTensor *THTensor_(newContiguous)(THTensor *tensor); +TH_API THTensor *THTensor_(newSelect)(THTensor *tensor, int dimension_, long sliceIndex_); +TH_API THTensor *THTensor_(newNarrow)(THTensor *tensor, int dimension_, long firstIndex_, long size_); +TH_API THTensor 
*THTensor_(newTranspose)(THTensor *tensor, int dimension1_, int dimension2_); +TH_API THTensor *THTensor_(newUnfold)(THTensor *tensor, int dimension_, long size_, long step_); + +TH_API void THTensor_(resize)(THTensor *tensor, THLongStorage *size, THLongStorage *stride); +TH_API void THTensor_(resizeAs)(THTensor *tensor, THTensor *src); +TH_API void THTensor_(resize1d)(THTensor *tensor, long size0_); +TH_API void THTensor_(resize2d)(THTensor *tensor, long size0_, long size1_); +TH_API void THTensor_(resize3d)(THTensor *tensor, long size0_, long size1_, long size2_); +TH_API void THTensor_(resize4d)(THTensor *tensor, long size0_, long size1_, long size2_, long size3_); +TH_API void THTensor_(resize5d)(THTensor *tensor, long size0_, long size1_, long size2_, long size3_, long size4_); + +TH_API void THTensor_(set)(THTensor *self, THTensor *src); +TH_API void THTensor_(setStorage)(THTensor *self, THStorage *storage_, long storageOffset_, THLongStorage *size_, THLongStorage *stride_); +TH_API void THTensor_(setStorage1d)(THTensor *self, THStorage *storage_, long storageOffset_, + long size0_, long stride0_); +TH_API void THTensor_(setStorage2d)(THTensor *self, THStorage *storage_, long storageOffset_, + long size0_, long stride0_, + long size1_, long stride1_); +TH_API void THTensor_(setStorage3d)(THTensor *self, THStorage *storage_, long storageOffset_, + long size0_, long stride0_, + long size1_, long stride1_, + long size2_, long stride2_); +TH_API void THTensor_(setStorage4d)(THTensor *self, THStorage *storage_, long storageOffset_, + long size0_, long stride0_, + long size1_, long stride1_, + long size2_, long stride2_, + long size3_, long stride3_); + +TH_API void THTensor_(narrow)(THTensor *self, THTensor *src, int dimension_, long firstIndex_, long size_); +TH_API void THTensor_(select)(THTensor *self, THTensor *src, int dimension_, long sliceIndex_); +TH_API void THTensor_(transpose)(THTensor *self, THTensor *src, int dimension1_, int dimension2_); +TH_API void THTensor_(unfold)(THTensor *self, THTensor *src, int dimension_, long size_, long step_); + +TH_API void THTensor_(squeeze)(THTensor *self, THTensor *src); +TH_API void THTensor_(squeeze1d)(THTensor *self, THTensor *src, int dimension_); + +TH_API int THTensor_(isContiguous)(THTensor *self); +TH_API long THTensor_(nElement)(THTensor *self); + +TH_API void THTensor_(retain)(THTensor *self); +TH_API void THTensor_(free)(THTensor *self); +TH_API void THTensor_(freeCopyTo)(THTensor *self, THTensor *dst); + +/* Slow access methods [check everything] */ +TH_API void THTensor_(set1d)(THTensor *tensor, long x0, real value); +TH_API void THTensor_(set2d)(THTensor *tensor, long x0, long x1, real value); +TH_API void THTensor_(set3d)(THTensor *tensor, long x0, long x1, long x2, real value); +TH_API void THTensor_(set4d)(THTensor *tensor, long x0, long x1, long x2, long x3, real value); + +TH_API real THTensor_(get1d)(THTensor *tensor, long x0); +TH_API real THTensor_(get2d)(THTensor *tensor, long x0, long x1); +TH_API real THTensor_(get3d)(THTensor *tensor, long x0, long x1, long x2); +TH_API real THTensor_(get4d)(THTensor *tensor, long x0, long x1, long x2, long x3); + +#endif diff --git a/lib/TH/generic/THTensorConv.c b/lib/TH/generic/THTensorConv.c new file mode 100644 index 00000000000..b351166d3b8 --- /dev/null +++ b/lib/TH/generic/THTensorConv.c @@ -0,0 +1,1489 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THTensorConv.c" +#else + +/* + 2D Input, 2D kernel : convolve given image with the given kernel. 
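+  This is a valid cross-correlation: the kernel is not flipped, and the output
+  has (ir - kr)/sr + 1 rows and (ic - kc)/sc + 1 columns. Each output element
+  is incremented by alpha times the dot product of the corresponding kr x kc
+  input window with k_.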
+*/ +TH_API void THTensor_(validXCorr2Dptr)(real *r_, + real alpha, + real *t_, long ir, long ic, + real *k_, long kr, long kc, + long sr, long sc) +{ + long or = (ir - kr) / sr + 1; + long oc = (ic - kc) / sc + 1; + + long xx, yy, kx, ky; + + if ((sc != 1) || (oc < 4)) { + // regular convolution + for(yy = 0; yy < or; yy++) { + for(xx = 0; xx < oc; xx++) { + /* Dot product in two dimensions... (between input image and the mask) */ + real *pi_ = t_ + yy*sr*ic + xx*sc; + real *pw_ = k_; + real sum = 0; + for(ky = 0; ky < kr; ky++) { + for(kx = 0; kx < kc; kx++) { + sum += pi_[kx]*pw_[kx]; + } + pi_ += ic; /* next input line */ + pw_ += kc; /* next mask line */ + } + /* Update output */ + *r_ += alpha*sum; + *r_++; + } + } + + } else { + // SSE-based convolution + for(yy = 0; yy < or; yy++) { + real *pi_ = t_ + yy*sr*ic; + real *pw_ = k_; + for (ky = 0; ky < kr; ky++) { + real *pis_ = pi_; + for (kx = 0; kx < kc; kx++) { + THVector_(add)(r_, pis_, alpha*pw_[kx], oc); + pis_++; + } + pi_ += ic; /* next input line */ + pw_ += kc; /* next mask line */ + } + r_ += oc; + } + } +} + +/* + 2D Input, 2D kernel : convolve given image with the given kernel. +*/ +TH_API void THTensor_(validConv2Dptr)(real *r_, + real alpha, + real *t_, long ir, long ic, + real *k_, long kr, long kc, + long sr, long sc) +{ + long or = (ir - kr) / sr + 1; + long oc = (ic - kc) / sc + 1; + + long xx, yy, kx, ky; + + if ((sc != 1) || (oc < 4)) { + // regular convolution + for(yy = 0; yy < or; yy++) { + for(xx = 0; xx < oc; xx++) { + /* Dot product in two dimensions... (between input image and the mask) */ + real *pi_ = t_ + yy*sr*ic + xx*sc; + real *pw_ = k_ + kr*kc - 1; + real sum = 0; + for(ky = 0; ky < kr; ky++) { + for(kx = 0; kx < kc; kx++) { + sum += pi_[kx]*pw_[-kx]; + } + pi_ += ic; /* next input line */ + pw_ -= kc; /* next mask line */ + } + /* Update output */ + *r_ += alpha*sum; + *r_++; + } + } + + } else { + // SSE-based convolution + for(yy = 0; yy < or; yy++) { + real *pw_ = k_ + kr*kc - 1; + real *pi_ = t_ + yy*sr*ic; + for (ky = 0; ky < kr; ky++) { + real *pis_ = pi_; + for (kx = 0; kx < kc; kx++) { + THVector_(add)(r_, pis_, alpha*pw_[-kx], oc); + pis_++; + } + pi_ += ic; /* next input line */ + pw_ -= kc; /* next mask line */ + } + r_ += oc; + } + } +} + +/* + 2D Input, 2D kernel : convolve given image with the given kernel, full convolution. +*/ +TH_API void THTensor_(fullConv2Dptr)(real *r_, + real alpha, + real *t_, long ir, long ic, + real *k_, long kr, long kc, + long sr, long sc) +{ + long or = (ir - 1) * sr + kr; + long oc = (ic - 1) * sc + kc; + + long xx, yy, kx, ky; + + if ((sc != 1) || (ic < 4)) { + // regular convolution + for(yy = 0; yy < ir; yy++) { + for(xx = 0; xx < ic; xx++) { + /* Outer product in two dimensions... (between input image and the mask) */ + real *po_ = r_ + yy*sr*oc + xx*sc; + real *pw_ = k_; + for(ky = 0; ky < kr; ky++) + { + real z = *t_ * alpha; + for(kx = 0; kx < kc; kx++) { + po_[kx] += z * pw_[kx]; + } + po_ += oc; /* next input line */ + pw_ += kc; /* next mask line */ + } + t_++; + } + } + + } else { + // SSE-based convolution + for(yy = 0; yy < ir; yy++) { + real *po_ = r_ + yy*sr*oc; + real *pw_ = k_; + for (ky = 0; ky < kr; ky++) { + real *pos_ = po_; + for (kx = 0; kx < kc; kx++) { + THVector_(add)(pos_, t_, alpha*pw_[kx], ic); + pos_++; + } + po_ += oc; /* next input line */ + pw_ += kc; /* next mask line */ + } + t_ += ic; + } + } +} + +/* + 2D Input, 2D kernel : convolve given image with the given kernel, full convolution. 
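+  Full cross-correlation: the output has (ir - 1)*sr + kr rows and
+  (ic - 1)*sc + kc columns, and each input element scatters alpha times its
+  value into a kr x kc window of the output, weighted by the kernel traversed
+  back to front.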
+*/ +TH_API void THTensor_(fullXCorr2Dptr)(real *r_, + real alpha, + real *t_, long ir, long ic, + real *k_, long kr, long kc, + long sr, long sc) +{ + long or = (ir - 1) * sr + kr; + long oc = (ic - 1) * sc + kc; + + long xx, yy, kx, ky; + + if ((sc != 1) || (ic < 4)) { + // regular convolution + for(yy = 0; yy < ir; yy++) { + for(xx = 0; xx < ic; xx++) { + /* Outer product in two dimensions... (between input image and the mask) */ + real *po_ = r_ + yy*sr*oc + xx*sc; + real *pw_ = k_ + kr*kc -1; + long kx, ky; + for(ky = 0; ky < kr; ky++) + { + real z = *t_ * alpha; + for(kx = 0; kx < kc; kx++) { + po_[kx] += z * pw_[-kx]; + } + po_ += oc; /* next input line */ + pw_ -= kc; /* next mask line */ + } + t_++; + } + } + + } else { + // SSE-based convolution + for(yy = 0; yy < ir; yy++) { + real *po_ = r_ + yy*sr*oc; + real *pw_ = k_ + kr*kc -1; + for (ky = 0; ky < kr; ky++) { + real *pos_ = po_; + for (kx = 0; kx < kc; kx++) { + THVector_(add)(pos_, t_, pw_[-kx]*alpha, ic); + pos_++; + } + po_ += oc; /* next input line */ + pw_ -= kc; /* next mask line */ + } + t_ += ic; + } + } +} + +/* + 2D Input, 2D kernel : convolve given image with the given kernel, valid convolution. + for sr,sc=1 this is equivalent to validXCorr2Dptr, but otherwise it is useful for + calculating derivatives wrt a kernel that is applied with stride sr,sc != 1 +*/ +TH_API void THTensor_(validXCorr2DRevptr)(real *r_, + real alpha, + real *t_, long ir, long ic, + real *k_, long kr, long kc, + long sr, long sc) +{ + long or = ir - (kr - 1) * sr; + long oc = ic - (kc - 1) * sc; + + long xx, yy, kx, ky; + + if ((sc != 1) || (kc < 4)) { + // regular convolution + for(yy = 0; yy < kr; yy++) { + for(xx = 0; xx < kc; xx++) { + real *po_ = r_; + real *pi_ = t_ + yy*sr*ic + xx*sc; + real z = *k_++ * alpha; + + for(ky = 0; ky < or; ky++) { + for(kx = 0; kx < oc; kx++) + po_[kx] += z * pi_[kx]; + pi_ += ic; + po_ += oc; + } + } + } + + } else { + // SSE-based convolution + for(yy = 0; yy < kr; yy++) { + for(xx = 0; xx < kc; xx++) { + real *po_ = r_; + real *pi_ = t_ + yy*sr*ic + xx*sc; + real z = *k_++ * alpha; + + for(ky = 0; ky < or; ky++) { + THVector_(add)(po_, pi_, z, oc); + pi_ += ic; + po_ += oc; + } + } + } + } +} +/* + 3D Input, 3D kernel : convolve given volume with the given kernel. +*/ +TH_API void THTensor_(validXCorr3Dptr)(real *r_, + real alpha, + real *t_, long it, long ir, long ic, + real *k_, long kt, long kr, long kc, + long st, long sr, long sc) +{ + long ot = (it - kt) / st + 1; + long or = (ir - kr) / sr + 1; + long oc = (ic - kc) / sc + 1; + + long zz, xx, yy; + + for (zz = 0; zz < ot; zz++) + { + for(yy = 0; yy < or; yy++) + { + for(xx = 0; xx < oc; xx++) + { + /* Dot product in two dimensions... (between input image and the mask) */ + real *pi_ = t_ + zz*st*ir*ic + yy*sr*ic + xx*sc; + real *pw_ = k_; + real sum = 0; + long kz, kx, ky; + for(kz = 0; kz < kt; kz++) + { + for(ky = 0; ky < kr; ky++) + { + for(kx = 0; kx < kc; kx++) { + sum += pi_[kx]*pw_[kx]; + } + pi_ += ic; /* next input line */ + pw_ += kc; /* next mask line */ + } + } + /* Update output */ + *r_ += sum*alpha; + *r_++; + } + } + } +} + +/* + 3D Input, 3D kernel : convolve given volume with the given kernel. 
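+  Valid true convolution in 3D: unlike validXCorr3Dptr above, the kernel is
+  traversed back to front (i.e. flipped), and the output has
+  (it - kt)/st + 1  x  (ir - kr)/sr + 1  x  (ic - kc)/sc + 1 elements, each
+  accumulating alpha times the windowed sum.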
+*/ +TH_API void THTensor_(validConv3Dptr)(real *r_, + real alpha, + real *t_, long it, long ir, long ic, + real *k_, long kt, long kr, long kc, + long st, long sr, long sc) +{ + long ot = (it - kt) / st + 1; + long or = (ir - kr) / sr + 1; + long oc = (ic - kc) / sc + 1; + + long zz, xx, yy; + + for(zz = 0; zz < ot; zz++) + { + for(yy = 0; yy < or; yy++) + { + for(xx = 0; xx < oc; xx++) + { + /* Dot product in two dimensions... (between input image and the mask) */ + real *pi_ = t_ + zz*st*ir*ic + yy*sr*ic + xx*sc; + real *pw_ = k_ + kt*kr*kc - 1; + real sum = 0; + long kz, kx, ky; + for(kz = 0; kz < kt; kz++) + { + for(ky = 0; ky < kr; ky++) + { + for(kx = 0; kx < kc; kx++) { + sum += pi_[kx]*pw_[-kx]; + } + pi_ += ic; /* next input line */ + pw_ -= kc; /* next mask line */ + } + } + /* Update output */ + *r_ += alpha*sum; + *r_++; + } + } + } +} + + +/* + 3D Input, 3D kernel : convolve given volume with the given kernel, full convolution. +*/ +TH_API void THTensor_(fullConv3Dptr)(real *r_, + real alpha, + real *t_, long it, long ir, long ic, + real *k_, long kt, long kr, long kc, + long st, long sr, long sc) +{ + long ot = (it - 1) * st + kt; + long or = (ir - 1) * sr + kr; + long oc = (ic - 1) * sc + kc; + + long zz, xx, yy; + + for(zz = 0; zz < it; zz++) + { + for(yy = 0; yy < ir; yy++) + { + for(xx = 0; xx < ic; xx++) + { + /* Outer product in two dimensions... (between input image and the mask) */ + real *po_ = r_ + zz*st*or*oc + yy*sr*oc + xx*sc; + real *pw_ = k_; + long kz, kx, ky; + //printf("Output Plane : %ld,%ld,%ld, input val=%g\n",zz,yy,xx,*t_); + for(kz = 0; kz < kt; kz++) + { + for(ky = 0; ky < kr; ky++) + { + real z = *t_ * alpha; + for(kx = 0; kx < kc; kx++) { + //printf("o=%g,k=%g," , po_[kx],pw_[kx]); + po_[kx] += z * pw_[kx]; + //printf("o=%g " , po_[kx]); + } + //printf("\n"); + po_ += oc; /* next input line */ + pw_ += kc; /* next mask line */ + } + //printf("\n"); + } + t_++; + } + } + } +} + +/* + 3D Input, 3D kernel : convolve given volume with the given kernel, full convolution. +*/ +TH_API void THTensor_(fullXCorr3Dptr)(real *r_, + real alpha, + real *t_, long it, long ir, long ic, + real *k_, long kt, long kr, long kc, + long st, long sr, long sc) +{ + long ot = (it - 1) * st + kt; + long or = (ir - 1) * sr + kr; + long oc = (ic - 1) * sc + kc; + + long zz, xx, yy; + + for(zz = 0; zz < it; zz++) + { + for(yy = 0; yy < ir; yy++) + { + for(xx = 0; xx < ic; xx++) + { + /* Outer product in two dimensions... (between input image and the mask) */ + real *po_ = r_ + zz*st*or*oc + yy*sr*oc + xx*sc; + real *pw_ = k_ + kt*kr*kc -1; + long kz, kx, ky; + for(kz = 0; kz < kt; kz++) + { + for(ky = 0; ky < kr; ky++) + { + real z = *t_ * alpha; + for(kx = 0; kx < kc; kx++) { + po_[kx] += z * pw_[-kx]; + } + po_ += oc; /* next input line */ + pw_ -= kc; /* next mask line */ + } + } + t_++; + } + } + } +} + +/* + 3D Input, 3D kernel : convolve given image with the given kernel, valid convolution. 
+ for sr,sc=1 this is equivalent to validXCorr3Dptr, but otherwise it is useful for + calculating derivatives wrt a kernel that is applied with stride sr,sc != 1 +*/ +TH_API void THTensor_(validXCorr3DRevptr)(real *r_, + real alpha, + real *t_, long it, long ir, long ic, + real *k_, long kt, long kr, long kc, + long st, long sr, long sc) +{ + long ot = it - (kt - 1) * st; + long or = ir - (kr - 1) * sr; + long oc = ic - (kc - 1) * sc; + + long zz, xx, yy; + for(zz = 0; zz < kt; zz++) + { + for(yy = 0; yy < kr; yy++) + { + for(xx = 0; xx < kc; xx++) + { + real *po_ = r_; + real *pi_ = t_ + zz*st*ir*ic + yy*sr*ic + xx*sc; + real z = *k_++ * alpha; + long kz, kx, ky; + for(kz = 0; kz < ot; kz++) + { + for(ky = 0; ky < or; ky++) + { + for(kx = 0; kx < oc; kx++) + po_[kx] += z * pi_[kx]; + pi_ += ic; + po_ += oc; + } + } + } + } + } +} + +void THTensor_(conv2d)(real* output_data, + real alpha, + real* ptr_input, long nInputRows, long nInputCols, + real* ptr_weight, long nKernelRows, long nKernelCols, + long srow, long scol, + const char *vf, const char *xc) +{ + THArgCheck(*vf == 'V' || *vf == 'F', 7, "type of convolution can be 'V' or 'F'"); + THArgCheck(*xc == 'C' || *xc == 'X', 7, "type of convolution can be 'X' or 'C'"); + if (*vf == 'F') + if (*xc == 'X') + THTensor_(fullXCorr2Dptr)(output_data, + alpha, + ptr_input, nInputRows, nInputCols, + ptr_weight, nKernelRows, nKernelCols, + srow, scol); + else + THTensor_(fullConv2Dptr)(output_data, + alpha, + ptr_input, nInputRows, nInputCols, + ptr_weight, nKernelRows, nKernelCols, + srow, scol); + else + if (*xc == 'X') + THTensor_(validXCorr2Dptr)(output_data, + alpha, + ptr_input, nInputRows, nInputCols, + ptr_weight, nKernelRows, nKernelCols, + srow, scol); + else + THTensor_(validConv2Dptr)(output_data, + alpha, + ptr_input, nInputRows, nInputCols, + ptr_weight, nKernelRows, nKernelCols, + srow, scol); +} + +void THTensor_(conv3d)(real* output_data, + real alpha, + real* ptr_input, long nInputDepth, long nInputRows, long nInputCols, + real* ptr_weight, long nKernelDepth, long nKernelRows, long nKernelCols, + long sdepth, long srow, long scol, + const char *vf, const char *xc) +{ + THArgCheck(*vf == 'V' || *vf == 'F', 7, "type of convolution can be 'V' or 'F'"); + THArgCheck(*xc == 'C' || *xc == 'X', 7, "type of convolution can be 'X' or 'C'"); + if (*vf == 'F') + if (*xc == 'X') + THTensor_(fullXCorr3Dptr)(output_data, + alpha, + ptr_input, nInputDepth, nInputRows, nInputCols, + ptr_weight, nKernelDepth, nKernelRows, nKernelCols, + sdepth, srow, scol); + else + THTensor_(fullConv3Dptr)(output_data, + alpha, + ptr_input, nInputDepth, nInputRows, nInputCols, + ptr_weight, nKernelDepth, nKernelRows, nKernelCols, + sdepth, srow, scol); + else + if (*xc == 'X') + THTensor_(validXCorr3Dptr)(output_data, + alpha, + ptr_input, nInputDepth, nInputRows, nInputCols, + ptr_weight, nKernelDepth, nKernelRows, nKernelCols, + sdepth, srow, scol); + else + THTensor_(validConv3Dptr)(output_data, + alpha, + ptr_input, nInputDepth, nInputRows, nInputCols, + ptr_weight, nKernelDepth, nKernelRows, nKernelCols, + sdepth, srow, scol); +} + +long THTensor_(convsize)(long x, long k, long s, const char* vf) +{ + THArgCheck(*vf == 'V' || *vf == 'F', 1, "type of convolution can be 'V' or 'F'"); + if (*vf == 'V') + return (x-k)/s + 1; + else + return (x-1)*s + k; +} + + +/* + 3D input, 3D kernel, 4D output + like rank1 update + A <- xx' + beta*A + for sr,sc=1 this is equivalent to xcorr2Dger, but otherwise it is useful for + calculating derivatives wrt a kernel that is 
applied with stride sr,sc != 1 +*/ +void THTensor_(conv2DRevger)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol) +{ + long nInputPlane, nInputRows, nInputCols; + long nKernelPlane, nKernelRows, nKernelCols; + long nOutputPlane, nOutputRows, nOutputCols; + long istride0, kstride0; + + THArgCheck(t_->nDimension == 3 , 3, "input: 3D Tensor expected"); + THArgCheck(k_->nDimension == 3 , 4, "kernel: 3D Tensor expected"); + THArgCheck(srow >= 1, 5, "Stride should be a positive integer"); + THArgCheck(scol >= 1, 6, "Stride should be a positive integer"); + + THTensor *input = THTensor_(newContiguous)(t_); + THTensor *kernel = THTensor_(newContiguous)(k_); + + nInputPlane = input->size[0]; + istride0 = input->stride[0]; + nInputRows = input->size[1]; + nInputCols = input->size[2]; + + kstride0 = kernel->stride[0]; + nKernelPlane = kernel->size[0]; + nKernelRows = kernel->size[1]; + nKernelCols = kernel->size[2]; + nOutputPlane = nInputPlane * kernel->size[0]; + + THArgCheck(nInputRows >= nKernelRows && nInputCols >= nKernelCols , 2, "conv2DRevger : Input image is smaller than kernel"); + + nOutputRows = nInputRows - (nKernelRows - 1) * srow; + nOutputCols = nInputCols - (nKernelCols - 1) * scol; + + long nelem = THTensor_(nElement)(r_); + THTensor_(resize4d)(r_,nKernelPlane, nInputPlane, nOutputRows, nOutputCols); + + if (nelem == 0 || beta == 0 || nelem != THTensor_(nElement)(r_)) + { + THTensor_(zero)(r_); + } + else if (beta != 1) + THTensor_(mul)(r_, r_, beta); + + real *input_data = THTensor_(data)(input); + real *weight_data = THTensor_(data)(kernel); + real *output_data = THTensor_(data)(r_); + + long k,i; + for(k = 0; k < nKernelPlane; k++) + { + /* get kernel */ + real *ptr_weight = weight_data+k*kstride0; + + for(i = 0; i < nInputPlane; i++) + { + /* get input */ + real *ptr_input = input_data+i*istride0; + + /* do image, kernel convolution */ + THTensor_(validXCorr2DRevptr)(output_data, + alpha, + ptr_input, nInputRows, nInputCols, + ptr_weight, nKernelRows, nKernelCols, + srow, scol); + /* Next output plane */ + output_data += nOutputCols*nOutputRows; + } + } + THTensor_(free)(input); + THTensor_(free)(kernel); +} + + +/* + 3D input, 3D kernel, 4D output + like rank1 update + A <- xx' + beta*A +*/ +void THTensor_(conv2Dger)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol, const char *vf, const char *xc) +{ + long nInputPlane, nInputRows, nInputCols; + long nKernelPlane, nKernelRows, nKernelCols; + long nOutputPlane, nOutputRows, nOutputCols; + long istride0, kstride0; + + THArgCheck(t_->nDimension == 3 , 3, "input: 3D Tensor expected"); + THArgCheck(k_->nDimension == 3 , 4, "kernel: 3D Tensor expected"); + THArgCheck(srow >= 1, 5, "Stride should be a positive integer"); + THArgCheck(scol >= 1, 6, "Stride should be a positive integer"); + + THTensor *input = THTensor_(newContiguous)(t_); + THTensor *kernel = THTensor_(newContiguous)(k_); + + nInputPlane = input->size[0]; + istride0 = input->stride[0]; + nInputRows = input->size[1]; + nInputCols = input->size[2]; + + kstride0 = kernel->stride[0]; + nKernelPlane = kernel->size[0]; + nKernelRows = kernel->size[1]; + nKernelCols = kernel->size[2]; + nOutputPlane = nInputPlane * kernel->size[0]; + + THArgCheck((nInputRows >= nKernelRows && nInputCols >= nKernelCols) || *vf == 'F', 2, "conv2Dger : Input image is smaller than kernel"); + + nOutputRows = THTensor_(convsize)(nInputRows, nKernelRows, srow, vf); + nOutputCols = THTensor_(convsize)(nInputCols, nKernelCols, 
scol, vf); + + long nelem = THTensor_(nElement)(r_); + THTensor_(resize4d)(r_,nKernelPlane, nInputPlane, nOutputRows, nOutputCols); + + if (nelem == 0 || beta == 0 || nelem != THTensor_(nElement)(r_)) + { + THTensor_(zero)(r_); + } + else if (beta != 1) + THTensor_(mul)(r_, r_, beta); + + real *input_data = THTensor_(data)(input); + real *weight_data = THTensor_(data)(kernel); + real *output_data = THTensor_(data)(r_); + + long k,i; + for(k = 0; k < nKernelPlane; k++) + { + /* get kernel */ + real *ptr_weight = weight_data+k*kstride0; + + for(i = 0; i < nInputPlane; i++) + { + /* get input */ + real *ptr_input = input_data+i*istride0; + + /* do image, kernel convolution */ + THTensor_(conv2d)(output_data, + alpha, + ptr_input, nInputRows, nInputCols, + ptr_weight, nKernelRows, nKernelCols, + srow, scol, vf, xc); + + /* Next output plane */ + output_data += nOutputCols*nOutputRows; + } + } + THTensor_(free)(input); + THTensor_(free)(kernel); +} + +/* + 3D input, 4D kernel, 3D output + matrix vector product like + y <- Ax + beta*y +*/ +void THTensor_(conv2Dmv)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol, const char *vf, const char *xc) +{ + long nInputPlane, nInputRows, nInputCols; + long nKernelRows, nKernelCols; + long nOutputPlane, nOutputRows, nOutputCols; + long istride0, kstride0, kstride1; + + THArgCheck(t_->nDimension == 3 , 3, "input: 3D Tensor expected"); + THArgCheck(k_->nDimension == 4 , 4, "kernel: 4D Tensor expected"); + THArgCheck(srow >= 1, 5, "Stride should be a positive integer"); + THArgCheck(scol >= 1, 6, "Stride should be a positive integer"); + + THTensor *input = THTensor_(newContiguous)(t_); + THTensor* kernel; + if (!(k_->stride[3] == 1) || !(k_->stride[2] == k_->size[3])) { + kernel = THTensor_(newContiguous)(k_); + } else { + THTensor_(retain)(k_); + kernel = k_; + } + + nInputPlane = input->size[0]; + istride0 = input->stride[0]; + nInputRows = input->size[1]; + nInputCols = input->size[2]; + + kstride0 = kernel->stride[0]; + kstride1 = kernel->stride[1]; + nKernelRows = kernel->size[2]; + nKernelCols = kernel->size[3]; + nOutputPlane = kernel->size[0]; + THArgCheck(kernel->size[1] == nInputPlane, 2, "invalid number of input planes"); + + THArgCheck( (nInputRows >= nKernelRows && nInputCols >= nKernelCols) || *vf == 'F', 2, "conv2Dmv : Input image is smaller than kernel"); + + nOutputRows = THTensor_(convsize)(nInputRows, nKernelRows, srow, vf); + nOutputCols = THTensor_(convsize)(nInputCols, nKernelCols, scol, vf); + + long nelem = THTensor_(nElement)(r_); + THTensor_(resize3d)(r_, nOutputPlane, nOutputRows, nOutputCols); + + if (nelem == 0 || beta == 0 || nelem != THTensor_(nElement)(r_)) + { + THTensor_(zero)(r_); + } + else if (beta != 1) + THTensor_(mul)(r_, r_, beta); + + real *input_data = THTensor_(data)(input); + real *weight_data = THTensor_(data)(kernel); + real *output_data = THTensor_(data)(r_); + + long k,i; + for(k = 0; k < nOutputPlane; k++) + { + for(i = 0; i < nInputPlane; i++) + { + /* get kernel */ + real *ptr_weight = weight_data + k*kstride0 + i*kstride1; + /* get input */ + real *ptr_input = input_data + i*istride0; + + /* do image, kernel convolution */ + THTensor_(conv2d)(output_data, + alpha, + ptr_input, nInputRows, nInputCols, + ptr_weight, nKernelRows, nKernelCols, + srow, scol, vf, xc); + } + /* Next output plane */ + output_data += nOutputCols*nOutputRows; + } + THTensor_(free)(input); + THTensor_(free)(kernel); +} + +/* + 2D input, 2D kernel, 2D output + scalar multiplication like + y <- x*y 
+ beta*y +*/ +void THTensor_(conv2Dmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol, const char *vf, const char *xc) +{ + + THArgCheck(t_->nDimension == 2 , 3, "input: 2D Tensor expected"); + THArgCheck(k_->nDimension == 2 , 4, "kernel: 2D Tensor expected"); + THArgCheck(srow >= 1, 5, "Stride should be a positive integer"); + THArgCheck(scol >= 1, 6, "Stride should be a positive integer"); + + THTensor *input = THTensor_(newContiguous)(t_); + THTensor* kernel = THTensor_(newContiguous)(k_); + + long nInputRows = input->size[0]; + long nInputCols = input->size[1]; + long nKernelRows = kernel->size[0]; + long nKernelCols = kernel->size[1]; + long nOutputRows, nOutputCols; + + THArgCheck((nInputRows >= nKernelRows && nInputCols >= nKernelCols) || *vf == 'F', 2, "conv2Dmul : Input image is smaller than kernel"); + + nOutputRows = THTensor_(convsize)(nInputRows, nKernelRows, srow, vf); + nOutputCols = THTensor_(convsize)(nInputCols, nKernelCols, scol, vf); + + long nelem = THTensor_(nElement)(r_); + THTensor_(resize2d)(r_, nOutputRows, nOutputCols); + if (nelem == 0 || beta == 0 || nelem != THTensor_(nElement)(r_)) + THTensor_(zero)(r_); + else if (beta != 1) + THTensor_(mul)(r_, r_, beta); + + real *ptr_input = THTensor_(data)(input); + real *ptr_weight = THTensor_(data)(kernel); + real *output_data = THTensor_(data)(r_); + + + /* do image, kernel convolution */ + THTensor_(conv2d)(output_data, + alpha, + ptr_input, nInputRows, nInputCols, + ptr_weight, nKernelRows, nKernelCols, + srow, scol, vf, xc); + THTensor_(free)(input); + THTensor_(free)(kernel); +} + +/* + 3D input, 3D kernel, 3D output + component wise multiplication like + y <- y.*x + beta*y +*/ +void THTensor_(conv2Dcmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol, const char *vf, const char *xc) +{ + long nInputPlane, nInputRows, nInputCols; + long nKernelRows, nKernelCols; + long nOutputPlane, nOutputRows, nOutputCols; + long istride0, kstride0; + + THArgCheck(t_->nDimension == 3 , 3, "input: 3D Tensor expected"); + THArgCheck(k_->nDimension == 3 , 4, "kernel: 3D Tensor expected"); + THArgCheck(srow >= 1, 5, "Stride should be a positive integer"); + THArgCheck(scol >= 1, 6, "Stride should be a positive integer"); + + THTensor *input = THTensor_(newContiguous)(t_); + THTensor* kernel = THTensor_(newContiguous)(k_); + + istride0 = input->stride[0]; + nInputPlane = input->size[0]; + nInputRows = input->size[1]; + nInputCols = input->size[2]; + + kstride0 = kernel->stride[0]; + nOutputPlane = kernel->size[0]; + nKernelRows = kernel->size[1]; + nKernelCols = kernel->size[2]; + + THArgCheck(nOutputPlane == nInputPlane, 2, "invalid number of input/kernel planes"); + THArgCheck( (nInputRows >= nKernelRows && nInputCols >= nKernelCols) || *vf == 'F', 2, "conv2Dcmul : Input image is smaller than kernel"); + + nOutputRows = THTensor_(convsize)(nInputRows, nKernelRows, srow, vf); + nOutputCols = THTensor_(convsize)(nInputCols, nKernelCols, scol, vf); + + long nelem = THTensor_(nElement)(r_); + THTensor_(resize3d)(r_, nOutputPlane, nOutputRows, nOutputCols); + + if (nelem == 0 || beta == 0 || nelem != THTensor_(nElement)(r_)) + { + THTensor_(zero)(r_); + } + else if (beta != 1) + THTensor_(mul)(r_, r_, beta); + + real *input_data = THTensor_(data)(input); + real *weight_data = THTensor_(data)(kernel); + real *output_data = THTensor_(data)(r_); + + long k; + for(k = 0; k < nOutputPlane; k++) + { + /* get kernel */ + real *ptr_weight = weight_data + k*kstride0; 
+ /* get input */ + real *ptr_input = input_data + k*istride0; + + /* do image, kernel convolution */ + THTensor_(conv2d)(output_data, + alpha, + ptr_input, nInputRows, nInputCols, + ptr_weight, nKernelRows, nKernelCols, + srow, scol, vf, xc); + /* Next output plane */ + output_data += nOutputCols*nOutputRows; + } + THTensor_(free)(input); + THTensor_(free)(kernel); +} + +/* + 3D input, 3D kernel, 3D output + component wise multiplication like with a permutation map + y <- y.*x + beta*y +*/ +void THTensor_(conv2Dmap)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, THTensor *map, long srow, long scol, const char *vf, const char *xc) +{ + long nInputPlane, nInputRows, nInputCols; + long nKernelRows, nKernelCols; + long nOutputPlane, nOutputRows, nOutputCols; + long istride0, kstride0; + + THArgCheck(t_->nDimension == 3 , 3, "input: 3D Tensor expected"); + THArgCheck(k_->nDimension == 3 , 4, "kernel: 3D Tensor expected"); + THArgCheck(map->nDimension == 2 , 4, "map: 2D Tensor expected"); + THArgCheck(srow >= 1, 6, "Stride should be a positive integer"); + THArgCheck(scol >= 1, 7, "Stride should be a positive integer"); + + THTensor *input = THTensor_(newContiguous)(t_); + THTensor* kernel = THTensor_(newContiguous)(k_); + + istride0 = input->stride[0]; + nInputPlane = input->size[0]; + nInputRows = input->size[1]; + nInputCols = input->size[2]; + + kstride0 = kernel->stride[0]; + nOutputPlane = kernel->size[0]; + nKernelRows = kernel->size[1]; + nKernelCols = kernel->size[2]; + + THArgCheck(nOutputPlane == nInputPlane, 2, "invalid number of input/kernel planes"); + THArgCheck( (nInputRows >= nKernelRows && nInputCols >= nKernelCols) + || *vf == 'F', 2, "conv2Dmap : Input image is smaller than kernel"); + + nOutputRows = THTensor_(convsize)(nInputRows, nKernelRows, srow, vf); + nOutputCols = THTensor_(convsize)(nInputCols, nKernelCols, scol, vf); + + long nelem = THTensor_(nElement)(r_); + THTensor_(resize3d)(r_, nOutputPlane, nOutputRows, nOutputCols); + + if (nelem == 0 || beta == 0 || nelem != THTensor_(nElement)(r_)) + { + THTensor_(zero)(r_); + } + else if (beta != 1) + THTensor_(mul)(r_, r_, beta); + + real *input_data = THTensor_(data)(input); + real *weight_data = THTensor_(data)(kernel); + real *output_data = THTensor_(data)(r_); + + long nmaps = map->size[0]; + + long k; + for(k = 0; k < nmaps; k++) + { + /* get indices */ + long from = (long)THTensor_(get2d)(map,k,0)-1; + long to = (long)THTensor_(get2d)(map,k,1)-1; + + /* get kernel */ + real *ptr_weight = weight_data + k*kstride0; + /* get input */ + real *ptr_input = input_data + from*istride0; + /* get output */ + real *ptr_output = output_data + to*nOutputRows*nOutputCols; + + /* do image, kernel convolution */ + THTensor_(conv2d)(ptr_output, + alpha, + ptr_input, nInputRows, nInputCols, + ptr_weight, nKernelRows, nKernelCols, + srow, scol, vf, xc); + } + THTensor_(free)(input); + THTensor_(free)(kernel); +} + +/* + 4D input, 4D kernel, 5D output + like rank1 update + A <- xx' + beta*A + for sr,sc=1 this is equivalent to xcorr2Dger, but otherwise it is useful for + calculating derivatives wrt a kernel that is applied with stride sr,sc != 1 +*/ +void THTensor_(conv3DRevger)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, + long sdepth, long srow, long scol) +{ + long nInputPlane, nInputDepth, nInputRows, nInputCols; + long nKernelPlane, nKernelDepth, nKernelRows, nKernelCols; + long nOutputPlane, nOutputDepth, nOutputRows, nOutputCols; + long istride0, kstride0; + + THArgCheck(t_->nDimension == 
4 , 3, "input: 4D Tensor expected"); + THArgCheck(k_->nDimension == 4 , 4, "kernel: 4D Tensor expected"); + THArgCheck(sdepth >= 1, 5, "Stride should be a positive integer"); + THArgCheck(srow >= 1, 6, "Stride should be a positive integer"); + THArgCheck(scol >= 1, 7, "Stride should be a positive integer"); + + THTensor *input = THTensor_(newContiguous)(t_); + THTensor *kernel = THTensor_(newContiguous)(k_); + + nInputPlane = input->size[0]; + istride0 = input->stride[0]; + nInputDepth = input->size[1]; + nInputRows = input->size[2]; + nInputCols = input->size[3]; + + kstride0 = kernel->stride[0]; + nKernelPlane = kernel->size[0]; + nKernelDepth= kernel->size[1]; + nKernelRows = kernel->size[2]; + nKernelCols = kernel->size[3]; + nOutputPlane = nInputPlane * kernel->size[0]; + + THArgCheck(nInputDepth >= nKernelDepth && nInputRows >= nKernelRows && nInputCols >= nKernelCols , 2, "conv3DRevger : Input image is smaller than kernel"); + + nOutputDepth = nInputDepth - (nKernelDepth - 1) * sdepth; + nOutputRows = nInputRows - (nKernelRows - 1) * srow; + nOutputCols = nInputCols - (nKernelCols - 1) * scol; + + long nelem = THTensor_(nElement)(r_); + THTensor_(resize5d)(r_,nKernelPlane, nInputPlane, nOutputDepth, nOutputRows, nOutputCols); + + if (nelem == 0 || beta == 0 || nelem != THTensor_(nElement)(r_)) + { + THTensor_(zero)(r_); + } + else if (beta != 1) + THTensor_(mul)(r_, r_, beta); + + real *input_data = THTensor_(data)(input); + real *weight_data = THTensor_(data)(kernel); + real *output_data = THTensor_(data)(r_); + + long k,i; + for(k = 0; k < nKernelPlane; k++) + { + /* get kernel */ + real *ptr_weight = weight_data+k*kstride0; + + for(i = 0; i < nInputPlane; i++) + { + /* get input */ + real *ptr_input = input_data+i*istride0; + + /* do image, kernel convolution */ + THTensor_(validXCorr3DRevptr)(output_data, + alpha, + ptr_input, nInputDepth, nInputRows, nInputCols, + ptr_weight, nKernelDepth, nKernelRows, nKernelCols, + sdepth, srow, scol); + /* Next output plane */ + output_data += nOutputDepth*nOutputCols*nOutputRows; + } + } + THTensor_(free)(input); + THTensor_(free)(kernel); +} + + +/* + 4D input, 4D kernel, 5D output + like rank1 update + A <- xx' + beta*A +*/ +void THTensor_(conv3Dger)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, + long sdepth, long srow, long scol, const char *vf, const char *xc) +{ + long nInputPlane, nInputDepth, nInputRows, nInputCols; + long nKernelPlane, nKernelDepth, nKernelRows, nKernelCols; + long nOutputPlane, nOutputDepth, nOutputRows, nOutputCols; + long istride0, kstride0; + + THArgCheck(t_->nDimension == 4 , 3, "input: 4D Tensor expected"); + THArgCheck(k_->nDimension == 4 , 4, "kernel: 4D Tensor expected"); + THArgCheck(sdepth >= 1, 5, "Stride should be a positive integer"); + THArgCheck(srow >= 1, 6, "Stride should be a positive integer"); + THArgCheck(scol >= 1, 7, "Stride should be a positive integer"); + THArgCheck(*vf == 'V' || *vf == 'F', 8, "type of convolution can 'V' or 'F'"); + THArgCheck(*xc == 'C' || *xc == 'X', 8, "type of convolution can 'X' or 'C'"); + + THTensor *input = THTensor_(newContiguous)(t_); + THTensor *kernel = THTensor_(newContiguous)(k_); + + nInputPlane = input->size[0]; + istride0 = input->stride[0]; + nInputDepth = input->size[1]; + nInputRows = input->size[2]; + nInputCols = input->size[3]; + + kstride0 = kernel->stride[0]; + nKernelPlane = kernel->size[0]; + nKernelDepth = kernel->size[1]; + nKernelRows = kernel->size[2]; + nKernelCols = kernel->size[3]; + nOutputPlane = nInputPlane * 
kernel->size[0]; + + THArgCheck((nInputDepth >= nKernelDepth + && nInputRows >= nKernelRows + && nInputCols >= nKernelCols) + || *vf == 'F', 2, "conv3Dger : Input image is smaller than kernel"); + + nOutputDepth = THTensor_(convsize)(nInputDepth, nKernelDepth, sdepth, vf); + nOutputRows = THTensor_(convsize)(nInputRows, nKernelRows, srow, vf); + nOutputCols = THTensor_(convsize)(nInputCols, nKernelCols, scol, vf); + + long nelem = THTensor_(nElement)(r_); + THTensor_(resize5d)(r_,nKernelPlane, nInputPlane, nOutputDepth, nOutputRows, nOutputCols); + + if (nelem == 0 || beta == 0 || nelem != THTensor_(nElement)(r_)) + { + THTensor_(zero)(r_); + } + else if (beta != 1) + THTensor_(mul)(r_, r_, beta); + + real *input_data = THTensor_(data)(input); + real *weight_data = THTensor_(data)(kernel); + real *output_data = THTensor_(data)(r_); + + long k,i; + for(k = 0; k < nKernelPlane; k++) + { + /* get kernel */ + real *ptr_weight = weight_data+k*kstride0; + + for(i = 0; i < nInputPlane; i++) + { + /* get input */ + real *ptr_input = input_data+i*istride0; + + /* do image, kernel convolution */ + THTensor_(conv3d)(output_data, + alpha, + ptr_input, nInputDepth, nInputRows, nInputCols, + ptr_weight, nKernelDepth, nKernelRows, nKernelCols, + sdepth, srow, scol, vf, xc); + + /* Next output plane */ + output_data += nOutputDepth*nOutputCols*nOutputRows; + } + } + THTensor_(free)(input); + THTensor_(free)(kernel); +} + +/* + 4D input, 5D kernel, 4D output + matrix vector product like + y <- Ax + beta*y +*/ +void THTensor_(conv3Dmv)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, + long sdepth, long srow, long scol, const char *vf, const char *xc) +{ + long nInputPlane, nInputDepth, nInputRows, nInputCols; + long nKernelDepth, nKernelRows, nKernelCols; + long nOutputPlane, nOutputDepth, nOutputRows, nOutputCols; + long istride0, kstride0, kstride1; + + THArgCheck(t_->nDimension == 4 , 3, "input: 4D Tensor expected"); + THArgCheck(k_->nDimension == 5 , 4, "kernel: 5D Tensor expected"); + THArgCheck(sdepth >= 1, 5, "Stride should be a positive integer"); + THArgCheck(srow >= 1, 6, "Stride should be a positive integer"); + THArgCheck(scol >= 1, 7, "Stride should be a positive integer"); + THArgCheck(*vf == 'V' || *vf == 'F', 8, "type of convolution can 'V' or 'F'"); + THArgCheck(*xc == 'C' || *xc == 'X', 8, "type of convolution can 'X' or 'C'"); + + THTensor *input = THTensor_(newContiguous)(t_); + THTensor* kernel; + if (!(k_->stride[4] == 1) || !(k_->stride[3] == k_->size[4])) { + kernel = THTensor_(newContiguous)(k_); + } else { + THTensor_(retain)(k_); + kernel = k_; + } + + nInputPlane = input->size[0]; + istride0 = input->stride[0]; + nInputDepth = input->size[1]; + nInputRows = input->size[2]; + nInputCols = input->size[3]; + + kstride0 = kernel->stride[0]; + kstride1 = kernel->stride[1]; + nKernelDepth = kernel->size[2]; + nKernelRows = kernel->size[3]; + nKernelCols = kernel->size[4]; + nOutputPlane = kernel->size[0]; + THArgCheck(kernel->size[1] == nInputPlane, 2, "invalid number of input planes"); + + THArgCheck( (nInputDepth >= nKernelDepth && nInputRows >= nKernelRows && nInputCols >= nKernelCols) || *vf == 'F', 2, "conv3Dmv : Input image is smaller than kernel"); + + nOutputDepth = THTensor_(convsize)(nInputDepth, nKernelDepth, sdepth, vf); + nOutputRows = THTensor_(convsize)(nInputRows, nKernelRows, srow, vf); + nOutputCols = THTensor_(convsize)(nInputCols, nKernelCols, scol, vf); + + long nelem = THTensor_(nElement)(r_); + THTensor_(resize4d)(r_, nOutputPlane, nOutputDepth, 
nOutputRows, nOutputCols); + + if (nelem == 0 || beta == 0 || nelem != THTensor_(nElement)(r_)) + { + THTensor_(zero)(r_); + } + else if (beta != 1) + THTensor_(mul)(r_, r_, beta); + + real *input_data = THTensor_(data)(input); + real *weight_data = THTensor_(data)(kernel); + real *output_data = THTensor_(data)(r_); + + long k,i; + for(k = 0; k < nOutputPlane; k++) + { + for(i = 0; i < nInputPlane; i++) + { + /* get kernel */ + real *ptr_weight = weight_data + k*kstride0 + i*kstride1; + /* get input */ + real *ptr_input = input_data + i*istride0; + + /* do image, kernel convolution */ + THTensor_(conv3d)(output_data, + alpha, + ptr_input, nInputDepth, nInputRows, nInputCols, + ptr_weight, nKernelDepth, nKernelRows, nKernelCols, + sdepth, srow, scol, vf, xc); + } + /* Next output plane */ + output_data += nOutputDepth*nOutputCols*nOutputRows; + } + THTensor_(free)(input); + THTensor_(free)(kernel); +} + +/* + 3D input, 3D kernel, 3D output + scalar multiplication like + y <- x*y + beta*y +*/ +void THTensor_(conv3Dmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, + long sdepth, long srow, long scol, const char *vf, const char *xc) +{ + + THArgCheck(t_->nDimension == 3 , 3, "input: 3D Tensor expected"); + THArgCheck(k_->nDimension == 3 , 4, "kernel: 3D Tensor expected"); + THArgCheck(sdepth >= 1, 5, "Stride should be a positive integer"); + THArgCheck(srow >= 1, 6, "Stride should be a positive integer"); + THArgCheck(scol >= 1, 7, "Stride should be a positive integer"); + THArgCheck(*vf == 'V' || *vf == 'F', 8, "type of convolution can 'V' or 'F'"); + THArgCheck(*xc == 'C' || *xc == 'X', 8, "type of convolution can 'X' or 'C'"); + + THTensor *input = THTensor_(newContiguous)(t_); + THTensor* kernel = THTensor_(newContiguous)(k_); + + long nInputDepth = input->size[0]; + long nInputRows = input->size[1]; + long nInputCols = input->size[2]; + long nKernelDepth = kernel->size[0]; + long nKernelRows = kernel->size[1]; + long nKernelCols = kernel->size[2]; + long nOutputDepth, nOutputRows, nOutputCols; + + THArgCheck((nInputDepth >= nKernelDepth && nInputRows >= nKernelRows && nInputCols >= nKernelCols) || *vf == 'F', 2, "conv3Dmul : Input image is smaller than kernel"); + + nOutputDepth = THTensor_(convsize)(nInputDepth, nKernelDepth, sdepth, vf); + nOutputRows = THTensor_(convsize)(nInputRows, nKernelRows, srow, vf); + nOutputCols = THTensor_(convsize)(nInputCols, nKernelCols, scol, vf); + + long nelem = THTensor_(nElement)(r_); + THTensor_(resize3d)(r_, nOutputDepth, nOutputRows, nOutputCols); + if (nelem == 0 || beta == 0 || nelem != THTensor_(nElement)(r_)) + THTensor_(zero)(r_); + else if (beta != 1) + THTensor_(mul)(r_, r_, beta); + + real *ptr_input = THTensor_(data)(input); + real *ptr_weight = THTensor_(data)(kernel); + real *output_data = THTensor_(data)(r_); + + + /* do image, kernel convolution */ + THTensor_(conv3d)(output_data, + alpha, + ptr_input, nInputDepth, nInputRows, nInputCols, + ptr_weight, nKernelDepth, nKernelRows, nKernelCols, + sdepth, srow, scol, vf, xc); + THTensor_(free)(input); + THTensor_(free)(kernel); +} + +/* + 4D input, 4D kernel, 4D output + component wise multiplication like + y <- y.*x + beta*y +*/ +void THTensor_(conv3Dcmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, + long sdepth, long srow, long scol, const char *vf, const char *xc) +{ + long nInputPlane, nInputDepth, nInputRows, nInputCols; + long nKernelDepth, nKernelRows, nKernelCols; + long nOutputPlane, nOutputDepth, nOutputRows, nOutputCols; + long istride0, 
kstride0; + + THArgCheck(t_->nDimension == 4 , 3, "input: 3D Tensor expected"); + THArgCheck(k_->nDimension == 4 , 4, "kernel: 3D Tensor expected"); + THArgCheck(srow >= 1, 5, "Stride should be a positive integer"); + THArgCheck(scol >= 1, 6, "Stride should be a positive integer"); + THArgCheck(*vf == 'V' || *vf == 'F', 7, "type of convolution can 'V' or 'F'"); + THArgCheck(*xc == 'C' || *xc == 'X', 7, "type of convolution can 'X' or 'C'"); + + THTensor *input = THTensor_(newContiguous)(t_); + THTensor* kernel = THTensor_(newContiguous)(k_); + + istride0 = input->stride[0]; + nInputPlane = input->size[0]; + nInputDepth = input->size[1]; + nInputRows = input->size[2]; + nInputCols = input->size[3]; + + kstride0 = kernel->stride[0]; + nOutputPlane = kernel->size[0]; + nKernelDepth = kernel->size[1]; + nKernelRows = kernel->size[2]; + nKernelCols = kernel->size[3]; + + THArgCheck(nOutputPlane == nInputPlane, 2, "invalid number of input/kernel planes"); + THArgCheck( (nInputDepth >= nKernelDepth && nInputRows >= nKernelRows && nInputCols >= nKernelCols) || *vf == 'F', 2, "conv3Dcmul : Input image is smaller than kernel"); + + nOutputDepth = THTensor_(convsize)(nInputDepth, nKernelDepth, sdepth, vf); + nOutputRows = THTensor_(convsize)(nInputRows, nKernelRows, srow, vf); + nOutputCols = THTensor_(convsize)(nInputCols, nKernelCols, scol, vf); + + long nelem = THTensor_(nElement)(r_); + THTensor_(resize4d)(r_, nOutputPlane, nOutputDepth, nOutputRows, nOutputCols); + + if (nelem == 0 || beta == 0 || nelem != THTensor_(nElement)(r_)) + { + THTensor_(zero)(r_); + } + else if (beta != 1) + THTensor_(mul)(r_, r_, beta); + + real *input_data = THTensor_(data)(input); + real *weight_data = THTensor_(data)(kernel); + real *output_data = THTensor_(data)(r_); + + long k; + for(k = 0; k < nOutputPlane; k++) + { + /* get kernel */ + real *ptr_weight = weight_data + k*kstride0; + /* get input */ + real *ptr_input = input_data + k*istride0; + + /* do image, kernel convolution */ + THTensor_(conv3d)(output_data, + alpha, + ptr_input, nInputDepth, nInputRows, nInputCols, + ptr_weight, nKernelDepth, nKernelRows, nKernelCols, + sdepth, srow, scol, vf, xc); + + /* Next output plane */ + output_data += nOutputDepth*nOutputCols*nOutputRows; + } + THTensor_(free)(input); + THTensor_(free)(kernel); +} + +/* + 4D input, 4D kernel, 4D output + component wise multiplication like with a permutation map + y <- y.*x + beta*y +*/ +void THTensor_(conv3Dmap)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, THTensor *map, + long sdepth, long srow, long scol, const char *vf, const char *xc) +{ + long nInputPlane, nInputDepth, nInputRows, nInputCols; + long nKernelDepth, nKernelRows, nKernelCols; + long nOutputPlane, nOutputDepth, nOutputRows, nOutputCols; + long istride0, kstride0; + + THArgCheck(t_->nDimension == 4 , 3, "input: 4D Tensor expected"); + THArgCheck(k_->nDimension == 4 , 4, "kernel: 4D Tensor expected"); + THArgCheck(map->nDimension == 2 , 4, "map: 2D Tensor expected"); + THArgCheck(srow >= 1, 6, "Stride should be a positive integer"); + THArgCheck(scol >= 1, 7, "Stride should be a positive integer"); + THArgCheck(*vf == 'V' || *vf == 'F', 8, "type of convolution can 'V' or 'F'"); + THArgCheck(*xc == 'C' || *xc == 'X', 8, "type of convolution can 'X' or 'C'"); + + THTensor *input = THTensor_(newContiguous)(t_); + THTensor* kernel = THTensor_(newContiguous)(k_); + + istride0 = input->stride[0]; + nInputPlane = input->size[0]; + nInputDepth = input->size[1]; + nInputRows = input->size[2]; + nInputCols 
= input->size[3]; + + kstride0 = kernel->stride[0]; + nOutputPlane = kernel->size[0]; + nKernelDepth = kernel->size[1]; + nKernelRows = kernel->size[2]; + nKernelCols = kernel->size[3]; + + THArgCheck(nOutputPlane == nInputPlane, 2, "invalid number of input/kernel planes"); + THArgCheck((nInputDepth >= nKernelDepth + && nInputRows >= nKernelRows + && nInputCols >= nKernelCols) || *vf == 'F', + 2, "conv3Dmap : Input image is smaller than kernel"); + + nOutputDepth = THTensor_(convsize)(nInputDepth, nKernelDepth, sdepth, vf); + nOutputRows = THTensor_(convsize)(nInputRows, nKernelRows, srow, vf); + nOutputCols = THTensor_(convsize)(nInputCols, nKernelCols, scol, vf); + + long nelem = THTensor_(nElement)(r_); + THTensor_(resize4d)(r_, nOutputPlane, nOutputDepth, nOutputRows, nOutputCols); + + if (nelem == 0 || beta == 0 || nelem != THTensor_(nElement)(r_)) + { + THTensor_(zero)(r_); + } + else if (beta != 1) + THTensor_(mul)(r_, r_, beta); + + real *input_data = THTensor_(data)(input); + real *weight_data = THTensor_(data)(kernel); + real *output_data = THTensor_(data)(r_); + + long nmaps = map->size[0]; + + long k; + for(k = 0; k < nmaps; k++) + { + /* get indices */ + long from = (long)THTensor_(get2d)(map,k,0)-1; + long to = (long)THTensor_(get2d)(map,k,1)-1; + + /* get kernel */ + real *ptr_weight = weight_data + k*kstride0; + /* get input */ + real *ptr_input = input_data + from*istride0; + /* get output */ + real *ptr_output = output_data + to*nOutputDepth*nOutputRows*nOutputCols; + + /* do image, kernel convolution */ + THTensor_(conv3d)(ptr_output, + alpha, + ptr_input, nInputDepth, nInputRows, nInputCols, + ptr_weight, nKernelDepth, nKernelRows, nKernelCols, + sdepth, srow, scol, vf, xc); + } + THTensor_(free)(input); + THTensor_(free)(kernel); +} + +#endif diff --git a/lib/TH/generic/THTensorConv.h b/lib/TH/generic/THTensorConv.h new file mode 100644 index 00000000000..dac48695bc4 --- /dev/null +++ b/lib/TH/generic/THTensorConv.h @@ -0,0 +1,78 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THTensorConv.h" +#else + + +TH_API void THTensor_(validXCorr2Dptr)(real *r_, + real alpha, + real *t_, long ir, long ic, + real *k_, long kr, long kc, + long sr, long sc); + +TH_API void THTensor_(validConv2Dptr)(real *r_, + real alpha, + real *t_, long ir, long ic, + real *k_, long kr, long kc, + long sr, long sc); + +TH_API void THTensor_(fullXCorr2Dptr)(real *r_, + real alpha, + real *t_, long ir, long ic, + real *k_, long kr, long kc, + long sr, long sc); + +TH_API void THTensor_(fullConv2Dptr)(real *r_, + real alpha, + real *t_, long ir, long ic, + real *k_, long kr, long kc, + long sr, long sc); + +TH_API void THTensor_(validXCorr2DRevptr)(real *r_, + real alpha, + real *t_, long ir, long ic, + real *k_, long kr, long kc, + long sr, long sc); + +TH_API void THTensor_(conv2DRevger)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol); +TH_API void THTensor_(conv2Dger)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol, const char *vf, const char *xc); +TH_API void THTensor_(conv2Dmv)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol, const char *vf, const char *xc); +TH_API void THTensor_(conv2Dmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol, const char *vf, const char *xc); +TH_API void THTensor_(conv2Dcmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol, const char *vf, const char *xc); + 
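+/* Conventions shared by the 2D routines above and the 3D routines below:
+     vf: 'V' = valid convolution, output size (x - k)/s + 1 per dimension
+         'F' = full convolution,  output size (x - 1)*s + k per dimension
+         (see THTensor_(convsize))
+     xc: 'X' = cross-correlation (kernel read front to back)
+         'C' = true convolution  (kernel read back to front)
+   A minimal usage sketch, assuming a 3-plane 32x32 input and a bank of
+   16 5x5 kernels (storage contents are left uninitialized here):
+     THTensor *input  = THTensor_(newWithSize3d)(3, 32, 32);
+     THTensor *kernel = THTensor_(newWithSize4d)(16, 3, 5, 5);
+     THTensor *output = THTensor_(new)();
+     THTensor_(conv2Dmv)(output, 0, 1, input, kernel, 1, 1, "V", "X");
+   (output is resized to 16 x 28 x 28)
+*/
+ 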
+TH_API void THTensor_(validXCorr3Dptr)(real *r_, + real alpha, + real *t_, long it, long ir, long ic, + real *k_, long kt, long kr, long kc, + long st, long sr, long sc); + +TH_API void THTensor_(validConv3Dptr)(real *r_, + real alpha, + real *t_, long it, long ir, long ic, + real *k_, long kt, long kr, long kc, + long st, long sr, long sc); + +TH_API void THTensor_(fullXCorr3Dptr)(real *r_, + real alpha, + real *t_, long it, long ir, long ic, + real *k_, long kt, long kr, long kc, + long st, long sr, long sc); + +TH_API void THTensor_(fullConv3Dptr)(real *r_, + real alpha, + real *t_, long it, long ir, long ic, + real *k_, long kt, long kr, long kc, + long st, long sr, long sc); + +TH_API void THTensor_(validXCorr3DRevptr)(real *r_, + real alpha, + real *t_, long it, long ir, long ic, + real *k_, long kt, long kr, long kc, + long st, long sr, long sc); + +TH_API void THTensor_(conv3DRevger)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long sdepth, long srow, long scol); +TH_API void THTensor_(conv3Dger)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long sdepth, long srow, long scol, const char *vf, const char *xc); +TH_API void THTensor_(conv3Dmv)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long sdepth, long srow, long scol, const char *vf, const char *xc); +TH_API void THTensor_(conv3Dmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long sdepth, long srow, long scol, const char *vf, const char *xc); +TH_API void THTensor_(conv3Dcmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long sdepth, long srow, long scol, const char *vf, const char *xc); + +#endif diff --git a/lib/TH/generic/THTensorCopy.c b/lib/TH/generic/THTensorCopy.c new file mode 100644 index 00000000000..371d3f72e7e --- /dev/null +++ b/lib/TH/generic/THTensorCopy.c @@ -0,0 +1,21 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THTensorCopy.c" +#else + +#define IMPLEMENT_THTensor_COPY(TYPENAMESRC, TYPE_SRC) \ +void THTensor_(copy##TYPENAMESRC)(THTensor *tensor, TH##TYPENAMESRC##Tensor *src) \ +{ \ + TH_TENSOR_APPLY2(real, tensor, TYPE_SRC, src, *tensor_data = (real)(*src_data);) \ +} + +IMPLEMENT_THTensor_COPY(, real) + +IMPLEMENT_THTensor_COPY(Byte, unsigned char) +IMPLEMENT_THTensor_COPY(Char, char) +IMPLEMENT_THTensor_COPY(Short, short) +IMPLEMENT_THTensor_COPY(Int, int) +IMPLEMENT_THTensor_COPY(Long, long) +IMPLEMENT_THTensor_COPY(Float, float) +IMPLEMENT_THTensor_COPY(Double, double) + +#endif diff --git a/lib/TH/generic/THTensorCopy.h b/lib/TH/generic/THTensorCopy.h new file mode 100644 index 00000000000..8d03b2207f1 --- /dev/null +++ b/lib/TH/generic/THTensorCopy.h @@ -0,0 +1,16 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THTensorCopy.h" +#else + +/* Support for copy between different Tensor types */ + +TH_API void THTensor_(copy)(THTensor *tensor, THTensor *src); +TH_API void THTensor_(copyByte)(THTensor *tensor, struct THByteTensor *src); +TH_API void THTensor_(copyChar)(THTensor *tensor, struct THCharTensor *src); +TH_API void THTensor_(copyShort)(THTensor *tensor, struct THShortTensor *src); +TH_API void THTensor_(copyInt)(THTensor *tensor, struct THIntTensor *src); +TH_API void THTensor_(copyLong)(THTensor *tensor, struct THLongTensor *src); +TH_API void THTensor_(copyFloat)(THTensor *tensor, struct THFloatTensor *src); +TH_API void THTensor_(copyDouble)(THTensor *tensor, struct THDoubleTensor *src); + +#endif diff --git a/lib/TH/generic/THTensorLapack.c 
b/lib/TH/generic/THTensorLapack.c new file mode 100644 index 00000000000..8eb3ad52e46 --- /dev/null +++ b/lib/TH/generic/THTensorLapack.c @@ -0,0 +1,343 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THTensorLapack.c" +#else + +static int THTensor_(lapackClone)(THTensor *r_, THTensor *m, int forced) +{ + int clone; + + if (!forced && m->stride[0] == 1 && m->stride[1] == m->size[0]) + { + clone = 0; + THTensor_(set)(r_,m); + } + else + { + clone = 1; + /* we need to copy */ + THTensor_(resize2d)(r_,m->size[1],m->size[0]); + THTensor_(transpose)(r_,NULL,0,1); + THTensor_(copy)(r_,m); + } + return clone; +} + +TH_API void THTensor_(gesv)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a) +{ + int n, nrhs, lda, ldb, info; + THIntTensor *ipiv; + THTensor *ra__; + THTensor *rb__; + + int clonea; + int cloneb; + int destroya; + int destroyb; + + + if (a == NULL || ra_ == a) /* possibly destroy the inputs */ + { + ra__ = THTensor_(new)(); + clonea = THTensor_(lapackClone)(ra__,ra_,0); + destroya = 1; + } + else /*we want to definitely clone and use ra_ and rb_ as computational space*/ + { + clonea = THTensor_(lapackClone)(ra_,a,1); + ra__ = ra_; + destroya = 0; + } + if (b == NULL || rb_ == b) /* possibly destroy the inputs */ + { + rb__ = THTensor_(new)(); + cloneb = THTensor_(lapackClone)(rb__,rb_,0); + destroyb = 1; + } + else /*we want to definitely clone and use ra_ and rb_ as computational space*/ + { + cloneb = THTensor_(lapackClone)(rb_,b,1); + rb__ = rb_; + destroyb = 0; + } + + THArgCheck(ra__->nDimension == 2, 1, "A should be 2 dimensional"); + THArgCheck(rb__->nDimension == 2, 2, "b should be 2 dimensional"); + THArgCheck(ra__->size[0] == ra__->size[1], 1, "A should be square"); + THArgCheck(rb__->size[0] == ra__->size[0], 2, "A,b size incomptable"); + + n = (int)ra__->size[0]; + nrhs = (int)rb__->size[1]; + lda = n; + ldb = n; + + ipiv = THIntTensor_newWithSize1d((long)n); + THLapack_(gesv)(n, nrhs, + THTensor_(data)(ra__), lda, THIntTensor_data(ipiv), + THTensor_(data)(rb__), ldb, &info); + + /* clean up */ + if (destroya) + { + if (clonea) + { + THTensor_(copy)(ra_,ra__); + } + THTensor_(free)(ra__); + } + if (destroyb) + { + if (cloneb) + { + THTensor_(copy)(rb_,rb__); + } + THTensor_(free)(rb__); + } + + if (info < 0) + { + THError("Lapack gesv : Argument %d : illegal value", -info); + } + else if (info > 0) + { + THError("Lapack gesv : U(%d,%d) is zero, singular U.", info,info); + } + + THIntTensor_free(ipiv); +} + +TH_API void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a) +{ + int m, n, nrhs, lda, ldb, info, lwork; + char transpose; + THTensor *work = NULL; + real wkopt = 0; + + THTensor *ra__; + THTensor *rb__; + + int clonea; + int cloneb; + int destroya; + int destroyb; + + + if (a == NULL || ra_ == a) /* possibly destroy the inputs */ + { + ra__ = THTensor_(new)(); + clonea = THTensor_(lapackClone)(ra__,ra_,0); + destroya = 1; + } + else /*we want to definitely clone and use ra_ and rb_ as computational space*/ + { + clonea = THTensor_(lapackClone)(ra_,a,1); + ra__ = ra_; + destroya = 0; + } + if (b == NULL || rb_ == b) /* possibly destroy the inputs */ + { + rb__ = THTensor_(new)(); + cloneb = THTensor_(lapackClone)(rb__,rb_,0); + destroyb = 1; + } + else /*we want to definitely clone and use ra_ and rb_ as computational space*/ + { + cloneb = THTensor_(lapackClone)(rb_,b,1); + rb__ = rb_; + destroyb = 0; + } + + THArgCheck(ra__->nDimension == 2, 1, "A should be 2 dimensional"); + THArgCheck(ra_->size[0] == rb__->size[0], 2, 
"size incompatible A,b"); + + m = ra__->size[0]; + n = ra__->size[1]; + nrhs = rb__->size[1]; + lda = m; + ldb = m; + info = 0; + + // get optimal workspace size + THLapack_(gels)('N', m, n, nrhs, THTensor_(data)(ra__), lda, + THTensor_(data)(rb__), ldb, + &wkopt, -1, &info); + lwork = (int)wkopt; + work = THTensor_(newWithSize1d)(lwork); + THLapack_(gels)('N', m, n, nrhs, THTensor_(data)(ra__), lda, + THTensor_(data)(rb__), ldb, + THTensor_(data)(work), lwork, &info); + + //printf("lwork = %d,%g\n",lwork,THTensor_(data)(work)[0]); + if (info != 0) + { + THError("Lapack gels : Argument %d : illegal value", -info); + } + /* clean up */ + if (destroya) + { + if (clonea) + { + THTensor_(copy)(ra_,ra__); + } + THTensor_(free)(ra__); + } + if (destroyb) + { + if (cloneb) + { + THTensor_(copy)(rb_,rb__); + } + THTensor_(free)(rb__); + } + THTensor_(free)(work); +} + +TH_API void THTensor_(syev)(THTensor *re_, THTensor *rv_, THTensor *a, const char *jobz, const char *uplo) +{ + int n, lda, lwork, info; + THTensor *work; + real wkopt; + + THTensor *rv__; + + int clonea; + int destroy; + + if (a == NULL) /* possibly destroy the inputs */ + { + rv__ = THTensor_(new)(); + clonea = THTensor_(lapackClone)(rv__,rv_,0); + destroy = 1; + } + else /*we want to definitely clone and use ra_ and rb_ as computational space*/ + { + clonea = THTensor_(lapackClone)(rv_,a,1); + rv__ = rv_; + destroy = 0; + } + + THArgCheck(rv__->nDimension == 2, 2, "A should be 2 dimensional"); + + n = rv__->size[0]; + lda = n; + + THTensor_(resize1d)(re_,n); + + // get optimal workspace size + THLapack_(syev)(jobz[0], uplo[0], n, THTensor_(data)(rv__), lda, + THTensor_(data)(re_), &wkopt, -1, &info); + lwork = (int)wkopt; + work = THTensor_(newWithSize1d)(lwork); + THLapack_(syev)(jobz[0], uplo[0], n, THTensor_(data)(rv__), lda, + THTensor_(data)(re_), THTensor_(data)(work), lwork, &info); + + if (info > 0) + { + THError(" Lapack syev : Failed to converge. %d off-diagonal elements of an didn't converge to zero",info); + } + else if (info < 0) + { + THError("Lapack syev : Argument %d : illegal value", -info); + } + /* clean up */ + if (destroy) + { + if (clonea) + { + THTensor_(copy)(rv_,rv__); + } + THTensor_(free)(rv__); + } + THTensor_(free)(work); +} + +TH_API void THTensor_(gesvd)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *a, const char* jobu) +{ + THTensor *ra_ = THTensor_(new)(); + THTensor_(gesvd2)(ru_, rs_, rv_, ra_, a, jobu); + THTensor_(free)(ra_); +} + +TH_API void THTensor_(gesvd2)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *ra_, THTensor *a, const char* jobu) +{ + int k,m, n, lda, ldu, ldvt, lwork, info; + THTensor *work; + real wkopt; + + THTensor *ra__; + + int clonea; + int destroy; + + if (a == NULL) /* possibly destroy the inputs */ + { + ra__ = THTensor_(new)(); + clonea = THTensor_(lapackClone)(ra__,ra_,0); + destroy = 1; + } + else /*we want to definitely clone */ + { + clonea = THTensor_(lapackClone)(ra_,a,1); + ra__ = ra_; + destroy = 0; + } + + THArgCheck(ra__->nDimension == 2, 2, "A should be 2 dimensional"); + + m = ra__->size[0]; + n = ra__->size[1]; + k = (m < n ? 
m : n); + + lda = m; + ldu = m; + ldvt = n; + THTensor_(resize1d)(rs_,k); + THTensor_(resize2d)(rv_,ldvt,n); + if (*jobu == 'A') + { + THTensor_(resize2d)(ru_,m,ldu); + } + else + { + THTensor_(resize2d)(ru_,k,ldu); + } + THTensor_(transpose)(ru_,NULL,0,1); + THTensor_(transpose)(rv_,NULL,0,1); + + THLapack_(gesvd)(jobu[0],jobu[0], + m,n,THTensor_(data)(ra__),lda, + THTensor_(data)(rs_), + THTensor_(data)(ru_), + ldu, + THTensor_(data)(rv_), ldvt, + &wkopt, -1, &info); + lwork = (int)wkopt; + work = THTensor_(newWithSize1d)(lwork); + THLapack_(gesvd)(jobu[0],jobu[0], + m,n,THTensor_(data)(ra__),lda, + THTensor_(data)(rs_), + THTensor_(data)(ru_), + ldu, + THTensor_(data)(rv_), ldvt, + THTensor_(data)(work),lwork, &info); + if (info > 0) + { + THError(" Lapack gesvd : %d superdiagonals failed to converge.",info); + } + else if (info < 0) + { + THError("Lapack gesvd : Argument %d : illegal value", -info); + } + + /* clean up */ + if (destroy) + { + if (clonea) + { + THTensor_(copy)(ra_,ra__); + } + THTensor_(free)(ra__); + } + THTensor_(free)(work); +} + +#endif diff --git a/lib/TH/generic/THTensorLapack.h b/lib/TH/generic/THTensorLapack.h new file mode 100644 index 00000000000..6d3f344cd1f --- /dev/null +++ b/lib/TH/generic/THTensorLapack.h @@ -0,0 +1,11 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THTensorLapack.h" +#else + +TH_API void THTensor_(gesv)(THTensor *rb_, THTensor *ra_, THTensor *b_, THTensor *a_); +TH_API void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b_, THTensor *a_); +TH_API void THTensor_(syev)(THTensor *re_, THTensor *rv_, THTensor *a_, const char *jobz, const char *uplo); +TH_API void THTensor_(gesvd)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *a, const char *jobu); +TH_API void THTensor_(gesvd2)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *ra_, THTensor *a, const char *jobu); + +#endif diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c new file mode 100644 index 00000000000..e66d6ba9cd2 --- /dev/null +++ b/lib/TH/generic/THTensorMath.c @@ -0,0 +1,1063 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THTensorMath.c" +#else + +void THTensor_(fill)(THTensor *r_, real value) +{ + TH_TENSOR_APPLY(real, r_, + THVector_(fill)(r__data, value, r__size); break;); +} + +void THTensor_(zero)(THTensor *r_) +{ + TH_TENSOR_APPLY(real, r_, + THVector_(fill)(r__data, 0, r__size); break;); +} + +accreal THTensor_(dot)(THTensor *tensor, THTensor *src) +{ + accreal sum = 0; + /* we use a trick here. careful with that. */ + TH_TENSOR_APPLY2(real, tensor, real, src, + long sz = (tensor_size-tensor_i < src_size-src_i ? 
tensor_size-tensor_i : src_size-src_i); + sum += THBlas_(dot)(sz, src_data, src_stride, tensor_data, tensor_stride); + tensor_i += sz; + src_i += sz; + tensor_data += sz*tensor_stride; + src_data += sz*src_stride; + break;); + return sum; +} + +real THTensor_(minall)(THTensor *tensor) +{ + real theMin; + THArgCheck(tensor->nDimension > 0, 1, "tensor must have one dimension"); + theMin = THTensor_(data)(tensor)[0]; + TH_TENSOR_APPLY(real, tensor, if(*tensor_data < theMin) theMin = *tensor_data;); + return theMin; +} + +real THTensor_(maxall)(THTensor *tensor) +{ + real theMax; + THArgCheck(tensor->nDimension > 0, 1, "tensor must have one dimension"); + theMax = THTensor_(data)(tensor)[0]; + TH_TENSOR_APPLY(real, tensor, if(*tensor_data > theMax) theMax = *tensor_data;); + return theMax; +} + +accreal THTensor_(sumall)(THTensor *tensor) +{ + accreal sum = 0; + TH_TENSOR_APPLY(real, tensor, sum += *tensor_data;); + return sum; +} + +void THTensor_(add)(THTensor *r_, THTensor *t, real value) +{ + THTensor_(resizeAs)(r_, t); + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = *t_data + value;); +} + +void THTensor_(mul)(THTensor *r_, THTensor *t, real value) +{ + THTensor_(resizeAs)(r_, t); + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = *t_data * value;); +} + +void THTensor_(div)(THTensor *r_, THTensor *t, real value) +{ + THTensor_(resizeAs)(r_, t); + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = *t_data / value;); +} + +void THTensor_(cadd)(THTensor *r_, THTensor *t, real value, THTensor *src) +{ + THTensor_(resizeAs)(r_, t); + TH_TENSOR_APPLY3(real, r_, real, t, real, src, *r__data = *t_data + value * *src_data;); +} + +void THTensor_(cmul)(THTensor *r_, THTensor *t, THTensor *src) +{ + THTensor_(resizeAs)(r_, t); + TH_TENSOR_APPLY3(real, r_, real, t, real, src, *r__data = *t_data * *src_data;); +} + +void THTensor_(cdiv)(THTensor *r_, THTensor *t, THTensor *src) +{ + THTensor_(resizeAs)(r_, t); + TH_TENSOR_APPLY3(real, r_, real, t, real, src, *r__data = *t_data / *src_data;); +} + +void THTensor_(addcmul)(THTensor *r_, THTensor *t, real value, THTensor *src1, THTensor *src2) +{ + if(r_ != t) + { + THTensor_(resizeAs)(r_, t); + THTensor_(copy)(r_, t); + } + + TH_TENSOR_APPLY3(real, r_, real, src1, real, src2, *r__data += value * *src1_data * *src2_data;); +} + + +void THTensor_(addcdiv)(THTensor *r_, THTensor *t, real value, THTensor *src1, THTensor *src2) +{ + if(r_ != t) + { + THTensor_(resizeAs)(r_, t); + THTensor_(copy)(r_, t); + } + + TH_TENSOR_APPLY3(real, r_, real, src1, real, src2, *r__data += value * *src1_data / *src2_data;); +} + +void THTensor_(addmv)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *mat, THTensor *vec) +{ + if( (mat->nDimension != 2) || (vec->nDimension != 1) ) + THError("matrix and vector expected"); + + if( mat->size[1] != vec->size[0] ) + THError("size mismatch"); + + if(t->nDimension != 1) + THError("size mismatch"); + + if(t->size[0] != mat->size[0]) + THError("size mismatch"); + + if(r_ != t) + { + THTensor_(resizeAs)(r_, t); + THTensor_(copy)(r_, t); + } + + if(mat->stride[0] == 1) + { + THBlas_(gemv)('n', mat->size[0], mat->size[1], + alpha, THTensor_(data)(mat), mat->stride[1], + THTensor_(data)(vec), vec->stride[0], + beta, THTensor_(data)(r_), r_->stride[0]); + } + else if(mat->stride[1] == 1) + { + THBlas_(gemv)('t', mat->size[1], mat->size[0], + alpha, THTensor_(data)(mat), mat->stride[0], + THTensor_(data)(vec), vec->stride[0], + beta, THTensor_(data)(r_), r_->stride[0]); + } + else + { + THTensor *cmat = THTensor_(newContiguous)(mat); + 
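+    /* mat has no unit-stride dimension here, so gemv cannot address it directly: make a contiguous (row-major) copy and pass it to gemv as its transpose ('t'), which computes the same mat*vec product. */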
+ THBlas_(gemv)('t', mat->size[1], mat->size[0], + alpha, THTensor_(data)(cmat), cmat->stride[0], + THTensor_(data)(vec), vec->stride[0], + beta, THTensor_(data)(r_), r_->stride[0]); + + THTensor_(free)(cmat); + } +} + +void THTensor_(addmm)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *m1, THTensor *m2) +{ + long r, c; + char transpose, transpose_m1, transpose_m2; + THTensor *r__, *m1_, *m2_; + + if( (m1->nDimension != 2) || (m2->nDimension != 2) ) + THError("matrix and matrix expected"); + + if(t->nDimension != 2) + THError("size mismatch"); + + if( (t->size[0] != m1->size[0]) || (t->size[1] != m2->size[1]) || (m1->size[1] != m2->size[0]) ) + THError("size mismatch"); + + if(t != r_) + { + THTensor_(resizeAs)(r_, t); + THTensor_(copy)(r_, t); + } + + /* r_ */ + if(r_->stride[0] == 1) + { + transpose = 'n'; + r__ = r_; + } + else if(r_->stride[1] == 1) + { + THTensor *swap = m2; + m2 = m1; + m1 = swap; + THTensor_(transpose)(r_, NULL, 0, 1); + THTensor_(transpose)(m1, NULL, 0, 1); + THTensor_(transpose)(m2, NULL, 0, 1); + transpose = 't'; + r__ = r_; + } + else + { + transpose = 'n'; + THTensor_(transpose)(r_, NULL, 0, 1); + r__ = THTensor_(newClone)(r_); + THTensor_(transpose)(r_, NULL, 0, 1); + THTensor_(transpose)(r__, NULL, 0, 1); + } + + /* m1 */ + if(m1->stride[0] == 1) + { + transpose_m1 = 'n'; + m1_ = m1; + } + else if(m1->stride[1] == 1) + { + transpose_m1 = 't'; + m1_ = m1; + } + else + { + transpose_m1 = 't'; + m1_ = THTensor_(newContiguous)(m1); + } + + /* m2 */ + if(m2->stride[0] == 1) + { + transpose_m2 = 'n'; + m2_ = m2; + } + else if(m2->stride[1] == 1) + { + transpose_m2 = 't'; + m2_ = m2; + } + else + { + transpose_m2 = 't'; + m2_ = THTensor_(newContiguous)(m2); + } + + /* do the operation */ + THBlas_(gemm)(transpose_m1, + transpose_m2, + r__->size[0], + r__->size[1], + m1_->size[1], + alpha, + THTensor_(data)(m1_), + (transpose_m1 == 'n' ? m1_->stride[1] : m1_->stride[0]), + THTensor_(data)(m2_), + (transpose_m2 == 'n' ? 
m2_->stride[1] : m2_->stride[0]), + beta, + THTensor_(data)(r__), + r__->stride[1]); + + /* free intermediate variables */ + if(m1_ != m1) + THTensor_(free)(m1_); + + if(m2_ != m2) + THTensor_(free)(m2_); + + if(r__ != r_) + THTensor_(freeCopyTo)(r__, r_); + + if(transpose == 't') + { + THTensor_(transpose)(r_, NULL, 0, 1); + THTensor_(transpose)(m1, NULL, 0, 1); + THTensor_(transpose)(m2, NULL, 0, 1); + } +} + +void THTensor_(addr)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *vec1, THTensor *vec2) +{ + if( (vec1->nDimension != 1) || (vec2->nDimension != 1) ) + THError("vector and vector expected"); + + if(t->nDimension != 2) + THError("size mismatch"); + + if( (t->size[0] != vec1->size[0]) || (t->size[1] != vec2->size[0]) ) + THError("size mismatch"); + + if(r_ != t) + { + THTensor_(resizeAs)(r_, t); + THTensor_(copy)(r_, t); + } + + if(beta != 1) + THTensor_(mul)(r_, r_, beta); + + if(r_->stride[0] == 1) + { + THBlas_(ger)(vec1->size[0], vec2->size[0], + alpha, THTensor_(data)(vec1), vec1->stride[0], + THTensor_(data)(vec2), vec2->stride[0], + THTensor_(data)(r_), r_->stride[1]); + } + else if(r_->stride[1] == 1) + { + THBlas_(ger)(vec2->size[0], vec1->size[0], + alpha, THTensor_(data)(vec2), vec2->stride[0], + THTensor_(data)(vec1), vec1->stride[0], + THTensor_(data)(r_), r_->stride[0]); + } + else + { + THTensor *cr = THTensor_(newClone)(r_); + + THBlas_(ger)(vec2->size[0], vec1->size[0], + alpha, THTensor_(data)(vec2), vec2->stride[0], + THTensor_(data)(vec1), vec1->stride[0], + THTensor_(data)(cr), cr->stride[0]); + + THTensor_(freeCopyTo)(cr, r_); + } +} + +long THTensor_(numel)(THTensor *t) +{ + return THTensor_(nElement)(t); +} + +void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension) +{ + THLongStorage *dim; + long i; + + THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension out of range"); + + dim = THTensor_(newSizeOf)(t); + THLongStorage_set(dim, dimension, 1); + THTensor_(resize)(values_, dim, NULL); + THLongTensor_resize(indices_, dim, NULL); + THLongStorage_free(dim); + + TH_TENSOR_DIM_APPLY3(real, t, real, values_, long, indices_, dimension, + long theIndex = 0; + real theMax = t_data[0]; + for(i = 1; i < t_size; i++) + { + if(t_data[i*t_stride] > theMax) + { + theIndex = i; + theMax = t_data[i*t_stride]; + } + } + *indices__data = theIndex; + *values__data = theMax;); + +} + +void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension) +{ + THLongStorage *dim; + long i; + + THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension out of range"); + + dim = THTensor_(newSizeOf)(t); + THLongStorage_set(dim, dimension, 1); + THTensor_(resize)(values_, dim, NULL); + THLongTensor_resize(indices_, dim, NULL); + THLongStorage_free(dim); + + TH_TENSOR_DIM_APPLY3(real, t, real, values_, long, indices_, dimension, + long theIndex = 0; + real theMin = t_data[0]; + for(i = 1; i < t_size; i++) + { + if(t_data[i*t_stride] < theMin) + { + theIndex = i; + theMin = t_data[i*t_stride]; + } + } + *indices__data = theIndex; + *values__data = theMin;); + +} + + +void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension) +{ + THLongStorage *dim; + + THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension out of range"); + + dim = THTensor_(newSizeOf)(t); + THLongStorage_set(dim, dimension, 1); + THTensor_(resize)(r_, dim, NULL); + THLongStorage_free(dim); + + TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, + accreal sum = 0; + long i; 
+ for(i = 0; i < t_size; i++) + sum += t_data[i*t_stride]; + *r__data = (real)sum;); +} + +void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension) +{ + THLongStorage *dim; + + THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension out of range"); + + dim = THTensor_(newSizeOf)(t); + THLongStorage_set(dim, dimension, 1); + THTensor_(resize)(r_, dim, NULL); + THLongStorage_free(dim); + + TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, + accreal prod = 1; + long i; + for(i = 0; i < t_size; i++) + prod *= t_data[i*t_stride]; + *r__data = (real)prod;); + +} + +void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension) +{ + THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension out of range"); + + THTensor_(resizeAs)(r_, t); + + TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, + accreal cumsum = 0; + long i; + for(i = 0; i < t_size; i++) + { + cumsum += t_data[i*t_stride]; + r__data[i*r__stride] = (real)cumsum; + }); +} + +void THTensor_(cumprod)(THTensor *r_, THTensor *t, int dimension) +{ + THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension out of range"); + + THTensor_(resizeAs)(r_, t); + + TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, + accreal cumprod = 1; + long i; + for(i = 0; i < t_size; i++) + { + cumprod *= t_data[i*t_stride]; + r__data[i*r__stride] = (real)cumprod; + }); +} + + +void THTensor_(sign)(THTensor *r_, THTensor *t) +{ + THTensor_(resizeAs)(r_, t); + +#if defined (TH_REAL_IS_BYTE) + TH_TENSOR_APPLY2(real, r_, real, t, + if (*t_data > 0) *r__data = 1; + else *r__data = 0;); +#else + TH_TENSOR_APPLY2(real, r_, real, t, + if (*t_data > 0) *r__data = 1; + else if (*t_data < 0) *r__data = -1; + else *r__data = 0;); +#endif +} + + +accreal THTensor_(trace)(THTensor *t) +{ + real *t_data = THTensor_(data)(t); + accreal sum = 0; + long i = 0; + long t_stride_0, t_stride_1, t_diag_size; + + THArgCheck(THTensor_(nDimension)(t) == 2, 1, "not a matrix"); + + t_stride_0 = THTensor_(stride)(t, 0); + t_stride_1 = THTensor_(stride)(t, 1); + t_diag_size = THMin(THTensor_(size)(t, 0), THTensor_(size)(t, 1)); + while(i < t_diag_size) + { + sum += t_data[i*(t_stride_0+t_stride_1)]; + i++; + } + + return sum; +} + +void THTensor_(cross)(THTensor *r_, THTensor *a, THTensor *b, int dimension) +{ + int i; + + if(THTensor_(nDimension)(a) != THTensor_(nDimension)(b)) + THError("inconsitent tensor sizes"); + + for(i = 0; i < THTensor_(nDimension)(a); i++) + { + if(THTensor_(size)(a, i) != THTensor_(size)(b, i)) + THError("inconsistent tensor sizes"); + } + + if(dimension < 0) + { + for(i = 0; i < THTensor_(nDimension)(a); i++) + { + if(THTensor_(size)(a, i) == 3) + { + dimension = i; + break; + } + } + if(dimension < 0) + THError("no dimension of size 3"); + } + + THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(a), 3, "dimension out of range"); + THArgCheck(THTensor_(size)(a, dimension) == 3, 3, "dimension size is not 3"); + + THTensor_(resizeAs)(r_, a); + + TH_TENSOR_DIM_APPLY3(real, a, real, b, real, r_, dimension, + r__data[0*r__stride] = a_data[1*a_stride]*b_data[2*b_stride] - a_data[2*a_stride]*b_data[1*b_stride]; + r__data[1*r__stride] = a_data[2*a_stride]*b_data[0*b_stride] - a_data[0*a_stride]*b_data[2*b_stride]; + r__data[2*r__stride] = a_data[0*a_stride]*b_data[1*b_stride] - a_data[1*a_stride]*b_data[0*b_stride];); +} + +void THTensor_(zeros)(THTensor *r_, THLongStorage *size) +{ + THTensor_(resize)(r_, size, NULL); + THTensor_(zero)(r_); +} + +void THTensor_(ones)(THTensor 
*r_, THLongStorage *size) +{ + THTensor_(resize)(r_, size, NULL); + THTensor_(fill)(r_, 1); +} + +void THTensor_(diag)(THTensor *r_, THTensor *t, int k) +{ + THArgCheck(THTensor_(nDimension)(t) == 1 || THTensor_(nDimension)(t) == 2, 1, "matrix or a vector expected"); + + if(THTensor_(nDimension)(t) == 1) + { + real *t_data = THTensor_(data)(t); + long t_stride_0 = THTensor_(stride)(t, 0); + long t_size = THTensor_(size)(t, 0); + long sz = t_size + (k >= 0 ? k : -k); + real *r__data; + long r__stride_0; + long r__stride_1; + long i; + + THTensor_(resize2d)(r_, sz, sz); + THTensor_(zero)(r_); + r__data = THTensor_(data)(r_); + r__stride_0 = THTensor_(stride)(r_, 0); + r__stride_1 = THTensor_(stride)(r_, 1); + r__data += (k >= 0 ? k*r__stride_1 : -k*r__stride_0); + + for(i = 0; i < t_size; i++) + r__data[i*(r__stride_0+r__stride_1)] = t_data[i*t_stride_0]; + } + else + { + real *t_data = THTensor_(data)(t); + long t_stride_0 = THTensor_(stride)(t, 0); + long t_stride_1 = THTensor_(stride)(t, 1); + long sz; + real *r__data; + long r__stride_0; + long i; + + if(k >= 0) + sz = THMin(THTensor_(size)(t, 0), THTensor_(size)(t, 1)-k); + else + sz = THMin(THTensor_(size)(t, 0)+k, THTensor_(size)(t, 1)); + THTensor_(resize1d)(r_, sz); + r__data = THTensor_(data)(r_); + r__stride_0 = THTensor_(stride)(r_, 0); + + t_data += (k >= 0 ? k*t_stride_1 : -k*t_stride_0); + for(i = 0; i < sz; i++) + r__data[i*r__stride_0] = t_data[i*(t_stride_0+t_stride_1)]; + } +} + +void THTensor_(eye)(THTensor *r_, long n, long m) +{ + real *r__data; + long i, sz; + + THArgCheck(n > 0, 1, "invalid argument"); + + if(m <= 0) + m = n; + + THTensor_(resize2d)(r_, n, m); + THTensor_(zero)(r_); + + i = 0; + r__data = THTensor_(data)(r_); + sz = THMin(THTensor_(size)(r_, 0), THTensor_(size)(r_, 1)); + for(i = 0; i < sz; i++) + r__data[i*(r_->stride[0]+r_->stride[1])] = 1; +} + + +void THTensor_(range)(THTensor *r_, real xmin, real xmax, real step) +{ + long size; + real i = 0; + + THArgCheck(step > 0, 3, "step must be a positive number"); + THArgCheck(xmax > xmin, 2, "upper bound must be larger than lower bound"); + + size = (long)((xmax-xmin)/step+1); + + THTensor_(resize1d)(r_, size); + + TH_TENSOR_APPLY(real, r_, *r__data = xmin + (i++)*step;); +} + +void THTensor_(randperm)(THTensor *r_, long n) +{ + real *r__data; + long r__stride_0; + long i; + + THArgCheck(n > 0, 1, "must be strictly positive"); + + THTensor_(resize1d)(r_, n); + r__data = THTensor_(data)(r_); + r__stride_0 = THTensor_(stride)(r_,0); + + for(i = 0; i < n; i++) + r__data[i*r__stride_0] = (real)(i); + + for(i = 0; i < n-1; i++) + { + long z = THRandom_random() % (n-i); + real sav = r__data[i*r__stride_0]; + r__data[i*r__stride_0] = r__data[(z+i)*r__stride_0]; + r__data[(z+i)*r__stride_0] = sav; + } +} + +void THTensor_(reshape)(THTensor *r_, THTensor *t, THLongStorage *size) +{ + THTensor_(resize)(r_, size, NULL); + THTensor_(copy)(r_, t); +} + +/* I cut and pasted (slightly adapted) the quicksort code from + http://www.alienryderflex.com/quicksort/ + This public-domain C implementation by Darel Rex Finley. 
+ Thanks man :) +*/ +#define MAX_LEVELS 300 +static void THTensor_(quicksortascend)(real *arr, long *idx, long elements, long stride) +{ + long beg[MAX_LEVELS], end[MAX_LEVELS], i=0, L, R, swap, pid; + real piv; + + beg[0]=0; end[0]=elements; + while (i>=0) { + L=beg[i]; R=end[i]-1; + if (L=piv && Lend[i-1]-beg[i-1]) { + swap=beg[i]; beg[i]=beg[i-1]; beg[i-1]=swap; + swap=end[i]; end[i]=end[i-1]; end[i-1]=swap; }} + else { + i--; }}} + +static void THTensor_(quicksortdescend)(real *arr, long *idx, long elements, long stride) +{ + long beg[MAX_LEVELS], end[MAX_LEVELS], i=0, L, R, swap, pid; + real piv; + + beg[0]=0; end[0]=elements; + while (i>=0) { + L=beg[i]; R=end[i]-1; + if (L=piv && Lend[i-1]-beg[i-1]) { + swap=beg[i]; beg[i]=beg[i-1]; beg[i-1]=swap; + swap=end[i]; end[i]=end[i-1]; end[i-1]=swap; }} + else { + i--; }}} + +void THTensor_(sort)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int dimension, int descendingOrder) +{ + THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "invalid dimension"); + + THTensor_(resizeAs)(rt_, t); + THTensor_(copy)(rt_, t); + + { + THLongStorage *size = THTensor_(newSizeOf)(t); + THLongTensor_resize(ri_, size, NULL); + THLongStorage_free(size); + } + + if(descendingOrder) + { + TH_TENSOR_DIM_APPLY2(real, rt_, long, ri_, dimension, + long i; + for(i = 0; i < ri__size; i++) + ri__data[i*ri__stride] = i; + THTensor_(quicksortdescend)(rt__data, ri__data, rt__size, rt__stride);) + } + else + { + TH_TENSOR_DIM_APPLY2(real, rt_, long, ri_, dimension, + long i; + for(i = 0; i < ri__size; i++) + ri__data[i*ri__stride] = i; + THTensor_(quicksortascend)(rt__data, ri__data, rt__size, rt__stride);) + } +} + +void THTensor_(tril)(THTensor *r_, THTensor *t, long k) +{ + long t_size_0, t_size_1; + long t_stride_0, t_stride_1; + long r__stride_0, r__stride_1; + real *t_data, *r__data; + long r, c; + + THArgCheck(THTensor_(nDimension)(t) == 2, 1, "not a matrix"); + + THTensor_(resizeAs)(r_, t); + + t_size_0 = THTensor_(size)(t, 0); + t_size_1 = THTensor_(size)(t, 1); + t_stride_0 = THTensor_(stride)(t, 0); + t_stride_1 = THTensor_(stride)(t, 1); + r__stride_0 = THTensor_(stride)(r_, 0); + r__stride_1 = THTensor_(stride)(r_, 1); + r__data = THTensor_(data)(r_); + t_data = THTensor_(data)(t); + + for(r = 0; r < t_size_0; r++) + { + long sz = THMin(r+k+1, t_size_1); + for(c = THMax(0, r+k); c < t_size_1; c++) + r__data[r*r__stride_0+c*r__stride_1] = 0; + for(c = 0; c < sz; c++) + r__data[r*r__stride_0+c*r__stride_1] = t_data[r*t_stride_0+c*t_stride_1]; + } +} + +void THTensor_(triu)(THTensor *r_, THTensor *t, long k) +{ + long t_size_0, t_size_1; + long t_stride_0, t_stride_1; + long r__stride_0, r__stride_1; + real *t_data, *r__data; + long r, c; + + THArgCheck(THTensor_(nDimension)(t) == 2, 1, "not a matrix"); + + THTensor_(resizeAs)(r_, t); + + t_size_0 = THTensor_(size)(t, 0); + t_size_1 = THTensor_(size)(t, 1); + t_stride_0 = THTensor_(stride)(t, 0); + t_stride_1 = THTensor_(stride)(t, 1); + r__stride_0 = THTensor_(stride)(r_, 0); + r__stride_1 = THTensor_(stride)(r_, 1); + r__data = THTensor_(data)(r_); + t_data = THTensor_(data)(t); + + for(r = 0; r < t_size_0; r++) + { + long sz = THMin(r+k, t_size_1); + for(c = THMax(0, r+k); c < t_size_1; c++) + r__data[r*r__stride_0+c*r__stride_1] = t_data[r*t_stride_0+c*t_stride_1]; + for(c = 0; c < sz; c++) + r__data[r*r__stride_0+c*r__stride_1] = 0; + } +} + +void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension) +{ + THLongStorage *size; + int i; + int ndim = THMax(ta->nDimension, 
tb->nDimension); + ndim = THMax(ndim, dimension+1); + + THArgCheck(dimension >= 0, 4, "invalid dimension"); + + size = THLongStorage_newWithSize(ndim); + for(i = 0; i < ndim; i++) + { + int tadi = (i < ta->nDimension ? ta->size[i] : 1); + int tbdi = (i < tb->nDimension ? tb->size[i] : 1); + + if(i == dimension) + size->data[i] = tadi+tbdi; + else + { + if(tadi != tbdi) + { + THLongStorage_free(size); + THError("inconsistent tensor sizes"); + } + size->data[i] = tadi; + } + } + + THTensor_(resize)(r_, size, NULL); + THLongStorage_free(size); + + { + THTensor *nta = THTensor_(newWithTensor)(r_); + THTensor_(narrow)(nta, NULL, dimension, 0, (dimension < ta->nDimension ? ta->size[dimension] : 1)); + THTensor_(copy)(nta, ta); + THTensor_(free)(nta); + } + + { + THTensor *ntb = THTensor_(newWithTensor)(r_); + THTensor_(narrow)(ntb, NULL, dimension, (dimension < ta->nDimension ? ta->size[dimension] : 1), (dimension < tb->nDimension ? tb->size[dimension] : 1)); + THTensor_(copy)(ntb, tb); + THTensor_(free)(ntb); + } +} + +/* floating point only now */ + +#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) + +#define LAB_IMPLEMENT_BASIC_FUNCTION(NAME, CFUNC) \ + void THTensor_(NAME)(THTensor *r_, THTensor *t) \ + { \ + THTensor_(resizeAs)(r_, t); \ + TH_TENSOR_APPLY2(real, t, real, r_, *r__data = CFUNC(*t_data);); \ + } \ + +#define LAB_IMPLEMENT_BASIC_FUNCTION_VALUE(NAME, CFUNC) \ + void THTensor_(NAME)(THTensor *r_, THTensor *t, real value) \ + { \ + THTensor_(resizeAs)(r_, t); \ + TH_TENSOR_APPLY2(real, t, real, r_, *r__data = CFUNC(*t_data, value);); \ + } \ + \ +LAB_IMPLEMENT_BASIC_FUNCTION(log,log) +LAB_IMPLEMENT_BASIC_FUNCTION(log1p,log1p) +LAB_IMPLEMENT_BASIC_FUNCTION(exp,exp) +LAB_IMPLEMENT_BASIC_FUNCTION(cos,cos) +LAB_IMPLEMENT_BASIC_FUNCTION(acos,acos) +LAB_IMPLEMENT_BASIC_FUNCTION(cosh,cosh) +LAB_IMPLEMENT_BASIC_FUNCTION(sin,sin) +LAB_IMPLEMENT_BASIC_FUNCTION(asin,asin) +LAB_IMPLEMENT_BASIC_FUNCTION(sinh,sinh) +LAB_IMPLEMENT_BASIC_FUNCTION(tan,tan) +LAB_IMPLEMENT_BASIC_FUNCTION(atan,atan) +LAB_IMPLEMENT_BASIC_FUNCTION(tanh,tanh) +LAB_IMPLEMENT_BASIC_FUNCTION_VALUE(pow,pow) +LAB_IMPLEMENT_BASIC_FUNCTION(sqrt,sqrt) +LAB_IMPLEMENT_BASIC_FUNCTION(ceil,ceil) +LAB_IMPLEMENT_BASIC_FUNCTION(floor,floor) +LAB_IMPLEMENT_BASIC_FUNCTION(abs,fabs) + +void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension) +{ + THLongStorage *dim; + + THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "invalid dimension"); + + dim = THTensor_(newSizeOf)(t); + THLongStorage_set(dim, dimension, 1); + THTensor_(resize)(r_, dim, NULL); + THLongStorage_free(dim); + + TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, + accreal sum = 0; + long i; + for(i = 0; i < t_size; i++) + sum += t_data[i*t_stride]; + *r__data = (real)sum/t_size;); +} + +void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int flag) +{ + THLongStorage *dim; + + THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "invalid dimension"); + + dim = THTensor_(newSizeOf)(t); + THLongStorage_set(dim, dimension, 1); + THTensor_(resize)(r_, dim, NULL); + THLongStorage_free(dim); + + TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, + accreal sum = 0; + accreal sum2 = 0; + long i; + for(i = 0; i < t_size; i++) + { + real z = t_data[i*t_stride]; + sum += z; + sum2 += z*z; + } + + if(flag) + { + sum /= t_size; + sum2 /= t_size; + sum2 -= sum*sum; + sum2 = (sum2 < 0 ? 
0 : sum2); + *r__data = (real)sqrt(sum2); + } + else + { + sum /= t_size; + sum2 /= t_size-1; + sum2 -= ((real)t_size)/((real)(t_size-1))*sum*sum; + sum2 = (sum2 < 0 ? 0 : sum2); + *r__data = (real)sqrt(sum2); + }); +} + +void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int flag) +{ + THLongStorage *dim; + + THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "invalid dimension"); + + dim = THTensor_(newSizeOf)(t); + THLongStorage_set(dim, dimension, 1); + THTensor_(resize)(r_, dim, NULL); + THLongStorage_free(dim); + + TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, + accreal sum = 0; + accreal sum2 = 0; + long i; + for(i = 0; i < t_size; i++) + { + real z = t_data[i*t_stride]; + sum += z; + sum2 += z*z; + } + + if(flag) + { + sum /= t_size; + sum2 /= t_size; + sum2 -= sum*sum; + sum2 = (sum2 < 0 ? 0 : sum2); + *r__data = sum2; + } + else + { + sum /= t_size; + sum2 /= t_size-1; + sum2 -= ((real)t_size)/((real)(t_size-1))*sum*sum; + sum2 = (sum2 < 0 ? 0 : sum2); + *r__data = (real)sum2; + }); +} + +accreal THTensor_(norm)(THTensor *tensor, real value) +{ + accreal sum = 0; + TH_TENSOR_APPLY(real, tensor, sum += pow(fabs(*tensor_data), value);); + return pow(sum, 1.0/value); +} + +accreal THTensor_(dist)(THTensor *tensor, THTensor *src, real value) +{ + real sum = 0; + TH_TENSOR_APPLY2(real, tensor, real, src, + sum += pow(fabs(*tensor_data - *src_data), value);) + return pow(sum, 1.0/value); +} + +accreal THTensor_(meanall)(THTensor *tensor) +{ + THArgCheck(tensor->nDimension > 0, 1, "empty Tensor"); + return THTensor_(sumall)(tensor)/THTensor_(nElement)(tensor); +} + +accreal THTensor_(varall)(THTensor *tensor) +{ + accreal mean = THTensor_(meanall)(tensor); + accreal sum = 0; + TH_TENSOR_APPLY(real, tensor, sum += (*tensor_data - mean)*(*tensor_data - mean);); + sum /= (THTensor_(nElement)(tensor)-1); + return sum; +} + +accreal THTensor_(stdall)(THTensor *tensor) +{ + return sqrt(THTensor_(varall)(tensor)); +} + +void THTensor_(linspace)(THTensor *r_, real a, real b, long n) +{ + real i = 0; + + THArgCheck(n > 0, 3, "invalid number of points"); + THArgCheck(a <= b, 2, "end range should be greater than start range"); + + THTensor_(resize1d)(r_, n); + + TH_TENSOR_APPLY(real, r_, + *r__data = a + i*(b-a)/((real)(n-1)); + i++; + ); +} + +void THTensor_(logspace)(THTensor *r_, real a, real b, long n) +{ + real i = 0; + + THArgCheck(n > 0, 3, "invalid number of points"); + THArgCheck(a <= b, 2, "end range should be greater than start range"); + + THTensor_(resize1d)(r_, n); + + TH_TENSOR_APPLY(real, r_, + *r__data = pow(10.0, a + i*(b-a)/((real)(n-1))); + i++; + ); +} + +void THTensor_(rand)(THTensor *r_, THLongStorage *size) +{ + THTensor_(resize)(r_, size, NULL); + THTensor_(uniform)(r_, 0, 1); +} + +void THTensor_(randn)(THTensor *r_, THLongStorage *size) +{ + THTensor_(resize)(r_, size, NULL); + THTensor_(normal)(r_, 0, 1); +} + +#endif /* floating point only part */ +#endif diff --git a/lib/TH/generic/THTensorMath.h b/lib/TH/generic/THTensorMath.h new file mode 100644 index 00000000000..ba0b9913b06 --- /dev/null +++ b/lib/TH/generic/THTensorMath.h @@ -0,0 +1,90 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THTensorMath.h" +#else + +TH_API void THTensor_(fill)(THTensor *r_, real value); +TH_API void THTensor_(zero)(THTensor *r_); + +TH_API accreal THTensor_(dot)(THTensor *t, THTensor *src); + +TH_API real THTensor_(minall)(THTensor *t); +TH_API real THTensor_(maxall)(THTensor *t); +TH_API accreal THTensor_(sumall)(THTensor *t); + +TH_API 
void THTensor_(add)(THTensor *r_, THTensor *t, real value); +TH_API void THTensor_(mul)(THTensor *r_, THTensor *t, real value); +TH_API void THTensor_(div)(THTensor *r_, THTensor *t, real value); + +TH_API void THTensor_(cadd)(THTensor *r_, THTensor *t, real value, THTensor *src); +TH_API void THTensor_(cmul)(THTensor *r_, THTensor *t, THTensor *src); +TH_API void THTensor_(cdiv)(THTensor *r_, THTensor *t, THTensor *src); + +TH_API void THTensor_(addcmul)(THTensor *r_, THTensor *t, real value, THTensor *src1, THTensor *src2); +TH_API void THTensor_(addcdiv)(THTensor *r_, THTensor *t, real value, THTensor *src1, THTensor *src2); + +TH_API void THTensor_(addmv)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *mat, THTensor *vec); +TH_API void THTensor_(addmm)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *mat1, THTensor *mat2); +TH_API void THTensor_(addr)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *vec1, THTensor *vec2); + +TH_API long THTensor_(numel)(THTensor *t); +TH_API void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension); +TH_API void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension); +TH_API void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension); +TH_API void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension); +TH_API void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension); +TH_API void THTensor_(cumprod)(THTensor *r_, THTensor *t, int dimension); +TH_API void THTensor_(sign)(THTensor *r_, THTensor *t); +TH_API accreal THTensor_(trace)(THTensor *t); +TH_API void THTensor_(cross)(THTensor *r_, THTensor *a, THTensor *b, int dimension); + +TH_API void THTensor_(zeros)(THTensor *r_, THLongStorage *size); +TH_API void THTensor_(ones)(THTensor *r_, THLongStorage *size); +TH_API void THTensor_(diag)(THTensor *r_, THTensor *t, int k); +TH_API void THTensor_(eye)(THTensor *r_, long n, long m); +TH_API void THTensor_(range)(THTensor *r_, real xmin, real xmax, real step); +TH_API void THTensor_(randperm)(THTensor *r_, long n); + +TH_API void THTensor_(reshape)(THTensor *r_, THTensor *t, THLongStorage *size); +TH_API void THTensor_(sort)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int dimension, int descendingOrder); +TH_API void THTensor_(tril)(THTensor *r_, THTensor *t, long k); +TH_API void THTensor_(triu)(THTensor *r_, THTensor *t, long k); +TH_API void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension); + +#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) + +TH_API void THTensor_(log)(THTensor *r_, THTensor *t); +TH_API void THTensor_(log1p)(THTensor *r_, THTensor *t); +TH_API void THTensor_(exp)(THTensor *r_, THTensor *t); +TH_API void THTensor_(cos)(THTensor *r_, THTensor *t); +TH_API void THTensor_(acos)(THTensor *r_, THTensor *t); +TH_API void THTensor_(cosh)(THTensor *r_, THTensor *t); +TH_API void THTensor_(sin)(THTensor *r_, THTensor *t); +TH_API void THTensor_(asin)(THTensor *r_, THTensor *t); +TH_API void THTensor_(sinh)(THTensor *r_, THTensor *t); +TH_API void THTensor_(tan)(THTensor *r_, THTensor *t); +TH_API void THTensor_(atan)(THTensor *r_, THTensor *t); +TH_API void THTensor_(tanh)(THTensor *r_, THTensor *t); +TH_API void THTensor_(pow)(THTensor *r_, THTensor *t, real value); +TH_API void THTensor_(sqrt)(THTensor *r_, THTensor *t); +TH_API void THTensor_(ceil)(THTensor *r_, THTensor *t); +TH_API void THTensor_(floor)(THTensor *r_, THTensor *t); +TH_API void THTensor_(abs)(THTensor *r_, THTensor *t); + +TH_API 
void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension); +TH_API void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int flag); +TH_API void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int flag); +TH_API accreal THTensor_(norm)(THTensor *t, real value); +TH_API accreal THTensor_(dist)(THTensor *a, THTensor *b, real value); + +TH_API accreal THTensor_(meanall)(THTensor *self); +TH_API accreal THTensor_(varall)(THTensor *self); +TH_API accreal THTensor_(stdall)(THTensor *self); + +TH_API void THTensor_(linspace)(THTensor *r_, real a, real b, long n); +TH_API void THTensor_(logspace)(THTensor *r_, real a, real b, long n); +TH_API void THTensor_(rand)(THTensor *r_, THLongStorage *size); +TH_API void THTensor_(randn)(THTensor *r_, THLongStorage *size); + +#endif + +#endif diff --git a/lib/TH/generic/THTensorRandom.c b/lib/TH/generic/THTensorRandom.c new file mode 100644 index 00000000000..372398ca520 --- /dev/null +++ b/lib/TH/generic/THTensorRandom.c @@ -0,0 +1,65 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THTensorRandom.c" +#else + +TH_API void THTensor_(random)(THTensor *self) +{ +#if defined(TH_REAL_IS_BYTE) + TH_TENSOR_APPLY(real, self, *self_data = (unsigned char)(THRandom_random() % (UCHAR_MAX+1));); +#elif defined(TH_REAL_IS_CHAR) + TH_TENSOR_APPLY(real, self, *self_data = (char)(THRandom_random() % (CHAR_MAX+1));); +#elif defined(TH_REAL_IS_SHORT) + TH_TENSOR_APPLY(real, self, *self_data = (short)(THRandom_random() % (SHRT_MAX+1));); +#elif defined(TH_REAL_IS_INT) + TH_TENSOR_APPLY(real, self, *self_data = (int)(THRandom_random() % (INT_MAX+1UL));); +#elif defined(TH_REAL_IS_LONG) + TH_TENSOR_APPLY(real, self, *self_data = (long)(THRandom_random() % (LONG_MAX+1UL));); +#elif defined(TH_REAL_IS_FLOAT) + TH_TENSOR_APPLY(real, self, *self_data = (float)(THRandom_random() % ((1UL << FLT_MANT_DIG)+1));); +#elif defined(TH_REAL_IS_DOUBLE) + TH_TENSOR_APPLY(real, self, *self_data = (float)(THRandom_random() % ((1UL << DBL_MANT_DIG)+1));); +#else +#error "Unknown type" +#endif +} + +TH_API void THTensor_(geometric)(THTensor *self, double p) +{ + TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_geometric(p);); +} + +TH_API void THTensor_(bernoulli)(THTensor *self, double p) +{ + TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_bernoulli(p);); +} + +#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) + +TH_API void THTensor_(uniform)(THTensor *self, double a, double b) +{ + TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_uniform(a, b);); +} + +TH_API void THTensor_(normal)(THTensor *self, double mean, double stdv) +{ + TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_normal(mean, stdv);); +} + +TH_API void THTensor_(exponential)(THTensor *self, double lambda) +{ + TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_exponential(lambda);); +} + +TH_API void THTensor_(cauchy)(THTensor *self, double median, double sigma) +{ + TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_cauchy(median, sigma);); +} + +TH_API void THTensor_(logNormal)(THTensor *self, double mean, double stdv) +{ + TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_logNormal(mean, stdv);); +} + +#endif + +#endif diff --git a/lib/TH/generic/THTensorRandom.h b/lib/TH/generic/THTensorRandom.h new file mode 100644 index 00000000000..ec320edd6e8 --- /dev/null +++ b/lib/TH/generic/THTensorRandom.h @@ -0,0 +1,17 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THTensorRandom.h" +#else + +TH_API void THTensor_(random)(THTensor 
*self); +TH_API void THTensor_(geometric)(THTensor *self, double p); +TH_API void THTensor_(bernoulli)(THTensor *self, double p); + +#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) +TH_API void THTensor_(uniform)(THTensor *self, double a, double b); +TH_API void THTensor_(normal)(THTensor *self, double mean, double stdv); +TH_API void THTensor_(exponential)(THTensor *self, double lambda); +TH_API void THTensor_(cauchy)(THTensor *self, double median, double sigma); +TH_API void THTensor_(logNormal)(THTensor *self, double mean, double stdv); +#endif + +#endif diff --git a/lib/TH/generic/THVector.c b/lib/TH/generic/THVector.c new file mode 100644 index 00000000000..cc64d52d3a2 --- /dev/null +++ b/lib/TH/generic/THVector.c @@ -0,0 +1,84 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THVector.c" +#else + +static inline void THVector_(fill)(real *x, const real c, const long n) { + long i = 0; + + for(; i < n-4; i += 4) + { + x[i] = c; + x[i+1] = c; + x[i+2] = c; + x[i+3] = c; + } + + for(; i < n; i++) + x[i] = c; +} + +static inline void THVector_(add)(real *y, const real *x, const real c, const long n) +{ + long i = 0; + + for(;i < n-4; i += 4) + { + y[i] += c * x[i]; + y[i+1] += c * x[i+1]; + y[i+2] += c * x[i+2]; + y[i+3] += c * x[i+3]; + } + + for(; i < n; i++) + y[i] += c * x[i]; +} + +static inline void THVector_(diff)(real *z, const real *x, const real *y, const long n) +{ + long i = 0; + + for(; i < n-4; i += 4) + { + z[i] = x[i] - y[i]; + z[i+1] = x[i+1] - y[i+1]; + z[i+2] = x[i+2] - y[i+2]; + z[i+3] = x[i+3] - y[i+3]; + } + + for(; i < n; i++) + z[i] = x[i] - y[i]; +} + +static inline void THVector_(scale)(real *y, const real c, const long n) +{ + long i = 0; + + for(; i < n-4; i +=4) + { + y[i] *= c; + y[i+1] *= c; + y[i+2] *= c; + y[i+3] *= c; + } + + for(; i < n; i++) + y[i] *= c; +} + +static inline void THVector_(mul)(real *y, const real *x, const long n) +{ + long i = 0; + + for(; i < n-4; i += 4) + { + y[i] *= x[i]; + y[i+1] *= x[i+1]; + y[i+2] *= x[i+2]; + y[i+3] *= x[i+3]; + } + + for(; i < n; i++) + y[i] *= x[i]; +} + +#endif diff --git a/lib/luaT/CMakeLists.txt b/lib/luaT/CMakeLists.txt new file mode 100644 index 00000000000..8d2cffb5ae5 --- /dev/null +++ b/lib/luaT/CMakeLists.txt @@ -0,0 +1,28 @@ +# -*- cmake -*- + +FIND_PACKAGE(Lua REQUIRED) + +INCLUDE_DIRECTORIES(${LUA_INCLUDE_DIR}) + +ADD_LIBRARY(luaT SHARED luaT.h luaT.c) +TARGET_LINK_LIBRARIES(luaT ${LUA_LIBRARIES}) + +INSTALL(TARGETS luaT + RUNTIME DESTINATION "${Torch_INSTALL_BIN_SUBDIR}" + LIBRARY DESTINATION "${Torch_INSTALL_LIB_SUBDIR}" + ARCHIVE DESTINATION "${Torch_INSTALL_LIB_SUBDIR}") + +INSTALL(FILES luaT.h + DESTINATION "${Torch_INSTALL_INCLUDE_SUBDIR}") + +# Create luaT.cmake +GET_TARGET_PROPERTY(LUAT_OUTPUT_NAME luaT LOCATION) +GET_FILENAME_COMPONENT(LUAT_OUTPUT_NAME ${LUAT_OUTPUT_NAME} NAME) +SET(LUAT_LIBRARIES "${Torch_INSTALL_LIB}/${LUAT_OUTPUT_NAME}") +SET(LUAT_INCLUDE_DIR "${Torch_INSTALL_INCLUDE}") +CONFIGURE_FILE(luaTConfig.cmake.in "${Torch_BINARY_DIR}/cmake-external/luaTConfig.cmake") +INSTALL(FILES "${Torch_BINARY_DIR}/cmake-external/luaTConfig.cmake" + DESTINATION "${Torch_INSTALL_CMAKE_SUBDIR}") + +# luaT help +ADD_TORCH_DOK(dok luaT "Torch C Libraries" "luaT" 5.1) diff --git a/lib/luaT/dok/index.dok b/lib/luaT/dok/index.dok new file mode 100644 index 00000000000..9145b542fe3 --- /dev/null +++ b/lib/luaT/dok/index.dok @@ -0,0 +1,137 @@ +====== Lua Torch C API ====== +{{anchor:luat.dok}} + +luaT provides an API to interface Lua and C in Torch packages.
It defines a +concept of //classes// to Lua for Torch, and provides a mechanism to easily +handle these Lua classes from C. + +It additionally provides few functions that ''luaL'' should have defined, and +redefine some ''luaL'' functions for better type error printing when using +''luaT'' classes. + +===== Memory functions ===== +{{anchor:luat.memory.dok}} + +Classical memory allocation functions which generate a Lua error in case of +problem. + +==== void* luaT_alloc(lua_State *L, long size) ==== +{{anchor:luaT_alloc}} + +Allocates ''size'' bytes, and return a pointer on the allocated +memory. A Lua error will be generated if running out of memory. + +==== void* luaT_realloc(lua_State *L, void *ptr, long size) ==== +{{anchor:luaT_realloc}} + +Realloc ''ptr'' to ''size'' bytes. ''ptr'' must have been previously +allocated with [[#luaT_alloc|luaT_alloc]] or +[[#luaT_realloc|luaT_realloc]], or the C ''malloc'' or ''realloc'' +functions. A Lua error will be generated if running out of memory. + +==== void luaT_free(lua_State *L, void *ptr) ==== +{{anchor:luaT_free}} + +Free memory allocated at address ''ptr''. The memory must have been +previously allocated with [[#luaT_alloc|luaT_alloc]] or +[[#luaT_realloc|luaT_realloc]], or the C ''malloc'' or ''realloc'' +functions. + +===== Class creation and basic handling ===== +{{anchor:luat.classcreate}} + +A ''luaT'' class is basically either a Lua //table// or //userdata// +with an appropriate //metatable//. This appropriate metatable, that +we call in this section //root metatable// is created with +[[#luaT_newmetatable|luaT_newmetatable]]. + +The root metatable of a ''luaT'' object has itself a metatable that we +call //metaclass//. The metaclass is the actual metatable containing +all the methods of the class. If the class inherit from another class, +then the metaclass will itself have a metatable corresponding to the +//parent metaclass//: the metaclasses are cascaded according to the +class inheritance. Multiple inheritance is not supported. + +The root metatable of a ''luaT'' object contains ''Lua'' operators +like ''%%__index%%'', ''%%__newindex%%'', ''%%__tostring%%'', +''%%__add%%'' (etc...). These operators will respectively look for +''//index//'', ''//newindex//'', ''//tostring//'', ''//add//'' +(etc...) in the metaclass. If found, the corresponding function or +value will be returned, else a Lua error will be raised. + +If one wants to provide ''//index//'' or ''//newindex//'' in the +metaclass, these operators must follow a particular scheme: + + * ''//index//'' must either return a value //and// ''true'' or return ''false'' only. In the first case, it means ''//index//'' was able to handle the given argument (for e.g., the type was correct). The second case means it was not able to do anything, so ''%%__index%%'' in the root metatable can then try to see if the metaclass contains the required value. + + * ''//newindex//'' must either return ''true'' or ''false''. As for ''//index//'', ''true'' means it could handle the argument and ''false'' not. If not, the root metatable ''%%__newindex%%'' will then raise an error if the object was a userdata, or apply a rawset if the object was a Lua table. + +Other metaclass operators like ''//tostring//'', ''//add//'', etc... do not have any particular constraint. 
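+
+For example (a minimal sketch for illustration only, not part of the luaT sources), a
+metaclass ''//index//'' handler obeying this scheme could be written in Lua for a
+hypothetical table-based class ''Foo''; the handler name ''%%__index__%%'' and the key
+''answer'' are assumptions made for the example:
+
+<file lua>
+function Foo:__index__(key)
+   -- handled: return the value AND true
+   if key == 'answer' then
+      return 42, true
+   end
+   -- not handled: return false only; the root metatable then looks the
+   -- key up in the metaclass as usual
+   return false
+end
+</file>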
+ +==== const void* luaT_newmetatable(lua_State *L, const char *tname, const char *parenttname, lua_CFunction constructor, lua_CFunction destructor, lua_CFunction factory) ==== +{{anchor:luat_newmetatable}} + +==== void luaT_pushmetatable(lua_State *L, const void *id) ==== +{{anchor:luat_pushmetatable}} + +==== void luaT_pushmetaclass(lua_State *L, const void *id) ==== +{{anchor:luaT_pushmetaclass}} + +==== int luaT_getmetaclass(lua_State *L, int index) ==== +{{anchor:luaT_getmetaclass}} + + +===== Other Functions ===== + +==== void* luaT_alloc(lua_State *L, long size) ==== +==== void* luaT_realloc(lua_State *L, void *ptr, long size) ==== +==== void luaT_free(lua_State *L, void *ptr) ==== +==== void luaT_stackdump(lua_State *L) ==== + +==== void luaT_registeratid(lua_State *L, const struct luaL_Reg *methods, const void *id) ==== +==== void luaT_registeratname(lua_State *L, const struct luaL_Reg *methods, const char *name) ==== + +==== const char* luaT_id2typename(lua_State *L, const void *id) ==== +==== const void* luaT_typename2id(lua_State *L, const char*) ==== +==== const void* luaT_checktypename2id(lua_State *L, const char *tname) ==== + +==== const void* luaT_id(lua_State *L, int ud) ==== +==== const char* luaT_typename(lua_State *L, int ud) ==== + +==== void luaT_pushudata(lua_State *L, void *udata, const void *id) ==== + +==== void *luaT_toudata (lua_State *L, int ud, const void *id) ==== +==== int luaT_isudata (lua_State *L, int ud, const void *id) ==== +==== void *luaT_checkudata (lua_State *L, int ud, const void *id) ==== + +==== void *luaT_getfieldcheckudata (lua_State *L, int ud, const char *field, const void *id) ==== +==== void *luaT_getfieldchecklightudata (lua_State *L, int ud, const char *field) ==== +==== double luaT_getfieldchecknumber (lua_State *L, int ud, const char *field) ==== +==== int luaT_getfieldcheckint (lua_State *L, int ud, const char *field) ==== +==== const char* luaT_getfieldcheckstring (lua_State *L, int ud, const char *field) ==== +==== int luaT_getfieldcheckboolean (lua_State *L, int ud, const char *field) ==== +==== void luaT_getfieldchecktable (lua_State *L, int ud, const char *field) ==== + +==== int luaT_typerror(lua_State *L, int ud, const char *tname) ==== +==== int luaT_checkboolean(lua_State *L, int narg) ==== +==== int luaT_optboolean(lua_State *L, int narg, int def) ==== + +==== const char *luaT_classrootname(const char *tname) ==== +==== const char *luaT_classmodulename(const char *tname) ==== +==== void luaT_stackdump(lua_State *L) ==== + + +==== int luaT_lua_newmetatable(lua_State *L) ==== +==== int luaT_lua_factory(lua_State *L) ==== +==== int luaT_lua_getconstructortable(lua_State *L) ==== +==== int luaT_lua_id(lua_State *L) ==== +==== int luaT_lua_typename(lua_State *L) ==== +==== int luaT_lua_isequal(lua_State *L) ==== +==== int luaT_lua_pointer(lua_State *L) ==== +==== int luaT_lua_setenv(lua_State *L) ==== +==== int luaT_lua_getenv(lua_State *L) ==== +==== int luaT_lua_getmetatable(lua_State *L) ==== +==== int luaT_lua_version(lua_State *L) ==== +==== int luaT_lua_setmetatable(lua_State *L) ==== +==== int luaT_lua_typename2id(lua_State *L) ==== + diff --git a/lib/luaT/luaT.c b/lib/luaT/luaT.c new file mode 100644 index 00000000000..2cb6700cedc --- /dev/null +++ b/lib/luaT/luaT.c @@ -0,0 +1,1067 @@ +#include +#include + +#include "luaT.h" + +void* luaT_alloc(lua_State *L, long size) +{ + void *ptr; + + if(size == 0) + return NULL; + + if(size < 0) + luaL_error(L, "$ Torch: invalid memory size -- maybe an overflow?"); + + ptr = 
malloc(size); + if(!ptr) + luaL_error(L, "$ Torch: not enough memory: you tried to allocate %dGB. Buy new RAM!", size/1073741824); + + return ptr; +} + +void* luaT_realloc(lua_State *L, void *ptr, long size) +{ + if(!ptr) + return(luaT_alloc(L, size)); + + if(size == 0) + { + luaT_free(L, ptr); + return NULL; + } + + if(size < 0) + luaL_error(L, "$ Torch: invalid memory size -- maybe an overflow?"); + + ptr = realloc(ptr, size); + if(!ptr) + luaL_error(L, "$ Torch: not enough memory: you tried to reallocate %dGB. Buy new RAM!", size/1073741824); + return ptr; +} + +void luaT_free(lua_State *L, void *ptr) +{ + free(ptr); +} + +void luaT_stackdump(lua_State *L) +{ + int i; + const char *tname; + int top = lua_gettop(L); + for(i = 1; i <= top; i++) + { + int t = lua_type(L, i); + printf("%3d. ", i); + switch(t) + { + case LUA_TSTRING: + printf("'%s'", lua_tostring(L,i)); + break; + case LUA_TBOOLEAN: + printf(lua_toboolean(L, i) ? "true" : "false"); + break; + case LUA_TNUMBER: + printf("%g", lua_tonumber(L,i)); + break; + case LUA_TUSERDATA: + tname = luaT_typename(L, i); + printf("userdata [%s]", (tname ? tname : "not a Torch object")); + break; + case LUA_TTABLE: + tname = luaT_id2typename(L, lua_topointer(L, i)); + if(tname) + printf("metaclass [%s]", tname); + else + { + tname = luaT_typename(L, i); + printf("table [%s]", (tname ? tname : "not a Torch object")); + } + break; + default: + printf("%s", lua_typename(L,t)); + break; + } + printf("\n"); + } + printf("---------------------------------------------\n"); +} + +/* Root-metatable methods */ +static int luaT_rmt__index(lua_State *L); +static int luaT_rmt__newindex(lua_State *L); +static int luaT_rmt__tostring(lua_State *L); +static int luaT_rmt__add(lua_State *L); +static int luaT_rmt__sub(lua_State *L); +static int luaT_rmt__mul(lua_State *L); +static int luaT_rmt__div(lua_State *L); +static int luaT_rmt__mod(lua_State *L); +static int luaT_rmt__pow(lua_State *L); +static int luaT_rmt__unm(lua_State *L); +static int luaT_rmt__concat(lua_State *L); +static int luaT_rmt__len(lua_State *L); +static int luaT_rmt__eq(lua_State *L); +static int luaT_rmt__lt(lua_State *L); +static int luaT_rmt__le(lua_State *L); +static int luaT_rmt__call(lua_State *L); + +/* Constructor-metatable methods */ +static int luaT_cmt__call(lua_State *L); +static int luaT_cmt__newindex(lua_State *L); + +const void* luaT_newmetatable(lua_State *L, const char *tname, const char *parenttname, + lua_CFunction constructor, lua_CFunction destructor, lua_CFunction factory) +{ + lua_pushcfunction(L, luaT_lua_newmetatable); + lua_pushstring(L, tname); + (parenttname ? lua_pushstring(L, parenttname) : lua_pushnil(L)); + (constructor ? lua_pushcfunction(L, constructor) : lua_pushnil(L)); + (destructor ? lua_pushcfunction(L, destructor) : lua_pushnil(L)); + (factory ? 
lua_pushcfunction(L, factory) : lua_pushnil(L)); + lua_call(L, 5, 1); + return lua_topointer(L, -1); +} + +void luaT_pushmetatable(lua_State *L, const void *id) +{ + lua_pushlightuserdata(L, (void*)id); + lua_rawget(L, LUA_REGISTRYINDEX); +} + +void luaT_pushmetaclass(lua_State *L, const void *id) +{ + luaT_pushmetatable(L, id); + if(!lua_isnil(L, -1)) + { + if(!lua_getmetatable(L, -1)) + luaL_error(L, "internal error: cannot find metaclass"); + lua_remove(L, -2); /* remove metatable */ + } +} + +const char* luaT_id2typename(lua_State *L, const void *id) +{ + const char* tname = NULL; + + lua_getfield(L, LUA_REGISTRYINDEX, "*torch.id2tname*"); + if(lua_isnil(L, -1)) + { + lua_pop(L, 1); + return NULL; + } + lua_pushlightuserdata(L, (void*)id); + lua_gettable(L, -2); + if(!lua_isnil(L, -1)) + tname = lua_tostring(L, -1); + lua_pop(L, 2); + return tname; /* still exists, because in a table ... */ +} + +const void* luaT_typename2id(lua_State *L, const char *tname) +{ + const void *id = NULL; + + lua_getfield(L, LUA_REGISTRYINDEX, "*torch.tname2id*"); + if(lua_isnil(L, -1)) + { + lua_pop(L, 1); + return NULL; + } + lua_pushstring(L, tname); + lua_gettable(L, -2); + if(!lua_isnil(L, -1)) + id = lua_topointer(L, -1); + lua_pop(L, 2); + return id; /* id still exists, because in a table ... */ +} + +const void* luaT_checktypename2id(lua_State *L, const char *tname) +{ + const void* id = luaT_typename2id(L, tname); + if(!id) + luaL_error(L, "unknown class <%s>", tname); + return id; +} + +int luaT_getmetaclass(lua_State *L, int index) +{ + if(lua_getmetatable(L, index)) /* get metatable */ + { + if(lua_getmetatable(L, -1)) /* get metaclass */ + { + lua_remove(L, -2); + return 1; + } + else + { + lua_pop(L, 1); + return 0; + } + } + return 0; +} + +const void* luaT_id(lua_State *L, int ud) +{ + if(luaT_getmetaclass(L, ud)) + { + const char *id = lua_topointer(L, -1); + lua_pop(L, 1); + if(luaT_id2typename(L, id)) + return id; + } + return NULL; +} + +const char* luaT_typename(lua_State *L, int ud) +{ + if(luaT_getmetaclass(L, ud)) + { + const char *tname = luaT_id2typename(L, lua_topointer(L, -1)); + lua_pop(L, 1); + return tname; + } + return NULL; +} + +void luaT_pushudata(lua_State *L, void *udata, const void *id) +{ + if(udata) + { + void **udata_p = lua_newuserdata(L, sizeof(void*)); + *udata_p = udata; + luaT_pushmetatable(L, id); + if(lua_isnil(L, -1)) + luaL_error(L, "Torch internal problem: cannot find a metatable"); + lua_setmetatable(L, -2); + } + else + lua_pushnil(L); +} + +void *luaT_toudata (lua_State *L, int ud, const void *id) +{ + void **p = lua_touserdata(L, ud); + if (p != NULL) /* value is a userdata? 
*/ + { + lua_pushvalue(L, ud); /* initialize the table we want to get the metatable on */ + if(lua_getmetatable(L, -1)) /* get the metatable */ + { + lua_remove(L, -2); /* remove the original value */ + while(lua_getmetatable(L, -1)) /* get the next metaclass */ + { + lua_remove(L, -2); /* remove the original metatable/metaclass */ + if(lua_topointer(L, -1) == id) + { + lua_pop(L, 1); /* remove metaclass */ + return *p; + } + } + } + lua_pop(L, 1); /* remove remaing value/metatable/metaclass */ + } + return NULL; +} + +int luaT_isudata(lua_State *L, int ud, const void *id) +{ + if(luaT_toudata(L, ud, id)) + return 1; + else + return 0; +} + +void *luaT_checkudata (lua_State *L, int ud, const void *id) +{ + void *p = luaT_toudata(L, ud, id); + if(!p) + luaT_typerror(L, ud, luaT_id2typename(L, id)); + return p; +} + +void *luaT_getfieldcheckudata (lua_State *L, int ud, const char *field, const void *id) +{ + void *p; + lua_getfield(L, ud, field); + if(lua_isnil(L, -1)) + luaL_error(L, "bad argument #%d (field %s does not exist)", ud, field); + p = luaT_toudata(L, -1, id); + if(!p) + luaL_error(L, "bad argument #%d (field %s is not a %s)", ud, field, luaT_id2typename(L, id)); + return p; +} + +void *luaT_getfieldchecklightudata (lua_State *L, int ud, const char *field) +{ + void *p; + lua_getfield(L, ud, field); + if(lua_isnil(L, -1)) + luaL_error(L, "bad argument #%d (field %s does not exist)", ud, field); + + if(!lua_islightuserdata(L, -1)) + luaL_error(L, "bad argument #%d (field %s is not a light userdata)", ud, field); + + p = lua_touserdata(L, -1); + + return p; +} + +double luaT_getfieldchecknumber (lua_State *L, int ud, const char *field) +{ + lua_getfield(L, ud, field); + if(lua_isnil(L, -1)) + luaL_error(L, "bad argument #%d (field %s does not exist)", ud, field); + if(!lua_isnumber(L, -1)) + luaL_error(L, "bad argument #%d (field %s is not a number)", ud, field); + return lua_tonumber(L, -1); +} + +int luaT_getfieldcheckint (lua_State *L, int ud, const char *field) +{ + lua_getfield(L, ud, field); + if(lua_isnil(L, -1)) + luaL_error(L, "bad argument #%d (field %s does not exist)", ud, field); + if(!lua_isnumber(L, -1)) + luaL_error(L, "bad argument #%d (field %s is not a number)", ud, field); + return (int)lua_tonumber(L, -1); +} + +const char* luaT_getfieldcheckstring (lua_State *L, int ud, const char *field) +{ + lua_getfield(L, ud, field); + if(lua_isnil(L, -1)) + luaL_error(L, "bad argument #%d (field %s does not exist)", ud, field); + if(!lua_isstring(L, -1)) + luaL_error(L, "bad argument #%d (field %s is not a string)", ud, field); + return lua_tostring(L, -1); +} + +int luaT_getfieldcheckboolean (lua_State *L, int ud, const char *field) +{ + lua_getfield(L, ud, field); + if(lua_isnil(L, -1)) + luaL_error(L, "bad argument #%d (field %s does not exist)", ud, field); + if(!lua_isboolean(L, -1)) + luaL_error(L, "bad argument #%d (field %s is not a boolean)", ud, field); + return lua_toboolean(L, -1); +} + +void luaT_getfieldchecktable (lua_State *L, int ud, const char *field) +{ + lua_getfield(L, ud, field); + if(lua_isnil(L, -1)) + luaL_error(L, "bad argument #%d (field %s does not exist)", ud, field); + if(!lua_istable(L, -1)) + luaL_error(L, "bad argument #%d (field %s is not a table)", ud, field); +} + +/**** type checks as in luaL ****/ +int luaT_typerror(lua_State *L, int narg, const char *tname) +{ + const char *tnamenarg = (lua_istable(L, narg) ? 
luaT_id2typename(L, lua_topointer(L, narg)) : NULL); + const char *msg; + + if(tnamenarg) + { + msg = lua_pushfstring(L, "%s expected, got %s metatable", + tname, + (tnamenarg ? tnamenarg : luaL_typename(L, narg))); + } + else + { + tnamenarg = luaT_typename(L, narg); + msg = lua_pushfstring(L, "%s expected, got %s", + tname, + (tnamenarg ? tnamenarg : luaL_typename(L, narg))); + } + return luaL_argerror(L, narg, msg); +} + +int luaT_checkboolean(lua_State *L, int narg) +{ + if(!lua_isboolean(L, narg)) + luaT_typerror(L, narg, lua_typename(L, LUA_TBOOLEAN)); + return lua_toboolean(L, narg); +} + +int luaT_optboolean(lua_State *L, int narg, int def) +{ + if(lua_isnoneornil(L,narg)) + return def; + + return luaT_checkboolean(L, narg); +} + + +/* utility functions */ +const char *luaT_classrootname(const char *tname) +{ + int i; + int sz = strlen(tname); + + for(i = 0; i < sz; i++) + { + if(tname[i] == '.') + return tname+i+1; + } + return tname; +} + +static char luaT_class_module_name[256]; +const char *luaT_classmodulename(const char *tname) +{ + int i; + + strncpy(luaT_class_module_name, tname, 256); + + for(i = 0; i < 256; i++) + { + if(luaT_class_module_name[i] == '\0') + break; + if(luaT_class_module_name[i] == '.') + { + luaT_class_module_name[i] = '\0'; + return luaT_class_module_name; + } + } + return NULL; +} + +void luaT_registeratid(lua_State *L, const struct luaL_Reg *methods, const void *id) +{ + int idx = lua_gettop(L); + + luaL_checktype(L, idx, LUA_TTABLE); + lua_pushlightuserdata(L, (void*)id); + lua_rawget(L, idx); + + if(lua_isnil(L, -1)) + { + lua_pop(L, 1); + lua_pushlightuserdata(L, (void*)id); + lua_newtable(L); + lua_rawset(L, idx); + + lua_pushlightuserdata(L, (void*)id); + lua_rawget(L, idx); + } + + luaL_register(L, NULL, methods); + lua_pop(L, 1); +} + +void luaT_registeratname(lua_State *L, const struct luaL_Reg *methods, const char *name) +{ + int idx = lua_gettop(L); + + luaL_checktype(L, idx, LUA_TTABLE); + lua_pushstring(L, name); + lua_rawget(L, idx); + + if(lua_isnil(L, -1)) + { + lua_pop(L, 1); + lua_pushstring(L, name); + lua_newtable(L); + lua_rawset(L, idx); + + lua_pushstring(L, name); + lua_rawget(L, idx); + } + + luaL_register(L, NULL, methods); + lua_pop(L, 1); +} + +/* Lua only functions */ +int luaT_lua_newmetatable(lua_State *L) +{ + const char* tname = luaL_checkstring(L, 1); + const void *id; + + lua_settop(L, 5); + luaL_argcheck(L, lua_isnoneornil(L, 2) || lua_isstring(L, 2), 2, "parent class name or nil expected"); + luaL_argcheck(L, lua_isnoneornil(L, 3) || lua_isfunction(L, 3), 3, "constructor function or nil expected"); + luaL_argcheck(L, lua_isnoneornil(L, 4) || lua_isfunction(L, 4), 4, "destructor function or nil expected"); + luaL_argcheck(L, lua_isnoneornil(L, 5) || lua_isfunction(L, 5), 5, "factory function or nil expected"); + + if(luaT_classmodulename(tname)) + lua_getfield(L, LUA_GLOBALSINDEX, luaT_classmodulename(tname)); + else + lua_pushvalue(L, LUA_GLOBALSINDEX); + if(!lua_istable(L, 6)) + luaL_error(L, "while creating metatable %s: bad argument #1 (%s is an invalid module name)", tname, luaT_classmodulename(tname)); + + /* we first create the new metaclass if we have to */ + if(!luaT_typename2id(L, tname)) + { + /* create the metaclass */ + lua_newtable(L); + id = lua_topointer(L, -1); /* id = pointer on metaclass */ + + /* __index points on itself */ + lua_pushvalue(L, -1); + lua_setfield(L, -2, "__index"); + + /* new points to constructor */ + lua_pushvalue(L, 3); + lua_setfield(L, -2, "new"); + + /* __typename contains the 
typename */ + lua_pushstring(L, tname); + lua_setfield(L, -2, "__typename"); + + /* by default, __version equals 1 */ + lua_pushnumber(L, 1); + lua_setfield(L, -2, "__version"); + + /* register in "*torch.id2tname*" registry table + (id -> typename) */ + lua_getfield(L, LUA_REGISTRYINDEX, "*torch.id2tname*"); + if(lua_isnil(L, -1)) + { + lua_pop(L, 1); + lua_newtable(L); + lua_setfield(L, LUA_REGISTRYINDEX, "*torch.id2tname*"); + lua_getfield(L, LUA_REGISTRYINDEX, "*torch.id2tname*"); + } + lua_pushlightuserdata(L, (void*)id); + lua_pushstring(L, tname); + lua_settable(L, -3); + lua_pop(L, 1); + + /* register in "*torch.tname2id*" registry table + (typename -> id) */ + lua_getfield(L, LUA_REGISTRYINDEX, "*torch.tname2id*"); + if(lua_isnil(L, -1)) + { + lua_pop(L, 1); + lua_newtable(L); + lua_setfield(L, LUA_REGISTRYINDEX, "*torch.tname2id*"); + lua_getfield(L, LUA_REGISTRYINDEX, "*torch.tname2id*"); + } + lua_pushstring(L, tname); + lua_pushlightuserdata(L, (void*)id); + lua_settable(L, -3); + lua_pop(L, 1); + } + + /* we retrieve the existing metaclass */ + else + { + id = luaT_typename2id(L, tname); + luaT_pushmetaclass(L, id); + } + + /* we assign the parent class if necessary */ + if(!lua_isnoneornil(L, 2)) + { + if(lua_getmetatable(L, -1)) + luaL_error(L, "class %s has been already assigned a parent class\n", tname); + else + { + const char* parenttname = luaL_checkstring(L, 2); + luaT_pushmetaclass(L, luaT_typename2id(L, parenttname)); + if(lua_isnil(L, -1)) + luaL_error(L, "bad argument #2 (invalid parent class name %s)", parenttname); + lua_setmetatable(L, -2); + } + } + + /******** root-metatable ********/ + + /* id is the pointer on the metatable + registry[id] = root-metatable, so try to see if it exists */ + + lua_pushlightuserdata(L, (void*)id); /* id */ + lua_rawget(L, LUA_REGISTRYINDEX); + + /* not existing? we create a new one! */ + if(lua_isnil(L, -1)) + { + lua_pop(L, 1); /* remove nil on stack */ + lua_newtable(L); + + /* __index handling */ + lua_pushcfunction(L, luaT_rmt__index); + lua_setfield(L, -2, "__index"); + + /* __newindex handling */ + lua_pushcfunction(L, luaT_rmt__newindex); + lua_setfield(L, -2, "__newindex"); + + /* __metatable field (point on the metaclass) */ + lua_pushvalue(L, -2); + lua_setfield(L, -2, "__metatable"); + + /* __typename contains the typename */ + lua_pushstring(L, tname); + lua_setfield(L, -2, "__typename"); + + /* operators handling */ +#define MT_ADD_OPERATOR(name) \ + lua_pushcfunction(L, luaT_rmt__##name); \ + lua_setfield(L, -2, "__" #name) + + MT_ADD_OPERATOR(tostring); + MT_ADD_OPERATOR(add); + MT_ADD_OPERATOR(sub); + MT_ADD_OPERATOR(mul); + MT_ADD_OPERATOR(div); + MT_ADD_OPERATOR(mod); + MT_ADD_OPERATOR(pow); + MT_ADD_OPERATOR(unm); + MT_ADD_OPERATOR(concat); + MT_ADD_OPERATOR(len); + MT_ADD_OPERATOR(eq); + MT_ADD_OPERATOR(lt); + MT_ADD_OPERATOR(le); + MT_ADD_OPERATOR(call); + + /* assign the metaclass as metatable... */ + lua_pushvalue(L, -2); + lua_setmetatable(L, -2); + + /* id is the pointer on the metatable + set registry[id] = root-metatable */ + lua_pushlightuserdata(L, (void*)id); /* id */ + lua_pushvalue(L, -2); /* metatable */ + lua_rawset(L, LUA_REGISTRYINDEX); /* registry[id] = metatable */ + + } /* ok, so now we have the root-metatable on the stack */ + + /* register the destructor function */ + if(!lua_isnoneornil(L, 4)) + { + /* does it exists already? 
*/ + lua_pushstring(L, "__gc"); + lua_rawget(L, -2); + + if(lua_isnil(L, -1)) + { + lua_pop(L, 1); /* pop nil */ + lua_pushstring(L, "__gc"); + lua_pushvalue(L, 4); + lua_rawset(L, -3); + } + else + luaL_error(L, "%s has been already assigned a destructor", tname); + } + + /* register the factory function */ + if(!lua_isnoneornil(L, 5)) + { + /* does it exists already? */ + lua_pushstring(L, "__factory"); + lua_rawget(L, -2); + + if(lua_isnil(L, -1)) + { + lua_pop(L, 1); /* pop nil */ + lua_pushstring(L, "__factory"); + lua_pushvalue(L, 5); + lua_rawset(L, -3); + } + else + luaL_error(L, "%s has been already assigned a factory", tname); + } + + /******** Constructor table and metatable ********/ + lua_pushstring(L, "__constructor"); + lua_rawget(L, -2); + + if(lua_isnil(L, -1)) + { + lua_pop(L, 1); /* pop nil */ + lua_newtable(L); /* fancy table */ + lua_newtable(L); /* fancy metatable */ + + lua_pushvalue(L, -4); /* metaclass */ + lua_setfield(L, -2, "__index"); /* so we can get the methods */ + + lua_pushcfunction(L, luaT_cmt__newindex); + lua_setfield(L, -2, "__newindex"); /* so we cannot messup */ + + lua_pushcfunction(L, luaT_cmt__call); + lua_setfield(L, -2, "__call"); /* so we can create */ + + lua_pushvalue(L, -4); + lua_setfield(L, -2, "__metatable"); /* redirect to metatable with methods */ + + lua_setmetatable(L, -2); /* metatable is ... the fancy metatable */ + + /* set root-metatable[__constructor] = constructor-metatable */ + lua_pushstring(L, "__constructor"); + lua_pushvalue(L, -2); + lua_rawset(L, -4); + } + + /* register the constructor function */ + if(!lua_isnoneornil(L, 3)) + { + /* get constructor metatable */ + lua_getmetatable(L, -1); + + /* does it exists already? */ + lua_pushstring(L, "__new"); + lua_rawget(L, -2); + + if(lua_isnil(L, -1)) + { + lua_pop(L, 1); /* pop nil */ + lua_pushstring(L, "__new"); + lua_pushvalue(L, 3); + lua_rawset(L, -3); + } + else + luaL_error(L, "%s has been already assigned a constructor", tname); + + /* pop constructor metatable */ + lua_pop(L, 1); + } + + lua_setfield(L, 6, luaT_classrootname(tname)); /* module.name = constructor-metatable */ + lua_pop(L, 1); /* pop the root-metatable */ + + return 1; /* returns the metaclass */ +} + + +/* Lua only utility functions */ +int luaT_lua_factory(lua_State *L) +{ + const char* tname = luaL_checkstring(L, 1); + luaT_pushmetatable(L, luaT_typename2id(L, tname)); + if(!lua_isnil(L, -1)) + { + lua_pushstring(L, "__factory"); + lua_rawget(L, -2); + } + return 1; +} + +int luaT_lua_getconstructortable(lua_State *L) +{ + const char* tname = luaL_checkstring(L, 1); + luaT_pushmetatable(L, luaT_typename2id(L, tname)); + if(!lua_isnil(L, -1)) + { + lua_pushstring(L, "__constructor"); + lua_rawget(L, -2); + } + return 1; +} + + +int luaT_lua_typename(lua_State *L) +{ + luaL_checkany(L, 1); + + if(luaT_getmetaclass(L, 1)) + { + const char *tname = luaT_id2typename(L, lua_topointer(L, -1)); + if(tname) + { + lua_pushstring(L, tname); + return 1; + } + } + return 0; +} + +int luaT_lua_id(lua_State *L) +{ + const void *id; + + luaL_checkany(L, 1); + id = luaT_id(L, 1); + if(id) + lua_pushlightuserdata(L, (void*)id); + else + lua_pushnil(L); + + return 1; +} + +int luaT_lua_typename2id(lua_State *L) +{ + const char* typename = luaL_checkstring(L, 1); + const void* id = luaT_typename2id(L, typename); + if(id) + lua_pushlightuserdata(L, (void*)id); + else + lua_pushnil(L); + return 1; +} + +int luaT_lua_isequal(lua_State *L) +{ + if(lua_isuserdata(L, 1) && lua_isuserdata(L, 2)) + { + void **u1, **u2; + 
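/* a Torch object is a full userdata wrapping a single C pointer: two objects are equal when they wrap the same pointer */ +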
luaL_argcheck(L, luaT_id(L, 1), 1, "Torch object expected"); + luaL_argcheck(L, luaT_id(L, 2), 2, "Torch object expected"); + + u1 = lua_touserdata(L, 1); + u2 = lua_touserdata(L, 2); + if(*u1 == *u2) + lua_pushboolean(L, 1); + else + lua_pushboolean(L, 0); + } + else if(lua_istable(L, 1) && lua_istable(L, 2)) + lua_pushboolean(L, lua_rawequal(L, 1, 2)); + else + lua_pushboolean(L, 0); + return 1; +} + +int luaT_lua_pointer(lua_State *L) +{ + if(lua_isuserdata(L, 1)) + { + void **ptr; + luaL_argcheck(L, luaT_id(L, 1), 1, "Torch object expected"); + ptr = lua_touserdata(L, 1); + lua_pushnumber(L, (long)(*ptr)); + return 1; + } + else if(lua_istable(L, 1) || lua_isthread(L, 1) || lua_isfunction(L, 1)) + { + const void* ptr = lua_topointer(L, 1); + lua_pushnumber(L, (long)(ptr)); + return 1; + } + else + luaL_error(L, "Torch object, table, thread or function expected"); + + return 0; +} + +int luaT_lua_setenv(lua_State *L) +{ + if(!lua_isfunction(L, 1) && !lua_isuserdata(L, 1)) + luaL_typerror(L, 1, "function or userdata"); + luaL_checktype(L, 2, LUA_TTABLE); + lua_setfenv(L, 1); + return 0; +} + +int luaT_lua_getenv(lua_State *L) +{ + if(!lua_isfunction(L, 1) && !lua_isuserdata(L, 1)) + luaL_typerror(L, 1, "function or userdata"); + lua_getfenv(L, 1); + return 1; +} + +int luaT_lua_getmetatable(lua_State *L) +{ + const char *tname = luaL_checkstring(L, 1); + luaT_pushmetaclass(L, luaT_typename2id(L, tname)); /* note: in Lua, root-metatable is hidden, so... you get it eh... */ + return 1; +} + +int luaT_lua_version(lua_State *L) +{ + luaL_checkany(L, 1); + + if(luaT_getmetaclass(L, 1)) + { + lua_pushstring(L, "__version"); + lua_rawget(L, -2); + return 1; + } + return 0; +} + +int luaT_lua_setmetatable(lua_State *L) +{ + const char *tname = luaL_checkstring(L, 2); + luaL_checktype(L, 1, LUA_TTABLE); + + lua_pushvalue(L, 1); + luaT_pushmetatable(L, luaT_typename2id(L, tname)); + if(lua_isnil(L, -1)) + luaL_error(L, "unknown typename %s\n", tname); + lua_setmetatable(L, -2); + return 1; +} + +/* root-metatable functions */ +static int luaT_rmt__index(lua_State *L) +{ + if(!luaT_getmetaclass(L, 1)) + luaL_error(L, "critical internal indexing error: no metatable found"); + + if(!lua_istable(L, -1)) + luaL_error(L, "critical internal indexing error: not a metatable"); + + /* test for __index__ method first */ + lua_getfield(L, -1, "__index__"); + if(!lua_isnil(L, -1)) + { + int result; + + if(!lua_isfunction(L, -1)) + luaL_error(L, "critical internal indexing error: __index__ is not a function"); + + lua_pushvalue(L, 1); + lua_pushvalue(L, 2); + + lua_call(L, 2, LUA_MULTRET); /* DEBUG: risque: faut vraiment retourner 1 ou 2 valeurs... */ + + result = lua_toboolean(L, -1); + lua_pop(L, 1); + + if(result) + return 1; + + /* on the stack: 1. the object 2. the value 3. 
the metatable */ + /* apparently, __index wants only one element returned */ + /* return lua_gettop(L)-3; */ + + } + else + lua_pop(L, 1); /* remove nil __index__ on the stack */ + + lua_pushvalue(L, 2); + lua_gettable(L, -2); + + return 1; +} + +static int luaT_rmt__newindex(lua_State *L) +{ + if(!luaT_getmetaclass(L, 1)) + luaL_error(L, "critical internal indexing error: no metatable found"); + + if(!lua_istable(L, -1)) + luaL_error(L, "critical internal indexing error: not a metatable"); + + /* test for __newindex__ method first */ + lua_getfield(L, -1, "__newindex__"); + if(!lua_isnil(L, -1)) + { + int result; + + if(!lua_isfunction(L, -1)) + luaL_error(L, "critical internal indexing error: __newindex__ is not a function"); + + lua_pushvalue(L, 1); + lua_pushvalue(L, 2); + lua_pushvalue(L, 3); + + lua_call(L, 3, 1); /* DEBUG: risque: faut vraiment retourner qqch */ + + result = lua_toboolean(L, -1); + lua_pop(L, 1); + + if(result) + return 0; + } + else + lua_pop(L, 1); /* remove nil __newindex__ on the stack */ + + lua_pop(L, 1); /* pop the metaclass */ + if(lua_istable(L, 1)) + lua_rawset(L, 1); + else + luaL_error(L, "the class %s cannot be indexed", luaT_typename(L, 1)); + + return 0; +} + +/* note: check dans metatable pour ca, donc necessaire */ +#define MT_DECLARE_OPERATOR(NAME, NIL_BEHAVIOR) \ +int luaT_rmt__##NAME(lua_State *L) \ +{ \ + if(!lua_getmetatable(L, 1)) \ + luaL_error(L, "internal error in __" #NAME ": no metatable"); \ +\ + if(!lua_getmetatable(L, -1)) \ + luaL_error(L, "internal error in __" #NAME ": no metaclass"); \ +\ + lua_getfield(L, -1, "__" #NAME "__"); \ +\ + if(lua_isnil(L, -1)) \ + { \ + NIL_BEHAVIOR; \ + } \ + else \ + { \ + if(lua_isfunction(L, -1)) \ + { \ + lua_insert(L, 1); /* insert function */ \ + lua_pop(L, 2); /* remove metatable and metaclass */ \ + lua_call(L, lua_gettop(L)-1, 1); /* we return the result of the call */ \ + } \ + /* we return the thing the user left in __tostring__ */ \ + } \ + return 1; \ +} + +MT_DECLARE_OPERATOR(tostring, lua_pushstring(L, luaT_typename(L, 1))) +MT_DECLARE_OPERATOR(add, luaL_error(L, "%s has no addition operator", luaT_typename(L, 1))) +MT_DECLARE_OPERATOR(sub, luaL_error(L, "%s has no substraction operator", luaT_typename(L, 1))) +MT_DECLARE_OPERATOR(mul, luaL_error(L, "%s has no multiplication operator", luaT_typename(L, 1))) +MT_DECLARE_OPERATOR(div, luaL_error(L, "%s has no division operator", luaT_typename(L, 1))) +MT_DECLARE_OPERATOR(mod, luaL_error(L, "%s has no modulo operator", luaT_typename(L, 1))) +MT_DECLARE_OPERATOR(pow, luaL_error(L, "%s has no power operator", luaT_typename(L, 1))) +MT_DECLARE_OPERATOR(unm, luaL_error(L, "%s has no negation operator", luaT_typename(L, 1))) +MT_DECLARE_OPERATOR(concat, luaL_error(L, "%s has no concat operator", luaT_typename(L, 1))) +MT_DECLARE_OPERATOR(len, luaL_error(L, "%s has no length operator", luaT_typename(L, 1))) +MT_DECLARE_OPERATOR(eq, + lua_settop(L, 2); + lua_pushcfunction(L, luaT_lua_isequal); + lua_insert(L, 1); + lua_call(L, 2, 1);) +MT_DECLARE_OPERATOR(lt, luaL_error(L, "%s has no lower than operator", luaT_typename(L, 1))) +MT_DECLARE_OPERATOR(le, luaL_error(L, "%s has no lower or equal than operator", luaT_typename(L, 1))) +MT_DECLARE_OPERATOR(call, luaL_error(L, "%s has no call operator", luaT_typename(L, 1))) + +/* constructor metatable methods */ +int luaT_cmt__call(lua_State *L) +{ + if(!lua_istable(L, 1)) + luaL_error(L, "internal error in __call: not a constructor table"); + + if(!lua_getmetatable(L, 1)) + luaL_error(L, "internal error 
in __call: no metatable available"); + + lua_pushstring(L, "__new"); + lua_rawget(L, -2); + + if(lua_isnil(L, -1)) + luaL_error(L, "no constructor available"); + + lua_remove(L, 1); /* remove root metatable */ + lua_insert(L, 1); /* insert constructor */ + lua_pop(L, 1); /* remove fancy metatable */ + + lua_call(L, lua_gettop(L)-1, 1); + return 1; +} + +int luaT_cmt__newindex(lua_State *L) +{ + if(!lua_istable(L, 1)) + luaL_error(L, "internal error in __newindex: not a constructor table"); + + if(!lua_getmetatable(L, 1)) + luaL_error(L, "internal error in __newindex: no metatable available"); + + lua_pushstring(L, "__metatable"); + lua_rawget(L, -2); + + if(!lua_istable(L, -1)) + luaL_error(L, "internal error in __newindex: no metaclass available"); + + lua_insert(L, 2); + lua_pop(L, 1); /* remove the metatable over the constructor table */ + + lua_rawset(L, -3); + + return 0; +} diff --git a/lib/luaT/luaT.h b/lib/luaT/luaT.h new file mode 100644 index 00000000000..5c6f235a1a0 --- /dev/null +++ b/lib/luaT/luaT.h @@ -0,0 +1,93 @@ +#ifndef LUAT_UTILS_INC +#define LUAT_UTILS_INC + +#include +#include + +#ifndef LUA_EXTERNC +# ifdef __cplusplus +# define LUA_EXTERNC extern "C" +# else +# define LUA_EXTERNC extern +# endif +#endif + + +#ifdef _MSC_VER +# define DLL_EXPORT __declspec(dllexport) +# define DLL_IMPORT __declspec(dllimport) +# ifdef luaT_EXPORTS +# define LUAT_API LUA_EXTERNC DLL_EXPORT +# else +# define LUAT_API LUA_EXTERNC DLL_IMPORT +# endif +#else +# define DLL_EXPORT +# define DLL_IMPORT +# define LUAT_API LUA_EXTERNC +#endif + + +/* C functions */ + +LUAT_API void* luaT_alloc(lua_State *L, long size); +LUAT_API void* luaT_realloc(lua_State *L, void *ptr, long size); +LUAT_API void luaT_free(lua_State *L, void *ptr); +LUAT_API void luaT_stackdump(lua_State *L); + +LUAT_API void luaT_registeratid(lua_State *L, const struct luaL_Reg *methods, const void *id); +LUAT_API void luaT_registeratname(lua_State *L, const struct luaL_Reg *methods, const char *name); + +LUAT_API const void* luaT_newmetatable(lua_State *L, const char *tname, const char *parenttname, + lua_CFunction constructor, lua_CFunction destructor, lua_CFunction factory); + +LUAT_API void luaT_pushmetatable(lua_State *L, const void *id); +LUAT_API void luaT_pushmetaclass(lua_State *L, const void *id); +LUAT_API int luaT_getmetaclass(lua_State *L, int index); + +LUAT_API const char* luaT_id2typename(lua_State *L, const void *id); +LUAT_API const void* luaT_typename2id(lua_State *L, const char*); +LUAT_API const void* luaT_checktypename2id(lua_State *L, const char *tname); + +LUAT_API const void* luaT_id(lua_State *L, int ud); +LUAT_API const char* luaT_typename(lua_State *L, int ud); + +LUAT_API void luaT_pushudata(lua_State *L, void *udata, const void *id); + +LUAT_API void *luaT_toudata (lua_State *L, int ud, const void *id); +LUAT_API int luaT_isudata (lua_State *L, int ud, const void *id); +LUAT_API void *luaT_checkudata (lua_State *L, int ud, const void *id); + +LUAT_API void *luaT_getfieldcheckudata (lua_State *L, int ud, const char *field, const void *id); +LUAT_API void *luaT_getfieldchecklightudata (lua_State *L, int ud, const char *field); +LUAT_API double luaT_getfieldchecknumber (lua_State *L, int ud, const char *field); +LUAT_API int luaT_getfieldcheckint (lua_State *L, int ud, const char *field); +LUAT_API const char* luaT_getfieldcheckstring (lua_State *L, int ud, const char *field); +LUAT_API int luaT_getfieldcheckboolean (lua_State *L, int ud, const char *field); +LUAT_API void luaT_getfieldchecktable 
(lua_State *L, int ud, const char *field); + +LUAT_API int luaT_typerror(lua_State *L, int ud, const char *tname); +LUAT_API int luaT_checkboolean(lua_State *L, int narg); +LUAT_API int luaT_optboolean(lua_State *L, int narg, int def); + +LUAT_API const char *luaT_classrootname(const char *tname); +LUAT_API const char *luaT_classmodulename(const char *tname); +LUAT_API void luaT_stackdump(lua_State *L); + +/* Lua functions */ + +LUAT_API int luaT_lua_newmetatable(lua_State *L); +LUAT_API int luaT_lua_factory(lua_State *L); +LUAT_API int luaT_lua_getconstructortable(lua_State *L); +LUAT_API int luaT_lua_id(lua_State *L); +LUAT_API int luaT_lua_typename(lua_State *L); +LUAT_API int luaT_lua_isequal(lua_State *L); +LUAT_API int luaT_lua_pointer(lua_State *L); +LUAT_API int luaT_lua_setenv(lua_State *L); +LUAT_API int luaT_lua_getenv(lua_State *L); +LUAT_API int luaT_lua_getmetatable(lua_State *L); +LUAT_API int luaT_lua_version(lua_State *L); +LUAT_API int luaT_lua_setmetatable(lua_State *L); +LUAT_API int luaT_lua_typename2id(lua_State *L); + +#endif diff --git a/lib/luaT/luaTConfig.cmake.in b/lib/luaT/luaTConfig.cmake.in new file mode 100644 index 00000000000..bfb20b87a4c --- /dev/null +++ b/lib/luaT/luaTConfig.cmake.in @@ -0,0 +1,9 @@ +# Find the luaT includes and library +# +# LUAT_INCLUDE_DIR -- where to find the includes +# LUAT_LIBRARIES -- list of libraries to link against +# LUAT_FOUND -- set to 1 if found + +SET(LUAT_FOUND 1) +SET(LUAT_INCLUDE_DIR "@LUAT_INCLUDE_DIR@") +SET(LUAT_LIBRARIES "@LUAT_LIBRARIES@") diff --git a/random.lua b/random.lua new file mode 100644 index 00000000000..b6d49601b44 --- /dev/null +++ b/random.lua @@ -0,0 +1,29 @@ +local interface = wrap.CInterface.new() + +interface:print( + [[ +#include "luaT.h" +#include "TH.h" + ]]) + +for _,name in ipairs({"seed", "initialSeed"}) do + interface:wrap(name, + string.format("THRandom_%s",name), + {{name="long", creturned=true}}) +end + +interface:wrap('manualSeed', + 'THRandom_manualSeed', + {{name="long"}}) + +interface:register("random__") + +interface:print( + [[ +void torch_random_init(lua_State *L) +{ + luaL_register(L, NULL, random__); +} + ]]) + +interface:tofile(arg[1]) diff --git a/test/test.lua b/test/test.lua new file mode 100644 index 00000000000..4abc18a6172 --- /dev/null +++ b/test/test.lua @@ -0,0 +1,352 @@ +--require 'torch' + +local mytester +local torchtest = {} +local msize = 100 + +local function maxdiff(x,y) + local d = x-y + if x:type() == 'torch.DoubleTensor' or x:type() == 'torch.FloatTensor' then + return d:abs():maxall() + else + local dd = torch.Tensor():resize(d:size()):copy(d) + return dd:abs():maxall() + end +end + +function torchtest.max() + local x = torch.rand(msize,msize) + local mx,ix = torch.max(x,1) + local mxx = torch.Tensor() + local ixx = torch.LongTensor() + torch.max(mxx,ixx,x,1) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.max value') + mytester:asserteq(maxdiff(ix,ixx),0,'torch.max index') +end +function torchtest.min() + local x = torch.rand(msize,msize) + local mx,ix = torch.min(x) + local mxx = torch.Tensor() + local ixx = torch.LongTensor() + torch.min(mxx,ixx,x) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.min value') + mytester:asserteq(maxdiff(ix,ixx),0,'torch.min index') +end +function torchtest.sum() + local x = torch.rand(msize,msize) + local mx = torch.sum(x) + local mxx = torch.Tensor() + torch.sum(mxx,x) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.sum value') +end +function torchtest.prod() + local x = torch.rand(msize,msize) + local mx = torch.prod(x) + local 
mxx = torch.Tensor() + torch.prod(mxx,x) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.prod value') +end +function torchtest.cumsum() + local x = torch.rand(msize,msize) + local mx = torch.cumsum(x) + local mxx = torch.Tensor() + torch.cumsum(mxx,x) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.cumsum value') +end +function torchtest.cumprod() + local x = torch.rand(msize,msize) + local mx = torch.cumprod(x) + local mxx = torch.Tensor() + torch.cumprod(mxx,x) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.cumprod value') +end +function torchtest.cross() + local x = torch.rand(msize,3,msize) + local y = torch.rand(msize,3,msize) + local mx = torch.cross(x,y) + local mxx = torch.Tensor() + torch.cross(mxx,x,y) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.cross value') +end +function torchtest.zeros() + local mx = torch.zeros(msize,msize) + local mxx = torch.Tensor() + torch.zeros(mxx,msize,msize) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.zeros value') +end +function torchtest.ones() + local mx = torch.ones(msize,msize) + local mxx = torch.Tensor() + torch.ones(mxx,msize,msize) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.ones value') +end +function torchtest.diag() + local x = torch.rand(msize,msize) + local mx = torch.diag(x) + local mxx = torch.Tensor() + torch.diag(mxx,x) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.diag value') +end +function torchtest.eye() + local mx = torch.eye(msize,msize) + local mxx = torch.Tensor() + torch.eye(mxx,msize,msize) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.eye value') +end +function torchtest.range() + local mx = torch.range(0,1) + local mxx = torch.Tensor() + torch.range(mxx,0,1) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.range value') +end +function torchtest.randperm() + local t=os.time() + torch.manualSeed(t) + local mx = torch.randperm(msize) + local mxx = torch.Tensor() + torch.manualSeed(t) + torch.randperm(mxx,msize) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.randperm value') +end +function torchtest.reshape() + local x = torch.rand(10,13,23) + local mx = torch.reshape(x,130,23) + local mxx = torch.Tensor() + torch.reshape(mxx,x,130,23) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.reshape value') +end +function torchtest.sort() + local x = torch.rand(msize,msize) + local mx,ix = torch.sort(x) + local mxx = torch.Tensor() + local ixx = torch.LongTensor() + torch.sort(mxx,ixx,x) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.sort value') + mytester:asserteq(maxdiff(ix,ixx),0,'torch.sort index') +end +function torchtest.tril() + local x = torch.rand(msize,msize) + local mx = torch.tril(x) + local mxx = torch.Tensor() + torch.tril(mxx,x) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.tril value') +end +function torchtest.triu() + local x = torch.rand(msize,msize) + local mx = torch.triu(x) + local mxx = torch.Tensor() + torch.triu(mxx,x) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.tril value') +end +function torchtest.cat() + local x = torch.rand(13,msize,msize) + local y = torch.rand(17,msize,msize) + local mx = torch.cat(x,y,1) + local mxx = torch.Tensor() + torch.cat(mxx,x,y,1) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.cat value') +end +function torchtest.sin() + local x = torch.rand(msize,msize,msize) + local mx = torch.sin(x) + local mxx = torch.Tensor() + torch.sin(mxx,x) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.sin value') +end +function torchtest.linspace() + local from = math.random() + local to = from+math.random() + local mx = torch.linspace(from,to,137) + local mxx = torch.Tensor() + torch.linspace(mxx,from,to,137) + 
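-- the in-place form must produce exactly the same values as the returned tensor +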
mytester:asserteq(maxdiff(mx,mxx),0,'torch.linspace value') +end +function torchtest.logspace() + local from = math.random() + local to = from+math.random() + local mx = torch.logspace(from,to,137) + local mxx = torch.Tensor() + torch.logspace(mxx,from,to,137) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.logspace value') +end +function torchtest.rand() + torch.manualSeed(123456) + local mx = torch.rand(msize,msize) + local mxx = torch.Tensor() + torch.manualSeed(123456) + torch.rand(mxx,msize,msize) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.rand value') +end +function torchtest.randn() + torch.manualSeed(123456) + local mx = torch.randn(msize,msize) + local mxx = torch.Tensor() + torch.manualSeed(123456) + torch.randn(mxx,msize,msize) + mytester:asserteq(maxdiff(mx,mxx),0,'torch.randn value') +end +function torchtest.gesv() + if not torch.gesv then return end + local a=torch.Tensor({{6.80, -2.11, 5.66, 5.97, 8.23}, + {-6.05, -3.30, 5.36, -4.44, 1.08}, + {-0.45, 2.58, -2.70, 0.27, 9.04}, + {8.32, 2.71, 4.35, -7.17, 2.14}, + {-9.67, -5.14, -7.26, 6.08, -6.87}}):t() + local b=torch.Tensor({{4.02, 6.19, -8.22, -7.57, -3.03}, + {-1.56, 4.00, -8.67, 1.75, 2.86}, + {9.81, -4.09, -4.57, -8.61, 8.99}}):t() + local mx = torch.gesv(b,a) + mytester:assertlt(b:dist(a*mx),1e-12,'torch.gesv') + local ta = torch.Tensor() + local tb = torch.Tensor() + local mxx = torch.gesv(tb,ta,b,a) + local mxxx = torch.gesv(b,a,b,a) + mytester:asserteq(maxdiff(mx,tb),0,'torch.gesv value temp') + mytester:asserteq(maxdiff(mx,b),0,'torch.gesv value flag') + mytester:asserteq(maxdiff(mx,mxx),0,'torch.gesv value out1') + mytester:asserteq(maxdiff(mx,mxxx),0,'torch.gesv value out2') +end +function torchtest.gels() + if not torch.gels then return end + local a=torch.Tensor({{ 1.44, -9.96, -7.55, 8.34, 7.08, -5.45}, + {-7.84, -0.28, 3.24, 8.09, 2.52, -5.70}, + {-4.39, -3.24, 6.27, 5.28, 0.74, -1.19}, + {4.53, 3.83, -6.64, 2.06, -2.47, 4.70}}):t() + local b=torch.Tensor({{8.58, 8.26, 8.48, -5.28, 5.72, 8.93}, + {9.35, -4.43, -0.70, -0.26, -7.36, -2.52}}):t() + local mx = torch.gels(b,a) + local ta = torch.Tensor() + local tb = torch.Tensor() + local mxx = torch.gels(tb,ta,b,a) + local mxxx = torch.gels(b,a,b,a) + mytester:asserteq(maxdiff(mx,tb),0,'torch.gels value temp') + mytester:asserteq(maxdiff(mx,b),0,'torch.gels value flag') + mytester:asserteq(maxdiff(mx,mxx),0,'torch.gels value out1') + mytester:asserteq(maxdiff(mx,mxxx),0,'torch.gels value out2') +end +function torchtest.eig() + if not torch.eig then return end + local a=torch.Tensor({{ 1.96, 0.00, 0.00, 0.00, 0.00}, + {-6.49, 3.80, 0.00, 0.00, 0.00}, + {-0.47, -6.39, 4.17, 0.00, 0.00}, + {-7.20, 1.50, -1.51, 5.70, 0.00}, + {-0.65, -6.34, 2.67, 1.80, -7.10}}):t():clone() + local e = torch.eig(a) + local ee,vv = torch.eig(a,'V') + local te = torch.Tensor() + local tv = torch.Tensor() + local eee,vvv = torch.eig(te,tv,a,'V') + mytester:assertlt(maxdiff(e,ee),1e-12,'torch.eig value') + mytester:assertlt(maxdiff(ee,eee),1e-12,'torch.eig value') + mytester:assertlt(maxdiff(ee,te),1e-12,'torch.eig value') + mytester:assertlt(maxdiff(vv,vvv),1e-12,'torch.eig value') + mytester:assertlt(maxdiff(vv,tv),1e-12,'torch.eig value') +end +function torchtest.svd() + if not torch.svd then return end + local a=torch.Tensor({{8.79, 6.11, -9.15, 9.57, -3.49, 9.84}, + {9.93, 6.91, -7.93, 1.64, 4.02, 0.15}, + {9.83, 5.04, 4.86, 8.83, 9.80, -8.99}, + {5.45, -0.27, 4.85, 0.74, 10.00, -6.02}, + {3.16, 7.98, 3.01, 5.80, 4.27, -5.31}}):t():clone() + local u,s,v = torch.svd(a) + local uu = 
torch.Tensor() + local ss = torch.Tensor() + local vv = torch.Tensor() + uuu,sss,vvv = torch.svd(uu,ss,vv,a) + mytester:asserteq(maxdiff(u,uu),0,'torch.svd') + mytester:asserteq(maxdiff(u,uuu),0,'torch.svd') + mytester:asserteq(maxdiff(s,ss),0,'torch.svd') + mytester:asserteq(maxdiff(s,sss),0,'torch.svd') + mytester:asserteq(maxdiff(v,vv),0,'torch.svd') + mytester:asserteq(maxdiff(v,vvv),0,'torch.svd') +end + +function torchtest.conv2() + local x = torch.rand(math.floor(torch.uniform(50,100)),math.floor(torch.uniform(50,100))) + local k = torch.rand(math.floor(torch.uniform(10,20)),math.floor(torch.uniform(10,20))) + local imvc = torch.conv2(x,k) + local imvc2 = torch.conv2(x,k,'V') + local imfc = torch.conv2(x,k,'F') + + local ki = k:clone(); + local ks = k:storage() + local kis = ki:storage() + for i=ks:size(),1,-1 do kis[ks:size()-i+1]=ks[i] end + local imvx = torch.xcorr2(x,ki) + local imvx2 = torch.xcorr2(x,ki,'V') + local imfx = torch.xcorr2(x,ki,'F') + + mytester:asserteq(maxdiff(imvc,imvc2),0,'torch.conv2') + mytester:asserteq(maxdiff(imvc,imvx),0,'torch.conv2') + mytester:asserteq(maxdiff(imvc,imvx2),0,'torch.conv2') + mytester:asserteq(maxdiff(imfc,imfx),0,'torch.conv2') + mytester:assertlt(math.abs(x:dot(x)-torch.xcorr2(x,x)[1][1]),1e-10,'torch.conv2') + + local xx = torch.Tensor(2,x:size(1),x:size(2)) + xx[1]:copy(x) + xx[2]:copy(x) + local kk = torch.Tensor(2,k:size(1),k:size(2)) + kk[1]:copy(k) + kk[2]:copy(k) + + local immvc = torch.conv2(xx,kk) + local immvc2 = torch.conv2(xx,kk,'V') + local immfc = torch.conv2(xx,kk,'F') + + mytester:asserteq(maxdiff(immvc[1],immvc[2]),0,'torch.conv2') + mytester:asserteq(maxdiff(immvc[1],imvc),0,'torch.conv2') + mytester:asserteq(maxdiff(immvc2[1],imvc2),0,'torch.conv2') + mytester:asserteq(maxdiff(immfc[1],immfc[2]),0,'torch.conv2') + mytester:asserteq(maxdiff(immfc[1],imfc),0,'torch.conv2') +end + +function torchtest.conv3() + local x = torch.rand(math.floor(torch.uniform(20,40)), + math.floor(torch.uniform(20,40)), + math.floor(torch.uniform(20,40))) + local k = torch.rand(math.floor(torch.uniform(5,10)), + math.floor(torch.uniform(5,10)), + math.floor(torch.uniform(5,10))) + local imvc = torch.conv3(x,k) + local imvc2 = torch.conv3(x,k,'V') + local imfc = torch.conv3(x,k,'F') + + local ki = k:clone(); + local ks = k:storage() + local kis = ki:storage() + for i=ks:size(),1,-1 do kis[ks:size()-i+1]=ks[i] end + local imvx = torch.xcorr3(x,ki) + local imvx2 = torch.xcorr3(x,ki,'V') + local imfx = torch.xcorr3(x,ki,'F') + + mytester:asserteq(maxdiff(imvc,imvc2),0,'torch.conv3') + mytester:asserteq(maxdiff(imvc,imvx),0,'torch.conv3') + mytester:asserteq(maxdiff(imvc,imvx2),0,'torch.conv3') + mytester:asserteq(maxdiff(imfc,imfx),0,'torch.conv3') + mytester:assertlt(math.abs(x:dot(x)-torch.xcorr3(x,x)[1][1][1]),1e-10,'torch.conv3') + + local xx = torch.Tensor(2,x:size(1),x:size(2),x:size(3)) + xx[1]:copy(x) + xx[2]:copy(x) + local kk = torch.Tensor(2,k:size(1),k:size(2),k:size(3)) + kk[1]:copy(k) + kk[2]:copy(k) + + local immvc = torch.conv3(xx,kk) + local immvc2 = torch.conv3(xx,kk,'V') + local immfc = torch.conv3(xx,kk,'F') + + mytester:asserteq(maxdiff(immvc[1],immvc[2]),0,'torch.conv3') + mytester:asserteq(maxdiff(immvc[1],imvc),0,'torch.conv3') + mytester:asserteq(maxdiff(immvc2[1],imvc2),0,'torch.conv3') + mytester:asserteq(maxdiff(immfc[1],immfc[2]),0,'torch.conv3') + mytester:asserteq(maxdiff(immfc[1],imfc),0,'torch.conv3') +end + +function torch.test() + math.randomseed(os.time()) + mytester = torch.Tester() + mytester:add(torchtest) 
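+ -- run every test registered above and report the results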
+ mytester:run() +end diff --git a/torch.in b/torch.in new file mode 100644 index 00000000000..29bf2307f4c --- /dev/null +++ b/torch.in @@ -0,0 +1,43 @@ +#!/bin/bash + +# install prefix +PREFIX=@Torch_INSTALL_BIN@ + +# load torch.lua +INIT="require 'torch'; torch.include('torch','torch.lua')" + +# check special arguments +if [ $# -gt 0 ] +then + if [ $1 == 'install' ] + then + if [ $# -eq 2 ] + then + $PREFIX/lua -e "$INIT; install('$2')"; + else + $PREFIX/lua -e "$INIT; install('.')"; + fi + exit + fi +fi + +# try to run qlua, and default to lua if not available +# all the functions defined above are executed before +# returning to a user prompt +if [ -f $PREFIX/qlua ] +then + if [ $DISPLAY ] + then + echo "Try the IDE: torch -ide" + echo "Type help() for more info" + $PREFIX/qlua -e "$INIT" -i $*; + else + echo "Unable to connect X11 server (disabling graphics)" + echo "Type help() for more info" + $PREFIX/qlua -nographics -e "$INIT" -i $*; + fi +else + echo "Install Qt4 and rebuild Torch7 for graphics capability" + echo "Type help() for more info" + $PREFIX/lua -e "$INIT" -i $*; +fi diff --git a/torch.lua b/torch.lua new file mode 100644 index 00000000000..e6f4220330a --- /dev/null +++ b/torch.lua @@ -0,0 +1,177 @@ + +-- welcome message +print 'Torch 7.0 Copyright (C) 2001-2011 Idiap, NEC Labs, NYU' + +-- custom prompt +_PROMPT = 't7> ' +_PROMPT2 = '. > ' + +-- helper +local function sizestr(x) + local strt = {} + if x:nDimension() == 0 then + table.insert(strt, _G.torch.typename(x):match('torch%.(.+)') .. ' - empty') + else + table.insert(strt, _G.torch.typename(x):match('torch%.(.+)') .. ' - size: ') + for i=1,x:nDimension() do + table.insert(strt, x:size(i)) + if i ~= x:nDimension() then + table.insert(strt, 'x') + end + end + end + return table.concat(strt) +end + +-- k : name of variable +-- m : max length +local function printvar(key,val,m) + local name = '[' .. tostring(key) .. ']' + --io.write(name) + name = name .. string.rep(' ',m-name:len()+2) + local tp = type(val) + if tp == 'userdata' then + tp = torch.typename(val) or '' + if tp:find('torch.*Tensor') then + tp = sizestr(val) + elseif tp:find('torch.*Storage') then + tp = sizestr(val) + else + tp = tostring(val) + end + elseif tp == 'table' then + tp = tp .. ' - size: ' .. #val + elseif tp == 'string' then + local tostr = val:gsub('\n','\\n') + if #tostr>40 then + tostr = tostr:sub(1,40) .. '...' + end + tp = tp .. ' : "' .. tostr .. '"' + else + tp = tostring(val) + end + return name .. ' = ' .. tp +end + +-- helper +local function getmaxlen(vars) + local m = 0 + if type(vars) ~= 'table' then return tostring(vars):len() end + for k,v in pairs(vars) do + local s = tostring(k) + if s:len() > m then + m = s:len() + end + end + return m +end + +-- who: +-- a simple function that prints all the symbols defined by the user +-- very much like Matlab's who function +function who() + local m = getmaxlen(_G) + local p = _G._preloaded_ + local function printsymb(sys) + for k,v in pairs(_G) do + if (sys and p[k]) or (not sys and not p[k]) then + print(printvar(k,_G[k],m)) + end + end + end + print('== System Variables ==') + printsymb(true) + print('== User Variables ==') + printsymb(false) + print('==') +end + +print_old=print +_G._preloaded_ = {} +for k,v in pairs(_G) do + _G._preloaded_[k] = true +end + +-- print: +-- a smarter print for Lua, the default Lua print is quite terse +-- this new print is much more verbose, automatically recursing through +-- lua tables, and objects. +function print(obj,...) 
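+ -- tables without __tostring__ are expanded field by field via printvar(); other values fall back to tostring()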
+ local m = getmaxlen(obj) + if _G.type(obj) == 'table' then + local mt = _G.getmetatable(obj) + if mt and mt.__tostring__ then + _G.io.write(mt.__tostring__(obj)) + else + local tos = _G.tostring(obj) + local obj_w_usage = false + if tos and not _G.string.find(tos,'table: ') then + if obj.usage and _G.type(obj.usage) == 'string' then + _G.io.write(obj.usage) + _G.io.write('\n\nFIELDS:\n') + obj_w_usage = true + else + _G.io.write(tos .. ':\n') + end + end + _G.io.write('{') + local idx = 1 + local tab = '' + local newline = '' + for k,v in pairs(obj) do + local line = printvar(k,v,m) + _G.io.write(newline .. tab .. line) + if idx == 1 then + tab = ' ' + newline = '\n' + end + idx = idx + 1 + end + _G.io.write('}') + if obj_w_usage then + _G.io.write('') + end + end + else + _G.io.write(_G.tostring(obj)) + end + if _G.select('#',...) > 0 then + _G.io.write(' ') + print(...) + else + _G.io.write('\n') + end +end + +-- import: +-- this function is a python-like loader, it requires a module, +-- and then imports all its symbols globally +function import(package, forced) + require(package) + if _G[package] then + _G._torchimport = _G._torchimport or {} + _G._torchimport[package] = _G[package] + end + for k,v in pairs(_G[package]) do + if not _G[k] or forced then + _G[k] = v + end + end +end + +-- install module: +-- this function builds and install a specified module +function install(path) + path = paths.concat(paths.cwd(), path) + print('--> installing module ' .. path) + os.execute('mkdir ' .. paths.concat(path,'build') .. '; ' + .. 'cd ' .. paths.concat(path,'build') .. '; ' + .. 'cmake .. -DCMAKE_INSTALL_PREFIX=' .. paths.install_prefix .. '; ' + .. 'make install; cd .. ; rm -r build') + print('--> module installed') +end + +-- preload basic libraries +import 'torch' +import 'gnuplot' +import 'dok' diff --git a/utils.c b/utils.c new file mode 100644 index 00000000000..7d93b7a163a --- /dev/null +++ b/utils.c @@ -0,0 +1,135 @@ +#include "general.h" +#include "utils.h" + +#include + +static const void* torch_LongStorage_id = NULL; +static const void* torch_default_tensor_id = NULL; + +THLongStorage* torch_checklongargs(lua_State *L, int index) +{ + THLongStorage *storage; + int i; + int narg = lua_gettop(L)-index+1; + + if(narg == 1 && luaT_toudata(L, index, torch_LongStorage_id)) + { + THLongStorage *storagesrc = luaT_toudata(L, index, torch_LongStorage_id); + storage = THLongStorage_newWithSize(storagesrc->size); + THLongStorage_copy(storage, storagesrc); + } + else + { + storage = THLongStorage_newWithSize(narg); + for(i = index; i < index+narg; i++) + { + if(!lua_isnumber(L, i)) + { + THLongStorage_free(storage); + luaL_argerror(L, i, "number expected"); + } + THLongStorage_set(storage, i-index, lua_tonumber(L, i)); + } + } + return storage; +} + +int torch_islongargs(lua_State *L, int index) +{ + int narg = lua_gettop(L)-index+1; + + if(narg == 1 && luaT_toudata(L, index, torch_LongStorage_id)) + { + return 1; + } + else + { + int i; + + for(i = index; i < index+narg; i++) + { + if(!lua_isnumber(L, i)) + return 0; + } + return 1; + } + return 0; +} + + + +static int torch_lua_tic(lua_State* L) +{ + struct timeval tv; + gettimeofday(&tv,NULL); + double ttime = (double)tv.tv_sec + (double)(tv.tv_usec)/1000000.0; + lua_pushnumber(L,ttime); + return 1; +} + +static int torch_lua_toc(lua_State* L) +{ + struct timeval tv; + gettimeofday(&tv,NULL); + double toctime = (double)tv.tv_sec + (double)(tv.tv_usec)/1000000.0; + lua_Number tictime = luaL_checknumber(L,1); + 
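/* toc() expects the number returned by tic() and pushes the elapsed wall-clock time in seconds */ +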
lua_pushnumber(L,toctime-tictime); + return 1; +} + +static int torch_lua_setdefaulttensortype(lua_State *L) +{ + const void *id; + + luaL_checkstring(L, 1); + + if(!(id = luaT_typename2id(L, lua_tostring(L, 1)))) \ + return luaL_error(L, "<%s> is not a string describing a torch object", lua_tostring(L, 1)); \ + + torch_default_tensor_id = id; + + return 0; +} + +static int torch_lua_getdefaulttensortype(lua_State *L) +{ + lua_pushstring(L, luaT_id2typename(L, torch_default_tensor_id)); + return 1; +} + +void torch_setdefaulttensorid(const void* id) +{ + torch_default_tensor_id = id; +} + +const void* torch_getdefaulttensorid() +{ + return torch_default_tensor_id; +} + +static const struct luaL_Reg torch_utils__ [] = { + {"__setdefaulttensortype", torch_lua_setdefaulttensortype}, + {"getdefaulttensortype", torch_lua_getdefaulttensortype}, + {"tic", torch_lua_tic}, + {"toc", torch_lua_toc}, + {"factory", luaT_lua_factory}, + {"getconstructortable", luaT_lua_getconstructortable}, + {"id", luaT_lua_id}, + {"typename", luaT_lua_typename}, + {"typename2id", luaT_lua_typename2id}, + {"isequal", luaT_lua_isequal}, + {"getenv", luaT_lua_getenv}, + {"setenv", luaT_lua_setenv}, + {"newmetatable", luaT_lua_newmetatable}, + {"setmetatable", luaT_lua_setmetatable}, + {"getmetatable", luaT_lua_getmetatable}, + {"version", luaT_lua_version}, + {"pointer", luaT_lua_pointer}, + {NULL, NULL} +}; + +void torch_utils_init(lua_State *L) +{ + torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage"); + luaL_register(L, NULL, torch_utils__); +} diff --git a/utils.h b/utils.h new file mode 100644 index 00000000000..a00dcb8cebc --- /dev/null +++ b/utils.h @@ -0,0 +1,13 @@ +#ifndef TORCH_UTILS_INC +#define TORCH_UTILS_INC + +#include "luaT.h" +#include "TH.h" + +THLongStorage* torch_checklongargs(lua_State *L, int index); +int torch_islongargs(lua_State *L, int index); + +void torch_setdefaulttensorid(const void* id); +const void* torch_getdefaulttensorid(); + +#endif
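
As a rough illustration of how a C binding is expected to use the luaT API declared in luaT.h, here is a minimal sketch of registering a class and exposing one method. The Foo struct, the foo_* functions, the "torch.Foo" class name and the luaopen_foo entry point are hypothetical and not part of this commit; the sketch also assumes the torch module table already exists when the metatable is created.

#include "luaT.h"

/* hypothetical C structure exposed to Lua as torch.Foo */
typedef struct {
  double value;
} Foo;

static const void *foo_id;  /* type id returned by luaT_newmetatable */

/* constructor: torch.Foo(x) */
static int foo_new(lua_State *L)
{
  Foo *self = luaT_alloc(L, sizeof(Foo));
  self->value = luaL_optnumber(L, 1, 0);
  luaT_pushudata(L, self, foo_id);
  return 1;
}

/* destructor, installed as __gc on the root metatable */
static int foo_free(lua_State *L)
{
  Foo *self = luaT_checkudata(L, 1, foo_id);
  luaT_free(L, self);
  return 0;
}

/* method: obj:get() */
static int foo_get(lua_State *L)
{
  Foo *self = luaT_checkudata(L, 1, foo_id);
  lua_pushnumber(L, self->value);
  return 1;
}

static const struct luaL_Reg foo_methods[] = {
  {"get", foo_get},
  {NULL, NULL}
};

/* hypothetical module entry point */
int luaopen_foo(lua_State *L)
{
  /* no parent class, no factory; the metaclass is left on the stack */
  foo_id = luaT_newmetatable(L, "torch.Foo", NULL, foo_new, foo_free, NULL);
  /* methods registered in the metaclass become reachable through __index */
  luaL_register(L, NULL, foo_methods);
  lua_pop(L, 1);
  return 0;
}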