initial revamp of torch7 tree
commit 053065ba23
CMakeLists.txt
new file (18 lines)
@@ -0,0 +1,18 @@
SET(src DiskFile.c File.c MemoryFile.c PipeFile.c Storage.c Tensor.c Timer.c utils.c init.c TensorOperator.c TensorMathWrap.c random.c)
SET(luasrc init.lua File.lua Tensor.lua TensorMath.lua CmdLine.lua Tester.lua torch.lua test/test.lua)

# Necessary to generate wrapper
ADD_TORCH_WRAP(tensormathwrap TensorMathWrap.lua)
ADD_TORCH_WRAP(randomwrap random.lua)

ADD_TORCH_PACKAGE(torch "${src}" "${luasrc}" "Basics")
ADD_TORCH_DOK(dok torch "Fundamentals" "Torch package" 1.1)

TARGET_LINK_LIBRARIES(torch luaT TH)

CONFIGURE_FILE(torch.in "${Torch_BINARY_DIR}/torch")
INSTALL(FILES "${Torch_BINARY_DIR}/torch"
        DESTINATION "${Torch_INSTALL_BIN_SUBDIR}"
        PERMISSIONS OWNER_EXECUTE OWNER_WRITE OWNER_READ
                    GROUP_EXECUTE GROUP_READ
                    WORLD_EXECUTE WORLD_READ)
CmdLine.lua
new file (244 lines)
@@ -0,0 +1,244 @@
local CmdLine = torch.class('torch.CmdLine')

local function strip(str)
   return string.match(str, '%-*(.*)')
end

local function pad(str, sz)
   return str .. string.rep(' ', sz-#str)
end

function CmdLine:error(msg)
   print('')
   print(msg)
   print('')
   self:help()
   os.exit(0)
end

function CmdLine:__readArgument__(params, arg, i, nArgument)
   local argument = self.arguments[nArgument]
   local value = arg[i]

   if nArgument > #self.arguments then
      self:error('invalid argument: ' .. value)
   end
   if argument.type and type(value) ~= argument.type then
      self:error('invalid argument type for argument ' .. argument.key .. ' (should be ' .. argument.type .. ')')
   end
   params[strip(argument.key)] = value
   return 1
end

function CmdLine:__readOption__(params, arg, i)
   local key = arg[i]
   local option = self.options[key]
   if not option then
      self:error('unknown option ' .. key)
   end

   if option.type and option.type == 'boolean' then
      params[strip(key)] = not option.default
      return 1
   else
      local value = arg[i+1]
      if not value then
         self:error('missing argument for option ' .. key)
      end
      if not option.type or option.type == 'string' then
      elseif option.type == 'number' then
         value = tonumber(value)
      else
         self:error('unknown required option type ' .. option.type)
      end
      if not value then
         self:error('invalid type for option ' .. key .. ' (should be ' .. option.type .. ')')
      end
      params[strip(key)] = value
      return 2
   end
end

function CmdLine:__init(argseparator_,keyseparator_)
   self.argseparator = argseparator_ or ','
   self.keyseparator = keyseparator_ or '='
   self.options = {}
   self.arguments = {}
   self.helplines = {}
end

function CmdLine:argument(key, help, _type_)
   table.insert(self.arguments, {key=key, help=help, type=_type_})
   table.insert(self.helplines, self.arguments[#self.arguments])
end

function CmdLine:option(key, default, help, _type_)
   if default == nil then
      error('option ' .. key .. ' has no default value')
   end
   _type_ = _type_ or type(default)
   if type(default) ~= _type_ then
      error('option ' .. key .. ' has wrong default type value')
   end
   self.options[key] = {key=key, default=default, help=help, type=_type_}
   table.insert(self.helplines, self.options[key])
end

function CmdLine:default()
   local params = {}
   for option,v in pairs(self.options) do
      params[strip(option)] = v.default
   end
   return params
end

function CmdLine:parse(arg)
   local i = 1
   local params = self:default()

   local nArgument = 0

   while i <= #arg do
      if arg[i] == '-help' or arg[i] == '-h' or arg[i] == '--help' then
         self:help(arg)
         os.exit(0)
      end

      if self.options[arg[i]] then
         i = i + self:__readOption__(params, arg, i)
      else
         nArgument = nArgument + 1
         i = i + self:__readArgument__(params, arg, i, nArgument)
      end
   end

   if nArgument ~= #self.arguments then
      self:error('not enough arguments')
   end

   return params
end

function CmdLine:string(prefix, params, ignore)
   local arguments = {}
   local options = {}
   prefix = prefix or ''

   for k,v in pairs(params) do
      if ignore[k] then
         print('-- ignore option ' .. k)
      elseif self.options['-' .. k] then
         if v ~= self.options['-' .. k].default then
            if type(v) == 'boolean' then
               if v then
                  v = 't'
               else
                  v = 'f'
               end
            end
            table.insert(options, k .. self.keyseparator .. v)
            print(k,v,self.options['-' .. k].default)
         end
      else
         local narg
         for i=1,#self.arguments do
            if strip(self.arguments[i].key) == k then
               narg = i
            end
         end
         if narg then
            arguments[narg] = k .. self.keyseparator .. v
         else
            print('WARNING: unknown option/argument: ' .. k .. ' IGNORING for DIRECTORY NAME')
         end
      end
   end
   table.sort(options)
   local str = table.concat(arguments, self.argseparator)
   if str == '' then
      str = table.concat(options, self.argseparator)
   else
      str = str .. self.argseparator .. table.concat(options, self.argseparator)
   end
   if str == '' then
      return prefix
   else
      return prefix .. self.argseparator .. str
   end
end

function CmdLine:log(file, params)
   local f = io.open(file, 'w')
   local oprint = print
   function print(...)
      local n = select("#", ...)
      oprint(...)
      for i=1,n do
         f:write(tostring(select(i, ...)))
         if i ~= n then
            f:write(' ')
         else
            f:write('\n')
         end
      end
      f:flush()
   end
   print('[program started on ' .. os.date() .. ']')
   print('[command line arguments]')
   if params then
      for k,v in pairs(params) do
         print(k,v)
      end
   end
   print('[----------------------]')
end

function CmdLine:text(txt)
   txt = txt or ''
   assert(type(txt) == 'string')
   table.insert(self.helplines, txt)
end

function CmdLine:help(arg)
   io.write('Usage: ')
   if arg then io.write(arg[0] .. ' ') end
   io.write('[options] ')
   for i=1,#self.arguments do
      io.write('<' .. strip(self.arguments[i].key) .. '>')
   end
   io.write('\n')

   -- first pass to compute max length
   local optsz = 0
   for _,option in ipairs(self.helplines) do
      if type(option) == 'table' then
         if option.default ~= nil then -- it is an option
            if #option.key > optsz then
               optsz = #option.key
            end
         else -- it is an argument
            if #strip(option.key)+2 > optsz then
               optsz = #strip(option.key)+2
            end
         end
      end
   end

   -- second pass to print
   for _,option in ipairs(self.helplines) do
      if type(option) == 'table' then
         io.write(' ')
         if option.default ~= nil then -- it is an option
            io.write(pad(option.key, optsz))
            if option.help then io.write(' ' .. option.help) end
            io.write(' [' .. tostring(option.default) .. ']')
         else -- it is an argument
            io.write(pad('<' .. strip(option.key) .. '>', optsz))
            if option.help then io.write(' ' .. option.help) end
         end
      else
         io.write(option) -- just some additional help
      end
      io.write('\n')
   end
end
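For reference, a minimal usage sketch of the parser above; the script, argument and option names, and log file name are hypothetical, only the CmdLine methods come from this file.

```lua
-- minimal sketch, assuming it is run as e.g. `lua train.lua /path/to/data -nthreads 4 -verbose`
local cmd = torch.CmdLine()
cmd:text('hypothetical training script')
cmd:argument('-data', 'path to the dataset')            -- positional argument, stored as params.data
cmd:option('-nthreads', 2, 'number of worker threads')  -- typed from its default (number)
cmd:option('-verbose', false, 'print extra output')     -- boolean options behave as flags
local params = cmd:parse(arg)
-- cmd:string() builds a name from the non-default options, handy for experiment directories
cmd:log(cmd:string('exp', params, {data=true}) .. '.log', params)
```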
DiskFile.c
new file (87 lines)
@@ -0,0 +1,87 @@
#include "general.h"

static const void* torch_DiskFile_id = NULL;

static int torch_DiskFile_new(lua_State *L)
{
  const char *name = luaL_checkstring(L, 1);
  const char *mode = luaL_optstring(L, 2, "r");
  int isQuiet = luaT_optboolean(L, 3, 0);
  THFile *self = THDiskFile_new(name, mode, isQuiet);

  luaT_pushudata(L, self, torch_DiskFile_id);
  return 1;
}

static int torch_DiskFile_free(lua_State *L)
{
  THFile *self = luaT_checkudata(L, 1, torch_DiskFile_id);
  THFile_free(self);
  return 0;
}

static int torch_DiskFile_isLittleEndianCPU(lua_State *L)
{
  lua_pushboolean(L, THDiskFile_isLittleEndianCPU());
  return 1;
}

static int torch_DiskFile_isBigEndianCPU(lua_State *L)
{
  lua_pushboolean(L, !THDiskFile_isLittleEndianCPU());
  return 1;
}

static int torch_DiskFile_nativeEndianEncoding(lua_State *L)
{
  THFile *self = luaT_checkudata(L, 1, torch_DiskFile_id);
  THDiskFile_nativeEndianEncoding(self);
  lua_settop(L, 1);
  return 1;
}

static int torch_DiskFile_littleEndianEncoding(lua_State *L)
{
  THFile *self = luaT_checkudata(L, 1, torch_DiskFile_id);
  THDiskFile_littleEndianEncoding(self);
  lua_settop(L, 1);
  return 1;
}

static int torch_DiskFile_bigEndianEncoding(lua_State *L)
{
  THFile *self = luaT_checkudata(L, 1, torch_DiskFile_id);
  THDiskFile_bigEndianEncoding(self);
  lua_settop(L, 1);
  return 1;
}

static int torch_DiskFile___tostring__(lua_State *L)
{
  THFile *self = luaT_checkudata(L, 1, torch_DiskFile_id);
  lua_pushfstring(L, "torch.DiskFile on <%s> [status: %s -- mode %c%c]",
                  THDiskFile_name(self),
                  (THFile_isOpened(self) ? "open" : "closed"),
                  (THFile_isReadable(self) ? 'r' : ' '),
                  (THFile_isWritable(self) ? 'w' : ' '));

  return 1;
}
static const struct luaL_Reg torch_DiskFile__ [] = {
  {"isLittleEndianCPU", torch_DiskFile_isLittleEndianCPU},
  {"isBigEndianCPU", torch_DiskFile_isBigEndianCPU},
  {"nativeEndianEncoding", torch_DiskFile_nativeEndianEncoding},
  {"littleEndianEncoding", torch_DiskFile_littleEndianEncoding},
  {"bigEndianEncoding", torch_DiskFile_bigEndianEncoding},
  {"__tostring__", torch_DiskFile___tostring__},
  {NULL, NULL}
};

void torch_DiskFile_init(lua_State *L)
{
  torch_DiskFile_id = luaT_newmetatable(L, "torch.DiskFile", "torch.File",
                                        torch_DiskFile_new, torch_DiskFile_free, NULL);

  luaL_register(L, NULL, torch_DiskFile__);
  lua_pop(L, 1);
}
File.c
new file (225 lines)
@@ -0,0 +1,225 @@
#include "THFile.h"
#include "luaT.h"

static const void *torch_File_id = NULL;
static const void *torch_ByteStorage_id = NULL;
static const void *torch_CharStorage_id = NULL;
static const void *torch_ShortStorage_id = NULL;
static const void *torch_IntStorage_id = NULL;
static const void *torch_LongStorage_id = NULL;
static const void *torch_FloatStorage_id = NULL;
static const void *torch_DoubleStorage_id = NULL;

#define IMPLEMENT_TORCH_FILE_FLAG(NAME) \
  static int torch_File_##NAME(lua_State *L) \
  { \
    THFile *self = luaT_checkudata(L, 1, torch_File_id); \
    lua_pushboolean(L, THFile_##NAME(self)); \
    return 1; \
  }

IMPLEMENT_TORCH_FILE_FLAG(isQuiet)
IMPLEMENT_TORCH_FILE_FLAG(isReadable)
IMPLEMENT_TORCH_FILE_FLAG(isWritable)
IMPLEMENT_TORCH_FILE_FLAG(isBinary)
IMPLEMENT_TORCH_FILE_FLAG(isAutoSpacing)
IMPLEMENT_TORCH_FILE_FLAG(hasError)

#define IMPLEMENT_TORCH_FILE_FUNC(NAME) \
  static int torch_File_##NAME(lua_State *L) \
  { \
    THFile *self = luaT_checkudata(L, 1, torch_File_id); \
    THFile_##NAME(self); \
    lua_settop(L, 1); \
    return 1; \
  }

IMPLEMENT_TORCH_FILE_FUNC(binary)
IMPLEMENT_TORCH_FILE_FUNC(ascii)
IMPLEMENT_TORCH_FILE_FUNC(autoSpacing)
IMPLEMENT_TORCH_FILE_FUNC(noAutoSpacing)
IMPLEMENT_TORCH_FILE_FUNC(quiet)
IMPLEMENT_TORCH_FILE_FUNC(pedantic)
IMPLEMENT_TORCH_FILE_FUNC(clearError)

IMPLEMENT_TORCH_FILE_FUNC(synchronize)

static int torch_File_seek(lua_State *L)
{
  THFile *self = luaT_checkudata(L, 1, torch_File_id);
  long position = luaL_checklong(L, 2)-1;
  THFile_seek(self, position);
  lua_settop(L, 1);
  return 1;
}

IMPLEMENT_TORCH_FILE_FUNC(seekEnd)

static int torch_File_position(lua_State *L)
{
  THFile *self = luaT_checkudata(L, 1, torch_File_id);
  lua_pushnumber(L, THFile_position(self)+1);
  return 1;
}

IMPLEMENT_TORCH_FILE_FUNC(close)

#define IMPLEMENT_TORCH_FILE_RW(TYPEC, TYPE) \
  static int torch_File_read##TYPEC(lua_State *L) \
  { \
    THFile *self = luaT_checkudata(L, 1, torch_File_id); \
    int narg = lua_gettop(L); \
    \
    if(narg == 1) \
    { \
      lua_pushnumber(L, THFile_read##TYPEC##Scalar(self)); \
      return 1; \
    } \
    else if(narg == 2) \
    { \
      if(lua_isnumber(L, 2)) \
      { \
        long size = lua_tonumber(L, 2); \
        long nread; \
        \
        TH##TYPEC##Storage *storage = TH##TYPEC##Storage_newWithSize(size); \
        luaT_pushudata(L, storage, torch_##TYPEC##Storage_id); \
        nread = THFile_read##TYPEC(self, storage); \
        if(nread != size) \
          TH##TYPEC##Storage_resize(storage, nread); \
        return 1; \
      } \
      else if(luaT_toudata(L, 2, torch_##TYPEC##Storage_id)) \
      { \
        TH##TYPEC##Storage *storage = luaT_toudata(L, 2, torch_##TYPEC##Storage_id); \
        lua_pushnumber(L, THFile_read##TYPEC(self, storage)); \
        return 1; \
      } \
    } \
    \
    luaL_error(L, "nothing, number, or Storage expected"); \
    return 0; \
  } \
  \
  static int torch_File_write##TYPEC(lua_State *L) \
  { \
    THFile *self = luaT_checkudata(L, 1, torch_File_id); \
    int narg = lua_gettop(L); \
    \
    if(narg == 2) \
    { \
      if(lua_isnumber(L, 2)) \
      { \
        TYPE value = lua_tonumber(L, 2); \
        THFile_write##TYPEC##Scalar(self, (TYPE)value); \
        return 0; \
      } \
      else if(luaT_toudata(L, 2, torch_##TYPEC##Storage_id)) \
      { \
        TH##TYPEC##Storage *storage = luaT_toudata(L, 2, torch_##TYPEC##Storage_id); \
        lua_pushnumber(L, THFile_write##TYPEC(self, storage)); \
        return 1; \
      } \
    } \
    \
    luaL_error(L, "number, or Storage expected"); \
    return 0; \
  }


IMPLEMENT_TORCH_FILE_RW(Byte, unsigned char)
IMPLEMENT_TORCH_FILE_RW(Char, char)
IMPLEMENT_TORCH_FILE_RW(Short, short)
IMPLEMENT_TORCH_FILE_RW(Int, int)
IMPLEMENT_TORCH_FILE_RW(Long, long)
IMPLEMENT_TORCH_FILE_RW(Float, float)
IMPLEMENT_TORCH_FILE_RW(Double, double)

static int torch_File_readString(lua_State *L)
{
  THFile *self = luaT_checkudata(L, 1, torch_File_id);
  const char *format = luaL_checkstring(L, 2);
  char *str;
  long size;

  size = THFile_readStringRaw(self, format, &str);
  lua_pushlstring(L, str, size);
  THFree(str);

  return 1;
}

static int torch_File_writeString(lua_State *L)
{
  THFile *self = luaT_checkudata(L, 1, torch_File_id);
  const char *str = NULL;
  size_t size;
  long nwrite;

  luaL_checktype(L, 2, LUA_TSTRING);
  str = lua_tolstring(L, 2, &size);
  lua_pushnumber(L, THFile_writeStringRaw(self, str, (long)size));
  return 1;
}

static const struct luaL_Reg torch_File__ [] = {
  {"isQuiet", torch_File_isQuiet},
  {"isReadable", torch_File_isReadable},
  {"isWritable", torch_File_isWritable},
  {"isBinary", torch_File_isBinary},
  {"isAutoSpacing", torch_File_isAutoSpacing},
  {"hasError", torch_File_hasError},
  {"binary", torch_File_binary},
  {"ascii", torch_File_ascii},
  {"autoSpacing", torch_File_autoSpacing},
  {"noAutoSpacing", torch_File_noAutoSpacing},
  {"quiet", torch_File_quiet},
  {"pedantic", torch_File_pedantic},
  {"clearError", torch_File_clearError},

  /* DEBUG: CHECK DISK FREE & READ/WRITE STRING*/

  {"readByte", torch_File_readByte},
  {"readChar", torch_File_readChar},
  {"readShort", torch_File_readShort},
  {"readInt", torch_File_readInt},
  {"readLong", torch_File_readLong},
  {"readFloat", torch_File_readFloat},
  {"readDouble", torch_File_readDouble},
  {"readString", torch_File_readString},

  {"writeByte", torch_File_writeByte},
  {"writeChar", torch_File_writeChar},
  {"writeShort", torch_File_writeShort},
  {"writeInt", torch_File_writeInt},
  {"writeLong", torch_File_writeLong},
  {"writeFloat", torch_File_writeFloat},
  {"writeDouble", torch_File_writeDouble},
  {"writeString", torch_File_writeString},

  {"synchronize", torch_File_synchronize},
  {"seek", torch_File_seek},
  {"seekEnd", torch_File_seekEnd},
  {"position", torch_File_position},
  {"close", torch_File_close},

  {NULL, NULL}
};

void torch_File_init(lua_State *L)
{
  torch_File_id = luaT_newmetatable(L, "torch.File", NULL, NULL, NULL, NULL);
  luaL_register(L, NULL, torch_File__);
  lua_pop(L, 1);
}

void torch_File_init_storage_id(lua_State *L)
{
  torch_ByteStorage_id = luaT_checktypename2id(L, "torch.ByteStorage");
  torch_CharStorage_id = luaT_checktypename2id(L, "torch.CharStorage");
  torch_ShortStorage_id = luaT_checktypename2id(L, "torch.ShortStorage");
  torch_IntStorage_id = luaT_checktypename2id(L, "torch.IntStorage");
  torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage");
  torch_FloatStorage_id = luaT_checktypename2id(L, "torch.FloatStorage");
  torch_DoubleStorage_id = luaT_checktypename2id(L, "torch.DoubleStorage");
}
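A short sketch of the Lua-level API these bindings expose, using the torch.DiskFile constructor from the previous file; the file name is hypothetical.

```lua
-- minimal sketch, assuming a writable path 'tmp.bin'
local f = torch.DiskFile('tmp.bin', 'w'):binary()  -- switch from the default ASCII mode to binary
f:writeInt(42)
f:writeDouble(3.14)
f:close()

f = torch.DiskFile('tmp.bin', 'r'):binary()
print(f:readInt())      -- 42
f:seek(5)               -- positions are 1-based on the Lua side (see torch_File_seek above)
print(f:readDouble())   -- 3.14
f:close()
```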
File.lua
new file (240 lines)
@@ -0,0 +1,240 @@
local File = torch.getmetatable('torch.File')

function File:writeBool(value)
   if value then
      self:writeInt(1)
   else
      self:writeInt(0)
   end
end

function File:readBool()
   return (self:readInt() == 1)
end

local TYPE_NIL = 0
local TYPE_NUMBER = 1
local TYPE_STRING = 2
local TYPE_TABLE = 3
local TYPE_TORCH = 4
local TYPE_BOOLEAN = 5
local TYPE_FUNCTION = 6

function File:isWritableObject(object)
   local typename = type(object)
   local typeidx
   if type(object) ~= 'boolean' and not object then
      typeidx = TYPE_NIL
   elseif torch.typename(object) and torch.factory(torch.typename(object)) then
      typeidx = TYPE_TORCH
   elseif typename == 'table' then
      typeidx = TYPE_TABLE
   elseif typename == 'number' then
      typeidx = TYPE_NUMBER
   elseif typename == 'string' then
      typeidx = TYPE_STRING
   elseif typename == 'boolean' then
      typeidx = TYPE_BOOLEAN
   elseif typename == 'function' and pcall(string.dump, object) then
      typeidx = TYPE_FUNCTION
   end
   return typeidx
end

function File:writeObject(object)
   -- we use an environment to keep a record of written objects
   if not torch.getenv(self).writeObjects then
      torch.setenv(self, {writeObjects={}, writeObjectsRef={}, readObjects={}})
   end

   -- if nil object, only write the type and return
   if type(object) ~= 'boolean' and not object then
      self:writeInt(TYPE_NIL)
      return
   end

   -- check the type we are dealing with
   local typeidx = self:isWritableObject(object)
   if not typeidx then
      error(string.format('Unwritable object <%s>', type(object)))
   end
   self:writeInt(typeidx)

   if typeidx == TYPE_NUMBER then
      self:writeDouble(object)
   elseif typeidx == TYPE_BOOLEAN then
      self:writeBool(object)
   elseif typeidx == TYPE_STRING then
      local stringStorage = torch.CharStorage():string(object)
      self:writeInt(#stringStorage)
      self:writeChar(stringStorage)
   elseif typeidx == TYPE_FUNCTION then
      local upvalues = {}
      while true do
         local name,value = debug.getupvalue(object, #upvalues+1)
         if not name then break end
         table.insert(upvalues, value)
      end
      local dumped = string.dump(object)
      local stringStorage = torch.CharStorage():string(dumped)
      self:writeInt(#stringStorage)
      self:writeChar(stringStorage)
      self:writeObject(upvalues)
   elseif typeidx == TYPE_TORCH or typeidx == TYPE_TABLE then
      -- check it exists already (we look at the pointer!)
      local objects = torch.getenv(self).writeObjects
      local objectsRef = torch.getenv(self).writeObjectsRef
      local index = objects[torch.pointer(object)]

      if index then
         -- if already exists, write only its index
         self:writeInt(index)
      else
         -- else write the object itself
         index = objects.nWriteObject or 0
         index = index + 1
         objects[torch.pointer(object)] = index
         objectsRef[object] = index -- we make sure the object is not going to disappear
         self:writeInt(index)
         objects.nWriteObject = index

         if typeidx == TYPE_TORCH then
            local version = torch.CharStorage():string('V ' .. torch.version(object))
            local className = torch.CharStorage():string(torch.typename(object))
            self:writeInt(#version)
            self:writeChar(version)
            self:writeInt(#className)
            self:writeChar(className)
            if object.write then
               object:write(self)
            elseif type(object) == 'table' then
               local var = {}
               for k,v in pairs(object) do
                  if self:isWritableObject(v) then
                     var[k] = v
                  else
                     print(string.format('$ Warning: cannot write object field <%s>', k))
                  end
               end
               self:writeObject(var)
            else
               error(string.format('<%s> is a non-serializable Torch object', torch.typename(object)))
            end
         else -- it is a table
            local size = 0; for k,v in pairs(object) do size = size + 1 end
            self:writeInt(size)
            for k,v in pairs(object) do
               self:writeObject(k)
               self:writeObject(v)
            end
         end
      end
   else
      error('Unwritable object')
   end
end

function File:readObject()
   -- we use an environment to keep a record of read objects
   if not torch.getenv(self).writeObjects then
      torch.setenv(self, {writeObjects={}, writeObjectsRef={}, readObjects={}})
   end

   -- read the typeidx
   local typeidx = self:readInt()

   -- is it nil?
   if typeidx == TYPE_NIL then
      return nil
   end

   if typeidx == TYPE_NUMBER then
      return self:readDouble()
   elseif typeidx == TYPE_BOOLEAN then
      return self:readBool()
   elseif typeidx == TYPE_STRING then
      local size = self:readInt()
      return self:readChar(size):string()
   elseif typeidx == TYPE_FUNCTION then
      local size = self:readInt()
      local dumped = self:readChar(size):string()
      local func = loadstring(dumped)
      local upvalues = self:readObject()
      for index,upvalue in ipairs(upvalues) do
         debug.setupvalue(func, index, upvalue)
      end
      return func
   elseif typeidx == TYPE_TABLE or typeidx == TYPE_TORCH then
      -- read the index
      local index = self:readInt()

      -- check it is loaded already
      local objects = torch.getenv(self).readObjects
      if objects[index] then
         return objects[index]
      end

      -- otherwise read it
      if typeidx == TYPE_TORCH then
         local version, className, versionNumber
         version = self:readChar(self:readInt()):string()
         versionNumber = tonumber(string.match(version, '^V (.*)$'))
         if not versionNumber then
            className = version
            versionNumber = 0 -- file created before existence of versioning system
         else
            className = self:readChar(self:readInt()):string()
         end
         if not torch.factory(className) then
            error(string.format('unknown Torch class <%s>', className))
         end
         local object = torch.factory(className)()
         objects[index] = object
         if object.read then
            object:read(self, versionNumber)
         elseif type(object) == 'table' then
            local var = self:readObject()
            for k,v in pairs(var) do
               object[k] = v
            end
         else
            error(string.format('Cannot load object class <%s>', className))
         end
         return object
      else -- it is a table
         local size = self:readInt()
         local object = {}
         objects[index] = object
         for i = 1,size do
            local k = self:readObject()
            local v = self:readObject()
            object[k] = v
         end
         return object
      end
   else
      error('unknown object')
   end
end

-- simple helpers to save/load arbitrary objects/tables
function torch.save(filename, object, mode)
   mode = mode or 'binary'
   local file = torch.DiskFile(filename, 'w')
   file[mode](file)
   file:writeObject(object)
   file:close()
end

function torch.load(filename, mode)
   mode = mode or 'binary'
   local file = torch.DiskFile(filename, 'r')
   file[mode](file)
   local object = file:readObject()
   file:close()
   return object
end

-- public API (saveobj/loadobj are safe for global import)
torch.saveobj = torch.save
torch.loadobj = torch.load
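The two helpers at the end are the usual entry points; a minimal round-trip sketch (the file names and table contents are hypothetical):

```lua
-- minimal sketch of torch.save / torch.load defined above
local obj = {name = 'demo', values = {1, 2, 3}}
torch.save('obj.t7', obj)                  -- binary encoding by default
local restored = torch.load('obj.t7')
print(restored.name, restored.values[2])   -- demo   2

-- the optional mode argument names the File encoding method to call ('binary' or 'ascii')
torch.save('obj.ascii.t7', obj, 'ascii')
local again = torch.load('obj.ascii.t7', 'ascii')
```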
MemoryFile.c
new file (67 lines)
@@ -0,0 +1,67 @@
#include "general.h"

static const void* torch_MemoryFile_id;
static const void* torch_CharStorage_id;

static int torch_MemoryFile_new(lua_State *L)
{
  const char *mode;
  THCharStorage *storage = luaT_toudata(L, 1, torch_CharStorage_id);
  THFile *self;

  if(storage)
  {
    mode = luaL_optstring(L, 2, "rw");
    self = THMemoryFile_newWithStorage(storage, mode);
  }
  else
  {
    mode = luaL_optstring(L, 1, "rw");
    self = THMemoryFile_new(mode);
  }

  luaT_pushudata(L, self, torch_MemoryFile_id);
  return 1;
}

static int torch_MemoryFile_storage(lua_State *L)
{
  THFile *self = luaT_checkudata(L, 1, torch_MemoryFile_id);
  THCharStorage_retain(THMemoryFile_storage(self));
  luaT_pushudata(L, THMemoryFile_storage(self), torch_CharStorage_id);
  return 1;
}

static int torch_MemoryFile_free(lua_State *L)
{
  THFile *self = luaT_checkudata(L, 1, torch_MemoryFile_id);
  THFile_free(self);
  return 0;
}

static int torch_MemoryFile___tostring__(lua_State *L)
{
  THFile *self = luaT_checkudata(L, 1, torch_MemoryFile_id);
  lua_pushfstring(L, "torch.MemoryFile [status: %s -- mode: %c%c]",
                  (THFile_isOpened(self) ? "open" : "closed"),
                  (THFile_isReadable(self) ? 'r' : ' '),
                  (THFile_isWritable(self) ? 'w' : ' '));
  return 1;
}

static const struct luaL_Reg torch_MemoryFile__ [] = {
  {"storage", torch_MemoryFile_storage},
  {"__tostring__", torch_MemoryFile___tostring__},
  {NULL, NULL}
};

void torch_MemoryFile_init(lua_State *L)
{
  torch_CharStorage_id = luaT_checktypename2id(L, "torch.CharStorage");

  torch_MemoryFile_id = luaT_newmetatable(L, "torch.MemoryFile", "torch.File",
                                          torch_MemoryFile_new, torch_MemoryFile_free, NULL);

  luaL_register(L, NULL, torch_MemoryFile__);
  lua_pop(L, 1);
}
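A hedged sketch of how this can be combined with the serialization layer to keep everything in memory; that a fresh MemoryFile constructed over the returned storage starts reading at the beginning is an assumption here, not something stated in this file.

```lua
-- hedged sketch: serialize into memory instead of onto disk
local mf = torch.MemoryFile()        -- 'rw' by default, backed by a growing CharStorage
mf:binary()
mf:writeObject({1, 2, 3})
local blob = mf:storage()            -- retained CharStorage holding the serialized bytes
mf:close()

local reader = torch.MemoryFile(blob, 'r')  -- assumed to read from the start of the storage
reader:binary()
print(reader:readObject()[3])        -- 3
```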
PipeFile.c
new file (46 lines)
@@ -0,0 +1,46 @@
#include "general.h"

static const void* torch_PipeFile_id = NULL;

static int torch_PipeFile_new(lua_State *L)
{
  const char *name = luaL_checkstring(L, 1);
  const char *mode = luaL_optstring(L, 2, "r");
  int isQuiet = luaT_optboolean(L, 3, 0);
  THFile *self = THPipeFile_new(name, mode, isQuiet);

  luaT_pushudata(L, self, torch_PipeFile_id);
  return 1;
}

static int torch_PipeFile_free(lua_State *L)
{
  THFile *self = luaT_checkudata(L, 1, torch_PipeFile_id);
  THFile_free(self);
  return 0;
}

static int torch_PipeFile___tostring__(lua_State *L)
{
  THFile *self = luaT_checkudata(L, 1, torch_PipeFile_id);
  lua_pushfstring(L, "torch.PipeFile on <%s> [status: %s -- mode: %c%c]",
                  THDiskFile_name(self),
                  (THFile_isOpened(self) ? "open" : "closed"),
                  (THFile_isReadable(self) ? 'r' : ' '),
                  (THFile_isWritable(self) ? 'w' : ' '));
  return 1;
}

static const struct luaL_Reg torch_PipeFile__ [] = {
  {"__tostring__", torch_PipeFile___tostring__},
  {NULL, NULL}
};

void torch_PipeFile_init(lua_State *L)
{
  torch_PipeFile_id = luaT_newmetatable(L, "torch.PipeFile", "torch.DiskFile",
                                        torch_PipeFile_new, torch_PipeFile_free, NULL);

  luaL_register(L, NULL, torch_PipeFile__);
  lua_pop(L, 1);
}
Storage.c
new file (19 lines)
@@ -0,0 +1,19 @@
#include "general.h"

static const void *torch_File_id = NULL;
static const void *torch_ByteStorage_id = NULL;
static const void *torch_CharStorage_id = NULL;
static const void *torch_ShortStorage_id = NULL;
static const void *torch_IntStorage_id = NULL;
static const void *torch_LongStorage_id = NULL;
static const void *torch_FloatStorage_id = NULL;
static const void *torch_DoubleStorage_id = NULL;

#define torch_Storage_(NAME) TH_CONCAT_4(torch_,Real,Storage_,NAME)
#define torch_Storage_id TH_CONCAT_3(torch_,Real,Storage_id)
#define THFile_readRealRaw TH_CONCAT_3(THFile_read, Real, Raw)
#define THFile_writeRealRaw TH_CONCAT_3(THFile_write, Real, Raw)
#define STRING_torchStorage TH_CONCAT_STRING_3(torch.,Real,Storage)

#include "generic/Storage.c"
#include "THGenerateAllTypes.h"
Tensor.c
new file (29 lines)
@@ -0,0 +1,29 @@
#include "general.h"

static const void *torch_File_id = NULL;

static const void *torch_ByteStorage_id = NULL;
static const void *torch_CharStorage_id = NULL;
static const void *torch_ShortStorage_id = NULL;
static const void *torch_IntStorage_id = NULL;
static const void *torch_LongStorage_id = NULL;
static const void *torch_FloatStorage_id = NULL;
static const void *torch_DoubleStorage_id = NULL;

static const void *torch_ByteTensor_id = NULL;
static const void *torch_CharTensor_id = NULL;
static const void *torch_ShortTensor_id = NULL;
static const void *torch_IntTensor_id = NULL;
static const void *torch_LongTensor_id = NULL;
static const void *torch_FloatTensor_id = NULL;
static const void *torch_DoubleTensor_id = NULL;

#define torch_Storage_(NAME) TH_CONCAT_4(torch_,Real,Storage_,NAME)
#define torch_Storage_id TH_CONCAT_3(torch_,Real,Storage_id)
#define STRING_torchStorage TH_CONCAT_STRING_3(torch.,Real,Storage)
#define torch_Tensor_(NAME) TH_CONCAT_4(torch_,Real,Tensor_,NAME)
#define torch_Tensor_id TH_CONCAT_3(torch_,Real,Tensor_id)
#define STRING_torchTensor TH_CONCAT_STRING_3(torch.,Real,Tensor)

#include "generic/Tensor.c"
#include "THGenerateAllTypes.h"
Tensor.lua
new file (279 lines)
@@ -0,0 +1,279 @@
-- additional methods for Storage
local Storage = {}

-- additional methods for Tensor
local Tensor = {}

-- types
local types = {'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Double'}

-- tostring() functions for Tensor and Storage
local function Storage__printformat(self)
   local intMode = true
   local type = torch.typename(self)
--   if type == 'torch.FloatStorage' or type == 'torch.DoubleStorage' then
   for i=1,self:size() do
      if self[i] ~= math.ceil(self[i]) then
         intMode = false
         break
      end
   end
--   end
   local tensor = torch.DoubleTensor(torch.DoubleStorage(self:size()):copy(self), 1, self:size()):abs()
   local expMin = tensor:minall()
   if expMin ~= 0 then
      expMin = math.floor(math.log10(expMin)) + 1
   end
   local expMax = tensor:maxall()
   if expMax ~= 0 then
      expMax = math.floor(math.log10(expMax)) + 1
   end

   local format
   local scale
   local sz
   if intMode then
      if expMax > 9 then
         format = "%11.4e"
         sz = 11
      else
         format = "%SZd"
         sz = expMax + 1
      end
   else
      if expMax-expMin > 4 then
         format = "%SZ.4e"
         sz = 11
         if math.abs(expMax) > 99 or math.abs(expMin) > 99 then
            sz = sz + 1
         end
      else
         if expMax > 5 or expMax < 0 then
            format = "%SZ.4f"
            sz = 7
            scale = math.pow(10, expMax-1)
         else
            format = "%SZ.4f"
            if expMax == 0 then
               sz = 7
            else
               sz = expMax+6
            end
         end
      end
   end
   format = string.gsub(format, 'SZ', sz)
   if scale == 1 then
      scale = nil
   end
   return format, scale, sz
end

function Storage.__tostring__(self)
   local strt = {'\n'}
   local format,scale = Storage__printformat(self)
   if format:sub(2,4) == 'nan' then format = '%f' end
   if scale then
      table.insert(strt, string.format('%g', scale) .. ' *\n')
      for i = 1,self:size() do
         table.insert(strt, string.format(format, self[i]/scale) .. '\n')
      end
   else
      for i = 1,self:size() do
         table.insert(strt, string.format(format, self[i]) .. '\n')
      end
   end
   table.insert(strt, '[' .. torch.typename(self) .. ' of size ' .. self:size() .. ']\n')
   local str = table.concat(strt)
   return str
end

for _,type in ipairs(types) do
   local metatable = torch.getmetatable('torch.' .. type .. 'Storage')
   for funcname, func in pairs(Storage) do
      rawset(metatable, funcname, func)
   end
end

local function Tensor__printMatrix(self, indent)
   local format,scale,sz = Storage__printformat(self:storage())
   if format:sub(2,4) == 'nan' then format = '%f' end
--   print('format = ' .. format)
   scale = scale or 1
   indent = indent or ''
   local strt = {indent}
   local nColumnPerLine = math.floor((80-#indent)/(sz+1))
--   print('sz = ' .. sz .. ' and nColumnPerLine = ' .. nColumnPerLine)
   local firstColumn = 1
   local lastColumn = -1
   while firstColumn <= self:size(2) do
      if firstColumn + nColumnPerLine - 1 <= self:size(2) then
         lastColumn = firstColumn + nColumnPerLine - 1
      else
         lastColumn = self:size(2)
      end
      if nColumnPerLine < self:size(2) then
         if firstColumn ~= 1 then
            table.insert(strt, '\n')
         end
         table.insert(strt, 'Columns ' .. firstColumn .. ' to ' .. lastColumn .. '\n' .. indent)
      end
      if scale ~= 1 then
         table.insert(strt, string.format('%g', scale) .. ' *\n ' .. indent)
      end
      for l=1,self:size(1) do
         local row = self:select(1, l)
         for c=firstColumn,lastColumn do
            table.insert(strt, string.format(format, row[c]/scale))
            if c == lastColumn then
               table.insert(strt, '\n')
               if l~=self:size(1) then
                  if scale ~= 1 then
                     table.insert(strt, indent .. ' ')
                  else
                     table.insert(strt, indent)
                  end
               end
            else
               table.insert(strt, ' ')
            end
         end
      end
      firstColumn = lastColumn + 1
   end
   local str = table.concat(strt)
   return str
end

local function Tensor__printTensor(self)
   local counter = torch.LongStorage(self:nDimension()-2)
   local strt = {''}
   local finished
   counter:fill(1)
   counter[1] = 0
   while true do
      for i=1,self:nDimension()-2 do
         counter[i] = counter[i] + 1
         if counter[i] > self:size(i) then
            if i == self:nDimension()-2 then
               finished = true
               break
            end
            counter[i] = 1
         else
            break
         end
      end
      if finished then
         break
      end
--      print(counter)
      if #strt > 1 then
         table.insert(strt, '\n')
      end
      table.insert(strt, '(')
      local tensor = self
      for i=1,self:nDimension()-2 do
         tensor = tensor:select(1, counter[i])
         table.insert(strt, counter[i] .. ',')
      end
      table.insert(strt, '.,.) = \n')
      table.insert(strt, Tensor__printMatrix(tensor, ' '))
   end
   local str = table.concat(strt)
   return str
end

function Tensor.__tostring__(self)
   local str = '\n'
   local strt = {''}
   if self:nDimension() == 0 then
      table.insert(strt, '[' .. torch.typename(self) .. ' with no dimension]\n')
   else
      local tensor = torch.DoubleTensor():resize(self:size()):copy(self)
      if tensor:nDimension() == 1 then
         local format,scale,sz = Storage__printformat(tensor:storage())
         if format:sub(2,4) == 'nan' then format = '%f' end
         if scale then
            table.insert(strt, string.format('%g', scale) .. ' *\n')
            for i = 1,tensor:size(1) do
               table.insert(strt, string.format(format, tensor[i]/scale) .. '\n')
            end
         else
            for i = 1,tensor:size(1) do
               table.insert(strt, string.format(format, tensor[i]) .. '\n')
            end
         end
         table.insert(strt, '[' .. torch.typename(self) .. ' of dimension ' .. tensor:size(1) .. ']\n')
      elseif tensor:nDimension() == 2 then
         table.insert(strt, Tensor__printMatrix(tensor))
         table.insert(strt, '[' .. torch.typename(self) .. ' of dimension ' .. tensor:size(1) .. 'x' .. tensor:size(2) .. ']\n')
      else
         table.insert(strt, Tensor__printTensor(tensor))
         table.insert(strt, '[' .. torch.typename(self) .. ' of dimension ')
         for i=1,tensor:nDimension() do
            table.insert(strt, tensor:size(i))
            if i ~= tensor:nDimension() then
               table.insert(strt, 'x')
            end
         end
         table.insert(strt, ']\n')
      end
   end
   local str = table.concat(strt)
   return str
end

function Tensor.type(self,type)
   local current = torch.typename(self)
   if not type then return current end
   if type ~= current then
      local new = torch.getmetatable(type).new()
      if self:nElement() > 0 then
         new:resize(self:size()):copy(self)
      end
      return new
   else
      return self
   end
end

function Tensor.typeAs(self,tensor)
   return self:type(tensor:type())
end

function Tensor.byte(self,type)
   return self:type('torch.ByteTensor')
end

function Tensor.char(self,type)
   return self:type('torch.CharTensor')
end

function Tensor.short(self,type)
   return self:type('torch.ShortTensor')
end

function Tensor.int(self,type)
   return self:type('torch.IntTensor')
end

function Tensor.long(self,type)
   return self:type('torch.LongTensor')
end

function Tensor.float(self,type)
   return self:type('torch.FloatTensor')
end

function Tensor.double(self,type)
   return self:type('torch.DoubleTensor')
end


for _,type in ipairs(types) do
   local metatable = torch.getmetatable('torch.' .. type .. 'Tensor')
   for funcname, func in pairs(Tensor) do
      rawset(metatable, funcname, func)
   end
end
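A small sketch of the conversion helpers installed above; `fill` comes from TensorMath.lua later in this commit, so it is assumed available here.

```lua
-- minimal sketch of Tensor.type / typeAs and the shortcut converters
local t = torch.DoubleTensor(2, 3):fill(0.5)
print(t:type())                         -- torch.DoubleTensor
local f = t:float()                     -- new FloatTensor, contents copied and converted
local b = t:typeAs(torch.ByteTensor())  -- converts to the other tensor's type
print(f:type(), b:type())
print(t)                                -- pretty-printed through Tensor.__tostring__ above
```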
TensorConvWrap.lua
new file (127 lines)
@@ -0,0 +1,127 @@
--
-- require 'wrap'
---

interface = wrap.CInterface.new()


interface.dispatchregistry = {}
function interface:wrap(name, ...)
   -- usual stuff
   --wrap.CInterface.wrap(self, name, ...)

   -- dispatch function
   if not interface.dispatchregistry[name] then
      interface.dispatchregistry[name] = true
      table.insert(interface.dispatchregistry, {name=name, wrapname=string.format("torch_%s", name)})

      interface:print(string.gsub([[
static int torch_NAME(lua_State *L)
{
  int narg = lua_gettop(L);
  const void *id;

  if(narg < 1 || !(id = torch_istensorid(L, luaT_id(L, 1)))) /* first argument is tensor? */
  {
    if(narg < 2 || !(id = torch_istensorid(L, luaT_id(L, 2)))) /* second? */
    {
      if(lua_isstring(L, -1) && (id = torch_istensorid(L, luaT_typename2id(L, lua_tostring(L, -1))))) /* do we have a valid string then? */
        lua_pop(L, 1);
      else if(!(id = torch_istensorid(L, torch_getdefaulttensorid())))
        luaL_error(L, "internal error: the default tensor type does not seem to be an actual tensor");
    }
  }

  lua_pushstring(L, "NAME");
  lua_rawget(L, -2);
  if(lua_isfunction(L, -1))
  {
    lua_insert(L, 1);
    lua_pop(L, 2); /* the two tables we put on the stack above */
    lua_call(L, lua_gettop(L)-1, LUA_MULTRET);
  }
  else
    return luaL_error(L, "%s does not implement the torch.NAME() function", luaT_id2typename(L, id));

  return lua_gettop(L);
}
]], 'NAME', name))
   end
end

function interface:dispatchregister(name)
   local txt = self.txt
   table.insert(txt, string.format('static const struct luaL_Reg %s [] = {', name))
   for _,reg in ipairs(self.dispatchregistry) do
      table.insert(txt, string.format('{"%s", %s},', reg.name, reg.wrapname))
   end
   table.insert(txt, '{NULL, NULL}')
   table.insert(txt, '};')
   table.insert(txt, '')
   self.dispatchregistry = {}
end

interface:print('/* WARNING: autogenerated file */')
interface:print('')

local reals = {ByteTensor='byte',
               CharTensor='char',
               ShortTensor='short',
               IntTensor='int',
               LongTensor='long',
               FloatTensor='float',
               DoubleTensor='double'}

for _,Tensor in ipairs({"FloatTensor", "DoubleTensor", "IntTensor", "LongTensor", "ByteTensor", "CharTensor","ShortTensor"}) do

   local real = reals[Tensor]

   function interface.luaname2wrapname(self, name)
      return string.format('torch_%s_%s', Tensor, name)
   end

   local function cname(name)
      return string.format('TH%s_%s', Tensor, name)
   end

   local function lastdim(argn)
      return function(arg)
                return string.format("TH%s_nDimension(%s)", Tensor, arg.args[argn]:carg())
             end
   end


   for _,name in ipairs({"conv2","xcorr2","conv3","xcorr3"}) do
      interface:wrap(name,
                     cname(name),
                     {{name=Tensor, default=true, returned=true},
                      {name=Tensor, default=true, returned=true},
                      {name=Tensor},
                      {name=Tensor}}
                     )
   end


   --interface:register(string.format("torch_%sLapack__", Tensor))

   -- interface:print(string.gsub([[
   -- static void torch_TensorLapack_init(lua_State *L)
   -- {
   --   torch_Tensor_id = luaT_checktypename2id(L, "torch.Tensor");
   --   torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage");

   --   luaT_pushmetaclass(L, torch_Tensor_id);
   --   lua_getfield(L,-1,"torch");
   --   luaL_register(L, NULL, torch_TensorLapack__);
   --   lua_pop(L, 2);
   -- }
   -- ]], 'Tensor', Tensor))
end

interface:dispatchregister("torch_TensorConv__")

if arg[1] then
   interface:tofile(arg[1])
else
   interface:tostdio()
end
TensorLapackWrap.lua
new file (132 lines)
@@ -0,0 +1,132 @@
--
-- require 'wrap'
---

interface = wrap.CInterface.new()


interface.dispatchregistry = {}
function interface:wrap(name, ...)
   -- usual stuff
   wrap.CInterface.wrap(self, name, ...)

   -- dispatch function
   if not interface.dispatchregistry[name] then
      interface.dispatchregistry[name] = true
      table.insert(interface.dispatchregistry, {name=name, wrapname=string.format("torch_%s", name)})

      interface:print(string.gsub([[
static int torch_NAME(lua_State *L)
{
  int narg = lua_gettop(L);
  const void *id;

  if(narg < 1 || !(id = torch_istensorid(L, luaT_id(L, 1)))) /* first argument is tensor? */
  {
    if(narg < 2 || !(id = torch_istensorid(L, luaT_id(L, 2)))) /* second? */
    {
      if(lua_isstring(L, -1) && (id = torch_istensorid(L, luaT_typename2id(L, lua_tostring(L, -1))))) /* do we have a valid string then? */
        lua_pop(L, 1);
      else if(!(id = torch_istensorid(L, torch_getdefaulttensorid())))
        luaL_error(L, "internal error: the default tensor type does not seem to be an actual tensor");
    }
  }

  lua_pushstring(L, "NAME");
  lua_rawget(L, -2);
  if(lua_isfunction(L, -1))
  {
    lua_insert(L, 1);
    lua_pop(L, 2); /* the two tables we put on the stack above */
    lua_call(L, lua_gettop(L)-1, LUA_MULTRET);
  }
  else
    return luaL_error(L, "%s does not implement the torch.NAME() function", luaT_id2typename(L, id));

  return lua_gettop(L);
}
]], 'NAME', name))
   end
end

function interface:dispatchregister(name)
   local txt = self.txt
   table.insert(txt, string.format('static const struct luaL_Reg %s [] = {', name))
   for _,reg in ipairs(self.dispatchregistry) do
      table.insert(txt, string.format('{"%s", %s},', reg.name, reg.wrapname))
   end
   table.insert(txt, '{NULL, NULL}')
   table.insert(txt, '};')
   table.insert(txt, '')
   self.dispatchregistry = {}
end

interface:print('/* WARNING: autogenerated file */')
interface:print('')

local reals = {ByteTensor='byte',
               CharTensor='char',
               ShortTensor='short',
               IntTensor='int',
               LongTensor='long',
               FloatTensor='float',
               DoubleTensor='double'}

for _,Tensor in ipairs({"FloatTensor", "DoubleTensor"}) do

   local real = reals[Tensor]

   function interface.luaname2wrapname(self, name)
      return string.format('torch_%s_%s', Tensor, name)
   end

   local function cname(name)
      return string.format('TH%s_%s', Tensor, name)
   end

   local function lastdim(argn)
      return function(arg)
                return string.format("TH%s_nDimension(%s)", Tensor, arg.args[argn]:carg())
             end
   end


   -- for _,name in ipairs({"gesv","gels","eig","svd"}) do
   --    interface:wrap(name,
   --                   cname(name),
   --                   {{name=Tensor, returned=true},
   --                    {name=Tensor, returned=true},
   --                    {name=Tensor},
   --                    {name=Tensor}},
   --                   cname(name),
   --                   {{name=Tensor, default=true, returned=true, invisible=true},
   --                    {name=Tensor, default=true, returned=true, invisible=true},
   --                    {name=Tensor},
   --                    {name=Tensor}}
   --                   )
   -- end


   --interface:register(string.format("torch_%sLapack__", Tensor))

   -- interface:print(string.gsub([[
   -- static void torch_TensorLapack_init(lua_State *L)
   -- {
   --   torch_Tensor_id = luaT_checktypename2id(L, "torch.Tensor");
   --   torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage");

   --   luaT_pushmetaclass(L, torch_Tensor_id);
   --   lua_getfield(L,-1,"torch");
   --   luaL_register(L, NULL, torch_TensorLapack__);
   --   lua_pop(L, 2);
   -- }
   -- ]], 'Tensor', Tensor))
end

interface:dispatchregister("torch_TensorLapack__")

if arg[1] then
   interface:tofile(arg[1])
else
   interface:tostdio()
end
TensorMath.c
new file (54 lines)
@@ -0,0 +1,54 @@
#include "TH.h"
#include "luaT.h"
#include "utils.h"

#include "sys/time.h"

#define torch_(NAME) TH_CONCAT_3(torch_, Real, NAME)
#define torch_string_(NAME) TH_CONCAT_STRING_3(torch., Real, NAME)

static const void* torch_ByteTensor_id;
static const void* torch_CharTensor_id;
static const void* torch_ShortTensor_id;
static const void* torch_IntTensor_id;
static const void* torch_LongTensor_id;
static const void* torch_FloatTensor_id;
static const void* torch_DoubleTensor_id;

static const void* torch_LongStorage_id;


#include "TensorMathWrap.c"
//#include "TensorLapackWrap.c"
//#include "TensorConvWrap.c"

//#include "generic/TensorLapack.c"
//#include "THGenerateFloatTypes.h"

//#include "generic/TensorConv.c"
//#include "THGenerateAllTypes.h"

void torch_TensorMath_init(lua_State *L)
{
  torch_ByteTensorMath_init(L);
  torch_CharTensorMath_init(L);
  torch_ShortTensorMath_init(L);
  torch_IntTensorMath_init(L);
  torch_LongTensorMath_init(L);
  torch_FloatTensorMath_init(L);
  torch_DoubleTensorMath_init(L);
  luaL_register(L, NULL, torch_TensorMath__);

  /* torch_FloatLapack_init(L); */
  /* torch_DoubleLapack_init(L); */
  /* luaL_register(L, NULL, torch_TensorLapack__); */

  /* torch_ByteConv_init(L); */
  /* torch_CharConv_init(L); */
  /* torch_ShortConv_init(L); */
  /* torch_IntConv_init(L); */
  /* torch_LongConv_init(L); */
  /* torch_FloatConv_init(L); */
  /* torch_DoubleConv_init(L); */
  /* luaL_register(L, NULL, torch_TensorConv__); */
}
TensorMath.lua
new file (110 lines)
@@ -0,0 +1,110 @@
for _,tensortype in ipairs({'ByteTensor',
                            'CharTensor',
                            'ShortTensor',
                            'IntTensor',
                            'LongTensor',
                            'FloatTensor',
                            'DoubleTensor'}) do

   for _,func in ipairs({'add',
                         'mul',
                         'div',
                         'cmul',
                         'cdiv',
                         'addcmul',
                         'addcdiv',
                         'log',
                         'log1p',
                         'exp',
                         'cos',
                         'acos',
                         'cosh',
                         'sin',
                         'asin',
                         'sinh',
                         'tan',
                         'atan',
                         'tanh',
                         'pow',
                         'sqrt',
                         'ceil',
                         'floor',
                         'abs',
                         'sign'
                      }) do

      local torchfunc = torch[tensortype].torch[func]
      torch[tensortype][func] = function(self, ...)
                                   return torchfunc(self, self, ...)
                                end
   end

   for _,func in ipairs({'addmv',
                         'addmm',
                         'addr'}) do

      local torchfunc = torch[tensortype].torch[func]
      torch[tensortype][func] = function(self, next1, next2, ...)
                                   if type(next1) == 'number' and type(next2) == 'number' then
                                      return torchfunc(self, next1, self, next2, ...)
                                   elseif type(next1) == 'number' then
                                      return torchfunc(self, self, next1, next2, ...)
                                   else
                                      return torchfunc(self, self, next1, next2, ...)
                                   end
                                end
   end

   for _,func in ipairs({'zero',
                         'fill',
                         'dot',
                         'minall',
                         'maxall',
                         'sumall',
                         'numel',
                         'max',
                         'min',
                         'sum',
                         'prod',
                         'cumsum',
                         'cumprod',
                         'trace',
                         'cross',
                         'zeros',
                         'ones',
                         'diag',
                         'eye',
                         'range',
                         'randperm',
                         'reshape',
                         'sort',
                         'tril',
                         'triu',
                         '_histc',
                         'cat',
                         'mean',
                         'std',
                         'var',
                         'norm',
                         'dist',
                         'meanall',
                         'varall',
                         'stdall',
                         'linspace',
                         'logspace',
                         'rand',
                         'randn',
                         'random',
                         'uniform',
                         'normal',
                         'cauchy',
                         'logNormal',
                         'exponential',
                         'geometric',
                         'bernoulli',
                         'squeeze'
                      }) do

      torch[tensortype][func] = torch[tensortype].torch[func]
   end
end
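A sketch of the calling convention this sets up: method calls operate in place on `self`, because the wrapper passes the tensor both as the result and as the first operand.

```lua
-- minimal sketch of the installed method forms
local t = torch.DoubleTensor(4):fill(2)
t:mul(3)                        -- forwarded as torch.DoubleTensor.torch.mul(t, t, 3): in place
t:add(1)                        -- t is now {7, 7, 7, 7}
print(t:sumall())               -- 28; reductions are attached directly to the metatable
print(t:minall(), t:maxall())   -- 7   7
```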
TensorMathWrap.lua
new file (870 lines)
@@ -0,0 +1,870 @@
--
-- require 'wrap'
---

local interface = wrap.CInterface.new()

interface:print([[
#include "TH.h"
#include "luaT.h"
#include "utils.h"

static const void* torch_ByteTensor_id;
static const void* torch_CharTensor_id;
static const void* torch_ShortTensor_id;
static const void* torch_IntTensor_id;
static const void* torch_LongTensor_id;
static const void* torch_FloatTensor_id;
static const void* torch_DoubleTensor_id;

static const void* torch_LongStorage_id;
]])

-- special argument specific to torch package
interface.argtypes.LongArg = {

   vararg = true,

   helpname = function(arg)
                 return "(LongStorage | dim1 [dim2...])"
              end,

   declare = function(arg)
                return string.format("THLongStorage *arg%d = NULL;", arg.i)
             end,

   init = function(arg)
             if arg.default then
                error('LongArg cannot have a default value')
             end
          end,

   check = function(arg, idx)
              return string.format("torch_islongargs(L, %d)", idx)
           end,

   read = function(arg, idx)
             return string.format("arg%d = torch_checklongargs(L, %d);", arg.i, idx)
          end,

   carg = function(arg, idx)
             return string.format('arg%d', arg.i)
          end,

   creturn = function(arg, idx)
                return string.format('arg%d', arg.i)
             end,

   precall = function(arg)
                local txt = {}
                if arg.returned then
                   table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_LongStorage_id);', arg.i))
                end
                return table.concat(txt, '\n')
             end,

   postcall = function(arg)
                 local txt = {}
                 if arg.creturned then
                    -- this next line is actually debatable
                    table.insert(txt, string.format('THLongStorage_retain(arg%d);', arg.i))
                    table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_LongStorage_id);', arg.i))
                 end
                 if not arg.returned and not arg.creturned then
                    table.insert(txt, string.format('THLongStorage_free(arg%d);', arg.i))
                 end
                 return table.concat(txt, '\n')
              end
}

interface.argtypes.charoption = {

   helpname = function(arg)
                 if arg.values then
                    return "(" .. table.concat(arg.values, '|') .. ")"
                 end
              end,

   declare = function(arg)
                local txt = {}
                table.insert(txt, string.format("const char *arg%d = NULL;", arg.i))
                if arg.default then
                   table.insert(txt, string.format("char arg%d_default = '%s';", arg.i, arg.default))
                end
                return table.concat(txt, '\n')
             end,

   init = function(arg)
             return string.format("arg%d = &arg%d_default;", arg.i, arg.i)
          end,

   check = function(arg, idx)
              local txt = {}
              local txtv = {}
              table.insert(txt, string.format('(arg%d = lua_tostring(L, %d)) && (', arg.i, idx))
              for _,value in ipairs(arg.values) do
                 table.insert(txtv, string.format("*arg%d == '%s'", arg.i, value))
              end
              table.insert(txt, table.concat(txtv, ' || '))
              table.insert(txt, ')')
              return table.concat(txt, '')
           end,

   read = function(arg, idx)
          end,

   carg = function(arg, idx)
             return string.format('arg%d', arg.i)
          end,

   creturn = function(arg, idx)
             end,

   precall = function(arg)
             end,

   postcall = function(arg)
              end
}

-- also specific to torch: we generate a 'dispatch' function
-- first we create a helper function
interface:print([[
static const void* torch_istensorid(lua_State *L, const void *id)
{
  if(!id)
    return NULL;

  luaT_pushmetaclass(L, id);
  lua_pushstring(L, "torch");
  lua_rawget(L, -2);
  if(lua_istable(L, -1))
    return id;
  else
  {
    lua_pop(L, 2);
    return NULL;
  }

  return NULL;
}
]])

interface.dispatchregistry = {}
function interface:wrap(name, ...)
   -- usual stuff
   wrap.CInterface.wrap(self, name, ...)

   -- dispatch function
   if not interface.dispatchregistry[name] then
      interface.dispatchregistry[name] = true
      table.insert(interface.dispatchregistry, {name=name, wrapname=string.format("torch_%s", name)})

      interface:print(string.gsub([[
static int torch_NAME(lua_State *L)
{
  int narg = lua_gettop(L);
  const void *id;

  if(narg < 1 || !(id = torch_istensorid(L, luaT_id(L, 1)))) /* first argument is tensor? */
  {
    if(narg < 2 || !(id = torch_istensorid(L, luaT_id(L, 2)))) /* second? */
    {
      if(lua_isstring(L, -1) && (id = torch_istensorid(L, luaT_typename2id(L, lua_tostring(L, -1))))) /* do we have a valid string then? */
        lua_pop(L, 1);
      else if(!(id = torch_istensorid(L, torch_getdefaulttensorid())))
        luaL_error(L, "internal error: the default tensor type does not seem to be an actual tensor");
    }
  }

  lua_pushstring(L, "NAME");
  lua_rawget(L, -2);
  if(lua_isfunction(L, -1))
  {
    lua_insert(L, 1);
    lua_pop(L, 2); /* the two tables we put on the stack above */
    lua_call(L, lua_gettop(L)-1, LUA_MULTRET);
  }
  else
    return luaL_error(L, "%s does not implement the torch.NAME() function", luaT_id2typename(L, id));

  return lua_gettop(L);
}
]], 'NAME', name))
   end
end

function interface:dispatchregister(name)
   local txt = self.txt
   table.insert(txt, string.format('static const struct luaL_Reg %s [] = {', name))
   for _,reg in ipairs(self.dispatchregistry) do
|
||||
table.insert(txt, string.format('{"%s", %s},', reg.name, reg.wrapname))
|
||||
end
|
||||
table.insert(txt, '{NULL, NULL}')
|
||||
table.insert(txt, '};')
|
||||
table.insert(txt, '')
|
||||
self.dispatchregistry = {}
|
||||
end
|
||||
|
||||
interface:print('/* WARNING: autogenerated file */')
|
||||
interface:print('')
|
||||
|
||||
local reals = {ByteTensor='unsigned char',
|
||||
CharTensor='char',
|
||||
ShortTensor='short',
|
||||
IntTensor='int',
|
||||
LongTensor='long',
|
||||
FloatTensor='float',
|
||||
DoubleTensor='double'}
|
||||
|
||||
for _,Tensor in ipairs({"ByteTensor", "CharTensor",
|
||||
"ShortTensor", "IntTensor", "LongTensor",
|
||||
"FloatTensor", "DoubleTensor"}) do
|
||||
|
||||
local real = reals[Tensor]
|
||||
|
||||
function interface.luaname2wrapname(self, name)
|
||||
return string.format('torch_%s_%s', Tensor, name)
|
||||
end
|
||||
|
||||
local function cname(name)
|
||||
return string.format('TH%s_%s', Tensor, name)
|
||||
end
|
||||
|
||||
local function lastdim(argn)
|
||||
return function(arg)
|
||||
return string.format("TH%s_nDimension(%s)", Tensor, arg.args[argn]:carg())
|
||||
end
|
||||
end
|
||||
|
||||
interface:wrap("zero",
|
||||
cname("zero"),
|
||||
{{name=Tensor, returned=true}})
|
||||
|
||||
interface:wrap("fill",
|
||||
cname("fill"),
|
||||
{{name=Tensor, returned=true},
|
||||
{name=real}})
|
||||
|
||||
interface:wrap("zeros",
|
||||
cname("zeros"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name="LongArg"}})
|
||||
|
||||
interface:wrap("ones",
|
||||
cname("ones"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name="LongArg"}})
|
||||
|
||||
interface:wrap("reshape",
|
||||
cname("reshape"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name="LongArg"}})
|
||||
|
||||
interface:wrap("dot",
|
||||
cname("dot"),
|
||||
{{name=Tensor},
|
||||
{name=Tensor},
|
||||
{name=real, creturned=true}})
|
||||
|
||||
for _,name in ipairs({"minall", "maxall", "sumall"}) do
|
||||
interface:wrap(name,
|
||||
cname(name),
|
||||
{{name=Tensor},
|
||||
{name=real, creturned=true}})
|
||||
end
|
||||
|
||||
interface:wrap("add",
|
||||
cname("add"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name=real}},
|
||||
cname("cadd"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name=real, default=1},
|
||||
{name=Tensor}})
|
||||
|
||||
interface:wrap("mul",
|
||||
cname("mul"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name=real}})
|
||||
|
||||
interface:wrap("div",
|
||||
cname("div"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name=real}})
|
||||
|
||||
interface:wrap("cmul",
|
||||
cname("cmul"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name=Tensor}})
|
||||
|
||||
interface:wrap("cdiv",
|
||||
cname("cdiv"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name=Tensor}})
|
||||
|
||||
interface:wrap("addcmul",
|
||||
cname("addcmul"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name=real, default=1},
|
||||
{name=Tensor},
|
||||
{name=Tensor}})
|
||||
|
||||
interface:wrap("addcdiv",
|
||||
cname("addcdiv"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name=real, default=1},
|
||||
{name=Tensor},
|
||||
{name=Tensor}})
|
||||
|
||||
for _,name in ipairs({"addmv", "addmm", "addr"}) do
|
||||
interface:wrap(name,
|
||||
cname(name),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=real, default=1},
|
||||
{name=Tensor},
|
||||
{name=real, default=1},
|
||||
{name=Tensor},
|
||||
{name=Tensor}})
|
||||
end
|
||||
|
||||
interface:wrap("numel",
|
||||
cname("numel"),
|
||||
{{name=Tensor},
|
||||
{name=real, creturned=true}})
|
||||
|
||||
for _,name in ipairs({"sum", "prod", "cumsum", "cumprod"}) do
|
||||
interface:wrap(name,
|
||||
cname(name),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name="index", default=lastdim(2)}})
|
||||
end
|
||||
|
||||
interface:wrap("min",
|
||||
cname("min"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name="IndexTensor", default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name="index", default=lastdim(3)}})
|
||||
|
||||
interface:wrap("max",
|
||||
cname("max"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name="IndexTensor", default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name="index", default=lastdim(3)}})
|
||||
|
||||
interface:wrap("trace",
|
||||
cname("trace"),
|
||||
{{name=Tensor},
|
||||
{name=real, creturned=true}})
|
||||
|
||||
interface:wrap("cross",
|
||||
cname("cross"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name=Tensor},
|
||||
{name="index", default=0}})
|
||||
|
||||
interface:wrap("diag",
|
||||
cname("diag"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name="long", default=0}})
|
||||
|
||||
interface:wrap("eye",
|
||||
cname("eye"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name="long"},
|
||||
{name="long", default=0}})
|
||||
|
||||
interface:wrap("range",
|
||||
cname("range"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=real},
|
||||
{name=real},
|
||||
{name=real, default=1}})
|
||||
|
||||
interface:wrap("randperm",
|
||||
cname("randperm"),
|
||||
{{name=Tensor, default=true, returned=true, userpostcall=function(arg)
|
||||
return string.format("TH%s_add(%s, %s, 1);", Tensor, arg:carg(), arg:carg())
|
||||
end},
|
||||
{name="long"}})
|
||||
|
||||
interface:wrap("sort",
|
||||
cname("sort"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name="IndexTensor", default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name="index", default=lastdim(3)},
|
||||
{name="boolean", default=0}})
|
||||
|
||||
|
||||
interface:wrap("tril",
|
||||
cname("tril"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name="int", default=0}})
|
||||
|
||||
interface:wrap("triu",
|
||||
cname("triu"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name="int", default=0}})
|
||||
|
||||
interface:wrap("cat",
|
||||
cname("cat"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name=Tensor},
|
||||
{name="index", default=lastdim(2)}})
|
||||
|
||||
if Tensor == 'ByteTensor' then -- we declare this only once
|
||||
interface:print(
|
||||
[[
|
||||
static int THRandom_random2__(long a, long b)
|
||||
{
|
||||
THArgCheck(b >= a, 2, "upper bound must be larger than lower bound");
|
||||
return((THRandom_random() % (b+1-a)) + a);
|
||||
}
|
||||
|
||||
static int THRandom_random1__(long b)
|
||||
{
|
||||
THArgCheck(b > 0, 1, "upper bound must be strictly positive");
|
||||
return(THRandom_random() % b + 1);
|
||||
}
|
||||
]])
|
||||
end
|
||||
|
||||
interface:print(string.gsub(
|
||||
[[
|
||||
static void THTensor_random2__(THTensor *self, long a, long b)
|
||||
{
|
||||
THArgCheck(b >= a, 2, "upper bound must be larger than lower bound");
|
||||
TH_TENSOR_APPLY(real, self, *self_data = ((THRandom_random() % (b+1-a)) + a);)
|
||||
}
|
||||
|
||||
static void THTensor_random1__(THTensor *self, long b)
|
||||
{
|
||||
THArgCheck(b > 0, 1, "upper bound must be strictly positive");
|
||||
TH_TENSOR_APPLY(real, self, *self_data = (THRandom_random() % b + 1);)
|
||||
}
|
||||
]], 'Tensor', Tensor):gsub('real', real))
|
||||
|
||||
interface:wrap('random',
|
||||
'THRandom_random2__',
|
||||
{{name='long'},
|
||||
{name='long'},
|
||||
{name='long', creturned=true}},
|
||||
'THRandom_random1__',
|
||||
{{name='long'},
|
||||
{name='long', creturned=true}},
|
||||
'THRandom_random',
|
||||
{{name='long', creturned=true}},
|
||||
cname("random2__"),
|
||||
{{name=Tensor},
|
||||
{name='long'},
|
||||
{name='long'}},
|
||||
cname("random1__"),
|
||||
{{name=Tensor},
|
||||
{name='long'}},
|
||||
cname("random"),
|
||||
{{name=Tensor}})
|
||||
|
||||
for _,f in ipairs({{name='geometric'},
|
||||
{name='bernoulli', a=0.5}}) do
|
||||
|
||||
interface:wrap(f.name,
|
||||
string.format("THRandom_%s", f.name),
|
||||
{{name="double", default=f.a},
|
||||
{name="double", creturned=true}},
|
||||
cname(f.name),
|
||||
{{name=Tensor, returned=true},
|
||||
{name=real, default=f.a}})
|
||||
end
|
||||
|
||||
interface:wrap("squeeze",
|
||||
cname("squeeze"),
|
||||
{{name=Tensor, default=true, returned=true, postcall=function(arg)
|
||||
local txt = {}
|
||||
if arg.returned then
|
||||
table.insert(txt, string.format('if(arg%d->nDimension == 1 && arg%d->size[0] == 1)', arg.i, arg.i)) -- number
|
||||
table.insert(txt, string.format('lua_pushnumber(L, (lua_Number)(*TH%s_data(arg%d)));', Tensor, arg.i))
|
||||
end
|
||||
return table.concat(txt, '\n')
|
||||
end},
|
||||
{name=Tensor}},
|
||||
cname("squeeze1d"),
|
||||
{{name=Tensor, default=true, returned=true, postcall=function(arg)
|
||||
local txt = {}
|
||||
if arg.returned then
|
||||
table.insert(txt, string.format('if(arg%d->nDimension == 1 && arg%d->size[0] == 1)', arg.i, arg.i)) -- number
|
||||
table.insert(txt, string.format('lua_pushnumber(L, (lua_Number)(*TH%s_data(arg%d)));', Tensor, arg.i))
|
||||
end
|
||||
return table.concat(txt, '\n')
|
||||
end},
|
||||
{name=Tensor},
|
||||
{name="index"}})
|
||||
|
||||
interface:wrap("sign",
|
||||
cname("sign"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=Tensor}})
|
||||
|
||||
interface:wrap("conv2",
|
||||
cname("conv2Dmul"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=real, default=0, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=Tensor, dim=2},
|
||||
{name=Tensor, dim=2},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name='charoption', values={'V', 'F'}, default='V'},
|
||||
{name='charoption', default="C", invisible=true}},
|
||||
cname("conv2Dcmul"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=real, default=0, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=Tensor, dim=3},
|
||||
{name=Tensor, dim=3},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name='charoption', values={'V', 'F'}, default='V'},
|
||||
{name='charoption', default="C", invisible=true}},
|
||||
cname("conv2Dmv"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=real, default=0, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=Tensor, dim=3},
|
||||
{name=Tensor, dim=4},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name='charoption', values={'V', 'F'}, default='V'},
|
||||
{name='charoption', default="C", invisible=true}}
|
||||
)
|
||||
|
||||
interface:wrap("xcorr2",
|
||||
cname("conv2Dmul"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=real, default=0, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=Tensor, dim=2},
|
||||
{name=Tensor, dim=2},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name='charoption', values={'V', 'F'}, default='V'},
|
||||
{name='charoption', default="X", invisible=true}},
|
||||
cname("conv2Dcmul"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=real, default=0, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=Tensor, dim=3},
|
||||
{name=Tensor, dim=3},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name='charoption', values={'V', 'F'}, default='V'},
|
||||
{name='charoption', default="X", invisible=true}},
|
||||
cname("conv2Dmv"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=real, default=0, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=Tensor, dim=3},
|
||||
{name=Tensor, dim=4},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name='charoption', values={'V', 'F'}, default='V'},
|
||||
{name='charoption', default="X", invisible=true}}
|
||||
)
|
||||
|
||||
interface:wrap("conv3",
|
||||
cname("conv3Dmul"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=real, default=0, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=Tensor, dim=3},
|
||||
{name=Tensor, dim=3},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name='charoption', values={'V', 'F'}, default='V'},
|
||||
{name='charoption', default="C", invisible=true}},
|
||||
cname("conv3Dcmul"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=real, default=0, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=Tensor, dim=4},
|
||||
{name=Tensor, dim=4},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name='charoption', values={'V', 'F'}, default='V'},
|
||||
{name='charoption', default="C", invisible=true}},
|
||||
cname("conv3Dmv"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=real, default=0, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=Tensor, dim=4},
|
||||
{name=Tensor, dim=5},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name='charoption', values={'V', 'F'}, default='V'},
|
||||
{name='charoption', default="C", invisible=true}}
|
||||
)
|
||||
|
||||
interface:wrap("xcorr3",
|
||||
cname("conv3Dmul"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=real, default=0, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=Tensor, dim=3},
|
||||
{name=Tensor, dim=3},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name='charoption', values={'V', 'F'}, default='V'},
|
||||
{name='charoption', default="X", invisible=true}},
|
||||
cname("conv3Dcmul"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=real, default=0, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=Tensor, dim=4},
|
||||
{name=Tensor, dim=4},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name='charoption', values={'V', 'F'}, default='V'},
|
||||
{name='charoption', default="X", invisible=true}},
|
||||
cname("conv3Dmv"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=real, default=0, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=Tensor, dim=4},
|
||||
{name=Tensor, dim=5},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name=real, default=1, invisible=true},
|
||||
{name='charoption', values={'V', 'F'}, default='V'},
|
||||
{name='charoption', default="X", invisible=true}}
|
||||
)
|
||||
|
||||
if Tensor == 'FloatTensor' or Tensor == 'DoubleTensor' then
|
||||
|
||||
interface:wrap("mean",
|
||||
cname("mean"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name="index", default=lastdim(2)}})
|
||||
|
||||
interface:wrap("std",
|
||||
cname("std"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name="index", default=lastdim(2)},
|
||||
{name="boolean", default=false}})
|
||||
|
||||
interface:wrap("var",
|
||||
cname("var"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name="index", default=lastdim(2)},
|
||||
{name="boolean", default=false}})
|
||||
|
||||
interface:wrap("norm",
|
||||
cname("norm"),
|
||||
{{name=Tensor},
|
||||
{name=real, default=2},
|
||||
{name=real, creturned=true}})
|
||||
|
||||
interface:wrap("dist",
|
||||
cname("dist"),
|
||||
{{name=Tensor},
|
||||
{name=Tensor},
|
||||
{name=real, default=2},
|
||||
{name=real, creturned=true}})
|
||||
|
||||
for _,name in ipairs({"meanall", "varall", "stdall"}) do
|
||||
interface:wrap(name,
|
||||
cname(name),
|
||||
{{name=Tensor},
|
||||
{name=real, creturned=true}})
|
||||
end
|
||||
|
||||
interface:wrap("linspace",
|
||||
cname("linspace"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=real},
|
||||
{name=real},
|
||||
{name="long", default=100}})
|
||||
|
||||
interface:wrap("logspace",
|
||||
cname("logspace"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=real},
|
||||
{name=real},
|
||||
{name="long", default=100}})
|
||||
|
||||
for _,name in ipairs({"log", "log1p", "exp",
|
||||
"cos", "acos", "cosh",
|
||||
"sin", "asin", "sinh",
|
||||
"tan", "atan", "tanh",
|
||||
"sqrt",
|
||||
"ceil", "floor",
|
||||
"abs"}) do
|
||||
|
||||
interface:wrap(name,
|
||||
cname(name),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=Tensor}},
|
||||
name,
|
||||
{{name=real},
|
||||
{name=real, creturned=true}})
|
||||
|
||||
end
|
||||
|
||||
interface:wrap("pow",
|
||||
cname("pow"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name=Tensor},
|
||||
{name=real}},
|
||||
"pow",
|
||||
{{name=real},
|
||||
{name=real},
|
||||
{name=real, creturned=true}})
|
||||
|
||||
interface:wrap("rand",
|
||||
cname("rand"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name="LongArg"}})
|
||||
|
||||
interface:wrap("randn",
|
||||
cname("randn"),
|
||||
{{name=Tensor, default=true, returned=true},
|
||||
{name="LongArg"}})
|
||||
|
||||
for _,f in ipairs({{name='uniform', a=0, b=1},
|
||||
{name='normal', a=0, b=1},
|
||||
{name='cauchy', a=0, b=1},
|
||||
{name='logNormal', a=1, b=2}}) do
|
||||
|
||||
interface:wrap(f.name,
|
||||
string.format("THRandom_%s", f.name),
|
||||
{{name="double", default=f.a},
|
||||
{name="double", default=f.b},
|
||||
{name="double", creturned=true}},
|
||||
cname(f.name),
|
||||
{{name=Tensor, returned=true},
|
||||
{name=real, default=f.a},
|
||||
{name=real, default=f.b}})
|
||||
end
|
||||
|
||||
for _,f in ipairs({{name='exponential'}}) do
|
||||
|
||||
interface:wrap(f.name,
|
||||
string.format("THRandom_%s", f.name),
|
||||
{{name="double", default=f.a},
|
||||
{name="double", creturned=true}},
|
||||
cname(f.name),
|
||||
{{name=Tensor, returned=true},
|
||||
{name=real, default=f.a}})
|
||||
end
|
||||
|
||||
for _,name in ipairs({"gesv","gels"}) do
|
||||
interface:wrap(name,
|
||||
cname(name),
|
||||
{{name=Tensor, returned=true},
|
||||
{name=Tensor, returned=true},
|
||||
{name=Tensor},
|
||||
{name=Tensor}},
|
||||
cname(name),
|
||||
{{name=Tensor, default=true, returned=true, invisible=true},
|
||||
{name=Tensor, default=true, returned=true, invisible=true},
|
||||
{name=Tensor},
|
||||
{name=Tensor}}
|
||||
)
|
||||
end
|
||||
|
||||
interface:wrap("eig",
|
||||
cname("syev"),
|
||||
{{name=Tensor, returned=true},
|
||||
{name=Tensor, returned=true},
|
||||
{name=Tensor},
|
||||
{name='charoption', values={'N', 'V'}, default='N'},
|
||||
{name='charoption', values={'U', 'L'}, default='U'}},
|
||||
cname("syev"),
|
||||
{{name=Tensor, default=true, returned=true, invisible=true},
|
||||
{name=Tensor, default=true, returned=true, invisible=true},
|
||||
{name=Tensor},
|
||||
{name='charoption', values={'N', 'V'}, default='N'},
|
||||
{name='charoption', values={'U', 'L'}, default='U'}}
|
||||
)
|
||||
|
||||
interface:wrap("svd",
|
||||
cname("gesvd"),
|
||||
{{name=Tensor, returned=true},
|
||||
{name=Tensor, returned=true},
|
||||
{name=Tensor, returned=true},
|
||||
{name=Tensor},
|
||||
{name='charoption', values={'A', 'S'}, default='S'}},
|
||||
cname("gesvd"),
|
||||
{{name=Tensor, default=true, returned=true, invisible=true},
|
||||
{name=Tensor, default=true, returned=true, invisible=true},
|
||||
{name=Tensor, default=true, returned=true, invisible=true},
|
||||
{name=Tensor},
|
||||
{name='charoption', values={'A', 'S'}, default='S'}}
|
||||
)
|
||||
|
||||
end
|
||||
|
||||
interface:register(string.format("torch_%sMath__", Tensor))
|
||||
|
||||
interface:print(string.gsub([[
|
||||
static void torch_TensorMath_init(lua_State *L)
|
||||
{
|
||||
torch_Tensor_id = luaT_checktypename2id(L, "torch.Tensor");
|
||||
torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage");
|
||||
|
||||
/* register everything into the "torch" field of the tensor metaclass */
|
||||
luaT_pushmetaclass(L, torch_Tensor_id);
|
||||
lua_pushstring(L, "torch");
|
||||
lua_newtable(L);
|
||||
luaL_register(L, NULL, torch_TensorMath__);
|
||||
lua_rawset(L, -3);
|
||||
lua_pop(L, 1);
|
||||
}
|
||||
]], 'Tensor', Tensor))
|
||||
end
|
||||
|
||||
interface:dispatchregister("torch_TensorMath__")
|
||||
|
||||
interface:print([[
|
||||
void torch_TensorMath_init(lua_State *L)
|
||||
{
|
||||
torch_ByteTensorMath_init(L);
|
||||
torch_CharTensorMath_init(L);
|
||||
torch_ShortTensorMath_init(L);
|
||||
torch_IntTensorMath_init(L);
|
||||
torch_LongTensorMath_init(L);
|
||||
torch_FloatTensorMath_init(L);
|
||||
torch_DoubleTensorMath_init(L);
|
||||
luaL_register(L, NULL, torch_TensorMath__);
|
||||
}
|
||||
]])
|
||||
|
||||
if arg[1] then
|
||||
interface:tofile(arg[1])
|
||||
else
|
||||
interface:tostdio()
|
||||
end
|
||||
8
TensorOperator.c
Normal file
8
TensorOperator.c
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
#include "general.h"
|
||||
|
||||
#define torch_TensorOperator_(NAME) TH_CONCAT_4(torch_,Real,TensorOperator_,NAME)
|
||||
#define torch_Tensor_id TH_CONCAT_3(torch_,Real,Tensor_id)
|
||||
#define STRING_torchTensor TH_CONCAT_STRING_3(torch.,Real,Tensor)
|
||||
|
||||
#include "generic/TensorOperator.c"
|
||||
#include "THGenerateAllTypes.h"
|
||||
124
Tester.lua
Normal file
124
Tester.lua
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
local Tester = torch.class('torch.Tester')
|
||||
|
||||
function Tester:__init()
|
||||
self.errors = {}
|
||||
self.tests = {}
|
||||
self.testnames = {}
|
||||
self.curtestname = ''
|
||||
end
|
||||
|
||||
|
||||
function Tester:assert_sub (condition, message)
|
||||
if not condition then
|
||||
local ss = debug.traceback('tester',2)
|
||||
--print(ss)
|
||||
ss = ss:match('[^\n]+\n[^\n]+\n([^\n]+\n[^\n]+)\n')
|
||||
self.errors[#self.errors+1] = self.curtestname .. '\n' .. message .. '\n' .. ss .. '\n'
|
||||
end
|
||||
end
|
||||
function Tester:assert (condition, message)
|
||||
self:assert_sub(condition,string.format('%s\n%s condition=%s',message,' BOOL violation ', tostring(condition)))
|
||||
end
|
||||
function Tester:assertlt (val, condition, message)
|
||||
self:assert_sub(val<condition,string.format('%s\n%s val=%s, condition=%s',message,' LT(<) violation ', tostring(val), tostring(condition)))
|
||||
end
|
||||
function Tester:assertgt (val, condition, message)
|
||||
self:assert_sub(val>condition,string.format('%s\n%s val=%s, condition=%s',message,' GT(>) violation ', tostring(val), tostring(condition)))
|
||||
end
|
||||
function Tester:assertle (val, condition, message)
|
||||
self:assert_sub(val<=condition,string.format('%s\n%s val=%s, condition=%s',message,' LE(<=) violation ', tostring(val), tostring(condition)))
|
||||
end
|
||||
function Tester:assertge (val, condition, message)
|
||||
self:assert_sub(val>=condition,string.format('%s\n%s val=%s, condition=%s',message,' GE(>=) violation ', tostring(val), tostring(condition)))
|
||||
end
|
||||
function Tester:asserteq (val, condition, message)
|
||||
self:assert_sub(val==condition,string.format('%s\n%s val=%s, condition=%s',message,' EQ(==) violation ', tostring(val), tostring(condition)))
|
||||
end
|
||||
function Tester:assertne (val, condition, message)
|
||||
self:assert_sub(val~=condition,string.format('%s\n%s val=%s, condition=%s',message,' NE(~=) violation ', tostring(val), tostring(condition)))
|
||||
end
|
||||
function Tester:assertTensorEq(ta, tb, condition, message)
|
||||
local diff = ta-tb
|
||||
local err = diff:abs():maxall()
|
||||
self:assert_sub(err<condition,string.format('%s\n%s val=%s, condition=%s',message,' TensorEQ(~=) violation ', tostring(err), tostring(condition)))
|
||||
end
|
||||
|
||||
function Tester:pcall(f)
|
||||
local nerr = #self.errors
|
||||
local res = f()
|
||||
-- local stat, result = pcall(f)
|
||||
-- if not stat then
|
||||
-- result = result .. debug.traceback()
|
||||
-- end
|
||||
-- return stat, result, stat and (nerr == #self.errors)
|
||||
return true, res, nerr == #self.errors
|
||||
end
|
||||
|
||||
function Tester:report()
|
||||
print('Completed ' .. #self.tests .. ' tests with ' .. #self.errors .. ' errors')
|
||||
print()
|
||||
print(string.rep('-',80))
|
||||
for i,v in ipairs(self.errors) do
|
||||
print(v)
|
||||
print(string.rep('-',80))
|
||||
end
|
||||
end
|
||||
|
||||
function Tester:run()
|
||||
print('Running ' .. #self.tests .. ' tests')
|
||||
local statstr = string.rep('_',#self.tests)
|
||||
local pstr = ''
|
||||
io.write(statstr .. '\r')
|
||||
for i,v in ipairs(self.tests) do
|
||||
self.curtestname = self.testnames[i]
|
||||
|
||||
--clear
|
||||
io.write('\r' .. string.rep(' ', pstr:len()))
|
||||
io.flush()
|
||||
--write
|
||||
pstr = statstr:sub(1,i-1) .. '|' .. statstr:sub(i+1) .. ' ==> ' .. self.curtestname
|
||||
io.write('\r' .. pstr)
|
||||
io.flush()
|
||||
|
||||
local stat, message, pass = self:pcall(v)
|
||||
|
||||
if pass then
|
||||
--io.write(string.format('\b_'))
|
||||
statstr = statstr:sub(1,i-1) .. '_' .. statstr:sub(i+1)
|
||||
else
|
||||
statstr = statstr:sub(1,i-1) .. '*' .. statstr:sub(i+1)
|
||||
--io.write(string.format('\b*'))
|
||||
end
|
||||
|
||||
if not stat then
|
||||
print()
|
||||
print('Function call failed: Test No ' .. i .. ' ' .. self.testnames[i])
|
||||
print(message)
|
||||
end
|
||||
collectgarbage()
|
||||
end
|
||||
--clear
|
||||
io.write('\r' .. string.rep(' ', pstr:len()))
|
||||
io.flush()
|
||||
-- write finish
|
||||
pstr = statstr .. ' ==> Done '
|
||||
io.write('\r' .. pstr)
|
||||
io.flush()
|
||||
print()
|
||||
print()
|
||||
self:report()
|
||||
end
|
||||
|
||||
function Tester:add(f,name)
|
||||
name = name or 'unknown'
|
||||
if type(f) == "table" then
|
||||
for i,v in pairs(f) do
|
||||
self:add(v,i)
|
||||
end
|
||||
elseif type(f) == "function" then
|
||||
self.tests[#self.tests+1] = f
|
||||
self.testnames[#self.tests] = name
|
||||
else
|
||||
error('Tester:add(f) expects a function or a table of functions')
|
||||
end
|
||||
end
|
||||
157
Timer.c
Normal file
157
Timer.c
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
#include "general.h"
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#include <time.h>
|
||||
#else
|
||||
#include <sys/time.h>
|
||||
#include <sys/resource.h>
|
||||
#endif
|
||||
|
||||
#ifdef _MSC_VER
|
||||
static time_t base_time = 0;
|
||||
#endif
|
||||
|
||||
static const void* torch_Timer_id = NULL;
|
||||
|
||||
typedef struct _Timer
|
||||
{
|
||||
int isRunning;
|
||||
|
||||
double totalrealtime;
|
||||
double totalusertime;
|
||||
double totalsystime;
|
||||
|
||||
double startrealtime;
|
||||
double startusertime;
|
||||
double startsystime;
|
||||
|
||||
} Timer;
|
||||
|
||||
static double torch_Timer_realtime()
|
||||
{
|
||||
struct timeval current;
|
||||
gettimeofday(&current, NULL);
|
||||
return (current.tv_sec + current.tv_usec/1000000.0);
|
||||
}
|
||||
|
||||
static double torch_Timer_usertime()
|
||||
{
|
||||
struct rusage current;
|
||||
getrusage(RUSAGE_SELF, &current);
|
||||
return (current.ru_utime.tv_sec + current.ru_utime.tv_usec/1000000.0);
|
||||
}
|
||||
|
||||
static double torch_Timer_systime()
|
||||
{
|
||||
struct rusage current;
|
||||
getrusage(RUSAGE_SELF, &current);
|
||||
return (current.ru_stime.tv_sec + current.ru_stime.tv_usec/1000000.0);
|
||||
}
|
||||
|
||||
static int torch_Timer_new(lua_State *L)
|
||||
{
|
||||
Timer *timer = luaT_alloc(L, sizeof(Timer));
|
||||
#ifdef _MSC_VER
|
||||
while(!base_time)
|
||||
time(&base_time);
|
||||
#endif
|
||||
timer->isRunning = 1;
|
||||
timer->totalrealtime = 0;
|
||||
timer->totalusertime = 0;
|
||||
timer->totalsystime = 0;
|
||||
timer->startrealtime = torch_Timer_realtime();
|
||||
timer->startusertime = torch_Timer_usertime();
|
||||
timer->startsystime = torch_Timer_systime();
|
||||
luaT_pushudata(L, timer, torch_Timer_id);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Timer_reset(lua_State *L)
|
||||
{
|
||||
Timer *timer = luaT_checkudata(L, 1, torch_Timer_id);
|
||||
timer->totalrealtime = 0;
|
||||
timer->totalusertime = 0;
|
||||
timer->totalsystime = 0;
|
||||
timer->startrealtime = torch_Timer_realtime();
|
||||
timer->startusertime = torch_Timer_usertime();
|
||||
timer->startsystime = torch_Timer_systime();
|
||||
lua_settop(L, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Timer_free(lua_State *L)
|
||||
{
|
||||
Timer *timer = luaT_checkudata(L, 1, torch_Timer_id);
|
||||
luaT_free(L, timer);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int torch_Timer_stop(lua_State *L)
|
||||
{
|
||||
Timer *timer = luaT_checkudata(L, 1, torch_Timer_id);
|
||||
if(timer->isRunning)
|
||||
{
|
||||
double realtime = torch_Timer_realtime() - timer->startrealtime;
|
||||
double usertime = torch_Timer_usertime() - timer->startusertime;
|
||||
double systime = torch_Timer_systime() - timer->startsystime;
|
||||
timer->totalrealtime += realtime;
|
||||
timer->totalusertime += usertime;
|
||||
timer->totalsystime += systime;
|
||||
timer->isRunning = 0;
|
||||
}
|
||||
lua_settop(L, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Timer_resume(lua_State *L)
|
||||
{
|
||||
Timer *timer = luaT_checkudata(L, 1, torch_Timer_id);
|
||||
if(!timer->isRunning)
|
||||
{
|
||||
timer->isRunning = 1;
|
||||
timer->startrealtime = torch_Timer_realtime();
|
||||
timer->startusertime = torch_Timer_usertime();
|
||||
timer->startsystime = torch_Timer_systime();
|
||||
}
|
||||
lua_settop(L, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Timer_time(lua_State *L)
|
||||
{
|
||||
Timer *timer = luaT_checkudata(L, 1, torch_Timer_id);
|
||||
double realtime = (timer->isRunning ? (timer->totalrealtime + torch_Timer_realtime() - timer->startrealtime) : timer->totalrealtime);
|
||||
double usertime = (timer->isRunning ? (timer->totalusertime + torch_Timer_usertime() - timer->startusertime) : timer->totalusertime);
|
||||
double systime = (timer->isRunning ? (timer->totalsystime + torch_Timer_systime() - timer->startsystime) : timer->totalsystime);
|
||||
lua_createtable(L, 0, 3);
|
||||
lua_pushnumber(L, realtime);
|
||||
lua_setfield(L, -2, "real");
|
||||
lua_pushnumber(L, usertime);
|
||||
lua_setfield(L, -2, "user");
|
||||
lua_pushnumber(L, systime);
|
||||
lua_setfield(L, -2, "sys");
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Timer___tostring__(lua_State *L)
|
||||
{
|
||||
Timer *timer = luaT_checkudata(L, 1, torch_Timer_id);
|
||||
lua_pushfstring(L, "torch.Timer [status: %s]", (timer->isRunning ? "running" : "stopped"));
|
||||
return 1;
|
||||
}
|
||||
|
||||
static const struct luaL_Reg torch_Timer__ [] = {
|
||||
{"reset", torch_Timer_reset},
|
||||
{"stop", torch_Timer_stop},
|
||||
{"resume", torch_Timer_resume},
|
||||
{"time", torch_Timer_time},
|
||||
{"__tostring__", torch_Timer___tostring__},
|
||||
{NULL, NULL}
|
||||
};
|
||||
|
||||
void torch_Timer_init(lua_State *L)
|
||||
{
|
||||
torch_Timer_id = luaT_newmetatable(L, "torch.Timer", NULL, torch_Timer_new, torch_Timer_free, NULL);
|
||||
luaL_register(L, NULL, torch_Timer__);
|
||||
lua_pop(L, 1);
|
||||
}
|
||||
115
dok/cmdline.dok
Normal file
115
dok/cmdline.dok
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
====== CmdLine ======
|
||||
{{anchor:torch.CmdLine.dok}}
|
||||
|
||||
This class provides a parameter parsing framework which is very
|
||||
useful when one needs to run several experiments that rely on
|
||||
different parameter settings that are passed in the command line.
|
||||
This class will also override the default print function to direct
|
||||
all the output to a log file as well as the screen at the same time.
|
||||
|
||||
A sample ''lua'' file is given below that makes use of ''CmdLine''
|
||||
class.
|
||||
|
||||
<file lua>
|
||||
|
||||
cmd = torch.CmdLine()
|
||||
cmd:text()
|
||||
cmd:text()
|
||||
cmd:text('Training a simple network')
|
||||
cmd:text()
|
||||
cmd:text('Options')
|
||||
cmd:option('-seed',123,'initial random seed')
|
||||
cmd:option('-booloption',false,'boolean option')
|
||||
cmd:option('-stroption','mystring','string option')
|
||||
cmd:text()
|
||||
|
||||
-- parse input params
|
||||
params = cmd:parse(arg)
|
||||
|
||||
params.rundir = cmd:string('experiment', params, {dir=true})
|
||||
|
||||
-- create log file
|
||||
cmd:log(params.rundir .. '/log', params)
|
||||
|
||||
</file>
|
||||
|
||||
When this file is run on the Lua command line as follows
|
||||
<file shell>
|
||||
# lua myscript.lua
|
||||
</file>
|
||||
|
||||
It will produce the following output:
|
||||
|
||||
<file>
|
||||
[program started on Tue Jan 10 15:33:49 2012]
|
||||
[command line arguments]
|
||||
booloption false
|
||||
seed 123
|
||||
rundir experiment
|
||||
stroption mystring
|
||||
[----------------------]
|
||||
booloption false
|
||||
seed 123
|
||||
rundir experiment
|
||||
stroption mystring
|
||||
</file>
|
||||
|
||||
The same output will also be written to the file
|
||||
''experiment/log''. Whenever one of the options is passed on the
|
||||
command line and is different from the default value, the ''rundir''
|
||||
name is produced to reflect the parameter setting.
|
||||
|
||||
<file shell>
|
||||
# lua myscript.lua -seed 456 -stroption mycustomstring
|
||||
</file>
|
||||
|
||||
This will produce the following output:
|
||||
|
||||
<file>
|
||||
[program started on Tue Jan 10 15:36:55 2012]
|
||||
[command line arguments]
|
||||
booloption false
|
||||
seed 456
|
||||
rundir experiment,seed=456,stroption=mycustomstring
|
||||
stroption mycustomstring
|
||||
[----------------------]
|
||||
booloption false
|
||||
seed 456
|
||||
rundir experiment,seed=456,stroption=mycustomstring
|
||||
stroption mycustomstring
|
||||
</file>
|
||||
|
||||
and the output will be logged in
|
||||
''experiment,seed=456,stroption=mycustomstring/log''
|
||||
|
||||
==== text(string) ====
|
||||
{{anchor:torch.CmdLine.text}}
|
||||
Logs a custom text message.
|
||||
|
||||
==== option(name, default, help) ====
|
||||
{{anchor:torch.CmdLine.option}}
|
||||
Stores an option argument. The name should always start with '-'.
|
||||
|
||||
==== [table] parse(arg) ====
|
||||
{{anchor:torch.CmdLine.parse}}
|
||||
Parses a given table; ''arg'' is by default the argument table that
|
||||
is created by ''lua'' using the command line arguments passed to the
|
||||
executable. Returns a table of option values.
|
||||
|
||||
==== [string] string(prefix, params, ignore) ====
|
||||
{{anchor:torch.CmdLine.string}}
|
||||
|
||||
Returns a string representation of the options by concatenating the
|
||||
non-default options. ''ignore'' is a table ''{dir=true}'', which will
|
||||
ensure that option named ''dir'' will be ignored while creating the
|
||||
string representation.
|
||||
|
||||
This function is useful for creating unique experiment directories that
|
||||
depend on the parameter settings.
|
||||
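As a minimal sketch (reusing the options from the example at the top of this page; the prefix and option values are only illustrative), the returned string can be used directly as an experiment directory name:

<file lua>
cmd = torch.CmdLine()
cmd:option('-seed', 123, 'initial random seed')
cmd:option('-stroption', 'mystring', 'string option')
cmd:option('-dir', 'experiment', 'output directory')

params = cmd:parse(arg)

-- only non-default options end up in the string; the option named 'dir' is ignored
params.rundir = cmd:string('experiment', params, {dir=true})
-- e.g. 'experiment,seed=456' when invoked with: lua myscript.lua -seed 456
print(params.rundir)
</file>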
|
||||
==== log(filename, parameter_table) ====
|
||||
{{anchor:torch.CmdLine.log}}
|
||||
|
||||
It sets the log filename to ''filename'' and prints the values of
|
||||
parameters in the ''parameter_table''.
|
||||
|
||||
64
dok/diskfile.dok
Normal file
64
dok/diskfile.dok
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
====== DiskFile ======
|
||||
{{anchor:torch.DiskFile.dok}}
|
||||
|
||||
Parent classes: [[File|File]]
|
||||
|
||||
A ''DiskFile'' is a particular ''File'' which is able to perform basic read/write operations
|
||||
on a file stored on disk. It implements all methods described in [[File|File]], and
|
||||
some additional methods relative to //endian// encoding.
|
||||
|
||||
By default, a ''DiskFile'' is in [[File#torch.File.binary|ASCII]] mode. If changed to
|
||||
the [[File#torch.File.binary|binary]] mode, the default endian encoding is the native
|
||||
computer one.
|
||||
|
||||
The file might be opened in read, write, or read-write mode, depending on the parameter
|
||||
''mode'' (which can take the value ''"r"'', ''"w"'' or ''"rw"'' respectively)
|
||||
given to the [[#torch.DiskFile|torch.DiskFile(fileName, mode)]].
|
||||
|
||||
===== torch.DiskFile(fileName, [mode], [quiet]) =====
|
||||
{{anchor:torch.DiskFile}}
|
||||
|
||||
//Constructor// which opens ''fileName'' on disk, using the given ''mode''. Valid ''mode'' values are
|
||||
''"r"'' (read), ''"w"'' (write) or ''"rw"'' (read-write). Default is read mode.
|
||||
|
||||
In read-write mode, the file //will be created// if it does not exist. If it
|
||||
exists, it will be positioned at the beginning of the file after opening.
|
||||
|
||||
If (and only if) ''quiet'' is ''true'', no error will be raised in case of
|
||||
problem opening the file: instead ''nil'' will be returned.
|
||||
|
||||
The file is opened in [[File#torch.File.ascii|ASCII]] mode by default.
|
||||
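A minimal sketch of the constructor and the ''quiet'' flag (the file names are only illustrative):

<file lua>
-- write mode: the file is created if it does not exist
file = torch.DiskFile('foo.asc', 'w')
file:writeInt(10)
file:close()

-- read mode is the default; with quiet=true a failed open returns nil
file = torch.DiskFile('does_not_exist.asc', 'r', true)
if not file then
   print('could not open the file, and no error was raised')
end
</file>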
|
||||
===== bigEndianEncoding() =====
|
||||
{{anchor:torch.DiskFile.bigEndianEncoding}}
|
||||
|
||||
In [[file#torch.File.binary|binary]] mode, force encoding in //big endian//.
|
||||
(//big end first//: decreasing numeric significance with increasing memory
|
||||
addresses)
|
||||
|
||||
===== [boolean] isBigEndianCPU() =====
|
||||
{{anchor:torch.DiskFile.isBigEndianCPU}}
|
||||
|
||||
Returns ''true'' if, and only if, the computer CPU operates in //big endian//.
|
||||
//Big end first//: decreasing numeric significance with increasing
|
||||
memory addresses.
|
||||
|
||||
===== [boolean] isLittleEndianCPU() =====
|
||||
{{anchor:torch.DiskFile.isLittleEndianCPU}}
|
||||
|
||||
Returns ''true'' if, and only if, the computer CPU operates in //little endian//.
|
||||
//Little end first//: increasing numeric significance with increasing
|
||||
memory addresses.
|
||||
|
||||
===== littleEndianEncoding() =====
|
||||
{{anchor:torch.DiskFile.littleEndianEncoding}}
|
||||
|
||||
In [[file#torch.File.binary|binary]] mode, force encoding in //little endian//.
|
||||
(//little end first//: increasing numeric significance with increasing memory
|
||||
addresses)
|
||||
|
||||
===== nativeEndianEncoding() =====
|
||||
{{anchor:torch.DiskFile.nativeEndianEncoding}}
|
||||
|
||||
In [[file#torch.File.binary|binary]] mode, force encoding in //native endian//.
|
||||
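A minimal sketch combining the endian methods above with the [[File#torch.File.binary|binary]] mode, so that a file written on one machine can be read on another (the file name is only illustrative):

<file lua>
file = torch.DiskFile('foo.bin', 'w')
file:binary()
-- force a fixed byte order instead of the native one
file:littleEndianEncoding()
file:writeInt(1234)
file:close()
</file>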
|
||||
333
dok/file.dok
Normal file
333
dok/file.dok
Normal file
|
|
@ -0,0 +1,333 @@
|
|||
====== File ======
|
||||
{{anchor:torch.File.dok}}
|
||||
|
||||
This is an //abstract// class. It defines most methods implemented by its
|
||||
child classes, like [[DiskFile|DiskFile]],
|
||||
[[MemoryFile|MemoryFile]] and [[PipeFile|PipeFile]].
|
||||
|
||||
Methods defined here are intended for basic read/write functionalities.
|
||||
Read/write methods might write in [[#torch.File.ascii|ASCII]] mode or
|
||||
[[#torch.File.binary|binary]] mode.
|
||||
|
||||
In [[#torch.File.ascii|ASCII]] mode, numbers are converted into human readable
|
||||
format (characters). Booleans are converted into ''0'' (false) or ''1'' (true).
|
||||
In [[#torch.File.binary|binary]] mode, numbers and booleans are directly encoded
|
||||
as represented in a register of the computer. While not being human
|
||||
readable and less portable, the binary mode is obviously faster.
|
||||
|
||||
In [[#torch.File.ascii|ASCII]] mode, if the default option
|
||||
[[#torch.File.autoSpacing|autoSpacing()]] is chosen, a space will be generated
|
||||
after each written number or boolean. A carriage return will also be added
|
||||
after each call to a write method. With this option, the spaces are
|
||||
supposed to exist while reading. This option can be deactivated with
|
||||
[[#torch.File.noAutoSpacing|noAutoSpacing()]].
|
||||
|
||||
A ''Lua'' error might or might not be generated in case of read/write error
|
||||
or problem in the file. This depends on the choice made between
|
||||
[[#torch.File.quiet|quiet()]] and [[#torch.File.pedantic|pedantic()]] options. It
|
||||
is possible to query if an error occurred in the last operation by calling
|
||||
[[#torch.File.hasError|hasError()]].
|
||||
|
||||
===== Read methods =====
|
||||
{{anchor:torch.File.read}}
|
||||
{{anchor:torch.File.readBool}}
|
||||
{{anchor:torch.File.readByte}}
|
||||
{{anchor:torch.File.readChar}}
|
||||
{{anchor:torch.File.readShort}}
|
||||
{{anchor:torch.File.readInt}}
|
||||
{{anchor:torch.File.readLong}}
|
||||
{{anchor:torch.File.readFloat}}
|
||||
{{anchor:torch.File.readDouble}}
|
||||
|
||||
There are three types of reading methods:
|
||||
- ''[number] readTYPE()''
|
||||
- ''[TYPEStorage] readTYPE(n)''
|
||||
- ''[number] readTYPE(TYPEStorage)''
|
||||
|
||||
where ''TYPE'' can be either ''Byte'', ''Char'', ''Short'', ''Int'', ''Long'', ''Float'' or ''Double''.
|
||||
|
||||
A convenience method also exists for boolean types: ''[boolean] readBool()''. It reads
|
||||
a value from the file with ''readInt()'' and returns ''true'' if and only if this value is ''1''. It is not possible
|
||||
to read storages of booleans.
|
||||
|
||||
All these methods depend on the encoding choice: [[#torch.File.ascii|ASCII]]
|
||||
or [[#torch.File.binary|binary]] mode. In [[#torch.File.ascii|ASCII]] mode, the
|
||||
option [[#torch.File.autoSpacing|autoSpacing()]] and
|
||||
[[#torch.File.noAutoSpacing|noAutoSpacing()]] also have an effect on these
|
||||
methods.
|
||||
|
||||
If no parameter is given, one element is returned. This element is
|
||||
converted to a ''Lua'' number when reading.
|
||||
|
||||
If ''n'' is given, ''n'' values of the specified type are read
|
||||
and returned in a new [[Storage|Storage]] of that particular type.
|
||||
The storage size corresponds to the number of elements actually read.
|
||||
|
||||
If a ''Storage'' is given, the method will attempt to read a number of elements
|
||||
equal to the size of the given storage, and fill up the storage with these elements.
|
||||
The number of elements actually read is returned.
|
||||
|
||||
In case of read error, these methods will call the ''Lua'' error function using the default
|
||||
[[#torch.File.pedantic|pedantic]] option, or stay quiet with the [[#torch.File.quiet|quiet]]
|
||||
option. In the latter case, one can check if an error occurred with
|
||||
[[#torch.File.hasError|hasError()]].
|
||||
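As a minimal sketch of the three forms, assuming a file ''data.asc'' containing enough numbers was written beforehand:

<file lua>
file = torch.DiskFile('data.asc', 'r')
x = file:readInt()        -- one element, returned as a Lua number
s = file:readDouble(5)    -- five elements, returned in a new DoubleStorage
n = file:readDouble(s)    -- refills the given storage, returns the number of elements read
file:close()
</file>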
|
||||
===== Write methods =====
|
||||
{{anchor:torch.File.write}}
|
||||
{{anchor:torch.File.writeBool}}
|
||||
{{anchor:torch.File.writeByte}}
|
||||
{{anchor:torch.File.writeChar}}
|
||||
{{anchor:torch.File.writeShort}}
|
||||
{{anchor:torch.File.writeInt}}
|
||||
{{anchor:torch.File.writeLong}}
|
||||
{{anchor:torch.File.writeFloat}}
|
||||
{{anchor:torch.File.writeDouble}}
|
||||
|
||||
There are two types of writing methods:
|
||||
- ''[number] writeTYPE(number)''
|
||||
- ''[number] writeTYPE(TYPEStorage)''
|
||||
|
||||
where ''TYPE'' can be either ''Byte'', ''Char'', ''Short'', ''Int'', ''Long'', ''Float'' or ''Double''.
|
||||
|
||||
A convenience method also exists for boolean types: ''writeBool(value)''. If ''value'' is ''nil'' or
|
||||
not ''true'', it is equivalent to a ''writeInt(0)'' call, else to ''writeInt(1)''. It is not possible
|
||||
to write storages of booleans.
|
||||
|
||||
All these methods depend on the encoding choice: [[#torch.File.ascii|ASCII]]
|
||||
or [[#torch.File.binary|binary]] mode. In [[#torch.File.ascii|ASCII]] mode, the
|
||||
option [[#torch.File.autoSpacing|autoSpacing()]] and
|
||||
[[#torch.File.noAutoSpacing|noAutoSpacing()]] also have an effect on these
|
||||
methods.
|
||||
|
||||
If one ''Lua'' number is given, this number is converted according to the
|
||||
name of the method when writing (e.g. ''writeInt(3.14)'' will write ''3'').
|
||||
|
||||
If a ''Storage'' is given, the method will attempt to write all the elements contained
|
||||
in the storage.
|
||||
|
||||
These methods return the number of elements actually written.
|
||||
|
||||
In case of write error, these methods will call the ''Lua'' error function using the default
|
||||
[[#torch.File.pedantic|pedantic]] option, or stay quiet with the [[#torch.File.quiet|quiet]]
|
||||
option. In the latter case, one can check if an error occurred with
|
||||
[[#torch.File.hasError|hasError()]].
|
||||
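A minimal sketch of the two forms, assuming the usual [[Storage|Storage]] constructor (the file name is only illustrative):

<file lua>
file = torch.DiskFile('data.asc', 'w')
file:writeInt(3.14)                  -- converted according to the method name: writes 3
s = torch.DoubleStorage(5):fill(1)
n = file:writeDouble(s)              -- writes all 5 elements of the storage, returns 5
file:close()
</file>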
|
||||
===== Serialization methods =====
|
||||
{{anchor:torch.File.serialization}}
|
||||
|
||||
These methods allow the user to save any serializable object on disk and
|
||||
reload it later in its original state. In other words, it can perform a
|
||||
//deep// copy of an object into a given ''File''.
|
||||
|
||||
Serializable objects are ''Torch'' objects having a ''read()'' and
|
||||
''write()'' method. ''Lua'' objects such as ''table'', ''number'' or
|
||||
''string'' or //pure Lua// functions are also serializable.
|
||||
|
||||
If the object to save contains several other objects (say it is a tree
|
||||
of objects), then objects appearing several times in this tree will be
|
||||
//saved only once//. This saves disk space, speeds up loading/saving and
|
||||
respects the dependencies between objects.
|
||||
|
||||
Interestingly, if the ''File'' is a [[MemoryFile|MemoryFile]], it allows
|
||||
the user to easily make a //clone// of any serializable object:
|
||||
<file lua>
|
||||
file = torch.MemoryFile() -- creates a file in memory
|
||||
file:writeObject(object) -- writes the object into file
|
||||
file:seek(1) -- comes back at the beginning of the file
|
||||
objectClone = file:readObject() -- gets a clone of object
|
||||
</file>
|
||||
|
||||
==== readObject() ====
|
||||
{{anchor:torch.File.readObject}}
|
||||
|
||||
Returns the next [[#torch.File.serialization|serializable]] object saved beforehand
|
||||
in the file with [[#torch.File.writeObject|writeObject()]].
|
||||
|
||||
Note that objects which were [[#torch.File.writeObject|written]] with the same
|
||||
reference still have the same reference after loading.
|
||||
|
||||
Example:
|
||||
<file lua>
|
||||
-- creates an array which contains twice the same tensor
|
||||
array = {}
|
||||
x = torch.Tensor(1)
|
||||
table.insert(array, x)
|
||||
table.insert(array, x)
|
||||
|
||||
-- array[1] and array[2] refer to the same address
|
||||
-- x[1] == array[1][1] == array[2][1] == 3.14
|
||||
array[1][1] = 3.14
|
||||
|
||||
-- write the array on disk
|
||||
file = torch.DiskFile('foo.asc', 'w')
|
||||
file:writeObject(array)
|
||||
file:close() -- make sure the data is written
|
||||
|
||||
-- reload the array
|
||||
file = torch.DiskFile('foo.asc', 'r')
|
||||
arrayNew = file:readObject()
|
||||
|
||||
-- arrayNew[1] and arrayNew[2] refer to the same address!
|
||||
-- arrayNew[1][1] == arrayNew[2][1] == 3.14
|
||||
-- so if we do now:
|
||||
arrayNew[1][1] = 2.72
|
||||
-- arrayNew[1][1] == arrayNew[2][1] == 2.72 !
|
||||
</file>
|
||||
|
||||
==== writeObject(object) ====
|
||||
{{anchor:torch.File.writeObject}}
|
||||
|
||||
Writes ''object'' into the file. This object can be read later using
|
||||
[[#torch.File.readObject|readObject()]]. Serializable objects are ''Torch''
|
||||
objects having a ''read()'' and ''write()'' method. ''Lua'' objects such as
|
||||
''table'', ''number'' or ''string'' or pure Lua functions are also serializable.
|
||||
|
||||
If the object has already been written in the file, only a //reference// to
|
||||
this already saved object will be written: this saves space and speeds up
|
||||
writing; it also keeps the dependencies between objects intact.
|
||||
|
||||
As a consequence, if one writes an object, modifies its members, and writes the
|
||||
object again in the same file, the modifications will not be recorded
|
||||
in the file, as only a reference to the original will be written. See
|
||||
[[#torch.File.readObject|readObject()]] for an example.
|
||||
|
||||
==== [string] readString(format) ====
|
||||
{{anchor:torch.File.readString}}
|
||||
|
||||
If ''format'' starts with ''"*l"'' then returns the next line in the ''File''. The end-of-line character is skipped.
|
||||
|
||||
If ''format'' starts with ''"*a"'' then returns all the remaining contents of the ''File''.
|
||||
|
||||
If no data is available, then an error is raised, except if ''File'' is in [[#torch.File.quiet|quiet()]] mode where
|
||||
it then returns ''nil''.
|
||||
|
||||
Because Torch is more precise about number typing, the ''Lua'' format ''"*n"'' is not supported:
|
||||
instead use one of the [[#torch.File.read|number read methods]].
|
||||
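A minimal sketch (the file name is only illustrative):

<file lua>
file = torch.DiskFile('foo.txt', 'r')
line = file:readString('*l')   -- next line, end-of-line character skipped
rest = file:readString('*a')   -- all the remaining contents of the file
file:close()
</file>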
|
||||
==== [number] writeString(str) ====
|
||||
{{anchor:torch.File.writeString}}
|
||||
|
||||
Writes the string ''str'' in the ''File''. If the string cannot be written completely an error is raised, except
|
||||
if ''File'' is in [[#torch.File.quiet|quiet()]] mode where it returns the number of characters actually written.
|
||||
|
||||
===== ascii() [default] =====
|
||||
{{anchor:torch.File.ascii}}
|
||||
|
||||
The data read or written will be in ''ASCII'' mode: all numbers are converted
|
||||
to characters (human readable format) and booleans are converted to ''0''
|
||||
(false) or ''1'' (true). The input-output format in this mode depends on the
|
||||
options [[#torch.File.autoSpacing|autoSpacing()]] and
|
||||
[[#torch.File.noAutoSpacing|noAutoSpacing()]].
|
||||
|
||||
===== autoSpacing() [default] =====
|
||||
{{anchor:torch.File.autoSpacing}}
|
||||
|
||||
In [[#torch.File.ascii|ASCII]] mode, write additional spaces around the elements
|
||||
written on disk: if writing a [[Storage|Storage]], a space will be
|
||||
generated between each //element// and a //return line// after the last
|
||||
element. If only writing one element, a //return line// will be generated
|
||||
after this element.
|
||||
|
||||
Those spaces are supposed to exist while reading in this mode.
|
||||
|
||||
This is the default behavior. You can deactivate this option with the
|
||||
[[#torch.File.noAutoSpacing|noAutoSpacing()]] method.
|
||||
|
||||
===== binary() =====
|
||||
{{anchor:torch.File.binary}}
|
||||
|
||||
The data read or written will be in binary mode: the representation in the
|
||||
''File'' is the same as the one in the computer memory/register (not human
|
||||
readable). This mode is faster than [[#torch.File.ascii|ASCII]] but less
|
||||
portable.
|
||||
|
||||
===== clearError() =====
|
||||
{{anchor:torch.File.clearError}}
|
||||
|
||||
Clears the error flag returned by [[#torch.File.hasError|hasError()]].
|
||||
|
||||
===== close() =====
|
||||
{{anchor:torch.File.close}}
|
||||
|
||||
Close the file. Any subsequent operation will generate a ''Lua'' error.
|
||||
|
||||
===== noAutoSpacing() =====
|
||||
{{anchor:torch.File.noAutoSpacing}}
|
||||
|
||||
In [[#torch.File.ascii|ASCII]] mode, do not put extra spaces between element
|
||||
written on disk. This is the contrary of the option
|
||||
[[#torch.File.autoSpacing|autoSpacing()]].
|
||||
|
||||
===== synchronize() =====
|
||||
{{anchor:torch.File.synchronize}}
|
||||
|
||||
If the child class buffers the data while writing, this ensures that the data
|
||||
is actually written.
|
||||
|
||||
|
||||
===== pedantic() [default] =====
|
||||
{{anchor:torch.File.pedantic}}
|
||||
|
||||
If this mode is chosen (which is the default), a ''Lua'' error will be
|
||||
generated in case of error (which will cause the program to stop).
|
||||
|
||||
It is possible to use [[#torch.File.quiet|quiet()]] to avoid ''Lua'' error generation
|
||||
and set a flag instead.
|
||||
|
||||
===== [number] position() =====
|
||||
{{anchor:torch.File.position}}
|
||||
|
||||
Returns the current position (in bytes) in the file.
|
||||
The first position is ''1'' (following Lua standard indexing).
|
||||
|
||||
===== quiet() =====
|
||||
{{anchor:torch.File.quiet}}
|
||||
|
||||
If this mode is chosen instead of [[#torch.File.pedantic|pedantic()]], no ''Lua''
|
||||
error will be generated in case of read/write error. Instead, a flag will
|
||||
be raised, readable through [[#torch.File.hasError|hasError()]]. This flag can
|
||||
be cleared with [[#torch.File.clearError|clearError()]].
|
||||
|
||||
Checking if a file is quiet can be performed using [[#torch.File.isQuiet|isQuiet()]].
|
||||
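A minimal sketch of error handling in quiet mode (the file name is only illustrative):

<file lua>
file = torch.DiskFile('data.asc', 'r')
file:quiet()
x = file:readInt()
if file:hasError() then
   file:clearError()
   print('read failed, but no Lua error was raised')
end
</file>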
|
||||
===== seek(position) =====
|
||||
{{anchor:torch.File.seek}}
|
||||
|
||||
Jump into the file at the given ''position'' (in bytes). Might generate/raise
|
||||
an error in case of problem. The first position is ''1'' (following Lua standard indexing).
|
||||
|
||||
===== seekEnd() =====
|
||||
{{anchor:torch.File.seekEnd}}
|
||||
|
||||
Jump at the end of the file. Might generate/raise an error in case of
|
||||
problem.
|
||||
|
||||
===== File state query =====
|
||||
|
||||
These methods allow the user to query the state of the given ''File''.
|
||||
|
||||
==== [boolean] hasError() ====
|
||||
{{anchor:torch.File.hasError}}
|
||||
|
||||
Returns whether an error occurred since the last [[#torch.File.clearError|clearError()]] call, or since
|
||||
the opening of the file if ''clearError()'' has never been called.
|
||||
|
||||
==== [boolean] isQuiet() ====
|
||||
{{anchor:torch.File.isQuiet}}
|
||||
|
||||
Returns a boolean which tells if the file is in [[#torch.File.quiet|quiet]] mode or not.
|
||||
|
||||
==== [boolean] isReadable() ====
|
||||
{{anchor:torch.File.isReadable}}
|
||||
|
||||
Tells if one can read the file or not.
|
||||
|
||||
==== [boolean] isWritable() ====
|
||||
{{anchor:torch.File.isWritable}}
|
||||
|
||||
Tells if one can write in the file or not.
|
||||
|
||||
==== [boolean] isAutoSpacing() ====
|
||||
{{anchor:torch.File.isAutoSpacing}}
|
||||
|
||||
Returns ''true'' if [[#torch.File.autoSpacing|autoSpacing]] has been chosen.
|
||||
39
dok/index.dok
Normal file
39
dok/index.dok
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
====== Torch Package Reference Manual ======
|
||||
{{anchor:torch.reference.dok}}
|
||||
|
||||
The **''torch''** package contains basic classes used everywhere in ''Torch7''.
|
||||
|
||||
//Input-output management// is provided with [[File|File]] (abstract class), [[DiskFile|DiskFile]] (file on disk),
|
||||
[[MemoryFile|MemoryFile]] (file in ''RAM'') and [[PipeFile|PipeFile]] (file from a piped command). These
|
||||
classes also handle //serialization//.
|
||||
|
||||
[[Storage|Storage]] and [[Tensor|Tensor]] are the basic bricks for //powerful numeric operations//. Tensors support
|
||||
a wide variety of fundamental [[maths|math operations]].
|
||||
|
||||
[[Timer|Timer]] is provided for //measuring time//.
|
||||
|
||||
[[Tester|Tester]] is provided as a generic testing framework and it is also used by [[..:nn:index|nn]] package.
|
||||
|
||||
[[CmdLine|CmdLine]] is provided as a command line argument parsing utility.
|
||||
|
||||
Finally, ''Torch'' provides some [[Utility|utility functions]] for creating and handling ''Torch'' //classes//,
|
||||
as well as support for [[random|random number generation]].
|
||||
|
||||
===== Torch Packages =====
|
||||
{{anchor:torch.reference.dok}}
|
||||
|
||||
* File I/O Interface Library
|
||||
* [[File|File]] is an abstract interface for common file operations.
|
||||
* [[DiskFile|Disk File]] defines operations on files stored on disk.
|
||||
  * [[MemoryFile|Memory File]] defines operations on files stored in RAM.
|
||||
* [[PipeFile|Pipe File]] defines operations for using piped commands.
|
||||
* Tensor Library
|
||||
* [[Storage|Storage]] defines a simple storage interface that controls the underlying storage for any tensor object.
|
||||
* [[Tensor|Tensor]] defines the //all powerful// tensor object that defines multi-dimensional numerical arrays with type templating.
|
||||
  * [[maths|Mathematical operations]] are defined for the tensor object types.
|
||||
* Useful Utilities
|
||||
* [[Timer|Timer]] provides functionality for //measuring time//.
|
||||
* [[Tester|Tester]] is a generic tester framework.
|
||||
* [[CmdLine|CmdLine]] is a command line argument parsing utility.
|
||||
* [[Random|Random]] defines a random number generator package with various distributions.
|
||||
  * Finally, useful [[Utility|utility]] functions are provided for easy handling of torch tensor types and class inheritance.
|
||||
804
dok/maths.dok
Normal file
804
dok/maths.dok
Normal file
|
|
@ -0,0 +1,804 @@
|
|||
====== Math Functions ======
|
||||
{{anchor:torch.maths.dok}}
|
||||
|
||||
Torch provides Matlab-like functions for manipulating
|
||||
[[index#Tensor|Tensor]] objects. Functions fall into several types of
|
||||
categories:
|
||||
* [[#torch.construction.dok|constructors]] like [[#torch.zeros|zeros]], [[#torch.ones|ones]]
|
||||
* extractors like [[#torch.diag|diag]] and [[#torch.triu|triu]],
|
||||
* [[#torch.elementwise.dok|element-wise]] operations like [[#torch.abs|abs]] and [[#torch.pow|pow]],
|
||||
* [[#torch.columnwise.dok|column or row-wise operations]] like [[#torch.sum|sum]] and [[#torch.max|max]],
|
||||
* [[#torch.matrixwide.dok|matrix-wide operations]] like [[#torch.trace|trace]] and [[#torch.norm|norm]].
|
||||
* [[#torch.conv.dok|Convolution and cross-correlation]] operations like [[#torch.conv2|conv2]].
|
||||
* [[#torch.linalg.dok|Basic linear algebra operations]] like [[#torch.eig|eigen value/vector calculation]], [[#torch.svd|singular value decomposition (svd)]] and [[#torch.gesv|linear system solution]].
|
||||
|
||||
By default, all operations allocate a new tensor to return the
|
||||
result. However, all functions also support passing the result
|
||||
tensor(s) as the first argument(s), in which case the resulting tensor(s)
|
||||
will be resized accordingly and filled with the result.
|
||||
|
||||
For example, ''torch.conv2'' function can be used in the following manner.
|
||||
<file lua>
|
||||
|
||||
x = torch.rand(100,100)
|
||||
k = torch.rand(10,10)
|
||||
res1 = torch.conv2(x,k)
|
||||
|
||||
res2 = torch.Tensor()
|
||||
torch.conv2(res2,x,k)
|
||||
|
||||
=res2:dist(res1)
|
||||
0
|
||||
|
||||
</file>
|
||||
|
||||
The advantage of the second form is that the same ''res2'' tensor can be reused successively in a loop without any new memory allocation.
|
||||
|
||||
<file lua>
|
||||
-- no new memory allocations...
|
||||
for i=1,100 do
|
||||
torch.conv2(res2,x,k)
|
||||
end
|
||||
=res2:dist(res1)
|
||||
0
|
||||
</file>
|
||||
|
||||
====== Construction or extraction functions ======
|
||||
{{anchor:torch.construction.dok}}
|
||||
|
||||
===== torch.cat( [res,] x_1, x_2, [dimension] ) =====
|
||||
{{anchor:torch.cat}}
|
||||
|
||||
''x=torch.cat(x_1,x_2,[dimension])'' returns a tensor ''x'' which is the concatenation of tensors x_1 and x_2 along dimension ''dimension''.
|
||||
|
||||
If ''dimension'' is not specified it is 1.
|
||||
|
||||
The other dimensions of x_1 and x_2 have to be equal.
|
||||
|
||||
Examples:
|
||||
<file lua>
|
||||
> print(torch.cat(torch.ones(3),torch.zeros(2)))
|
||||
|
||||
1
|
||||
1
|
||||
1
|
||||
0
|
||||
0
|
||||
[torch.Tensor of dimension 5]
|
||||
|
||||
|
||||
> print(torch.cat(torch.ones(3,2),torch.zeros(2,2)))
|
||||
|
||||
1 1
|
||||
1 1
|
||||
1 1
|
||||
0 0
|
||||
0 0
|
||||
[torch.DoubleTensor of dimension 5x2]
|
||||
|
||||
|
||||
> print(torch.cat(torch.ones(2,2),torch.zeros(2,2)))
|
||||
1 1
|
||||
1 1
|
||||
0 0
|
||||
0 0
|
||||
[torch.DoubleTensor of dimension 4x2]
|
||||
|
||||
> print(torch.cat(torch.ones(2,2),torch.zeros(2,2),2))
|
||||
1 1 0 0
|
||||
1 1 0 0
|
||||
[torch.DoubleTensor of dimension 2x4]
|
||||
|
||||
|
||||
> print(torch.cat(torch.cat(torch.ones(2,2),torch.zeros(2,2)),torch.rand(3,2)))
|
||||
|
||||
1.0000 1.0000
|
||||
1.0000 1.0000
|
||||
0.0000 0.0000
|
||||
0.0000 0.0000
|
||||
0.3227 0.0493
|
||||
0.9161 0.1086
|
||||
0.2206 0.7449
|
||||
[torch.DoubleTensor of dimension 7x2]
|
||||
|
||||
</file>
|
||||
|
||||
|
||||
===== torch.diag( [res,] x) =====
|
||||
{{anchor:torch.diag}}
|
||||
|
||||
''y=torch.diag(x)'' when x is of dimension 1 returns a diagonal matrix with diagonal elements constructed from x.
|
||||
|
||||
''y=torch.diag(x)'' when x is of dimension 2 returns a tensor of dimension 1
|
||||
with elements constructed from the diagonal of x.
|
||||
|
||||
''y=torch.diag(x,k)'' returns the k-th diagonal of x,
|
||||
where k=0 is the main diagonal, k>0 is above the main diagonal and k<0
|
||||
is below the main diagonal.
|
||||
|
||||
===== torch.eye( [res,] n) =====
|
||||
{{anchor:torch.eye}}
|
||||
|
||||
''y=torch.eye(n)'' returns the n-by-n identity matrix.
|
||||
|
||||
''y=torch.eye(m,n)'' returns an m-by-n identity matrix with ones on the diagonal and zeros elsewhere.
|
||||
|
||||
|
||||
===== torch.linspace( [res,] x1,x2) =====
|
||||
{{anchor:torch.linspace}}
|
||||
|
||||
''y=torch.linspace(x1,x2)'' returns a one-dimensional tensor of size 100 equally spaced points between x1 and x2.
|
||||
|
||||
''y=torch.linspace(x1,x2,n)'' returns a one-dimensional tensor of n equally spaced points between x1 and x2.
|
||||
|
||||
|
||||
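For instance (indicative output; the values are the five equally spaced points between 0 and 1):

<file lua>
> print(torch.linspace(0, 1, 5))

 0.0000
 0.2500
 0.5000
 0.7500
 1.0000
[torch.DoubleTensor of dimension 5]
</file>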
===== torch.logspace( [res,] x1, x2) =====
|
||||
{{anchor:torch.logspace}}
|
||||
|
||||
''y=torch.logspace(x1,x2)'' returns a one-dimensional tensor of 50 logarithmically equally spaced points between x1 and x2.
|
||||
|
||||
''y=torch.logspace(x1,x2,n)'' returns a one-dimensional tensor of n logarithmically equally spaced points between x1 and x2.
|
||||
|
||||
===== torch.ones( [res,] m) =====
|
||||
{{anchor:torch.ones}}
|
||||
|
||||
''y=torch.ones(n)'' returns a one-dimensional tensor of size n filled with ones.
|
||||
|
||||
''y=torch.ones(m,n)'' returns a mxn tensor filled with ones.
|
||||
|
||||
''y=torch.ones(m,n,k)'' returns a mxnxk tensor filled with ones.
|
||||
|
||||
''y=torch.ones(d1,...,d_n)'' returns an n-dimensional tensor with sizes d1, ..., d_n filled with ones.
|
||||
|
||||
===== torch.rand( [res,] m [, n, k, ...]) =====
|
||||
{{anchor:torch.rand}}
|
||||
|
||||
''y=torch.rand(n)'' returns a one-dimensional tensor of size n filled with random numbers from a uniform distribution on the interval (0,1).
|
||||
|
||||
''y=torch.rand(m,n)'' returns a mxn tensor of random numbers from a uniform distribution on the interval (0,1).
|
||||
|
||||
===== torch.randn( [res,] m [, n, k, ...]) =====
|
||||
{{anchor:torch.randn}}
|
||||
|
||||
''y=torch.randn(n)'' returns a one-dimensional tensor of size n filled with random numbers from a normal distribution with mean zero and variance one.
|
||||
|
||||
''y=torch.randn(m,n)'' returns a mxn tensor of random numbers from a normal distribution with mean zero and variance one.
|
||||
|
||||
===== torch.range([res,] n,m) =====
|
||||
{{anchor:torch.range}}
|
||||
|
||||
''y=torch.range(n,m)'' returns a one-dimensional tensor of size ''m-n+1'' with integer
|
||||
values n to m.
|
||||
|
||||
<file lua>
|
||||
> print(torch.range(2,5))
|
||||
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
[torch.Tensor of dimension 4]
|
||||
</file>
|
||||
|
||||
''y=torch.range(n,m,incr)'' returns a tensor filled with values from n to m in increments of ''incr''.
|
||||
<file lua>
|
||||
print(torch.range(2,5,1.2))
|
||||
2.0000
|
||||
3.2000
|
||||
4.4000
|
||||
[torch.DoubleTensor of dimension 3]
|
||||
</file>
|
||||
|
||||
===== torch.randperm([res,] n) =====
|
||||
{{anchor:torch.randperm}}
|
||||
|
||||
''y=torch.randperm(n)'' returns a randomly ordered nx1 tensor of the integers from 1 to n.
|
||||
|
||||
===== torch.reshape([res,] x,m,n) =====
|
||||
{{anchor:torch.reshape}}
|
||||
|
||||
''y=torch.reshape(x,m,n)'' returns a new mxn tensor y whose elements
|
||||
are taken rowwise from x, which must have m*n elements. The elements are copied into the new tensor.
|
||||
|
||||
===== torch.tril([res,] x) =====
|
||||
{{anchor:torch.tril}}
|
||||
|
||||
''y=torch.tril(x)'' returns the lower triangular part of x, the other elements of y are set to 0.
|
||||
|
||||
''torch.tril(x,k)'' returns the elements on and below the k-th diagonal of x as non-zero. k=0 is the main diagonal, k>0 is above the main diagonal and k<0
|
||||
is below the main diagonal.
|
||||
|
||||
===== torch.triu([res,] x) =====
|
||||
{{anchor:torch.triu}}
|
||||
|
||||
''y=torch.triu(x)'' returns the upper triangular part of x,
|
||||
the other elements of y are set to 0.
|
||||
|
||||
''torch.triu(x,k)'' returns the elements on and above the k-th diagonal of x as non-zero. k=0 is the main diagonal, k>0 is above the main diagonal and k<0
|
||||
is below the main diagonal.
|
||||
|
||||
===== torch.zeros([res,] x) =====
|
||||
{{anchor:torch.zeros}}
|
||||
|
||||
''y=torch.zeros(n)'' returns a one-dimensional tensor of size n filled with zeros.
|
||||
|
||||
''y=torch.zeros(m,n)'' returns a mxn tensor filled with zeros.
|
||||
|
||||
|
||||
====== Element-wise operations ======
|
||||
{{anchor:torch.elementwise.dok}}
|
||||
|
||||
===== torch.abs([res,] x) =====
|
||||
{{anchor:torch.abs}}
|
||||
|
||||
''y=torch.abs(x)'' returns the absolute values of the elements of x.
|
||||
|
||||
===== torch.acos([res,] x) =====
|
||||
{{anchor:torch.acos}}
|
||||
|
||||
''y=torch.acos(x)'' returns the arccosine of the elements of x.
|
||||
|
||||
===== torch.asin([res,] x) =====
|
||||
{{anchor:torch.asin}}
|
||||
|
||||
''y=torch.asin(x)'' returns the arcsine of the elements of x.
|
||||
|
||||
===== torch.atan([res,] x) =====
|
||||
{{anchor:torch.atan}}
|
||||
|
||||
''y=torch.atan(x)'' returns the arctangent of the elements of x.
|
||||
|
||||
===== torch.ceil([res,] x) =====
|
||||
{{anchor:torch.ceil}}
|
||||
|
||||
''y=torch.ceil(x)'' returns the values of the elements of x rounded up to the nearest integers.
|
||||
|
||||
===== torch.cos([res,] x) =====
|
||||
{{anchor:torch.cos}}
|
||||
|
||||
''y=torch.cos(x)'' returns the cosine of the elements of x.
|
||||
|
||||
===== torch.cosh([res,] x) =====
|
||||
{{anchor:torch.cosh}}
|
||||
|
||||
''y=torch.cosh(x)'' returns the hyperbolic cosine of the elements of x.
|
||||
|
||||
===== torch.exp([res,] x) =====
|
||||
{{anchor:torch.exp}}
|
||||
|
||||
''y=torch.exp(x)'' returns, for each element in x, e (the base of natural logarithms) raised to the power of the element in x.
|
||||
|
||||
===== torch.floor([res,] x) =====
|
||||
{{anchor:torch.floor}}
|
||||
|
||||
''y=torch.floor(x)'' returns the values of the elements of x rounded down to the nearest integers.
|
||||
|
||||
===== torch.log([res,] x) =====
|
||||
{{anchor:torch.log}}
|
||||
|
||||
''y=torch.log(x)'' returns the natural logarithm of the elements of x.
|
||||
|
||||
===== torch.pow([res,] x, n) =====
|
||||
{{anchor:torch.pow}}
|
||||
|
||||
''y=torch.pow(x,n)'' returns the elements of x to the power of n.
|
||||
|
||||
===== torch.sin([res,] x) =====
|
||||
{{anchor:torch.sin}}
|
||||
|
||||
''y=torch.sin(x)'' returns the sine of the elements of x.
|
||||
|
||||
===== torch.sinh([res,] x) =====
|
||||
{{anchor:torch.sinh}}
|
||||
|
||||
''y=torch.sinh(x)'' returns the hyperbolic sine of the elements of x.
|
||||
|
||||
===== torch.sqrt([res,] x) =====
|
||||
{{anchor:torch.sqrt}}
|
||||
|
||||
''y=torch.sqrt(x)'' returns the square root of the elements of x.
|
||||
|
||||
===== torch.tan([res,] x) =====
|
||||
{{anchor:torch.tan}}
|
||||
|
||||
''y=torch.tan(x)'' returns the tangent of the elements of x.
|
||||
|
||||
===== torch.tanh([res,] x) =====
|
||||
{{anchor:torch.tanh}}
|
||||
|
||||
''y=torch.tanh(x)'' returns the hyperbolic tangent of the elements of x.
|
||||
|
||||
====== Column or row-wise operations (dimension-wise operations) ======
|
||||
{{anchor:torch.columnwise.dok}}
|
||||
|
||||
===== torch.cross([res,] a,b) =====
|
||||
{{anchor:torch.cross}}
|
||||
|
||||
''y=torch.cross(a,b)'' returns the cross product of the tensors a and b.
|
||||
a and b must be 3 element vectors.
|
||||
|
||||
''y=torch.cross(a,b)'' returns the cross product of a and b along the first dimension of length 3.
|
||||
|
||||
''y=torch.cross(a,b,n)'' returns the cross
|
||||
product of vectors in dimension n of a and b.
|
||||
a and b must have the same size,
|
||||
and both a:size(n) and b:size(n) must be 3.
|
||||
|
||||
|
||||
===== torch.cumprod([res,] x) =====
|
||||
{{anchor:torch.cumprod}}
|
||||
|
||||
''y=torch.cumprod(x)'' returns the cumulative product of the elements of x, performing the operation over the last dimension.
|
||||
|
||||
''y=torch.cumprod(x,n)'' returns the cumulative product of the elements of x, performing the operation over dimension n.
|
||||
|
||||
===== torch.cumsum([res,] x) =====
|
||||
{{anchor:torch.cumsum}}
|
||||
|
||||
''y=torch.cumsum(x)'' returns the cumulative sum of the elements of x, performing the operation over the first dimension.
|
||||
|
||||
''y=torch.cumsum(x,n)'' returns the cumulative sum of the elements of x, performing the operation over dimension n.
|
||||
|
||||
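For example (indicative output; the running totals 1, 3, 6, 10 are the cumulative sums of 1, 2, 3, 4):

<file lua>
> print(torch.cumsum(torch.Tensor({1,2,3,4})))

  1
  3
  6
 10
[torch.DoubleTensor of dimension 4]
</file>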
===== torch.max([resval, resind, ] x) =====
|
||||
{{anchor:torch.max}}
|
||||
|
||||
''y,i=torch.max(x)'' returns a tensor y of the largest element in
|
||||
each row of x, and a tensor i of their corresponding indices in x.
|
||||
|
||||
''y,i=torch.max(x,2)'' performs the max operation for each row and
|
||||
''y,i=torch.max(x,n)'' performs the max operation over the dimension n.
|
||||
|
||||
|
||||
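A short sketch of the two returned tensors, assuming (as for the other dimension-wise operations in this section) that dimension ''2'' reduces over each row; indices are 1-based:

<file lua>
x = torch.Tensor({{1, 5, 2},
                  {7, 3, 4}})
y, i = torch.max(x, 2)   -- maximum over each row
print(y)                 -- expected values: 5 and 7
print(i)                 -- expected column indices: 2 and 1
</file>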
===== torch.mean([res,] x) =====
|
||||
{{anchor:torch.mean}}
|
||||
|
||||
''y=torch.mean(x)'' returns a tensor y of the mean of the elements in
|
||||
each row of x.
|
||||
|
||||
''y=torch.mean(x,2)'' performs the mean operation for each row and
|
||||
''y=torch.mean(x,n)'' performs the mean operation over the dimension n.
|
||||
|
||||
===== torch.min([resval, resind, ] x) =====
|
||||
{{anchor:torch.min}}
|
||||
|
||||
''y,i=torch.min(x)'' returns a tensor y of the smallest element in
|
||||
each row of x, and a tensor i of their corresponding indices in x.
|
||||
|
||||
''y,i=torch.min(x,2)'' performs the min operation for each row and
|
||||
''y,i=torch.min(x,n)'' performs the min operation over the dimension n.
|
||||
|
||||
|
||||
===== torch.prod([res,] x) =====
|
||||
{{anchor:torch.prod}}
|
||||
|
||||
''y=torch.prod(x)'' returns a tensor y of the product of the elements in
|
||||
each row of x.
|
||||
|
||||
''y=torch.prod(x,2)'' performs the prod operation for each row and
|
||||
''y=torch.prod(x,n)'' performs the prod operation over the dimension n.
|
||||
|
||||
===== torch.sort([resval, resind, ] x) =====
|
||||
{{anchor:torch.sort}}
|
||||
|
||||
''y,i=torch.sort(x)'' returns a tensor y of the sorted
|
||||
rows of x, and a tensor i of the corresponding indices from x.
|
||||
|
||||
''y,i=torch.sort(x,2)'' performs the sort operation for each row and
|
||||
''y,i=torch.sort(x,n)'' performs the sort operation over the dimension n.
|
||||
|
||||
===== torch.std([res,] x) =====
|
||||
{{anchor:torch.std}}
|
||||
|
||||
''y=torch.std(x)'' returns a tensor y of the standard deviation of the elements in
|
||||
each row of x.
|
||||
|
||||
''torch.std(x)'' normalizes by (n-1) where n is the number of elements. This
|
||||
makes torch.sum(torch.pow(torch.std(x),2))
|
||||
the best unbiased estimate of the variance if x
|
||||
is a sample from a normal distribution.
|
||||
|
||||
''y=torch.std(x,true)'' performs the std operation normalizing by n instead of n-1.
|
||||
|
||||
''y=torch.std(x,false)'' performs the std operation normalizing by n-1.
|
||||
|
||||
''y=torch.std(x,flag,n)'' performs the std operation over the dimension n.
|
||||
|
||||
|
||||
===== torch.sum([res,] x) =====
|
||||
{{anchor:torch.sum}}
|
||||
|
||||
''y=torch.sum(x)'' returns a tensor y of the sum of the elements in
|
||||
each row of x.
|
||||
|
||||
''y=torch.sum(x,2)'' performs the sum operation for each row and
|
||||
''y=torch.sum(x,n)'' performs the sum operation over the dimension n.
|
||||
|
||||
===== torch.var([res,] x) =====
|
||||
{{anchor:torch.var}}
|
||||
|
||||
''y=torch.var(x)'' returns a tensor y of the variance of the elements in
|
||||
each row of x.
|
||||
|
||||
''torch.var(x)'' normalizes by (n-1) where n is the number of elements. This
|
||||
makes torch.sum(torch.var(x))
|
||||
the best unbiased estimate of the variance if x
|
||||
is a sample from a normal distribution.
|
||||
|
||||
''y=torch.var(x,true)'' performs the var operation normalizing by n instead of n-1.
|
||||
|
||||
''y=torch.var(x,false)'' performs the var operation normalizing by n-1.
|
||||
|
||||
''y=torch.var(x,flag,n)'' performs the var operation over the dimension n.
|
||||
|
||||
====== Matrix-wide operations (tensor-wide operations) ======
|
||||
{{anchor:torch.matrixwide.dok}}
|
||||
|
||||
===== torch.norm(x) =====
|
||||
{{anchor:torch.norm}}
|
||||
|
||||
''y=torch.norm(x)'' returns the 2-norm of the tensor x.
|
||||
|
||||
''y=torch.norm(x,p)'' returns the p-norm of the tensor x.
|
||||
|
||||
|
||||
===== torch.dist(x,y) =====
|
||||
{{anchor:torch.dist}}
|
||||
|
||||
''y=torch.dist(x,y)'' returns the 2-norm of (x-y).
|
||||
|
||||
''y=torch.dist(x,y,p)'' returns the p-norm of (x-y).
|
||||
|
||||
===== torch.numel(x) =====
|
||||
{{anchor:torch.numel}}
|
||||
|
||||
''y=torch.numel(x)'' returns the number of elements in the tensor x.
|
||||
|
||||
===== torch.trace(x) =====
|
||||
{{anchor:torch.trace}}
|
||||
|
||||
''y=torch.trace(x)'' returns the trace (sum of the diagonal elements)
|
||||
of a matrix x. This is equal to the sum of the eigenvalues of x.
|
||||
The returned value ''y'' is a number, not a tensor.
|
||||
|
||||
====== Convolution Operations ======
|
||||
{{anchor:torch.conv.dok}}
|
||||
|
||||
These functions implement convolution or cross-correlation of an input
|
||||
image (or set of input images) with a kernel (or set of kernels). The
|
||||
convolution function in Torch can handle different types of
|
||||
input/kernel dimensions and produces corresponding outputs. The
|
||||
general form of the operations always remains the same.
|
||||
|
||||
===== torch.conv2([res,] x, k, ['f' or 'v']) =====
|
||||
{{anchor:torch.conv2}}
|
||||
|
||||
This function computes 2 dimensional convolutions between '' x '' and '' k ''. These operations are analogous to BLAS operations with the number of dimensions of the input and kernel reduced by 2.
|
||||
|
||||
* '' x '' and '' k '' are 2D : convolution of a single image with a single kernel (2D output). This operation is similar to multiplication of two scalars.
|
||||
* '' x '' and '' k '' are 3D : convolution of each input slice with corresponding kernel (3D output).
|
||||
* '' x (p x m x n) '' 3D, '' k (q x p x ki x kj)'' 4D : convolution of all input slices with the corresponding slice of kernel. Output is 3D '' (q x m x n) ''. This operation is similar to matrix vector product of matrix '' k '' and vector '' x ''.
|
||||
|
||||
The last argument controls whether the convolution is a full ('f') or valid ('v') convolution. The default is 'valid' convolution.
|
||||
|
||||
<file lua>
|
||||
x=torch.rand(100,100)
|
||||
k=torch.rand(10,10)
|
||||
c = torch.conv2(x,k)
|
||||
=c:size()
|
||||
|
||||
91
|
||||
91
|
||||
[torch.LongStorage of size 2]
|
||||
|
||||
c = torch.conv2(x,k,'f')
|
||||
=c:size()
|
||||
|
||||
109
|
||||
109
|
||||
[torch.LongStorage of size 2]
|
||||
|
||||
</file>
|
||||
|
||||
===== torch.xcorr2([res,] x, k, ['f' or 'v']) =====
|
||||
{{anchor:torch.xcorr2}}
|
||||
|
||||
This function operates with same options and input/output
|
||||
configurations as [[#torch.conv2|torch.conv2]], but performs
|
||||
cross-correlation of the input with the kernel '' k ''.
|
||||
|
||||
===== torch.conv3([res,] x, k, ['f' or 'v']) =====
|
||||
{{anchor:torch.conv3}}
|
||||
|
||||
This function computes 3 dimensional convolutions between '' x '' and '' k ''. These operations are analogous to BLAS operations with the number of dimensions of the input and kernel reduced by 3.
|
||||
|
||||
* '' x '' and '' k '' are 3D : convolution of a single image with a single kernel (3D output). This operation is similar to multiplication of two scalars.
|
||||
* '' x '' and '' k '' are 4D : convolution of each input slice with corresponding kernel (4D output).
|
||||
* '' x (p x m x n x o) '' 4D, '' k (q x p x ki x kj x kk)'' 5D : convolution of all input slices with the corresponding slice of kernel. Output is 4D '' (q x m x n x o) ''. This operation is similar to matrix vector product of matrix '' k '' and vector '' x ''.
|
||||
|
||||
The last argument controls whether the convolution is a full ('f') or valid ('v') convolution. The default is 'valid' convolution.
|
||||
|
||||
<file lua>
|
||||
x=torch.rand(100,100,100)
|
||||
k=torch.rand(10,10,10)
|
||||
c = torch.conv3(x,k)
|
||||
=c:size()
|
||||
|
||||
91
|
||||
91
|
||||
91
|
||||
[torch.LongStorage of size 3]
|
||||
|
||||
c = torch.conv3(x,k,'f')
|
||||
=c:size()
|
||||
|
||||
109
|
||||
109
|
||||
109
|
||||
[torch.LongStorage of size 3]
|
||||
|
||||
</file>
|
||||
|
||||
===== torch.xcorr3([res,] x, k, ['f' or 'v']) =====
|
||||
{{anchor:torch.xcorr3}}
|
||||
|
||||
This function operates with same options and input/output
|
||||
configurations as [[#torch.conv3|torch.conv3]], but performs
|
||||
cross-correlation of the input with the kernel '' k ''.
|
||||
|
||||
====== Eigenvalues, SVD, Linear System Solution ======
|
||||
{{anchor:torch.linalg.dok}}
|
||||
|
||||
Functions in this section are implemented with an interface to LAPACK
|
||||
libraries. If LAPACK libraries are not found during the compilation step,
|
||||
then these functions will not be available.
|
||||
|
||||
===== torch.gesv([resb, resa,] b,a [, true]) =====
|
||||
{{anchor:torch.gesv}}
|
||||
|
||||
Solution of '' AX=B '', where ''A'' has to be square and non-singular. ''
|
||||
A '' is '' m x m '', '' X '' is '' m x k '', '' B '' is '' m x k ''.
|
||||
|
||||
If ''resb'' and ''resa'' are given, then they will be used for
|
||||
temporary storage and returning the result.
|
||||
|
||||
* ''resa'' will contain L and U factors for ''LU'' factorization of ''A''.
|
||||
* ''resb'' will contain the solution.
|
||||
|
||||
If ''gesv'' is called with 3 parameters and the last parameter is ''true'',
|
||||
then ''b'' and ''a'' will be destroyed and their output values will be
|
||||
the same as ''resb'' and ''resa'' respectively.
|
||||
|
||||
<file lua>
|
||||
a=torch.Tensor({{6.80, -2.11, 5.66, 5.97, 8.23},
|
||||
{-6.05, -3.30, 5.36, -4.44, 1.08},
|
||||
{-0.45, 2.58, -2.70, 0.27, 9.04},
|
||||
{8.32, 2.71, 4.35, -7.17, 2.14},
|
||||
{-9.67, -5.14, -7.26, 6.08, -6.87}}):t()
|
||||
|
||||
b=torch.Tensor({{4.02, 6.19, -8.22, -7.57, -3.03},
|
||||
{-1.56, 4.00, -8.67, 1.75, 2.86},
|
||||
{9.81, -4.09, -4.57, -8.61, 8.99}}):t()
|
||||
|
||||
=b
|
||||
4.0200 -1.5600 9.8100
|
||||
6.1900 4.0000 -4.0900
|
||||
-8.2200 -8.6700 -4.5700
|
||||
-7.5700 1.7500 -8.6100
|
||||
-3.0300 2.8600 8.9900
|
||||
[torch.DoubleTensor of dimension 5x3]
|
||||
|
||||
=a
|
||||
6.8000 -6.0500 -0.4500 8.3200 -9.6700
|
||||
-2.1100 -3.3000 2.5800 2.7100 -5.1400
|
||||
5.6600 5.3600 -2.7000 4.3500 -7.2600
|
||||
5.9700 -4.4400 0.2700 -7.1700 6.0800
|
||||
8.2300 1.0800 9.0400 2.1400 -6.8700
|
||||
[torch.DoubleTensor of dimension 5x5]
|
||||
|
||||
|
||||
x=torch.gesv(b,a)
|
||||
=x
|
||||
-0.8007 -0.3896 0.9555
|
||||
-0.6952 -0.5544 0.2207
|
||||
0.5939 0.8422 1.9006
|
||||
1.3217 -0.1038 5.3577
|
||||
0.5658 0.1057 4.0406
|
||||
[torch.DoubleTensor of dimension 5x3]
|
||||
|
||||
=b:dist(a*x)
|
||||
1.1682163181673e-14
|
||||
|
||||
</file>
|
||||
|
||||
===== torch.gels([resb, resa,] b,a) =====
|
||||
{{anchor:torch.gels}}
|
||||
|
||||
Solution of least squares and least norm problems for a full rank '' A '' that is '' m x n''.
|
||||
  * If '' n %%<=%% m '', then solve '' min ||AX-B||_F ''.
|
||||
* If '' n > m '' , then solve '' min ||X||_F s.t. AX=B ''.
|
||||
|
||||
On return, the first '' n '' rows of the '' X '' matrix contain the solution
|
||||
and the rest contains residual information. The square root of the sum of squares
|
||||
of elements of each column of '' X '' starting at row '' n + 1 '' is
|
||||
the residual for the corresponding column.
|
||||
|
||||
<file lua>
|
||||
|
||||
a=torch.Tensor({{ 1.44, -9.96, -7.55, 8.34, 7.08, -5.45},
|
||||
{-7.84, -0.28, 3.24, 8.09, 2.52, -5.70},
|
||||
{-4.39, -3.24, 6.27, 5.28, 0.74, -1.19},
|
||||
{4.53, 3.83, -6.64, 2.06, -2.47, 4.70}}):t()
|
||||
|
||||
b=torch.Tensor({{8.58, 8.26, 8.48, -5.28, 5.72, 8.93},
|
||||
{9.35, -4.43, -0.70, -0.26, -7.36, -2.52}}):t()
|
||||
|
||||
=a
|
||||
1.4400 -7.8400 -4.3900 4.5300
|
||||
-9.9600 -0.2800 -3.2400 3.8300
|
||||
-7.5500 3.2400 6.2700 -6.6400
|
||||
8.3400 8.0900 5.2800 2.0600
|
||||
7.0800 2.5200 0.7400 -2.4700
|
||||
-5.4500 -5.7000 -1.1900 4.7000
|
||||
[torch.DoubleTensor of dimension 6x4]
|
||||
|
||||
=b
|
||||
8.5800 9.3500
|
||||
8.2600 -4.4300
|
||||
8.4800 -0.7000
|
||||
-5.2800 -0.2600
|
||||
5.7200 -7.3600
|
||||
8.9300 -2.5200
|
||||
[torch.DoubleTensor of dimension 6x2]
|
||||
|
||||
x = torch.gels(b,a)
|
||||
=x
|
||||
-0.4506 0.2497
|
||||
-0.8492 -0.9020
|
||||
0.7066 0.6323
|
||||
0.1289 0.1351
|
||||
13.1193 -7.4922
|
||||
-4.8214 -7.1361
|
||||
[torch.DoubleTensor of dimension 6x2]
|
||||
|
||||
=b:dist(a*x:narrow(1,1,4))
|
||||
17.390200628863
|
||||
|
||||
=math.sqrt(x:narrow(1,5,2):pow(2):sumall())
|
||||
17.390200628863
|
||||
|
||||
</file>
|
||||
|
||||
===== torch.eig([rese, resv,] a, [, 'n' or 'v']) =====
|
||||
{{anchor:torch.eig}}
|
||||
|
||||
Eigenvalues and eigenvectors of a symmetric real matrix '' A '' of
|
||||
size '' m x m ''. This function calculates all eigenvalues (and
|
||||
vectors) of '' A '' such that '' A = V' diag(e) V ''. Since the input
|
||||
matrix '' A '' is supposed to be symmetric, only upper triangular
|
||||
portion is used.
|
||||
|
||||
Last argument defines computation of eigenvectors or eigenvalues
|
||||
only. If '' n '', only eigenvalues are computed. If '' v '', both
|
||||
eigenvalues and eigenvectors are computed.
|
||||
|
||||
<file lua>
|
||||
|
||||
a=torch.Tensor({{ 1.96, 0.00, 0.00, 0.00, 0.00},
|
||||
{-6.49, 3.80, 0.00, 0.00, 0.00},
|
||||
{-0.47, -6.39, 4.17, 0.00, 0.00},
|
||||
{-7.20, 1.50, -1.51, 5.70, 0.00},
|
||||
{-0.65, -6.34, 2.67, 1.80, -7.10}}):t()
|
||||
|
||||
=a
|
||||
1.9600 -6.4900 -0.4700 -7.2000 -0.6500
|
||||
0.0000 3.8000 -6.3900 1.5000 -6.3400
|
||||
0.0000 0.0000 4.1700 -1.5100 2.6700
|
||||
0.0000 0.0000 0.0000 5.7000 1.8000
|
||||
0.0000 0.0000 0.0000 0.0000 -7.1000
|
||||
[torch.DoubleTensor of dimension 5x5]
|
||||
|
||||
e = torch.eig(a)
|
||||
=e
|
||||
-11.0656
|
||||
-6.2287
|
||||
0.8640
|
||||
8.8655
|
||||
16.0948
|
||||
[torch.DoubleTensor of dimension 5]
|
||||
|
||||
e,v = torch.eig(a,'v')
|
||||
=e
|
||||
-11.0656
|
||||
-6.2287
|
||||
0.8640
|
||||
8.8655
|
||||
16.0948
|
||||
[torch.DoubleTensor of dimension 5]
|
||||
|
||||
=v
|
||||
-0.2981 -0.6075 0.4026 -0.3745 0.4896
|
||||
-0.5078 -0.2880 -0.4066 -0.3572 -0.6053
|
||||
-0.0816 -0.3843 -0.6600 0.5008 0.3991
|
||||
-0.0036 -0.4467 0.4553 0.6204 -0.4564
|
||||
-0.8041 0.4480 0.1725 0.3108 0.1622
|
||||
[torch.DoubleTensor of dimension 5x5]
|
||||
|
||||
=v*torch.diag(e)*v:t()
|
||||
1.9600 -6.4900 -0.4700 -7.2000 -0.6500
|
||||
-6.4900 3.8000 -6.3900 1.5000 -6.3400
|
||||
-0.4700 -6.3900 4.1700 -1.5100 2.6700
|
||||
-7.2000 1.5000 -1.5100 5.7000 1.8000
|
||||
-0.6500 -6.3400 2.6700 1.8000 -7.1000
|
||||
[torch.DoubleTensor of dimension 5x5]
|
||||
|
||||
=a:dist(torch.triu(v*torch.diag(e)*v:t()))
|
||||
1.0219480822443e-14
|
||||
|
||||
</file>
|
||||
|
||||
===== torch.svd([resu, ress, resv] a, [, 's' or 'a']) =====
|
||||
{{anchor:torch.svd}}
|
||||
|
||||
Singular value decomposition of a real matrix '' A '' of size '' n x m
|
||||
'' such that '' A = U*S*V^T ''. The call to ''svd'' returns ''U, S, V^T''.
|
||||
|
||||
The last argument, if it is a string, represents the number of singular
|
||||
values to be computed. 's' stands for 'some' and 'a' stands for 'all'.
|
||||
|
||||
|
||||
<file lua>
|
||||
|
||||
a=torch.Tensor({{8.79, 6.11, -9.15, 9.57, -3.49, 9.84},
|
||||
{9.93, 6.91, -7.93, 1.64, 4.02, 0.15},
|
||||
{9.83, 5.04, 4.86, 8.83, 9.80, -8.99},
|
||||
{5.45, -0.27, 4.85, 0.74, 10.00, -6.02},
|
||||
{3.16, 7.98, 3.01, 5.80, 4.27, -5.31}}):t()
|
||||
=a
|
||||
8.7900 9.9300 9.8300 5.4500 3.1600
|
||||
6.1100 6.9100 5.0400 -0.2700 7.9800
|
||||
-9.1500 -7.9300 4.8600 4.8500 3.0100
|
||||
9.5700 1.6400 8.8300 0.7400 5.8000
|
||||
-3.4900 4.0200 9.8000 10.0000 4.2700
|
||||
9.8400 0.1500 -8.9900 -6.0200 -5.3100
|
||||
|
||||
u,s,v = torch.svd(a)
|
||||
|
||||
=u
|
||||
-0.5911 0.2632 0.3554 0.3143 0.2299
|
||||
-0.3976 0.2438 -0.2224 -0.7535 -0.3636
|
||||
-0.0335 -0.6003 -0.4508 0.2334 -0.3055
|
||||
-0.4297 0.2362 -0.6859 0.3319 0.1649
|
||||
-0.4697 -0.3509 0.3874 0.1587 -0.5183
|
||||
0.2934 0.5763 -0.0209 0.3791 -0.6526
|
||||
[torch.DoubleTensor of dimension 6x5]
|
||||
|
||||
=s
|
||||
27.4687
|
||||
22.6432
|
||||
8.5584
|
||||
5.9857
|
||||
2.0149
|
||||
[torch.DoubleTensor of dimension 5]
|
||||
|
||||
=v
|
||||
-0.2514 -0.3968 -0.6922 -0.3662 -0.4076
|
||||
0.8148 0.3587 -0.2489 -0.3686 -0.0980
|
||||
-0.2606 0.7008 -0.2208 0.3859 -0.4933
|
||||
0.3967 -0.4507 0.2513 0.4342 -0.6227
|
||||
-0.2180 0.1402 0.5891 -0.6265 -0.4396
|
||||
[torch.DoubleTensor of dimension 5x5]
|
||||
|
||||
=u*torch.diag(s)*v
|
||||
8.7900 9.9300 9.8300 5.4500 3.1600
|
||||
6.1100 6.9100 5.0400 -0.2700 7.9800
|
||||
-9.1500 -7.9300 4.8600 4.8500 3.0100
|
||||
9.5700 1.6400 8.8300 0.7400 5.8000
|
||||
-3.4900 4.0200 9.8000 10.0000 4.2700
|
||||
9.8400 0.1500 -8.9900 -6.0200 -5.3100
|
||||
[torch.DoubleTensor of dimension 6x5]
|
||||
|
||||
=a:dist(u*torch.diag(s)*v)
|
||||
2.8923773593204e-14
|
||||
|
||||
</file>
|
||||
|
||||
36
dok/memoryfile.dok
Normal file
36
dok/memoryfile.dok
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
====== MemoryFile ======
|
||||
{{anchor:torch.MemoryFile.dok}}
|
||||
|
||||
Parent classes: [[File|File]]
|
||||
|
||||
A ''MemoryFile'' is a particular ''File'' which is able to perform basic
|
||||
read/write operations on a buffer in ''RAM''. It implements all methods
|
||||
described in [[File|File]].
|
||||
|
||||
The data of this ''File'' is contained in a ''NULL''-terminated
|
||||
[[Storage|CharStorage]].
|
||||
|
||||
===== torch.MemoryFile([mode]) =====
|
||||
{{anchor:torch.MemoryFile}}
|
||||
|
||||
//Constructor// which returns a new ''MemoryFile'' object using ''mode''. Valid
|
||||
''mode'' are ''"r"'' (read), ''"w"'' (write) or ''"rw"'' (read-write). Default is ''"rw"''.
|
||||
|
||||
|
||||
===== torch.MemoryFile(storage, mode) =====
|
||||
{{anchor:torch.MemoryFile}}
|
||||
|
||||
//Constructor// which returns a new ''MemoryFile'' object, using the given
|
||||
[[Storage|storage]] (which must be a ''CharStorage'') and ''mode''. Valid
|
||||
''mode'' are ''"r"'' (read), ''"w"'' (write) or ''"rw"'' (read-write). The last character
|
||||
in this storage //must// be ''NULL'' or an error will be generated. This allows
|
||||
reading existing memory. If used for writing, note that the ''storage'' might
|
||||
be resized by this class if needed.
|
||||
|
||||
===== [CharStorage] storage() =====
|
||||
{{anchor:torch.MemoryFile.storage}}
|
||||
|
||||
Returns the [[Storage|storage]] which contains all the data of the
|
||||
''File'' (note: this is //not// a copy, but a //reference// on this storage). The
|
||||
size of the storage is the size of the data in the ''File'', plus one, the
|
||||
last character being ''NULL''.
|
||||
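A minimal sketch (the ''ascii()'' and ''writeInt()'' methods are assumed from the ''File'' interface; the exact contents of the returned storage depend on the spacing mode):

<file lua>
f = torch.MemoryFile()       -- read-write memory file
f:ascii()
f:writeInt(10)
f:writeInt(20)
s = f:storage()              -- reference on the underlying CharStorage
print(s:string())            -- the ASCII text written so far
</file>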
21
dok/pipefile.dok
Normal file
21
dok/pipefile.dok
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
====== PipeFile ======
|
||||
{{anchor:torch.PipeFile.dok}}
|
||||
|
||||
Parent classes: [[DiskFile|DiskFile]]
|
||||
|
||||
A ''PipeFile'' is a particular ''File'' which is able to perform basic read/write operations
|
||||
on a command pipe. It implements all methods described in [[DiskFile|DiskFile]] and [[File|File]].
|
||||
|
||||
The file might be opened in read or write mode, depending on the parameter
|
||||
''mode'' (which can take the value ''"r"'' or ''"w"'')
|
||||
given to the [[#torch.PipeFile|torch.PipeFile(fileName, mode)]]. Read-write mode is not allowed.
|
||||
|
||||
===== torch.PipeFile(command, [mode], [quiet]) =====
|
||||
{{anchor:torch.PipeFile}}
|
||||
|
||||
//Constructor// which executes ''command'' by opening a pipe in read or write
|
||||
''mode''. Valid ''mode'' are ''"r"'' (read) or ''"w"'' (write). Default is read
|
||||
mode.
|
||||
|
||||
If (and only if) ''quiet'' is ''true'', no error will be raised in case of
|
||||
problem opening the file: instead ''nil'' will be returned.
|
||||
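A hedged sketch (it assumes a POSIX shell providing ''echo'', and the ''readInt()'' method from the ''File'' interface with ASCII mode as the default):

<file lua>
f = torch.PipeFile('echo 42')   -- read mode ("r") is the default
x = f:readInt()                 -- the pipe is read like any other File
print(x)                        -- expected: 42
f:close()
</file>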
105
dok/random.dok
Normal file
105
dok/random.dok
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
====== Random Numbers ======
|
||||
{{anchor:torch.random.dok}}
|
||||
|
||||
Torch provides accurate mathematical random number generation, based on the
|
||||
[[http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/emt.html|Mersenne Twister]]
|
||||
random number generator.
|
||||
|
||||
====== Seed Handling ======
|
||||
{{anchor:torch.seed.dok}}
|
||||
|
||||
If no seed is provided to the random generator (using
|
||||
[[#torch.seed|seed()]] or [[#torch.manualSeed|manualSeed()]]), a
|
||||
random seed will be set automatically (as by [[#torch.seed|seed()]]) the first
|
||||
time a random number is generated.
|
||||
|
||||
The initial seed can be obtained using [[#torch.initialSeed|initialSeed()]].
|
||||
|
||||
Setting a particular seed allows the user to (re)-generate a particular series of
|
||||
random numbers. Example:
|
||||
<file>
|
||||
> torch.manualSeed(123)
|
||||
> = torch.uniform()
|
||||
0.69646918727085
|
||||
> return torch.uniform()
|
||||
0.71295532141812
|
||||
> return torch.uniform()
|
||||
0.28613933874294
|
||||
> torch.manualSeed(123)
|
||||
> return torch.uniform()
|
||||
0.69646918727085
|
||||
> return torch.uniform()
|
||||
0.71295532141812
|
||||
> return torch.uniform()
|
||||
0.28613933874294
|
||||
> torch.manualSeed(torch.initialSeed())
|
||||
> return torch.uniform()
|
||||
0.69646918727085
|
||||
> return torch.uniform()
|
||||
0.71295532141812
|
||||
> return torch.uniform()
|
||||
0.28613933874294
|
||||
</file>
|
||||
|
||||
===== [number] seed() =====
|
||||
{{anchor:torch.seed}}
|
||||
|
||||
Set the seed of the random number generator according to the time of the
|
||||
computer. Granularity is seconds. Returns the seed obtained.
|
||||
|
||||
===== manualSeed(number) =====
|
||||
{{anchor:torch.manualSeed}}
|
||||
|
||||
Set the seed of the random number generator to the given ''number''.
|
||||
|
||||
===== initialSeed() =====
|
||||
{{anchor:torch.initialSeed}}
|
||||
|
||||
Returns the initial seed used to initialize the random generator.
|
||||
|
||||
====== [number] random() ======
|
||||
{{anchor:torch.random}}
|
||||
|
||||
Returns a 32 bit integer random number.
|
||||
|
||||
====== [number] uniform([a],[b]) ======
|
||||
{{anchor:torch.uniform}}
|
||||
|
||||
Returns a random real number according to a uniform distribution on [a,b[. By default ''a'' is 0 and ''b'' is 1.
|
||||
|
||||
====== [number] normal([mean],[stdv]) ======
|
||||
{{anchor:torch.normal}}
|
||||
|
||||
Returns a random real number according to a normal distribution with the given ''mean'' and standard deviation ''stdv''.
|
||||
''stdv'' must be positive.
|
||||
|
||||
====== [number] exponential(lambda) ======
|
||||
{{anchor:torch.exponential}}
|
||||
|
||||
Returns a random real number according to the exponential distribution
|
||||
''p(x) = lambda * exp(-lambda * x)''
|
||||
|
||||
====== [number] cauchy(median, sigma) ======
|
||||
{{anchor:torch.cauchy}}
|
||||
|
||||
Returns a random real number according to the Cauchy distribution
|
||||
''p(x) = sigma/(pi*(sigma^2 + (x-median)^2))''
|
||||
|
||||
====== [number] logNormal(mean, stdv) ======
|
||||
{{anchor:torch.logNormal}}
|
||||
|
||||
Returns a random real number according to the log-normal distribution, with
|
||||
the given ''mean'' and standard deviation ''stdv''.
|
||||
''stdv'' must be positive.
|
||||
|
||||
====== [number] geometric(p) ======
|
||||
{{anchor:torch.geometric}}
|
||||
|
||||
Returns a random integer number according to a geometric distribution
|
||||
''p(i) = (1-p) * p^(i-1)''. ''p'' must satisfy ''0 < p < 1''.
|
||||
|
||||
====== [number] bernoulli([p]) ======
|
||||
{{anchor:torch.bernoulli}}
|
||||
|
||||
Returns ''1'' with probability ''p'' and ''0'' with probability ''1-p''. ''p'' must satisfy ''0 < p < 1''.
|
||||
By default ''p'' is equal to ''0.5''.
|
||||
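A short sketch of the scalar generators described above (the actual values depend on the seed):

<file lua>
torch.manualSeed(1234)
print(torch.random())         -- a 32 bit integer
print(torch.uniform(0, 10))   -- uniform on [0,10[
print(torch.normal(0, 1))     -- mean 0, standard deviation 1
print(torch.bernoulli(0.8))   -- 1 with probability 0.8, 0 otherwise
</file>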
222
dok/storage.dok
Normal file
222
dok/storage.dok
Normal file
|
|
@ -0,0 +1,222 @@
|
|||
====== Storage ======
|
||||
{{anchor:torch.Storage.dok}}
|
||||
{{anchor:torch.ByteStorage.dok}}
|
||||
{{anchor:torch.CharStorage.dok}}
|
||||
{{anchor:torch.ShortStorage.dok}}
|
||||
{{anchor:torch.IntStorage.dok}}
|
||||
{{anchor:torch.LongStorage.dok}}
|
||||
{{anchor:torch.FloatStorage.dok}}
|
||||
{{anchor:torch.DoubleStorage.dok}}
|
||||
|
||||
//Storages// are basically a way for ''Lua'' to access memory of a ''C'' pointer
|
||||
or array. //Storages// can also [[#__torch.StorageMap|map the contents of a file to memory]].
|
||||
A ''Storage'' is an array of //basic// ''C'' types. For arrays of ''Torch'' objects,
|
||||
use the ''Lua'' tables.
|
||||
|
||||
Several ''Storage'' classes for all the basic ''C'' types exist and have the
|
||||
following self-explanatory names: ''ByteStorage'', ''CharStorage'', ''ShortStorage'',
|
||||
''IntStorage'', ''LongStorage'', ''FloatStorage'', ''DoubleStorage''.
|
||||
|
||||
Note that ''ByteStorage'' and ''CharStorage'' both represent arrays of bytes. ''ByteStorage'' represents an array of
|
||||
//unsigned// chars, while ''CharStorage'' represents an array of //signed// chars.
|
||||
|
||||
Conversions between two ''Storage'' types might be done using ''copy'':
|
||||
<file lua>
|
||||
x = torch.IntStorage(10):fill(1)
|
||||
y = torch.DoubleStorage(10):copy(x)
|
||||
</file>
|
||||
|
||||
[[#torch.Storage|Classical storages]] are [[File#torch.File.serialization|serializable]].
|
||||
[[#__torch.StorageMap|Storages mapping a file]] are also [[#FileSerialization|serializable]],
|
||||
but //will be saved as a normal storage//.
|
||||
|
||||
An alias ''torch.Storage()'' is made over your preferred Storage type,
|
||||
controlled by the
|
||||
[[utility#torch.setdefaulttensortype|torch.setdefaulttensortype]]
|
||||
function. By default, this "points" on ''torch.DoubleStorage''.
|
||||
|
||||
===== torch.TYPEStorage([size]) =====
|
||||
{{anchor:torch.Storage}}
|
||||
|
||||
Returns a new ''Storage'' of type ''TYPE''. Valid ''TYPE'' are ''Byte'', ''Char'', ''Short'',
|
||||
''Int'', ''Long'', ''Float'', and ''Double''. If ''size'' is given, resize the
|
||||
''Storage'' accordingly, else create an empty ''Storage''.
|
||||
|
||||
Example:
|
||||
<file lua>
|
||||
-- Creates a Storage of 10 double:
|
||||
x = torch.DoubleStorage(10)
|
||||
</file>
|
||||
|
||||
The data in the ''Storage'' is //uninitialized//.
|
||||
|
||||
===== torch.TYPEStorage(table) =====
|
||||
{{anchor:torch.Storage}}
|
||||
|
||||
The argument is assumed to be a Lua array of numbers. The constructor returns a new storage of the specified 'TYPE',
|
||||
of the size of the table, containing all the table elements converted to the storage type.
|
||||
|
||||
Example:
|
||||
<file lua>
|
||||
> = torch.IntStorage({1,2,3,4})
|
||||
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
[torch.IntStorage of size 4]
|
||||
</file>
|
||||
|
||||
===== torch.TYPEStorage(filename [, shared]) =====
|
||||
{{anchor:torch.Storage}}
|
||||
{{anchor:__torch.StorageMap}}
|
||||
|
||||
Returns a new kind of ''Storage'' which maps the contents of the given
|
||||
''filename'' to memory. Valid ''TYPE'' are ''Byte'', ''Char'', ''Short'', ''Int'', ''Long'',
|
||||
''Float'', and ''Double''. If the optional boolean argument ''shared'' is ''true'',
|
||||
the mapped memory is shared amongst all processes on the computer.
|
||||
|
||||
When ''shared'' is ''true'', the file must be accessible in read-write mode. Any
|
||||
changes on the storage will be written to the file. The changes might be written
|
||||
only after destruction of the storage.
|
||||
|
||||
When ''shared'' is ''false'' (or not provided), the file must be at least
|
||||
readable. Any changes on the storage will not affect the file. Note:
|
||||
changes made on the file after creation of the storage have an unspecified
|
||||
effect on the storage contents.
|
||||
|
||||
The [[#torch.Storage.size|size]] of the returned ''Storage'' will be
|
||||
<file lua>
|
||||
(size of file in bytes)/(size of TYPE).
|
||||
</file>
|
||||
|
||||
Example:
|
||||
<file lua>
|
||||
$ echo "Hello World" > hello.txt
|
||||
$ lua
|
||||
Lua 5.1.3 Copyright (C) 1994-2008 Lua.org, PUC-Rio
|
||||
> require 'torch'
|
||||
> x = torch.CharStorage('hello.txt')
|
||||
> = x
|
||||
72
|
||||
101
|
||||
108
|
||||
108
|
||||
111
|
||||
32
|
||||
87
|
||||
111
|
||||
114
|
||||
108
|
||||
100
|
||||
10
|
||||
[torch.CharStorage of size 12]
|
||||
|
||||
> = x:string()
|
||||
Hello World
|
||||
|
||||
> = x:fill(42):string()
|
||||
************
|
||||
>
|
||||
$ cat hello.txt
|
||||
Hello World
|
||||
$ lua
|
||||
Lua 5.1.3 Copyright (C) 1994-2008 Lua.org, PUC-Rio
|
||||
> require 'torch'
|
||||
> x = torch.CharStorage('hello.txt', true)
|
||||
> = x:string()
|
||||
Hello World
|
||||
|
||||
> x:fill(42)
|
||||
>
|
||||
$ cat hello.txt
|
||||
************
|
||||
</file>
|
||||
|
||||
===== [number] #self =====
|
||||
{{anchor:__torch.StorageSharp}}
|
||||
|
||||
Returns the number of elements in the storage. Equivalent to [[#torch.Storage.size|size()]].
|
||||
|
||||
===== [number] self[index] =====
|
||||
{{anchor:torch.Storage.__index__}}
|
||||
|
||||
Returns or sets the element at position ''index'' in the storage. The valid range
|
||||
of ''index'' is 1 to [[#torch.Storage.size|size()]].
|
||||
|
||||
Example:
|
||||
<file lua>
|
||||
x = torch.DoubleStorage(10)
|
||||
print(x[5])
|
||||
</file>
|
||||
|
||||
===== [self] copy(storage) =====
|
||||
{{anchor:torch.Storage.copy}}
|
||||
|
||||
Copy another ''storage''. The types of the two storages might be different: in that case
|
||||
a type conversion occurs (which might result, of course, in loss of precision or rounding).
|
||||
This method returns self, allowing things like:
|
||||
<file lua>
|
||||
x = torch.IntStorage(10):fill(1)
|
||||
y = torch.DoubleStorage(10):copy(x) -- y won't be nil!
|
||||
</file>
|
||||
|
||||
===== [self] fill(value) =====
|
||||
{{anchor:torch.Storage.fill}}
|
||||
|
||||
Fill the ''Storage'' with the given value. This method returns self, allowing things like:
|
||||
<file lua>
|
||||
x = torch.IntStorage(10):fill(0) -- x won't be nil!
|
||||
</file>
|
||||
|
||||
===== [self] resize(size) =====
|
||||
{{anchor:torch.Storage.resize}}
|
||||
|
||||
Resize the storage to the provided ''size''. //The new contents are undetermined//.
|
||||
|
||||
This function returns self, allowing things like:
|
||||
<file lua>
|
||||
x = torch.DoubleStorage(10):fill(1)
|
||||
y = torch.DoubleStorage():resize(x:size()):copy(x) -- y won't be nil!
|
||||
</file>
|
||||
|
||||
===== [number] size() =====
|
||||
{{anchor:torch.Storage.size}}
|
||||
|
||||
Returns the number of elements in the storage. Equivalent to [[#__torch.StorageSharp|#]].
|
||||
|
||||
===== [self] string(str) =====
|
||||
{{anchor:torch.Storage.string}}
|
||||
|
||||
This function is available only on ''ByteStorage'' and ''CharStorage''.
|
||||
|
||||
This method resizes the storage to the length of the provided
|
||||
string ''str'', and copies the contents of ''str'' into the storage. The terminating ''NULL'' character is not copied,
|
||||
but ''str'' might contain ''NULL'' characters. The method returns the ''Storage''.
|
||||
<file lua>
|
||||
> x = torch.CharStorage():string("blah blah")
|
||||
> print(x)
|
||||
98
|
||||
108
|
||||
97
|
||||
104
|
||||
32
|
||||
98
|
||||
108
|
||||
97
|
||||
104
|
||||
[torch.CharStorage of size 9]
|
||||
</file>
|
||||
|
||||
===== [string] string() =====
|
||||
{{anchor:torch.Storage.string}}
|
||||
|
||||
This function is available only on ''ByteStorage'' and ''CharStorage''.
|
||||
|
||||
The contents of the storage viewed as a string are returned. The string might contain
|
||||
''NULL'' characters.
|
||||
<file lua>
|
||||
> x = torch.CharStorage():string("blah blah")
|
||||
> print(x:string())
|
||||
blah blah
|
||||
</file>
|
||||
1794
dok/tensor.dok
Normal file
1794
dok/tensor.dok
Normal file
File diff suppressed because it is too large
Load Diff
130
dok/tester.dok
Normal file
130
dok/tester.dok
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
====== Tester ======
|
||||
{{anchor:torch.Tester.dok}}
|
||||
|
||||
This class provides a generic unit testing framework. It is already
|
||||
being used in [[..:nn:index|nn]] package to verify the correctness of classes.
|
||||
|
||||
The framework is generally used as follows.
|
||||
|
||||
<file lua>
|
||||
mytest = {}
|
||||
|
||||
tester = torch.Tester()
|
||||
|
||||
function mytest.TestA()
|
||||
local a = 10
|
||||
local b = 10
|
||||
tester:asserteq(a,b,'a == b')
|
||||
tester:assertne(a,b,'a ~= b')
|
||||
end
|
||||
|
||||
function mytest.TestB()
|
||||
local a = 10
|
||||
local b = 9
|
||||
tester:assertlt(a,b,'a < b')
|
||||
tester:assertgt(a,b,'a > b')
|
||||
end
|
||||
|
||||
tester:add(mytest)
|
||||
tester:run()
|
||||
|
||||
</file>
|
||||
|
||||
Running this code will report 2 errors in 2 test functions. Generally it is
|
||||
better to put single test cases in each test function unless several very related
|
||||
test cases exist. The error report includes the message and line number of the error.
|
||||
|
||||
<file>
|
||||
|
||||
Running 2 tests
|
||||
** ==> Done
|
||||
|
||||
Completed 2 tests with 2 errors
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
TestB
|
||||
a < b
|
||||
LT(<) violation val=10, condition=9
|
||||
...y/usr.t7/local.master/share/lua/5.1/torch/Tester.lua:23: in function 'assertlt'
|
||||
[string "function mytest.TestB()..."]:4: in function 'f'
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
TestA
|
||||
a ~= b
|
||||
NE(~=) violation val=10, condition=10
|
||||
...y/usr.t7/local.master/share/lua/5.1/torch/Tester.lua:38: in function 'assertne'
|
||||
[string "function mytest.TestA()..."]:5: in function 'f'
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
</file>
|
||||
|
||||
|
||||
==== torch.Tester() ====
|
||||
{{anchor:torch.Tester}}
|
||||
|
||||
Returns a new instance of ''torch.Tester'' class.
|
||||
|
||||
==== add(f, 'name') ====
|
||||
{{anchor:torch.Tester.add}}
|
||||
|
||||
Adds a new test function with name ''name''. The test function is stored in ''f''.
|
||||
The function is supposed to run without any arguments and not return any values.
|
||||
|
||||
==== add(ftable) ====
|
||||
{{anchor:torch.Tester.add}}
|
||||
|
||||
Recursively adds all function entries of the table ''ftable'' as tests. This table
|
||||
can only have functions or nested tables of functions.
|
||||
|
||||
==== assert(condition [, message]) ====
|
||||
{{anchor:torch.Tester.assert}}
|
||||
|
||||
Saves an error if condition is not true with the optional message.
|
||||
|
||||
==== assertlt(val, condition [, message]) ====
|
||||
{{anchor:torch.Tester.assertlt}}
|
||||
|
||||
Saves an error if ''val < condition'' is not true with the optional message.
|
||||
|
||||
==== assertgt(val, condition [, message]) ====
|
||||
{{anchor:torch.Tester.assertgt}}
|
||||
|
||||
Saves an error if ''val > condition'' is not true with the optional message.
|
||||
|
||||
==== assertle(val, condition [, message]) ====
|
||||
{{anchor:torch.Tester.assertle}}
|
||||
|
||||
Saves an error if ''val <= condition'' is not true with the optional message.
|
||||
|
||||
==== assertge(val, condition [, message]) ====
|
||||
{{anchor:torch.Tester.assertge}}
|
||||
|
||||
Saves an error if ''val >= condition'' is not true with the optional message.
|
||||
|
||||
==== asserteq(val, condition [, message]) ====
|
||||
{{anchor:torch.Tester.asserteq}}
|
||||
|
||||
Saves an error if ''val == condition'' is not true with the optional message.
|
||||
|
||||
==== assertne(val, condition [, message]) ====
|
||||
{{anchor:torch.Tester.assertne}}
|
||||
|
||||
Saves an error if ''val ~= condition'' is not true with the optional message.
|
||||
|
||||
==== assertTensorEq(ta, tb, condition [, message]) ====
|
||||
{{anchor:torch.Tester.assertTensorEq}}
|
||||
|
||||
Saves an error if ''max(abs(ta-tb)) < condition'' is not true with the optional message.
|
||||
|
||||
==== run() ====
|
||||
{{anchor:torch.Tester.run}}
|
||||
|
||||
Runs all the test functions that were added using the [[#torch.Tester.add|add()]] function.
|
||||
While running it reports progress and at the end gives a summary of all errors.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
43
dok/timer.dok
Normal file
43
dok/timer.dok
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
====== Timer ======
|
||||
{{anchor:torch.Timer.dok}}
|
||||
|
||||
This class is able to measure time (in seconds) elapsed in a particular period. Example:
|
||||
<file lua>
|
||||
timer = torch.Timer() -- the Timer starts to count now
|
||||
x = 0
|
||||
for i=1,1000000 do
|
||||
x = x + math.sin(x)
|
||||
end
|
||||
print('Time elapsed for 1,000,000 sin: ' .. timer:time().real .. ' seconds')
|
||||
</file>
|
||||
|
||||
===== torch.Timer() =====
|
||||
{{anchor:torch.Timer}}
|
||||
|
||||
Returns a new ''Timer''. The timer starts to count the time now.
|
||||
|
||||
===== [self] reset() =====
|
||||
{{anchor:torch.Timer.reset}}
|
||||
|
||||
Reset the timer accumulated time to ''0''. If the timer was running, the timer
|
||||
restarts to count the time now. If the timer was stopped, it stays stopped.
|
||||
|
||||
===== [self] resume() =====
|
||||
{{anchor:torch.Timer.resume}}
|
||||
|
||||
Resume a stopped timer. The timer starts counting the time again, adding
|
||||
to the accumulated time the time already counted before it was stopped.
|
||||
|
||||
===== [self] stop() =====
|
||||
{{anchor:torch.Timer.stop}}
|
||||
|
||||
Stop the timer. The accumulated time counted until now is stored.
|
||||
|
||||
===== [table] time() =====
|
||||
{{anchor:torch.Timer.time}}
|
||||
|
||||
Returns a table reporting the accumulated time elapsed until now. Following the UNIX shell ''time'' command,
|
||||
there are three fields in the table:
|
||||
* ''real'': the wall-clock elapsed time.
|
||||
* ''user'': the elapsed CPU time. Note that the CPU time of a threaded program sums time spent in all threads.
|
||||
* ''sys'': the time spent in system usage.
|
||||
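A short sketch combining ''stop()'', ''resume()'' and ''time()'' (the measured durations are of course machine dependent):

<file lua>
timer = torch.Timer()
for i=1,100000 do x = math.sin(i) end   -- timed work
timer:stop()                            -- the accumulated time is frozen here
-- ... work that should not be timed ...
timer:resume()                          -- counting continues from the accumulated time
for i=1,100000 do x = math.sin(i) end   -- more timed work
print(timer:time().real)
</file>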
234
dok/utility.dok
Normal file
234
dok/utility.dok
Normal file
|
|
@ -0,0 +1,234 @@
|
|||
====== Torch utility functions ======
|
||||
{{anchor:torch.utility.dok}}
|
||||
|
||||
These functions are used in all Torch packages for creating and handling classes.
|
||||
The most interesting function is probably [[#torch.class|torch.class()]] which allows
|
||||
the user to easily create new classes. [[#torch.typename|torch.typename()]] might
|
||||
also be useful to check the class of a given Torch object.
|
||||
|
||||
The other functions are more for advanced users.
|
||||
|
||||
===== [metatable] torch.class(name, [parentName]) =====
|
||||
{{anchor:torch.class}}
|
||||
|
||||
Creates a new ''Torch'' class called ''name''. If ''parentName'' is provided, the class will inherit
|
||||
''parentName'' methods. A class is a table which has a particular metatable.
|
||||
|
||||
If ''name'' is of the form ''package.className'' then the class ''className'' will be added to the specified ''package''.
|
||||
In that case, ''package'' has to be a valid (and already loaded) package. If ''name'' does not contain any ''"."'',
|
||||
then the class will be defined in the global environment.
|
||||
|
||||
One [or two] (meta)tables are returned. These tables contain all the methods
|
||||
provided by the class [and its parent class if it has been provided]. After
|
||||
a call to ''torch.class()'' you have to properly fill up the metatable.
|
||||
|
||||
After the class definition is complete, constructing a new object of class //name// is achieved by a call to ''//name//()''.
|
||||
This call will first call the method <file lua>__init()</file> if it exists, passing all arguments of ''//name//()''.
|
||||
|
||||
<file lua>
|
||||
require "torch"
|
||||
|
||||
-- for naming convenience
|
||||
do
|
||||
--- creates a class "Foo"
|
||||
local Foo = torch.class('Foo')
|
||||
|
||||
--- the initializer
|
||||
function Foo:__init()
|
||||
self.contents = "this is some text"
|
||||
end
|
||||
|
||||
--- a method
|
||||
function Foo:print()
|
||||
print(self.contents)
|
||||
end
|
||||
|
||||
--- another one
|
||||
function Foo:bip()
|
||||
print('bip')
|
||||
end
|
||||
|
||||
end
|
||||
|
||||
--- now create an instance of Foo
|
||||
foo = Foo()
|
||||
|
||||
--- try it out
|
||||
foo:print()
|
||||
|
||||
--- create a class torch.Bar which
|
||||
--- inherits from Foo
|
||||
do
|
||||
local Bar, parent = torch.class('torch.Bar', 'Foo')
|
||||
|
||||
--- the initializer
|
||||
function Bar:__init(stuff)
|
||||
--- call the parent initializer on ourself
|
||||
parent.__init(self)
|
||||
|
||||
--- do some stuff
|
||||
self.stuff = stuff
|
||||
end
|
||||
|
||||
--- a new method
|
||||
function Bar:boing()
|
||||
print('boing!')
|
||||
end
|
||||
|
||||
--- override parent's method
|
||||
function Bar:print()
|
||||
print(self.contents)
|
||||
print(self.stuff)
|
||||
end
|
||||
end
|
||||
|
||||
--- create a new instance and use it
|
||||
bar = torch.Bar("ha ha!")
|
||||
bar:print() -- overridden method
|
||||
bar:boing() -- child method
|
||||
bar:bip() -- parent's method
|
||||
|
||||
</file>
|
||||
|
||||
For advanced users, it is worth mentioning that ''torch.class()'' actually
|
||||
calls [[#torch.newmetatable|torch.newmetatable()]]. with a particular
|
||||
constructor. The constructor creates a Lua table and set the right
|
||||
metatable on it, and then calls <file lua>__init()</file> if it exists in the
|
||||
metatable. It also sets a [[#torch.factory|factory]] field <file lua>__factory</file> such that it
|
||||
is possible to create an empty object of this class.
|
||||
|
||||
===== [string] torch.typename(object) =====
|
||||
{{anchor:torch.typename}}
|
||||
|
||||
Checks if ''object'' has a metatable. If it does, and if it corresponds to a
|
||||
''Torch'' class, then returns a string containing the name of the
|
||||
class. Returns ''nil'' in all other cases.
|
||||
|
||||
A Torch class is a class created with [[#torch.class|torch.class()]] or
|
||||
[[#torch.newmetatable|torch.newmetatable()]].
|
||||
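For example (a sketch, reusing the ''Foo'' and ''torch.Bar'' classes defined in the [[#torch.class|torch.class()]] example above):
<file lua>
> torch.typename(Foo())
Foo
> torch.typename(torch.Bar("ha ha!"))
torch.Bar
> torch.typename(12)
nil
</file>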
|
||||
===== [userdata] torch.typename2id(string) =====
|
||||
{{anchor:torch.typename2id}}
|
||||
|
||||
Given a Torch class name specified by ''string'', returns a unique
|
||||
corresponding id (defined by a ''lightuserdata'' pointing to the internal
|
||||
structure of the class). This might be useful to do a //fast// check of the
|
||||
class of an object (if used with [[#torch.id|torch.id()]]), avoiding string
|
||||
comparisons.
|
||||
|
||||
Returns ''nil'' if ''string'' does not specify a Torch object.
|
||||
|
||||
===== [userdata] torch.id(object) =====
|
||||
{{anchor:torch.id}}
|
||||
|
||||
Returns a unique id corresponding to the //class// of the given Torch object.
|
||||
The id is defined by a ''lightuserdata'' pointing to the internal structure
|
||||
of the class.
|
||||
|
||||
Returns ''nil'' if ''object'' is not a Torch object.
|
||||
|
||||
This is different from the //object// id returned by [[#torch.pointer|torch.pointer()]].
|
||||
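Combined, ''torch.typename2id()'' and ''torch.id()'' allow a fast class check. A sketch (using ''torch.CharStorage'' purely as an illustration):
<file lua>
local charStorageId = torch.typename2id('torch.CharStorage')

local function isCharStorage(obj)
   -- compares lightuserdata ids, so no string comparison is needed
   return torch.id(obj) == charStorageId
end

print(isCharStorage(torch.CharStorage(10)))  -- true
print(isCharStorage({}))                     -- false (torch.id() returns nil for a plain table)
</file>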
|
||||
===== [table] torch.newmetatable(name, parentName, constructor) =====
|
||||
{{anchor:torch.newmetatable}}
|
||||
|
||||
Registers a new metatable as a Torch type with the given string ''name''. The new metatable is returned.
|
||||
|
||||
If the string ''parentName'' is not ''nil'' and is a valid Torch type (previously created
|
||||
by ''torch.newmetatable()''), then the corresponding metatable is set as the metatable of the returned new
|
||||
metatable.
|
||||
|
||||
If the given ''constructor'' function is not ''nil'', then the given constructor is assigned to the variable ''name''.
|
||||
The given ''name'' might be of the form ''package.className'', in which case the ''className'' will be local to the
|
||||
specified ''package''. In that case, ''package'' must be a valid and already loaded package.
|
||||
|
||||
===== [function] torch.factory(name) =====
|
||||
{{anchor:torch.factory}}
|
||||
|
||||
Returns the factory function of the Torch class ''name''. If the class name is invalid or if the class
|
||||
has no factory, then returns ''nil''.
|
||||
|
||||
A Torch class is a class created with [[#torch.class|torch.class()]] or
|
||||
[[#torch.newmetatable|torch.newmetatable()]].
|
||||
|
||||
A factory function is able to return a new (empty) object of its corresponding class. This is helpful for
|
||||
[[File#torch.File.serialization|object serialization]].
|
||||
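For instance (a sketch; ''torch.CharStorage'' is used only as an illustration):
<file lua>
local f = torch.factory('torch.CharStorage')
local s = f()                           -- a new, empty torch.CharStorage
print(torch.typename(s))                -- torch.CharStorage
print(torch.factory('no.such.class'))   -- nil
</file>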
|
||||
===== [table] torch.getmetatable(string) =====
|
||||
{{anchor:torch.getmetatable}}
|
||||
|
||||
Given a ''string'', returns a metatable corresponding to the Torch class described
|
||||
by ''string''. Returns ''nil'' if the class does not exist.
|
||||
|
||||
A Torch class is a class created with [[#torch.class|torch.class()]] or
|
||||
[[#torch.newmetatable|torch.newmetatable()]].
|
||||
|
||||
Example:
|
||||
<file lua>
|
||||
> for k,v in pairs(torch.getmetatable("torch.CharStorage")) do print(k,v) end
|
||||
__index__ function: 0x1a4ba80
|
||||
__typename torch.CharStorage
|
||||
write function: 0x1a49cc0
|
||||
__tostring__ function: 0x1a586e0
|
||||
__newindex__ function: 0x1a4ba40
|
||||
string function: 0x1a4d860
|
||||
__version 1
|
||||
copy function: 0x1a49c80
|
||||
read function: 0x1a4d840
|
||||
__len__ function: 0x1a37440
|
||||
fill function: 0x1a375c0
|
||||
resize function: 0x1a37580
|
||||
__index table: 0x1a4a080
|
||||
size function: 0x1a4ba20
|
||||
</file>
|
||||
|
||||
===== [boolean] torch.isequal(object1, object2) =====
|
||||
{{anchor:torch.isequal}}
|
||||
|
||||
If the two objects given as arguments are ''Lua'' tables (or Torch objects), then returns ''true'' if and only if the
|
||||
tables (or Torch objects) have the same address in memory. Returns ''false'' in all other cases.
|
||||
|
||||
A Torch class is a class created with [[#torch.class|torch.class()]] or
|
||||
[[#torch.newmetatable|torch.newmetatable()]].
|
||||
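For instance (a sketch):
<file lua>
t = {}
u = t
print(torch.isequal(t, u))    -- true: same table in memory
print(torch.isequal(t, {}))   -- false: different tables
print(torch.isequal(1, 1))    -- false: not tables or Torch objects
</file>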
|
||||
===== torch.setenv(function or userdata, table) =====
|
||||
{{anchor:torch.setenv}}
|
||||
|
||||
Assigns ''table'' as the Lua environment of the given ''function'' or the given
|
||||
''userdata''. To know more about environments, please read the documentation
|
||||
of [[http://www.lua.org/manual/5.1/manual.html#lua_setfenv|lua_setfenv()]]
|
||||
and [[http://www.lua.org/manual/5.1/manual.html#lua_getfenv|lua_getfenv()]].
|
||||
|
||||
===== [table] torch.getenv(function or userdata) =====
|
||||
{{anchor:torch.getenv}}
|
||||
|
||||
Returns the Lua ''table'' environment of the given ''function'' or the given
|
||||
''userdata''. To know more about environments, please read the documentation
|
||||
of [[http://www.lua.org/manual/5.1/manual.html#lua_setfenv|lua_setfenv()]]
|
||||
and [[http://www.lua.org/manual/5.1/manual.html#lua_getfenv|lua_getfenv()]].
|
||||
|
||||
===== [number] torch.version(object) =====
|
||||
{{anchor:torch.version}}
|
||||
|
||||
Returns the field <file lua>__version</file> of a given object. This might
|
||||
be helpful to handle variations in a class over time.
|
||||
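For example (a sketch; the ''__version'' field of ''torch.CharStorage'' appears as ''1'' in the [[#torch.getmetatable|torch.getmetatable()]] output above):
<file lua>
> torch.version(torch.CharStorage())
1
</file>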
|
||||
===== [number] torch.pointer(object) =====
|
||||
{{anchor:torch.pointer}}
|
||||
|
||||
Returns a unique id (pointer) of the given ''object'', which can be a Torch
|
||||
object, a table, a thread or a function.
|
||||
|
||||
This is different from the //class// id returned by [[#torch.id|torch.id()]].
|
||||
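For illustration (a sketch):
<file lua>
t = {}
print(torch.pointer(t) == torch.pointer(t))    -- true: same object
print(torch.pointer(t) == torch.pointer({}))   -- false: distinct objects
</file>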
|
||||
===== [object] torch.setmetatable(table, classname) =====
|
||||
{{anchor:torch.setmetatable}}
|
||||
|
||||
Sets the metatable of the given ''table'' to the metatable of the Torch object named ''classname''.
|
||||
This function must be used with great care.
|
||||
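A sketch of how it might be used (reusing the ''torch.Bar'' class from the [[#torch.class|torch.class()]] example; the table then behaves like a ''torch.Bar'' instance, although ''__init()'' was never called on it):
<file lua>
t = { contents = "some text", stuff = "some stuff" }
torch.setmetatable(t, 'torch.Bar')
t:print()                   -- works: t now has torch.Bar's metatable
print(torch.typename(t))    -- torch.Bar
</file>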
|
||||
===== [table] torch.getconstructortable(string) =====
|
||||
{{anchor:torch.getconstructortable}}
|
||||
|
||||
BUGGY
|
||||
Returns the constructor table of the Torch class specified by ''string''.
|
||||
18
general.h
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
#ifndef TORCH_GENERAL_INC
|
||||
#define TORCH_GENERAL_INC
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "luaT.h"
|
||||
#include "TH.h"
|
||||
|
||||
#ifdef _MSC_VER
|
||||
|
||||
#define snprintf _snprintf
|
||||
#define popen _popen
|
||||
#define pclose _pclose
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
221
generic/Storage.c
Normal file
|
|
@ -0,0 +1,221 @@
|
|||
#ifndef TH_GENERIC_FILE
|
||||
#define TH_GENERIC_FILE "generic/Storage.c"
|
||||
#else
|
||||
|
||||
static int torch_Storage_(new)(lua_State *L)
|
||||
{
|
||||
THStorage *storage;
|
||||
if(lua_type(L, 1) == LUA_TSTRING)
|
||||
{
|
||||
const char *fileName = luaL_checkstring(L, 1);
|
||||
int isShared = luaT_optboolean(L, 2, 0);
|
||||
storage = THStorage_(newWithMapping)(fileName, isShared);
}
|
||||
else if(lua_type(L, 1) == LUA_TTABLE)
|
||||
{
|
||||
long size = lua_objlen(L, 1);
|
||||
long i;
|
||||
storage = THStorage_(newWithSize)(size);
|
||||
for(i = 1; i <= size; i++)
|
||||
{
|
||||
lua_rawgeti(L, 1, i);
|
||||
if(!lua_isnumber(L, -1))
|
||||
{
|
||||
THStorage_(free)(storage);
|
||||
luaL_error(L, "element at index %d is not a number", i);
|
||||
}
|
||||
THStorage_(set)(storage, i-1, (real)lua_tonumber(L, -1));
|
||||
lua_pop(L, 1);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
long size = luaL_optlong(L, 1, 0);
|
||||
storage = THStorage_(newWithSize)(size);
|
||||
}
|
||||
luaT_pushudata(L, storage, torch_Storage_id);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Storage_(free)(lua_State *L)
|
||||
{
|
||||
THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
|
||||
THStorage_(free)(storage);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int torch_Storage_(resize)(lua_State *L)
|
||||
{
|
||||
THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
|
||||
long size = luaL_checklong(L, 2);
|
||||
/* int keepContent = luaT_optboolean(L, 3, 0); */
|
||||
THStorage_(resize)(storage, size);/*, keepContent); */
|
||||
lua_settop(L, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Storage_(copy)(lua_State *L)
|
||||
{
|
||||
THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
|
||||
void *src;
|
||||
if( (src = luaT_toudata(L, 2, torch_Storage_id)) )
|
||||
THStorage_(copy)(storage, src);
|
||||
else if( (src = luaT_toudata(L, 2, torch_ByteStorage_id)) )
|
||||
THStorage_(copyByte)(storage, src);
|
||||
else if( (src = luaT_toudata(L, 2, torch_CharStorage_id)) )
|
||||
THStorage_(copyChar)(storage, src);
|
||||
else if( (src = luaT_toudata(L, 2, torch_ShortStorage_id)) )
|
||||
THStorage_(copyShort)(storage, src);
|
||||
else if( (src = luaT_toudata(L, 2, torch_IntStorage_id)) )
|
||||
THStorage_(copyInt)(storage, src);
|
||||
else if( (src = luaT_toudata(L, 2, torch_LongStorage_id)) )
|
||||
THStorage_(copyLong)(storage, src);
|
||||
else if( (src = luaT_toudata(L, 2, torch_FloatStorage_id)) )
|
||||
THStorage_(copyFloat)(storage, src);
|
||||
else if( (src = luaT_toudata(L, 2, torch_DoubleStorage_id)) )
|
||||
THStorage_(copyDouble)(storage, src);
|
||||
else
|
||||
luaL_typerror(L, 2, "torch.*Storage");
|
||||
lua_settop(L, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Storage_(fill)(lua_State *L)
|
||||
{
|
||||
THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
|
||||
double value = luaL_checknumber(L, 2);
|
||||
THStorage_(fill)(storage, (real)value);
|
||||
lua_settop(L, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Storage_(__len__)(lua_State *L)
|
||||
{
|
||||
THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
|
||||
lua_pushnumber(L, storage->size);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Storage_(__newindex__)(lua_State *L)
|
||||
{
|
||||
if(lua_isnumber(L, 2))
|
||||
{
|
||||
THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
|
||||
long index = luaL_checklong(L, 2) - 1;
|
||||
double number = luaL_checknumber(L, 3);
|
||||
THStorage_(set)(storage, index, (real)number);
|
||||
lua_pushboolean(L, 1);
|
||||
}
|
||||
else
|
||||
lua_pushboolean(L, 0);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Storage_(__index__)(lua_State *L)
|
||||
{
|
||||
if(lua_isnumber(L, 2))
|
||||
{
|
||||
THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
|
||||
long index = luaL_checklong(L, 2) - 1;
|
||||
lua_pushnumber(L, THStorage_(get)(storage, index));
|
||||
lua_pushboolean(L, 1);
|
||||
return 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
lua_pushboolean(L, 0);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(TH_REAL_IS_CHAR) || defined(TH_REAL_IS_BYTE)
|
||||
static int torch_Storage_(string)(lua_State *L)
|
||||
{
|
||||
THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
|
||||
if(lua_isstring(L, -1))
|
||||
{
|
||||
size_t len = 0;
|
||||
const char *str = lua_tolstring(L, -1, &len);
|
||||
THStorage_(resize)(storage, len);
|
||||
memmove(storage->data, str, len);
|
||||
lua_settop(L, 1);
|
||||
}
|
||||
else
|
||||
lua_pushlstring(L, (char*)storage->data, storage->size);
|
||||
|
||||
return 1; /* either storage or string */
|
||||
}
|
||||
#endif
|
||||
|
||||
static int torch_Storage_(totable)(lua_State *L)
|
||||
{
|
||||
THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
|
||||
long i;
|
||||
|
||||
lua_newtable(L);
|
||||
for(i = 0; i < storage->size; i++)
|
||||
{
|
||||
lua_pushnumber(L, (lua_Number)storage->data[i]);
|
||||
lua_rawseti(L, -2, i+1);
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Storage_(factory)(lua_State *L)
|
||||
{
|
||||
THStorage *storage = THStorage_(new)();
|
||||
luaT_pushudata(L, storage, torch_Storage_id);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Storage_(write)(lua_State *L)
|
||||
{
|
||||
THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
|
||||
THFile *file = luaT_checkudata(L, 2, torch_File_id);
|
||||
|
||||
THFile_writeLongScalar(file, storage->size);
|
||||
THFile_writeRealRaw(file, storage->data, storage->size);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int torch_Storage_(read)(lua_State *L)
|
||||
{
|
||||
THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
|
||||
THFile *file = luaT_checkudata(L, 2, torch_File_id);
|
||||
long size = THFile_readLongScalar(file);
|
||||
|
||||
THStorage_(resize)(storage, size);
|
||||
THFile_readRealRaw(file, storage->data, storage->size);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static const struct luaL_Reg torch_Storage_(_) [] = {
|
||||
{"size", torch_Storage_(__len__)},
|
||||
{"__len__", torch_Storage_(__len__)},
|
||||
{"__newindex__", torch_Storage_(__newindex__)},
|
||||
{"__index__", torch_Storage_(__index__)},
|
||||
{"resize", torch_Storage_(resize)},
|
||||
{"fill", torch_Storage_(fill)},
|
||||
{"copy", torch_Storage_(copy)},
|
||||
{"totable", torch_Storage_(totable)},
|
||||
{"write", torch_Storage_(write)},
|
||||
{"read", torch_Storage_(read)},
|
||||
#if defined(TH_REAL_IS_CHAR) || defined(TH_REAL_IS_BYTE)
|
||||
{"string", torch_Storage_(string)},
|
||||
#endif
|
||||
{NULL, NULL}
|
||||
};
|
||||
|
||||
void torch_Storage_(init)(lua_State *L)
|
||||
{
|
||||
torch_File_id = luaT_checktypename2id(L, "torch.File");
|
||||
|
||||
torch_Storage_id = luaT_newmetatable(L, STRING_torchStorage, NULL,
|
||||
torch_Storage_(new), torch_Storage_(free), torch_Storage_(factory));
|
||||
luaL_register(L, NULL, torch_Storage_(_));
|
||||
lua_pop(L, 1);
|
||||
}
|
||||
|
||||
#endif
|
||||
939
generic/Tensor.c
Normal file
|
|
@ -0,0 +1,939 @@
|
|||
#ifndef TH_GENERIC_FILE
|
||||
#define TH_GENERIC_FILE "generic/Tensor.c"
|
||||
#else
|
||||
|
||||
static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index, int allowNone, int allowTensor, int allowStorage, int allowStride,
|
||||
THStorage **storage_, long *storageOffset_, THLongStorage **size_, THLongStorage **stride_);
|
||||
|
||||
static void torch_Tensor_(c_readSizeStride)(lua_State *L, int index, int allowStride, THLongStorage **size_, THLongStorage **stride_);
|
||||
|
||||
static int torch_Tensor_(size)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
if(lua_isnumber(L,2))
|
||||
{
|
||||
int dim = luaL_checkint(L, 2)-1;
|
||||
luaL_argcheck(L, dim >= 0 && dim < tensor->nDimension, 2, "out of range");
|
||||
lua_pushnumber(L, tensor->size[dim]);
|
||||
}
|
||||
else
|
||||
{
|
||||
THLongStorage *storage = THLongStorage_newWithSize(tensor->nDimension);
|
||||
memmove(storage->data, tensor->size, sizeof(long)*tensor->nDimension);
|
||||
luaT_pushudata(L, storage, torch_LongStorage_id);
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(stride)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
if(lua_isnumber(L,2))
|
||||
{
|
||||
int dim = luaL_checkint(L, 2)-1;
|
||||
luaL_argcheck(L, dim >= 0 && dim < tensor->nDimension, 2, "out of range");
|
||||
lua_pushnumber(L, tensor->stride[dim]);
|
||||
}
|
||||
else
|
||||
{
|
||||
THLongStorage *storage = THLongStorage_newWithSize(tensor->nDimension);
|
||||
memmove(storage->data, tensor->stride, sizeof(long)*tensor->nDimension);
|
||||
luaT_pushudata(L, storage, torch_LongStorage_id);
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(nDimension)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
lua_pushnumber(L, tensor->nDimension);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(storage)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
|
||||
if(tensor->storage)
|
||||
{
|
||||
THStorage_(retain)(tensor->storage);
|
||||
luaT_pushudata(L, tensor->storage, torch_Storage_id);
|
||||
}
|
||||
else
|
||||
lua_pushnil(L);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(storageOffset)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
lua_pushnumber(L, tensor->storageOffset+1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(new)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor;
|
||||
long storageOffset;
|
||||
THLongStorage *size, *stride;
|
||||
|
||||
if(lua_type(L, 1) == LUA_TTABLE)
|
||||
{
|
||||
long i, j;
|
||||
THLongStorage *counter;
|
||||
long si = 0;
|
||||
int dimension = 0;
|
||||
int is_finished = 0;
|
||||
|
||||
lua_settop(L, 1);
|
||||
size = THLongStorage_new();
|
||||
|
||||
while( (lua_type(L, -1) == LUA_TTABLE) && (lua_objlen(L, -1) > 0) )
|
||||
{
|
||||
THLongStorage_resize(size, dimension+1);
|
||||
size->data[dimension] = lua_objlen(L, -1);
|
||||
dimension++;
|
||||
lua_rawgeti(L, -1, 1);
|
||||
}
|
||||
lua_pop(L, 1);
|
||||
|
||||
counter = THLongStorage_newWithSize(size->size);
|
||||
THLongStorage_fill(counter, 0);
|
||||
|
||||
tensor = THTensor_(newWithSize)(size, NULL);
|
||||
|
||||
if(size->size == 0)
|
||||
is_finished = 1;
|
||||
|
||||
while(!is_finished)
|
||||
{
|
||||
if(!lua_istable(L, -1))
|
||||
{
|
||||
THLongStorage_free(size);
|
||||
THLongStorage_free(counter);
|
||||
THTensor_(free)(tensor);
|
||||
luaL_error(L, "invalid tensor definition");
|
||||
}
|
||||
|
||||
if(lua_objlen(L, -1) != size->data[size->size-1])
|
||||
{
|
||||
THLongStorage_free(size);
|
||||
THLongStorage_free(counter);
|
||||
THTensor_(free)(tensor);
|
||||
luaL_error(L, "invalid tensor sizes");
|
||||
}
|
||||
|
||||
for(i = 0; i < size->data[size->size-1]; i++)
|
||||
{
|
||||
lua_rawgeti(L, -1, i+1);
|
||||
if(!lua_isnumber(L, -1))
|
||||
{
|
||||
THLongStorage_free(size);
|
||||
THLongStorage_free(counter);
|
||||
THTensor_(free)(tensor);
|
||||
luaL_error(L, "invalid element (not a number)");
|
||||
}
|
||||
THStorage_(set)(THTensor_(storage)(tensor), si++, (real)lua_tonumber(L, -1));
|
||||
lua_pop(L, 1);
|
||||
}
|
||||
|
||||
if(size->size == 1)
|
||||
break;
|
||||
|
||||
for(i = size->size-2; i >= 0; i--)
|
||||
{
|
||||
if(++counter->data[i] == size->data[i])
|
||||
{
|
||||
if(i == 0)
|
||||
{
|
||||
is_finished = 1;
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
counter->data[i] = 0;
|
||||
lua_pop(L, 1);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
lua_pop(L, 1);
|
||||
for(j = i; j < size->size-1; j++)
|
||||
{
|
||||
if(!lua_istable(L, -1))
|
||||
{
|
||||
THLongStorage_free(size);
|
||||
THLongStorage_free(counter);
|
||||
THTensor_(free)(tensor);
|
||||
luaL_error(L, "invalid tensor definition");
|
||||
}
|
||||
if(lua_objlen(L, -1) != size->data[j])
|
||||
{
|
||||
THLongStorage_free(size);
|
||||
THLongStorage_free(counter);
|
||||
THTensor_(free)(tensor);
|
||||
luaL_error(L, "invalid tensor sizes");
|
||||
}
|
||||
lua_rawgeti(L, -1, counter->data[j]+1);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
THLongStorage_free(size);
|
||||
THLongStorage_free(counter);
|
||||
}
|
||||
else
|
||||
{
|
||||
THStorage *storage;
|
||||
|
||||
torch_Tensor_(c_readTensorStorageSizeStride)(L, 1, 1, 1, 1, 1,
|
||||
&storage, &storageOffset, &size, &stride);
|
||||
|
||||
tensor = THTensor_(newWithStorage)(storage, storageOffset, size, stride);
|
||||
|
||||
THLongStorage_free(size);
|
||||
THLongStorage_free(stride);
|
||||
}
|
||||
|
||||
luaT_pushudata(L, tensor, torch_Tensor_id);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(set)(lua_State *L)
|
||||
{
|
||||
THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
THStorage *storage;
|
||||
long storageOffset;
|
||||
THLongStorage *size, *stride;
|
||||
|
||||
torch_Tensor_(c_readTensorStorageSizeStride)(L, 2, 1, 1, 1, 1,
|
||||
&storage, &storageOffset, &size, &stride);
|
||||
|
||||
THTensor_(setStorage)(self, storage, storageOffset, size, stride);
|
||||
|
||||
THLongStorage_free(size);
|
||||
THLongStorage_free(stride);
|
||||
|
||||
lua_settop(L, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(clone)(lua_State *L)
|
||||
{
|
||||
THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
self = THTensor_(newClone)(self);
|
||||
luaT_pushudata(L, self, torch_Tensor_id);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(contiguous)(lua_State *L)
|
||||
{
|
||||
THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
self = THTensor_(newContiguous)(self);
|
||||
luaT_pushudata(L, self, torch_Tensor_id);
|
||||
return 1;
|
||||
}
|
||||
|
||||
/* Resize */
|
||||
static int torch_Tensor_(resizeAs)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
THTensor *src = luaT_checkudata(L, 2, torch_Tensor_id);
|
||||
THTensor_(resizeAs)(tensor, src);
|
||||
lua_settop(L, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(resize)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
THLongStorage *size, *stride;
|
||||
|
||||
torch_Tensor_(c_readSizeStride)(L, 2, 0, &size, &stride);
|
||||
|
||||
THTensor_(resize)(tensor, size, stride);
|
||||
|
||||
THLongStorage_free(size);
|
||||
THLongStorage_free(stride);
|
||||
|
||||
lua_settop(L, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(narrow)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
int dimension = luaL_checkint(L, 2)-1;
|
||||
long firstIndex = luaL_checklong(L, 3)-1;
|
||||
long size = luaL_checklong(L, 4);
|
||||
|
||||
/* THArgCheck( (dimension >= 0) && (dimension < tensor->nDimension), 2, "out of range");
|
||||
THArgCheck( (firstIndex >= 0) && (firstIndex < tensor->size[dimension]), 3, "out of range");
|
||||
THArgCheck( (size > 0) && (firstIndex+size <= tensor->size[dimension]), 4, "out of range");
|
||||
*/
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(narrow)(tensor, NULL, dimension, firstIndex, size);
|
||||
luaT_pushudata(L, tensor, torch_Tensor_id);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(sub)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
long d0s = -1, d0e = -1, d1s = -1, d1e = -1, d2s = -1, d2e = -1, d3s = -1, d3e = -1;
|
||||
|
||||
d0s = luaL_checklong(L, 2)-1;
|
||||
d0e = luaL_checklong(L, 3)-1;
|
||||
if(d0s < 0)
|
||||
d0s += tensor->size[0]+1;
|
||||
if(d0e < 0)
|
||||
d0e += tensor->size[0]+1;
|
||||
luaL_argcheck(L, tensor->nDimension > 0, 2, "invalid dimension");
|
||||
luaL_argcheck(L, d0s >= 0 && d0s < tensor->size[0], 2, "out of range");
|
||||
luaL_argcheck(L, d0e >= 0 && d0e < tensor->size[0], 3, "out of range");
|
||||
luaL_argcheck(L, d0e >= d0s, 3, "end smaller than beginning");
|
||||
|
||||
if(!lua_isnone(L, 4))
|
||||
{
|
||||
d1s = luaL_checklong(L, 4)-1;
|
||||
d1e = luaL_checklong(L, 5)-1;
|
||||
if(d1s < 0)
|
||||
d1s += tensor->size[1]+1;
|
||||
if(d1e < 0)
|
||||
d1e += tensor->size[1]+1;
|
||||
luaL_argcheck(L, tensor->nDimension > 1, 4, "invalid dimension");
|
||||
luaL_argcheck(L, d1s >= 0 && d1s < tensor->size[1], 4, "out of range");
|
||||
luaL_argcheck(L, d1e >= 0 && d1e < tensor->size[1], 5, "out of range");
|
||||
luaL_argcheck(L, d1e >= d1s, 5, "end smaller than beginning");
|
||||
|
||||
if(!lua_isnone(L, 6))
|
||||
{
|
||||
d2s = luaL_checklong(L, 6)-1;
|
||||
d2e = luaL_checklong(L, 7)-1;
|
||||
if(d2s < 0)
|
||||
d2s += tensor->size[2]+1;
|
||||
if(d2e < 0)
|
||||
d2e += tensor->size[2]+1;
|
||||
luaL_argcheck(L, tensor->nDimension > 2, 6, "invalid dimension");
|
||||
luaL_argcheck(L, d2s >= 0 && d2s < tensor->size[2], 6, "out of range");
|
||||
luaL_argcheck(L, d2e >= 0 && d2e < tensor->size[2], 7, "out of range");
|
||||
luaL_argcheck(L, d2e >= d2s, 7, "end smaller than beginning");
|
||||
|
||||
if(!lua_isnone(L, 8))
|
||||
{
|
||||
d3s = luaL_checklong(L, 8)-1;
|
||||
d3e = luaL_checklong(L, 9)-1;
|
||||
if(d3s < 0)
|
||||
d3s += tensor->size[3]+1;
|
||||
if(d3e < 0)
|
||||
d3e += tensor->size[3]+1;
|
||||
luaL_argcheck(L, tensor->nDimension > 3, 8, "invalid dimension");
|
||||
luaL_argcheck(L, d3s >= 0 && d3s < tensor->size[3], 8, "out of range");
|
||||
luaL_argcheck(L, d3e >= 0 && d3e < tensor->size[3], 9, "out of range");
|
||||
luaL_argcheck(L, d3e >= d3s, 9, "end smaller than beginning");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(narrow)(tensor, NULL, 0, d0s, d0e-d0s+1);
|
||||
if(d1s >= 0)
|
||||
THTensor_(narrow)(tensor, NULL, 1, d1s, d1e-d1s+1);
|
||||
if(d2s >= 0)
|
||||
THTensor_(narrow)(tensor, NULL, 2, d2s, d2e-d2s+1);
|
||||
if(d3s >= 0)
|
||||
THTensor_(narrow)(tensor, NULL, 3, d3s, d3e-d3s+1);
|
||||
luaT_pushudata(L, tensor, torch_Tensor_id);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(select)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
int dimension = luaL_checkint(L, 2)-1;
|
||||
long sliceIndex = luaL_checklong(L, 3)-1;
|
||||
|
||||
/* THArgCheck(src->nDimension > 1, 1, "cannot select on a vector");
|
||||
THArgCheck((dimension >= 0) && (dimension < src->nDimension), 2, "out of range");
|
||||
THArgCheck((sliceIndex >= 0) && (sliceIndex < src->size[dimension]), 3, "out of range");
|
||||
*/
|
||||
|
||||
if(tensor->nDimension > 1)
|
||||
{
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(select)(tensor, NULL, dimension, sliceIndex);
|
||||
luaT_pushudata(L, tensor, torch_Tensor_id);
|
||||
}
|
||||
else
|
||||
{
|
||||
THArgCheck(tensor->nDimension == 1, 1, "empty Tensor");
|
||||
lua_pushnumber(L, THTensor_(get1d)(tensor, sliceIndex));
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
static int torch_Tensor_(transpose)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
int dimension1 = luaL_checkint(L, 2)-1;
|
||||
int dimension2 = luaL_checkint(L, 3)-1;
|
||||
|
||||
/*
|
||||
THArgCheck( (dimension1 >= 0) && (dimension1 < src->nDimension), 2, "out of range");
|
||||
THArgCheck( (dimension2 >= 0) && (dimension2 < src->nDimension), 3, "out of range");
|
||||
*/
|
||||
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(transpose)(tensor, NULL, dimension1, dimension2);
|
||||
luaT_pushudata(L, tensor, torch_Tensor_id);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(t)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
|
||||
luaL_argcheck(L, tensor->nDimension == 2, 1, "Tensor must have 2 dimensions");
|
||||
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(transpose)(tensor, NULL, 0, 1);
|
||||
luaT_pushudata(L, tensor, torch_Tensor_id);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(unfold)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
int dimension = luaL_checkint(L, 2)-1;
|
||||
long size = luaL_checklong(L, 3);
|
||||
long step = luaL_checklong(L, 4);
|
||||
|
||||
/*
|
||||
THArgCheck( (src->nDimension > 0), 1, "cannot unfold an empty tensor");
|
||||
THArgCheck(dimension < src->nDimension, 2, "out of range");
|
||||
THArgCheck(size <= src->size[dimension], 3, "out of range");
|
||||
*/
|
||||
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(unfold)(tensor, NULL, dimension, size, step);
|
||||
luaT_pushudata(L, tensor, torch_Tensor_id);
|
||||
return 1;
|
||||
}
|
||||
|
||||
/* is contiguous? [a bit like in TnXIterator] */
|
||||
static int torch_Tensor_(isContiguous)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
lua_pushboolean(L, THTensor_(isContiguous)(tensor));
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(nElement)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
lua_pushnumber(L, THTensor_(nElement)(tensor));
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(copy)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
void *src;
|
||||
if( (src = luaT_toudata(L, 2, torch_Tensor_id)) )
|
||||
THTensor_(copy)(tensor, src);
|
||||
else if( (src = luaT_toudata(L, 2, torch_ByteTensor_id)) )
|
||||
THTensor_(copyByte)(tensor, src);
|
||||
else if( (src = luaT_toudata(L, 2, torch_CharTensor_id)) )
|
||||
THTensor_(copyChar)(tensor, src);
|
||||
else if( (src = luaT_toudata(L, 2, torch_ShortTensor_id)) )
|
||||
THTensor_(copyShort)(tensor, src);
|
||||
else if( (src = luaT_toudata(L, 2, torch_IntTensor_id)) )
|
||||
THTensor_(copyInt)(tensor, src);
|
||||
else if( (src = luaT_toudata(L, 2, torch_LongTensor_id)) )
|
||||
THTensor_(copyLong)(tensor, src);
|
||||
else if( (src = luaT_toudata(L, 2, torch_FloatTensor_id)) )
|
||||
THTensor_(copyFloat)(tensor, src);
|
||||
else if( (src = luaT_toudata(L, 2, torch_DoubleTensor_id)) )
|
||||
THTensor_(copyDouble)(tensor, src);
|
||||
else
|
||||
luaL_typerror(L, 2, "torch.*Tensor");
|
||||
lua_settop(L, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(__newindex__)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
THLongStorage *idx = NULL;
|
||||
|
||||
if(lua_isnumber(L, 2))
|
||||
{
|
||||
long index = luaL_checklong(L,2)-1;
|
||||
void *src;
|
||||
if (lua_isnumber(L,3)) {
|
||||
real value = (real)luaL_checknumber(L,3);
|
||||
luaL_argcheck(L, tensor->nDimension == 1, 1, "must be a one dimensional tensor");
|
||||
luaL_argcheck(L, index >= 0 && index < tensor->size[0], 2, "out of range");
|
||||
THStorage_(set)(tensor->storage, tensor->storageOffset+index*tensor->stride[0], value);
|
||||
} else if( (src = luaT_toudata(L, 3, torch_Tensor_id)) ) {
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(narrow)(tensor, NULL, 0, index, 1);
|
||||
THTensor_(copy)(tensor, src);
|
||||
} else if( (src = luaT_toudata(L, 3, torch_ByteTensor_id)) ) {
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(narrow)(tensor, NULL, 0, index, 1);
|
||||
THTensor_(copyByte)(tensor, src);
|
||||
} else if( (src = luaT_toudata(L, 3, torch_CharTensor_id)) ) {
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(narrow)(tensor, NULL, 0, index, 1);
|
||||
THTensor_(copyChar)(tensor, src);
|
||||
} else if( (src = luaT_toudata(L, 3, torch_ShortTensor_id)) ) {
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(narrow)(tensor, NULL, 0, index, 1);
|
||||
THTensor_(copyShort)(tensor, src);
|
||||
} else if( (src = luaT_toudata(L, 3, torch_IntTensor_id)) ) {
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(narrow)(tensor, NULL, 0, index, 1);
|
||||
THTensor_(copyInt)(tensor, src);
|
||||
} else if( (src = luaT_toudata(L, 3, torch_LongTensor_id)) ) {
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(narrow)(tensor, NULL, 0, index, 1);
|
||||
THTensor_(copyLong)(tensor, src);
|
||||
} else if( (src = luaT_toudata(L, 3, torch_FloatTensor_id)) ) {
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(narrow)(tensor, NULL, 0, index, 1);
|
||||
THTensor_(copyFloat)(tensor, src);
|
||||
} else {
|
||||
luaL_typerror(L, 3, "torch.*Tensor");
|
||||
}
|
||||
lua_pushboolean(L, 1);
|
||||
}
|
||||
else if((idx = luaT_toudata(L, 2, torch_LongStorage_id)))
|
||||
{
|
||||
long index = THTensor_(storageOffset)(tensor);
|
||||
real value = (real)luaL_checknumber(L,3);
|
||||
int dim;
|
||||
|
||||
luaL_argcheck(L, idx->size == tensor->nDimension, 2, "invalid size");
|
||||
|
||||
for(dim = 0; dim < idx->size; dim++)
|
||||
{
|
||||
long z = idx->data[dim]-1;
|
||||
luaL_argcheck(L, (z >= 0) && (z < tensor->size[dim]), 2, "index out of bound");
|
||||
index += z*tensor->stride[dim];
|
||||
}
|
||||
|
||||
THStorage_(set)(tensor->storage, index, value);
|
||||
lua_pushboolean(L, 1);
|
||||
}
|
||||
else if(lua_istable(L, 2))
|
||||
{
|
||||
long index = THTensor_(storageOffset)(tensor);
|
||||
real value = (real)luaL_checknumber(L,3);
|
||||
int dim;
|
||||
|
||||
luaL_argcheck(L, lua_objlen(L,2) == tensor->nDimension, 2, "invalid size");
|
||||
|
||||
for(dim = 0; dim < tensor->nDimension; dim++)
|
||||
{
|
||||
long z;
|
||||
|
||||
lua_rawgeti(L, 2, dim+1);
|
||||
if(!lua_isnumber(L, -1))
|
||||
luaL_error(L, "number expected for each dimension");
|
||||
|
||||
z = lua_tonumber(L, -1)-1;
|
||||
lua_pop(L, 1);
|
||||
|
||||
luaL_argcheck(L, (z >= 0) && (z < tensor->size[dim]), 2, "index out of bound");
|
||||
index += z*tensor->stride[dim];
|
||||
}
|
||||
THStorage_(set)(tensor->storage, index, value);
|
||||
lua_pushboolean(L, 1);
|
||||
}
|
||||
else
|
||||
lua_pushboolean(L, 0);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(__index__)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
THLongStorage *idx = NULL;
|
||||
|
||||
if(lua_isnumber(L, 2))
|
||||
{
|
||||
long index = luaL_checklong(L,2)-1;
|
||||
|
||||
luaL_argcheck(L, tensor->nDimension > 0, 1, "empty tensor");
|
||||
luaL_argcheck(L, index >= 0 && index < tensor->size[0], 2, "out of range");
|
||||
|
||||
if(tensor->nDimension == 1)
|
||||
{
|
||||
lua_pushnumber(L, THStorage_(get)(tensor->storage, tensor->storageOffset+index*tensor->stride[0]));
|
||||
}
|
||||
else
|
||||
{
|
||||
tensor = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(select)(tensor, NULL, 0, index);
|
||||
luaT_pushudata(L, tensor, torch_Tensor_id);
|
||||
}
|
||||
lua_pushboolean(L, 1);
|
||||
return 2;
|
||||
}
|
||||
else if((idx = luaT_toudata(L, 2, torch_LongStorage_id)))
|
||||
{
|
||||
long index = THTensor_(storageOffset)(tensor);
|
||||
int dim;
|
||||
|
||||
luaL_argcheck(L, idx->size == tensor->nDimension, 2, "invalid size");
|
||||
|
||||
for(dim = 0; dim < idx->size; dim++)
|
||||
{
|
||||
long z = idx->data[dim]-1;
|
||||
luaL_argcheck(L, (z >= 0) && (z < tensor->size[dim]), 2, "index out of bound");
|
||||
index += z*tensor->stride[dim];
|
||||
}
|
||||
lua_pushnumber(L, (double)THStorage_(get)(THTensor_(storage)(tensor), index));
|
||||
lua_pushboolean(L, 1);
|
||||
return 2;
|
||||
}
|
||||
else if(lua_istable(L, 2))
|
||||
{
|
||||
long index = THTensor_(storageOffset)(tensor);
|
||||
int dim;
|
||||
|
||||
luaL_argcheck(L, lua_objlen(L,2) == tensor->nDimension, 2, "invalid size");
|
||||
|
||||
for(dim = 0; dim < tensor->nDimension; dim++)
|
||||
{
|
||||
long z;
|
||||
|
||||
lua_rawgeti(L, 2, dim+1);
|
||||
if(!lua_isnumber(L, -1))
|
||||
luaL_error(L, "number expected for each dimension");
|
||||
|
||||
z = lua_tonumber(L, -1)-1;
|
||||
lua_pop(L, 1);
|
||||
|
||||
luaL_argcheck(L, (z >= 0) && (z < tensor->size[dim]), 2, "index out of bound");
|
||||
index += z*tensor->stride[dim];
|
||||
}
|
||||
lua_pushnumber(L, (double)THStorage_(get)(THTensor_(storage)(tensor), index));
|
||||
lua_pushboolean(L, 1);
|
||||
return 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
lua_pushboolean(L, 0);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
static int torch_Tensor_(free)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
THTensor_(free)(tensor);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/* helpful functions */
|
||||
static void torch_Tensor_(c_readSizeStride)(lua_State *L, int index, int allowStride, THLongStorage **size_, THLongStorage **stride_)
|
||||
{
|
||||
THLongStorage *size = NULL;
|
||||
THLongStorage *stride = NULL;
|
||||
|
||||
if( (size = luaT_toudata(L, index, torch_LongStorage_id)) )
|
||||
{
|
||||
if(!lua_isnoneornil(L, index+1))
|
||||
{
|
||||
if( (stride = luaT_toudata(L, index+1, torch_LongStorage_id)) )
|
||||
luaL_argcheck(L, stride->size == size->size, index+1, "provided stride and size are inconsistent");
|
||||
else
|
||||
luaL_argcheck(L, 0, index+1, "torch.LongStorage expected");
|
||||
}
|
||||
THLongStorage_retain(size);
|
||||
if(stride)
|
||||
THLongStorage_retain(stride);
|
||||
}
|
||||
else
|
||||
{
|
||||
int i;
|
||||
|
||||
size = THLongStorage_newWithSize(8);
|
||||
stride = THLongStorage_newWithSize(8);
|
||||
THLongStorage_fill(size, -1);
|
||||
THLongStorage_fill(stride, -1);
|
||||
|
||||
if(allowStride)
|
||||
{
|
||||
for(i = 0; i < 8; i++)
|
||||
{
|
||||
if(lua_isnone(L, index+2*i))
|
||||
break;
|
||||
size->data[i] = luaL_checklong(L, index+2*i);
|
||||
|
||||
if(lua_isnone(L, index+2*i+1))
|
||||
break;
|
||||
stride->data[i] = luaL_checklong(L, index+2*i+1);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(i = 0; i < 8; i++)
|
||||
{
|
||||
if(lua_isnone(L, index+i))
|
||||
break;
|
||||
size->data[i] = luaL_checklong(L, index+i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
*size_ = size;
|
||||
*stride_ = stride;
|
||||
}
|
||||
|
||||
static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index, int allowNone, int allowTensor, int allowStorage, int allowStride,
|
||||
THStorage **storage_, long *storageOffset_, THLongStorage **size_, THLongStorage **stride_)
|
||||
{
|
||||
static char errMsg[64];
|
||||
THTensor *src = NULL;
|
||||
THStorage *storage = NULL;
|
||||
|
||||
int arg1Type = lua_type(L, index);
|
||||
|
||||
if( allowNone && (arg1Type == LUA_TNONE) )
|
||||
{
|
||||
*storage_ = NULL;
|
||||
*storageOffset_ = 0;
|
||||
*size_ = NULL;
|
||||
*stride_ = NULL;
|
||||
return;
|
||||
}
|
||||
else if( allowTensor && (arg1Type == LUA_TUSERDATA) && (src = luaT_toudata(L, index, torch_Tensor_id)) )
|
||||
{
|
||||
*storage_ = src->storage;
|
||||
*storageOffset_ = src->storageOffset;
|
||||
*size_ = THTensor_(newSizeOf)(src);
|
||||
*stride_ = THTensor_(newStrideOf)(src);
|
||||
return;
|
||||
}
|
||||
else if( allowStorage && (arg1Type == LUA_TUSERDATA) && (storage = luaT_toudata(L, index, torch_Storage_id)) )
|
||||
{
|
||||
*storage_ = storage;
|
||||
if(lua_isnone(L, index+1))
|
||||
{
|
||||
*storageOffset_ = 0;
|
||||
*size_ = THLongStorage_newWithSize1(storage->size);
|
||||
*stride_ = THLongStorage_newWithSize1(1);
|
||||
}
|
||||
else
|
||||
{
|
||||
*storageOffset_ = luaL_checklong(L, index+1)-1;
|
||||
torch_Tensor_(c_readSizeStride)(L, index+2, allowStride, size_, stride_);
|
||||
}
|
||||
return;
|
||||
}
|
||||
else if( (arg1Type == LUA_TNUMBER) || (luaT_toudata(L, index, torch_LongStorage_id)) )
|
||||
{
|
||||
*storage_ = NULL;
|
||||
*storageOffset_ = 0;
|
||||
torch_Tensor_(c_readSizeStride)(L, index, 0, size_, stride_);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
*storage_ = NULL;
|
||||
*storageOffset_ = 0;
|
||||
|
||||
sprintf(errMsg, "expecting number%s%s", (allowTensor ? " or Tensor" : ""), (allowStorage ? " or Storage" : ""));
|
||||
luaL_argcheck(L, 0, index, errMsg);
|
||||
}
|
||||
|
||||
static int torch_Tensor_(apply)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
luaL_checktype(L, 2, LUA_TFUNCTION);
|
||||
lua_settop(L, 2);
|
||||
|
||||
TH_TENSOR_APPLY(real, tensor,
|
||||
lua_pushvalue(L, 2);
|
||||
lua_pushnumber(L, *tensor_data);
|
||||
lua_call(L, 1, 1);
|
||||
if(lua_isnumber(L, 3))
|
||||
{
|
||||
*tensor_data = (real)lua_tonumber(L, 3);
|
||||
lua_pop(L, 1);
|
||||
}
|
||||
else if(lua_isnil(L, 3))
|
||||
lua_pop(L, 1);
|
||||
else
|
||||
luaL_error(L, "given function should return a number or nil"););
|
||||
|
||||
lua_settop(L, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(map)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
THTensor *src = luaT_checkudata(L, 2, torch_Tensor_id);
|
||||
luaL_checktype(L, 3, LUA_TFUNCTION);
|
||||
lua_settop(L, 3);
|
||||
|
||||
TH_TENSOR_APPLY2(real, tensor, real, src,
|
||||
lua_pushvalue(L, 3);
|
||||
lua_pushnumber(L, *tensor_data);
|
||||
lua_pushnumber(L, *src_data);
|
||||
lua_call(L, 2, 1);
|
||||
if(lua_isnumber(L, 4))
|
||||
{
|
||||
*tensor_data = (real)lua_tonumber(L, 4);
|
||||
lua_pop(L, 1);
|
||||
}
|
||||
else if(lua_isnil(L, 4))
|
||||
lua_pop(L, 1);
|
||||
else
|
||||
luaL_error(L, "given function should return a number or nil"););
|
||||
|
||||
lua_settop(L, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(map2)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
THTensor *src1 = luaT_checkudata(L, 2, torch_Tensor_id);
|
||||
THTensor *src2 = luaT_checkudata(L, 3, torch_Tensor_id);
|
||||
luaL_checktype(L, 4, LUA_TFUNCTION);
|
||||
lua_settop(L, 4);
|
||||
|
||||
TH_TENSOR_APPLY3(real, tensor, real, src1, real, src2,
|
||||
lua_pushvalue(L, 4);
|
||||
lua_pushnumber(L, *tensor_data);
|
||||
lua_pushnumber(L, *src1_data);
|
||||
lua_pushnumber(L, *src2_data);
|
||||
lua_call(L, 3, 1);
|
||||
if(lua_isnumber(L, 5))
|
||||
{
|
||||
*tensor_data = (real)lua_tonumber(L, 5);
|
||||
lua_pop(L, 1);
|
||||
}
|
||||
else if(lua_isnil(L, 5))
|
||||
lua_pop(L, 1);
|
||||
else
|
||||
luaL_error(L, "given function should return a number or nothing"););
|
||||
|
||||
lua_settop(L, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(factory)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = THTensor_(new)();
|
||||
luaT_pushudata(L, tensor, torch_Tensor_id);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(write)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
THFile *file = luaT_checkudata(L, 2, torch_File_id);
|
||||
|
||||
THFile_writeIntScalar(file, tensor->nDimension);
|
||||
THFile_writeLongRaw(file, tensor->size, tensor->nDimension);
|
||||
THFile_writeLongRaw(file, tensor->stride, tensor->nDimension);
|
||||
THFile_writeLongScalar(file, tensor->storageOffset+1); /* to respect Lua convention */
|
||||
|
||||
lua_getfield(L, 2, "writeObject"); /* the method */
|
||||
lua_pushvalue(L, 2); /* the file */
|
||||
/* the storage */
|
||||
if(tensor->storage)
|
||||
{
|
||||
THStorage_(retain)(tensor->storage);
|
||||
luaT_pushudata(L, tensor->storage, torch_Storage_id);
|
||||
}
|
||||
else
|
||||
lua_pushnil(L);
|
||||
|
||||
lua_call(L, 2, 0); /* call the method */
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int torch_Tensor_(read)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
THFile *file = luaT_checkudata(L, 2, torch_File_id);
|
||||
|
||||
tensor->nDimension = THFile_readIntScalar(file);
|
||||
tensor->size = THAlloc(sizeof(long)*tensor->nDimension);
|
||||
tensor->stride = THAlloc(sizeof(long)*tensor->nDimension);
|
||||
THFile_readLongRaw(file, tensor->size, tensor->nDimension);
|
||||
THFile_readLongRaw(file, tensor->stride, tensor->nDimension);
|
||||
tensor->storageOffset = THFile_readLongScalar(file);
|
||||
tensor->storageOffset--; /* to respect Lua convention */
|
||||
|
||||
lua_getfield(L, 2, "readObject"); /* the method */
|
||||
lua_pushvalue(L, 2); /* the file */
|
||||
lua_call(L, 1, 1); /* call the method */
|
||||
|
||||
tensor->storage = luaT_toudata(L, -1, torch_Storage_id);
|
||||
if(tensor->storage)
|
||||
THStorage_(retain)(tensor->storage);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static const struct luaL_Reg torch_Tensor_(_) [] = {
|
||||
{"contiguous", torch_Tensor_(contiguous)},
|
||||
{"size", torch_Tensor_(size)},
|
||||
{"__len__", torch_Tensor_(size)},
|
||||
{"stride", torch_Tensor_(stride)},
|
||||
{"dim", torch_Tensor_(nDimension)},
|
||||
{"nDimension", torch_Tensor_(nDimension)},
|
||||
{"set", torch_Tensor_(set)},
|
||||
{"storage", torch_Tensor_(storage)},
|
||||
{"storageOffset", torch_Tensor_(storageOffset)},
|
||||
{"clone", torch_Tensor_(clone)},
|
||||
{"contiguous", torch_Tensor_(contiguous)},
|
||||
{"resizeAs", torch_Tensor_(resizeAs)},
|
||||
{"resize", torch_Tensor_(resize)},
|
||||
{"narrow", torch_Tensor_(narrow)},
|
||||
{"sub", torch_Tensor_(sub)},
|
||||
{"select", torch_Tensor_(select)},
|
||||
{"transpose", torch_Tensor_(transpose)},
|
||||
{"t", torch_Tensor_(t)},
|
||||
{"unfold", torch_Tensor_(unfold)},
|
||||
{"isContiguous", torch_Tensor_(isContiguous)},
|
||||
{"nElement", torch_Tensor_(nElement)},
|
||||
{"copy", torch_Tensor_(copy)},
|
||||
{"apply", torch_Tensor_(apply)},
|
||||
{"map", torch_Tensor_(map)},
|
||||
{"map2", torch_Tensor_(map2)},
|
||||
{"read", torch_Tensor_(read)},
|
||||
{"write", torch_Tensor_(write)},
|
||||
{"__index__", torch_Tensor_(__index__)},
|
||||
{"__newindex__", torch_Tensor_(__newindex__)},
|
||||
{NULL, NULL}
|
||||
};
|
||||
|
||||
void torch_Tensor_(init)(lua_State *L)
|
||||
{
|
||||
torch_File_id = luaT_checktypename2id(L, "torch.File");
|
||||
torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage");
|
||||
torch_Storage_id = luaT_checktypename2id(L, STRING_torchStorage);
|
||||
|
||||
torch_Tensor_id = luaT_newmetatable(L, STRING_torchTensor, NULL,
|
||||
torch_Tensor_(new), torch_Tensor_(free), torch_Tensor_(factory));
|
||||
luaL_register(L, NULL, torch_Tensor_(_));
|
||||
lua_pop(L, 1);
|
||||
}
|
||||
|
||||
#endif
|
||||
175
generic/TensorConv.c
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
#ifndef TH_GENERIC_FILE
|
||||
#define TH_GENERIC_FILE "generic/TensorConv.c"
|
||||
#else
|
||||
|
||||
static int torch_(convxcorr2)(lua_State *L, const char* ktype)
|
||||
{
|
||||
int narg = lua_gettop(L);
|
||||
THTensor *r_ = NULL;
|
||||
THTensor *im = NULL;
|
||||
THTensor *ker = NULL;
|
||||
char type[2];
|
||||
int rgiven = 0;
|
||||
|
||||
type[0] = 'v';
|
||||
type[1] = ktype[0];
|
||||
|
||||
if (narg == 2
|
||||
&& (ker = luaT_toudata(L,2,torch_(Tensor_id)))
|
||||
&& (im = luaT_toudata(L,1,torch_(Tensor_id))))
|
||||
{
|
||||
}
|
||||
else if (narg == 3
|
||||
&& (lua_type(L,3) == LUA_TSTRING)
|
||||
&& (ker = luaT_toudata(L,2,torch_(Tensor_id)))
|
||||
&& (im = luaT_toudata(L,1,torch_(Tensor_id))))
|
||||
{
|
||||
type[0] = *(luaL_checkstring(L,3));
|
||||
luaL_argcheck(L, (type[0] == 'v' || type[0] == 'V' || type[0] == 'f' || type[0] == 'F'),
|
||||
3, "[Tensor, ] Tensor, Tensor [, x or c]");
|
||||
if (type[0] == 'V') type[0] = 'v';
|
||||
if (type[0] == 'F') type[0] = 'f';
|
||||
}
|
||||
else if (narg == 4
|
||||
&& (type[0] = *(luaL_checkstring(L,4)))
|
||||
&& (ker = luaT_toudata(L,3,torch_(Tensor_id)))
|
||||
&& (im = luaT_toudata(L,2,torch_(Tensor_id)))
|
||||
&& (r_ = luaT_toudata(L,1,torch_(Tensor_id))))
|
||||
{
|
||||
rgiven = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
luaL_error(L,"[Tensor, ] Tensor, Tensor [, x or c]");
|
||||
}
|
||||
|
||||
if (!r_) r_ = THTensor_(new)();
|
||||
|
||||
if (im->nDimension == 2 && ker->nDimension == 2)
|
||||
{
|
||||
THTensor_(conv2Dmul)(r_,0.0,1.0,im,ker,1,1,type);
|
||||
}
|
||||
else if (im->nDimension == 3 && ker->nDimension == 3)
|
||||
{
|
||||
THTensor_(conv2Dcmul)(r_,0.0,1.0,im,ker,1,1,type);
|
||||
}
|
||||
else if (im->nDimension == 3 && ker->nDimension == 4)
|
||||
{
|
||||
THTensor_(conv2Dmv)(r_,0.0,1.0,im,ker,1,1,type);
|
||||
}
|
||||
else
|
||||
{
|
||||
luaL_error(L," (2D,2D) or (3D,3D) or (3D,4D) ");
|
||||
}
|
||||
|
||||
pushreturn(rgiven, r_, torch_(Tensor_id));
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_(convxcorr3)(lua_State *L, char* ktype)
|
||||
{
|
||||
int narg = lua_gettop(L);
|
||||
THTensor *r_ = NULL;
|
||||
THTensor *im = NULL;
|
||||
THTensor *ker = NULL;
|
||||
char type[2];
|
||||
int rgiven = 0;
|
||||
|
||||
type[0] = 'v';
|
||||
type[1] = ktype[0];
|
||||
|
||||
if (narg == 2
|
||||
&& (ker = luaT_toudata(L,2,torch_(Tensor_id)))
|
||||
&& (im = luaT_toudata(L,1,torch_(Tensor_id))))
|
||||
{
|
||||
}
|
||||
else if (narg == 3
|
||||
&& (lua_type(L,3) == LUA_TSTRING)
|
||||
&& (ker = luaT_toudata(L,2,torch_(Tensor_id)))
|
||||
&& (im = luaT_toudata(L,1,torch_(Tensor_id))))
|
||||
{
|
||||
type[0] = *(luaL_checkstring(L,3));
|
||||
luaL_argcheck(L, (type[0] == 'v' || type[0] == 'V' || type[0] == 'f' || type[0] == 'F'),
|
||||
3, "[Tensor, ] Tensor, Tensor [, x or c]");
|
||||
if (type[0] == 'V') type[0] = 'v';
|
||||
if (type[0] == 'F') type[0] = 'f';
|
||||
}
|
||||
else if (narg == 4
|
||||
&& (type[0] = *(luaL_checkstring(L,4)))
|
||||
&& (ker = luaT_toudata(L,3,torch_(Tensor_id)))
|
||||
&& (im = luaT_toudata(L,2,torch_(Tensor_id)))
|
||||
&& (r_ = luaT_toudata(L,1,torch_(Tensor_id))))
|
||||
{
|
||||
rgiven = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
luaL_error(L,"[Tensor, ] Tensor, Tensor [, x or c]");
|
||||
}
|
||||
|
||||
if (!r_) r_ = THTensor_(new)();
|
||||
|
||||
if (im->nDimension == 3 && ker->nDimension == 3)
|
||||
{
|
||||
THTensor_(conv3Dmul)(r_,0.0,1.0,im,ker,1,1,1,type);
|
||||
}
|
||||
else if (im->nDimension == 4 && ker->nDimension == 4)
|
||||
{
|
||||
THTensor_(conv3Dcmul)(r_,0.0,1.0,im,ker,1,1,1,type);
|
||||
}
|
||||
else if (im->nDimension == 4 && ker->nDimension == 5)
|
||||
{
|
||||
THTensor_(conv3Dmv)(r_,0.0,1.0,im,ker,1,1,1,type);
|
||||
}
|
||||
else
|
||||
{
|
||||
luaL_error(L," (3D,3D) or (4D,4D) or (4D,5D) ");
|
||||
}
|
||||
|
||||
pushreturn(rgiven, r_, torch_(Tensor_id));
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_(conv2)(lua_State *L)
|
||||
{
|
||||
return torch_(convxcorr2)(L,"convolution");
|
||||
}
|
||||
static int torch_(xcorr2)(lua_State *L)
|
||||
{
|
||||
return torch_(convxcorr2)(L,"xcorrelation");
|
||||
}
|
||||
|
||||
|
||||
static int torch_(conv3)(lua_State *L)
|
||||
{
|
||||
return torch_(convxcorr3)(L,"convolution");
|
||||
}
|
||||
static int torch_(xcorr3)(lua_State *L)
|
||||
{
|
||||
return torch_(convxcorr3)(L,"xcorrelation");
|
||||
}
|
||||
|
||||
static const struct luaL_Reg torch_(Conv__) [] = {
|
||||
{"conv2", torch_(conv2)},
|
||||
{"xcorr2", torch_(xcorr2)},
|
||||
{"conv3", torch_(conv3)},
|
||||
{"xcorr3", torch_(xcorr3)},
|
||||
{NULL, NULL}
|
||||
};
|
||||
|
||||
void torch_(Conv_init)(lua_State *L)
|
||||
{
|
||||
torch_(Tensor_id) = luaT_checktypename2id(L, torch_string_(Tensor));
|
||||
|
||||
/* register everything into the "torch" field of the tensor metaclass */
|
||||
luaT_pushmetaclass(L, torch_(Tensor_id));
|
||||
lua_pushstring(L, "torch");
|
||||
lua_rawget(L, -2);
|
||||
luaL_register(L, NULL, torch_(Conv__));
|
||||
lua_pop(L, 2);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
274
generic/TensorLapack.c
Normal file
|
|
@ -0,0 +1,274 @@
|
|||
#ifndef TH_GENERIC_FILE
|
||||
#define TH_GENERIC_FILE "generic/TensorLapack.c"
|
||||
#else
|
||||
|
||||
#define pushreturn(i,t,tid) \
|
||||
if (!i) \
|
||||
luaT_pushudata(L, t, tid); \
|
||||
else \
|
||||
lua_pushvalue(L,i)
|
||||
|
||||
static int torch_(gesv)(lua_State *L)
|
||||
{
|
||||
int narg = lua_gettop(L);
|
||||
THTensor *ra_ = NULL;
|
||||
THTensor *rb_ = NULL;
|
||||
THTensor *a_ = NULL;
|
||||
THTensor *b_ = NULL;
|
||||
int ragiven = 0;
|
||||
int rbgiven = 0;
|
||||
|
||||
if (narg == 2
|
||||
&& (a_ = luaT_toudata(L,2,torch_(Tensor_id)))
|
||||
&& (b_ = luaT_toudata(L,1,torch_(Tensor_id))))
|
||||
{
|
||||
}
|
||||
else if (narg == 3
|
||||
&& (a_ = luaT_toudata(L,2,torch_(Tensor_id)))
|
||||
&& (b_ = luaT_toudata(L,1,torch_(Tensor_id))))
|
||||
{
|
||||
if(lua_toboolean(L,3))
|
||||
{
|
||||
ra_ = a_;
|
||||
rb_ = b_;
|
||||
a_ = NULL;
|
||||
b_ = NULL;
|
||||
ragiven = 2;
|
||||
rbgiven = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
luaL_error(L,"[Tensor, Tensor], Tensor, Tensor, [,true]");
|
||||
}
|
||||
}
|
||||
else if (narg == 4
|
||||
&& (a_ = luaT_toudata(L,4,torch_(Tensor_id)))
|
||||
&& (b_ = luaT_toudata(L,3,torch_(Tensor_id)))
|
||||
&& (ra_ = luaT_toudata(L,2,torch_(Tensor_id)))
|
||||
&& (rb_ = luaT_toudata(L,1,torch_(Tensor_id))))
|
||||
{
|
||||
ragiven = 2;
|
||||
rbgiven = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
luaL_error(L,"[Tensor, Tensor], Tensor, Tensor, [,true]");
|
||||
}
|
||||
|
||||
if (!ra_) ra_ = THTensor_(new)();
|
||||
if (!rb_) rb_ = THTensor_(new)();
|
||||
|
||||
THTensor_(gesv)(rb_,ra_,b_,a_);
|
||||
|
||||
pushreturn(rbgiven,rb_,torch_(Tensor_id));
|
||||
pushreturn(ragiven,ra_,torch_(Tensor_id));
|
||||
|
||||
return 2;
|
||||
}
|
||||
|
||||
static int torch_(gels)(lua_State *L)
|
||||
{
|
||||
int narg = lua_gettop(L);
|
||||
THTensor *ra_ = NULL;
|
||||
THTensor *rb_ = NULL;
|
||||
THTensor *a_ = NULL;
|
||||
THTensor *b_ = NULL;
|
||||
int ragiven = 0;
|
||||
int rbgiven = 0;
|
||||
|
||||
if (narg == 2
|
||||
&& (a_ = luaT_toudata(L,2,torch_(Tensor_id)))
|
||||
&& (b_ = luaT_toudata(L,1,torch_(Tensor_id))))
|
||||
{
|
||||
}
|
||||
else if (narg == 3
|
||||
&& (a_ = luaT_toudata(L,2,torch_(Tensor_id)))
|
||||
&& (b_ = luaT_toudata(L,1,torch_(Tensor_id))))
|
||||
{
|
||||
if (lua_toboolean(L,3))
|
||||
{
|
||||
ra_ = a_;
|
||||
rb_ = b_;
|
||||
a_ = NULL;
|
||||
b_ = NULL;
|
||||
ragiven = 2;
|
||||
rbgiven = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
luaL_error(L,"[Tensor, Tensor], Tensor, Tensor, [,true]");
|
||||
}
|
||||
}
|
||||
else if (narg == 4
|
||||
&& (a_ = luaT_toudata(L,4,torch_(Tensor_id)))
|
||||
&& (b_ = luaT_toudata(L,3,torch_(Tensor_id)))
|
||||
&& (ra_ = luaT_toudata(L,2,torch_(Tensor_id)))
|
||||
&& (rb_ = luaT_toudata(L,1,torch_(Tensor_id))))
|
||||
{
|
||||
ragiven = 2;
|
||||
rbgiven = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
luaL_error(L,"[Tensor, Tensor], Tensor, Tensor, [,true]");
|
||||
}
|
||||
|
||||
if (!ra_) ra_ = THTensor_(new)();
|
||||
if (!rb_) rb_ = THTensor_(new)();
|
||||
|
||||
THTensor_(gels)(rb_,ra_,b_,a_);
|
||||
|
||||
pushreturn(rbgiven,rb_,torch_(Tensor_id));
|
||||
pushreturn(ragiven,ra_,torch_(Tensor_id));
|
||||
|
||||
return 2;
|
||||
}
|
||||
|
||||
static int torch_(eig)(lua_State *L)
|
||||
{
|
||||
int narg = lua_gettop(L);
|
||||
THTensor *re_ = NULL;
|
||||
THTensor *rv_ = NULL;
|
||||
THTensor *a_ = NULL;
|
||||
char type = 'N';
|
||||
char uplo = 'U';
|
||||
int regiven = 0;
|
||||
int rvgiven = 0;
|
||||
|
||||
if (narg == 1
|
||||
&& (a_ = luaT_toudata(L,1,torch_(Tensor_id))))
|
||||
{
|
||||
}
|
||||
else if (narg == 2
|
||||
&& (lua_type(L,2) == LUA_TSTRING)
|
||||
&& (a_ = luaT_toudata(L,1,torch_(Tensor_id))))
|
||||
{
|
||||
type = *(luaL_checkstring(L,2));
|
||||
luaL_argcheck(L, (type == 'v' || type == 'V' || type == 'n' || type == 'N'),
|
||||
2, "[Tensor, ] [Tensor, ] Tensor [, N or V]");
|
||||
if (type == 'v') type = 'V';
|
||||
if (type == 'n') type = 'N';
|
||||
}
|
||||
else if (narg == 2
|
||||
&& (a_ = luaT_toudata(L,2,torch_(Tensor_id)))
|
||||
&& (re_ = luaT_toudata(L,1,torch_(Tensor_id))))
|
||||
{
|
||||
regiven = 1;
|
||||
}
|
||||
else if (narg == 3
|
||||
&& (a_ = luaT_toudata(L,3,torch_(Tensor_id)))
|
||||
&& (rv_ = luaT_toudata(L,2,torch_(Tensor_id)))
|
||||
&& (re_ = luaT_toudata(L,1,torch_(Tensor_id))))
|
||||
{
|
||||
regiven = 1;
|
||||
rvgiven = 2;
|
||||
}
|
||||
else if (narg == 4
|
||||
&& (type = *(luaL_checkstring(L,4)))
|
||||
&& (a_ = luaT_toudata(L,3,torch_(Tensor_id)))
|
||||
&& (rv_ = luaT_toudata(L,2,torch_(Tensor_id)))
|
||||
&& (re_ = luaT_toudata(L,1,torch_(Tensor_id))))
|
||||
{
|
||||
regiven = 1;
|
||||
rvgiven = 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
luaL_error(L,"[Tensor, ] [Tensor, ] Tensor [, N or V]");
|
||||
}
|
||||
if (!re_) re_ = THTensor_(new)();
|
||||
if (!rv_) rv_ = THTensor_(new)();
|
||||
|
||||
THTensor_(syev)(re_,rv_,a_,&type,&uplo);
|
||||
|
||||
pushreturn(regiven, re_, torch_(Tensor_id));
|
||||
pushreturn(rvgiven, rv_, torch_(Tensor_id));
|
||||
|
||||
return 2;
|
||||
}
|
||||
|
||||
static int torch_(svd)(lua_State *L)
|
||||
{
|
||||
int narg = lua_gettop(L);
|
||||
THTensor *ru_ = NULL;
|
||||
THTensor *rs_ = NULL;
|
||||
THTensor *rv_ = NULL;
|
||||
THTensor *a_ = NULL;
|
||||
char type = 'S';
|
||||
int rugiven = 0;
|
||||
int rsgiven = 0;
|
||||
int rvgiven = 0;
|
||||
|
||||
if (narg == 1
|
||||
&& (a_ = luaT_toudata(L,1,torch_(Tensor_id))))
|
||||
{
|
||||
}
|
||||
else if (narg ==2
|
||||
&& (type = *(luaL_checkstring(L,2)))
|
||||
&& (a_ = luaT_toudata(L,1,torch_(Tensor_id))))
|
||||
{
|
||||
luaL_argcheck(L, (type == 's' || type == 'S' || type == 'a' || type == 'A'),
|
||||
2, "[Tensor, ] [Tensor, ] [Tensor, ] Tensor [, A or S]");
|
||||
if (type == 's') type = 'S';
|
||||
if (type == 'a') type = 'A';
|
||||
}
|
||||
else if (narg == 4
|
||||
&& (a_ = luaT_toudata(L,4,torch_(Tensor_id)))
|
||||
&& (rv_ = luaT_toudata(L,3,torch_(Tensor_id)))
|
||||
&& (rs_ = luaT_toudata(L,2,torch_(Tensor_id)))
|
||||
&& (ru_ = luaT_toudata(L,1,torch_(Tensor_id))))
|
||||
{
|
||||
rugiven = 1;
|
||||
rsgiven = 2;
|
||||
rvgiven = 3;
|
||||
}
|
||||
else if (narg == 5
|
||||
&& (type = *(luaL_checkstring(L,5)))
|
||||
&& (a_ = luaT_toudata(L,4,torch_(Tensor_id)))
|
||||
&& (rv_ = luaT_toudata(L,3,torch_(Tensor_id)))
|
||||
&& (rs_ = luaT_toudata(L,2,torch_(Tensor_id)))
|
||||
&& (ru_ = luaT_toudata(L,1,torch_(Tensor_id))))
|
||||
{
|
||||
rugiven = 1;
|
||||
rsgiven = 2;
|
||||
rvgiven = 3;
|
||||
}
|
||||
else
|
||||
{
|
||||
luaL_error(L,"[Tensor, Tensor, Tensor], Tensor, [, 'A' or 'S' ]");
|
||||
}
|
||||
|
||||
if (!ru_) ru_ = THTensor_(new)();
|
||||
if (!rs_) rs_ = THTensor_(new)();
|
||||
if (!rv_) rv_ = THTensor_(new)();
|
||||
|
||||
THTensor_(gesvd)(ru_,rs_,rv_,a_,&type);
|
||||
|
||||
pushreturn(rugiven,ru_,torch_(Tensor_id));
|
||||
pushreturn(rsgiven,rs_,torch_(Tensor_id));
|
||||
pushreturn(rvgiven,rv_,torch_(Tensor_id));
|
||||
|
||||
return 3;
|
||||
}
|
||||
|
||||
static const struct luaL_Reg torch_(lapack__) [] = {
|
||||
{"gesv", torch_(gesv)},
|
||||
{"gels", torch_(gels)},
|
||||
{"eig", torch_(eig)},
|
||||
{"svd", torch_(svd)},
|
||||
{NULL, NULL}
|
||||
};
|
||||
|
||||
void torch_(Lapack_init)(lua_State *L)
|
||||
{
|
||||
torch_(Tensor_id) = luaT_checktypename2id(L, torch_string_(Tensor));
|
||||
|
||||
/* register everything into the "torch" field of the tensor metaclass */
|
||||
luaT_pushmetaclass(L, torch_(Tensor_id));
|
||||
lua_pushstring(L, "torch");
|
||||
lua_rawget(L, -2);
|
||||
luaL_register(L, NULL, torch_(lapack__));
|
||||
lua_pop(L, 2);
|
||||
}
|
||||
|
||||
#endif
|
||||
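The four wrappers above follow the same convention: optional result tensors may be passed in front of the inputs and are then reused, otherwise fresh tensors are allocated and returned. A minimal Lua sketch of the intended call patterns (hedged: it assumes these functions end up reachable as torch.gels, torch.eig and torch.svd once the metaclass registration and the TensorMath wrappers are in place, and that a and b are tensors you have already filled):

-- hedged usage sketch; a (n x n) and b (n x nrhs) are assumed to be filled already
local a = torch.Tensor(3, 3)
local b = torch.Tensor(3, 2)

local x, r    = torch.gels(b, a)        -- least-squares solve; x holds the solution
local e, vec  = torch.eig(a, 'V')       -- eigenvalues, plus eigenvectors when 'V' is given
local u, s, v = torch.svd(a, 'S')       -- thin ('S') or full ('A') SVD

-- preallocated results go first and are reused instead of allocating new tensors
local rx, rr = torch.Tensor(), torch.Tensor()
torch.gels(rx, rr, b, a)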
177 generic/TensorOperator.c Normal file
@@ -0,0 +1,177 @@
#ifndef TH_GENERIC_FILE
|
||||
#define TH_GENERIC_FILE "generic/TensorOperator.c"
|
||||
#else
|
||||
|
||||
static const void* torch_Tensor_id;
|
||||
|
||||
static int torch_TensorOperator_(__add__)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor_id);
|
||||
THTensor *tensor2 = luaT_toudata(L, 2, torch_Tensor_id);
|
||||
THTensor *r;
|
||||
|
||||
if(!tensor1 && !tensor2)
|
||||
luaL_error(L, "expecting two Tensors or one Tensor and one number");
|
||||
else
|
||||
{
|
||||
r = THTensor_(new)();
|
||||
luaT_pushudata(L, r, torch_Tensor_id);
|
||||
|
||||
if(!tensor1 && tensor2)
|
||||
{
|
||||
THTensor_(resizeAs)(r, tensor2);
|
||||
THTensor_(copy)(r, tensor2);
|
||||
THTensor_(add)(r, r, luaL_checknumber(L, 1));
|
||||
}
|
||||
else if(tensor1 && !tensor2)
|
||||
{
|
||||
THTensor_(resizeAs)(r, tensor1);
|
||||
THTensor_(copy)(r, tensor1);
|
||||
THTensor_(add)(r, r, luaL_checknumber(L, 2));
|
||||
}
|
||||
else
|
||||
{
|
||||
THTensor_(resizeAs)(r, tensor1);
|
||||
THTensor_(copy)(r, tensor1);
|
||||
THTensor_(cadd)(r, r, 1, tensor2);
|
||||
}
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_TensorOperator_(__sub__)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor_id);
|
||||
THTensor *tensor2 = luaT_toudata(L, 2, torch_Tensor_id);
|
||||
THTensor *r;
|
||||
|
||||
if(!tensor1 && !tensor2)
|
||||
luaL_error(L, "expecting two Tensors or one Tensor and one number");
|
||||
else
|
||||
{
|
||||
r = THTensor_(new)();
|
||||
luaT_pushudata(L, r, torch_Tensor_id);
|
||||
|
||||
if(!tensor1 && tensor2)
|
||||
{
|
||||
THTensor_(resizeAs)(r, tensor2);
|
||||
THTensor_(fill)(r, luaL_checknumber(L, 1));
|
||||
THTensor_(cadd)(r, r, -1, tensor2);
|
||||
}
|
||||
else if(tensor1 && !tensor2)
|
||||
{
|
||||
THTensor_(resizeAs)(r, tensor1);
|
||||
THTensor_(copy)(r, tensor1);
|
||||
THTensor_(add)(r, r, -luaL_checknumber(L, 2));
|
||||
}
|
||||
else
|
||||
{
|
||||
THTensor_(resizeAs)(r, tensor1);
|
||||
THTensor_(copy)(r, tensor1);
|
||||
THTensor_(cadd)(r, r, -1, tensor2);
|
||||
}
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_TensorOperator_(__unm__)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
THTensor *r;
|
||||
|
||||
r = THTensor_(new)();
|
||||
luaT_pushudata(L, r, torch_Tensor_id);
|
||||
THTensor_(resizeAs)(r, tensor);
|
||||
THTensor_(copy)(r, tensor);
|
||||
THTensor_(mul)(r, r, -1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_TensorOperator_(__mul__)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor_id);
|
||||
THTensor *tensor2 = luaT_toudata(L, 2, torch_Tensor_id);
|
||||
THTensor *r;
|
||||
|
||||
if(!tensor1 && !tensor2)
|
||||
luaL_error(L, "expecting two Tensors or one Tensor and one number");
|
||||
else
|
||||
{
|
||||
r = THTensor_(new)();
|
||||
luaT_pushudata(L, r, torch_Tensor_id);
|
||||
|
||||
if(!tensor1 && tensor2)
|
||||
{
|
||||
THTensor_(resizeAs)(r, tensor2);
|
||||
THTensor_(copy)(r, tensor2);
|
||||
THTensor_(mul)(r, r, luaL_checknumber(L, 1));
|
||||
}
|
||||
else if(tensor1 && !tensor2)
|
||||
{
|
||||
THTensor_(resizeAs)(r, tensor1);
|
||||
THTensor_(copy)(r, tensor1);
|
||||
THTensor_(mul)(r, r, luaL_checknumber(L, 2));
|
||||
}
|
||||
else
|
||||
{
|
||||
int dimt = tensor1->nDimension;
|
||||
int dims = tensor2->nDimension;
|
||||
|
||||
if(dimt == 1 && dims == 1)
|
||||
lua_pushnumber(L, THTensor_(dot)(tensor1, tensor2)); /* ok, we wasted r, but who cares */
|
||||
else if(dimt == 2 && dims == 1)
|
||||
{
|
||||
THTensor_(resize1d)(r, tensor1->size[0]);
|
||||
THTensor_(zero)(r);
|
||||
THTensor_(addmv)(r, 1, r, 1, tensor1, tensor2);
|
||||
}
|
||||
else if(dimt == 2 && dims == 2)
|
||||
{
|
||||
THTensor_(resize2d)(r, tensor1->size[0], tensor2->size[1]);
|
||||
THTensor_(zero)(r);
|
||||
THTensor_(addmm)(r, 1, r, 1, tensor1, tensor2);
|
||||
}
|
||||
else
|
||||
luaL_error(L, "multiplication between %dD and %dD tensors not yet supported", tensor1->nDimension, tensor2->nDimension);
|
||||
}
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int torch_TensorOperator_(__div__)(lua_State *L)
|
||||
{
|
||||
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
|
||||
THTensor *r;
|
||||
|
||||
luaL_argcheck(L, lua_isnumber(L,2), 2, "number expected");
|
||||
|
||||
r = THTensor_(new)();
|
||||
luaT_pushudata(L, r, torch_Tensor_id);
|
||||
|
||||
THTensor_(resizeAs)(r, tensor);
|
||||
THTensor_(copy)(r, tensor);
|
||||
THTensor_(mul)(r, r, 1/lua_tonumber(L, 2));
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
static const struct luaL_Reg torch_TensorOperator_(_) [] = {
|
||||
{"__add__", torch_TensorOperator_(__add__)},
|
||||
{"__sub__", torch_TensorOperator_(__sub__)},
|
||||
{"__unm__", torch_TensorOperator_(__unm__)},
|
||||
{"__mul__", torch_TensorOperator_(__mul__)},
|
||||
{"__div__", torch_TensorOperator_(__div__)},
|
||||
{NULL, NULL}
|
||||
};
|
||||
|
||||
void torch_TensorOperator_(init)(lua_State *L)
|
||||
{
|
||||
torch_Tensor_id = luaT_checktypename2id(L, STRING_torchTensor);
|
||||
|
||||
luaT_pushmetaclass(L, torch_Tensor_id);
|
||||
luaL_register(L, NULL, torch_TensorOperator_(_));
|
||||
lua_pop(L, 1);
|
||||
}
|
||||
|
||||
#endif
|
||||
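These metamethods give tensors ordinary arithmetic syntax on the Lua side. A short hedged sketch of what each branch above implements (the tensors are assumed to be sized and filled beforehand):

-- hedged illustration of the operator overloads defined above
local a = torch.Tensor(4, 4)      -- assumed filled
local b = torch.Tensor(4, 4)      -- assumed filled
local v = torch.Tensor(4)         -- assumed filled

local sum  = a + b                -- element-wise add (cadd with weight 1)
local sub  = a - 2                -- tensor and number can be mixed in either order
local neg  = -a                   -- unary minus: copy, then multiply by -1
local half = a / 2                -- division only accepts a number on the right
local mv   = a * v                -- 2D * 1D: matrix-vector product (addmv)
local mm   = a * b                -- 2D * 2D: matrix-matrix product (addmm)
local dot  = v * v                -- 1D * 1D: plain Lua number (dot product)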
44 generic/hist.c Normal file
@@ -0,0 +1,44 @@
#ifndef TH_GENERIC_FILE
|
||||
#define TH_GENERIC_FILE "generic/lab.c"
|
||||
#else
|
||||
|
||||
#include "interfaces.c"
|
||||
|
||||
static int lab_(histc)(lua_State *L)
|
||||
{
|
||||
THTensor *r = luaT_checkudata(L, 1, torch_(Tensor_id));
|
||||
THTensor *h = luaT_checkudata(L, 2, torch_(Tensor_id));
|
||||
int nbins = luaL_checknumber(L, 3);
|
||||
real *h_data = THTensor_(data)(h);
|
||||
|
||||
TH_TENSOR_APPLY(real, r, \
|
||||
if ((*r_data <= nbins) && (*r_data >= 1)) { \
|
||||
*(h_data + (int)(*r_data) - 1) += 1; \
|
||||
})
|
||||
return 0;
|
||||
}
|
||||
|
||||
static const struct luaL_Reg lab_(stuff__) [] = {
|
||||
{"_histc", lab_(histc)},
|
||||
{NULL, NULL}
|
||||
};
|
||||
|
||||
void lab_(init)(lua_State *L)
|
||||
{
|
||||
torch_(Tensor_id) = luaT_checktypename2id(L, torch_string_(Tensor));
|
||||
torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage");
|
||||
|
||||
/* register everything into the "lab" field of the tensor metaclass */
|
||||
luaT_pushmetaclass(L, torch_(Tensor_id));
|
||||
lua_pushstring(L, "lab");
|
||||
lua_newtable(L);
|
||||
luaL_register(L, NULL, lab_(stuff__));
|
||||
lua_rawset(L, -3);
|
||||
lua_pop(L, 1);
|
||||
|
||||
/* luaT_registeratid(L, lab_(stuff__), torch_(Tensor_id)); */
|
||||
/* luaL_register(L, NULL, lab_(stuff__)); */
|
||||
}
|
||||
|
||||
#endif
|
||||
123 hist.lua Normal file
@@ -0,0 +1,123 @@
--
|
||||
-- rudimentary histogram display on the command line.
|
||||
--
|
||||
-- Author: Marco Scoffier
|
||||
-- Date :
|
||||
-- Mod : Oct 21, 2011
|
||||
-- + made 80 columns default
|
||||
-- + save index of max bin in h.max not pointer to bin
|
||||
--
|
||||
function torch.histc__tostring(h, barHeight)
|
||||
barHeight = barHeight or 10
|
||||
local lastm = h[h.max].nb
|
||||
local incr = lastm/(barHeight+1)
|
||||
local m = lastm - incr
|
||||
local tl = torch.Tensor(#h):fill(0)
|
||||
local toph = '|'
|
||||
local topm = ':'
|
||||
local topl = '.'
|
||||
local bar = '|'
|
||||
local blank = ' '
|
||||
local yaxis = '--------:'
|
||||
local str = 'nsamples:'
|
||||
str = str ..
|
||||
string.format(' min:(bin:%d/#%d/cntr:%2.2f) max:(bin:%d/#%d/cntr:%2.2f)\n',
|
||||
h.min,h[h.min].nb,h[h.min].val,
|
||||
h.max,h[h.max].nb,h[h.max].val)
|
||||
|
||||
str = str .. yaxis
|
||||
for j = 1,#h do
|
||||
str = str .. '-'
|
||||
end
|
||||
str = str .. '\n'
|
||||
|
||||
for i = 1,barHeight do
|
||||
-- y axis
|
||||
if i%1==0 then
|
||||
str = str .. string.format('%1.2e:',m)
|
||||
end
|
||||
for j = 1,#h do
|
||||
if tl[j] == 1 then
|
||||
str = str .. bar
|
||||
elseif h[j].nb < m then
|
||||
str = str .. blank
|
||||
else
|
||||
-- in the bracket
|
||||
tl[j] = 1
|
||||
-- find 1/3rds
|
||||
local p = (lastm - h[j].nb) / incr
|
||||
if p > 0.66 then
|
||||
str = str .. toph
|
||||
elseif p > 0.33 then
|
||||
str = str .. topm
|
||||
else
|
||||
str = str .. topl
|
||||
end
|
||||
end
|
||||
end
|
||||
str = str .. '\n'
|
||||
lastm = m
|
||||
m = m - incr
|
||||
end
|
||||
-- x axis
|
||||
str = str .. yaxis
|
||||
for j = 1,#h do
|
||||
if ((j - 2) % 6 == 0)then
|
||||
str = str .. '^'
|
||||
else
|
||||
str = str .. '-'
|
||||
end
|
||||
end
|
||||
str = str .. '\ncenters '
|
||||
for j = 1,#h do
|
||||
if ((j - 2) % 6 == 0)then
|
||||
if h[j].val < 0 then
|
||||
str = str .. '-'
|
||||
else
|
||||
str = str .. '+'
|
||||
end
|
||||
str = str .. string.format('%1.2f ',math.abs(h[j].val))
|
||||
end
|
||||
end
|
||||
return str
|
||||
end
|
||||
|
||||
-- a simple function that computes the histogram of a tensor
|
||||
function torch.histc(...)
|
||||
-- get args
|
||||
local args = {...}
|
||||
local tensor = args[1] or error('usage: torch.histc(tensor [, nBins] [, min] [, max])')
|
||||
local bins = args[2] or 80 - 8
|
||||
local min = args[3] or tensor:min()
|
||||
local max = args[4] or tensor:max()
|
||||
local raw = args[5] or false
|
||||
|
||||
-- compute histogram
|
||||
local hist = torch.zeros(bins)
|
||||
local ten = torch.Tensor(tensor:nElement()):copy(tensor)
|
||||
ten:add(-min):div(max-min):mul(bins - 1e-6):floor():add(1)
|
||||
ten.torch._histc(ten, hist, bins)
|
||||
|
||||
-- return raw histogram (no extra info)
|
||||
if raw then return hist end
|
||||
|
||||
-- cleanup hist
|
||||
local cleanhist = {}
|
||||
cleanhist.raw = hist
|
||||
local _,mx = torch.max(cleanhist.raw)
|
||||
local _,mn = torch.min(cleanhist.raw)
|
||||
cleanhist.bins = bins
|
||||
cleanhist.binwidth = (max-min)/bins
|
||||
for i = 1,bins do
|
||||
cleanhist[i] = {}
|
||||
cleanhist[i].val = min + (i-0.5)*cleanhist.binwidth
|
||||
cleanhist[i].nb = hist[i]
|
||||
end
|
||||
cleanhist.max = mx[1]
|
||||
cleanhist.min = mn[1]
|
||||
|
||||
-- print function
|
||||
setmetatable(cleanhist, {__tostring=torch.histc__tostring})
|
||||
return cleanhist
|
||||
end
|
||||
|
||||
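A brief hedged usage sketch for the two helpers above (it assumes the C kernel _histc from generic/hist.c has been registered and that t is a tensor you have already filled with data):

-- hedged usage sketch for torch.histc / torch.histc__tostring
local t = torch.Tensor(1000)                 -- assumed filled with the values to histogram

local h = torch.histc(t, 20)                 -- 20 bins spanning [t:min(), t:max()]
print(h)                                     -- ASCII bar chart via the __tostring above
print(h.max, h[h.max].nb, h[h.max].val)      -- fullest bin: index, count, bin center

local raw = torch.histc(t, 20, 0, 1, true)   -- raw counts only, over a fixed [0, 1] range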
99 init.c Normal file
@@ -0,0 +1,99 @@
#include "general.h"
|
||||
#include "utils.h"
|
||||
|
||||
extern void torch_utils_init(lua_State *L);
|
||||
extern void torch_random_init(lua_State *L);
|
||||
extern void torch_File_init(lua_State *L);
|
||||
extern void torch_File_init_storage_id(lua_State *L);
|
||||
extern void torch_DiskFile_init(lua_State *L);
|
||||
extern void torch_MemoryFile_init(lua_State *L);
|
||||
extern void torch_PipeFile_init(lua_State *L);
|
||||
extern void torch_Timer_init(lua_State *L);
|
||||
|
||||
extern void torch_ByteStorage_init(lua_State *L);
|
||||
extern void torch_CharStorage_init(lua_State *L);
|
||||
extern void torch_ShortStorage_init(lua_State *L);
|
||||
extern void torch_IntStorage_init(lua_State *L);
|
||||
extern void torch_LongStorage_init(lua_State *L);
|
||||
extern void torch_FloatStorage_init(lua_State *L);
|
||||
extern void torch_DoubleStorage_init(lua_State *L);
|
||||
|
||||
extern void torch_ByteTensor_init(lua_State *L);
|
||||
extern void torch_CharTensor_init(lua_State *L);
|
||||
extern void torch_ShortTensor_init(lua_State *L);
|
||||
extern void torch_IntTensor_init(lua_State *L);
|
||||
extern void torch_LongTensor_init(lua_State *L);
|
||||
extern void torch_FloatTensor_init(lua_State *L);
|
||||
extern void torch_DoubleTensor_init(lua_State *L);
|
||||
|
||||
extern void torch_ByteTensorOperator_init(lua_State *L);
|
||||
extern void torch_CharTensorOperator_init(lua_State *L);
|
||||
extern void torch_ShortTensorOperator_init(lua_State *L);
|
||||
extern void torch_IntTensorOperator_init(lua_State *L);
|
||||
extern void torch_LongTensorOperator_init(lua_State *L);
|
||||
extern void torch_FloatTensorOperator_init(lua_State *L);
|
||||
extern void torch_DoubleTensorOperator_init(lua_State *L);
|
||||
|
||||
extern void torch_TensorMath_init(lua_State *L);
|
||||
|
||||
static lua_State *globalL;
|
||||
static void luaTorchErrorHandlerFunction(const char *msg)
|
||||
{
|
||||
luaL_error(globalL, msg);
|
||||
}
|
||||
|
||||
static void luaTorchArgCheckHandlerFunction(int condition, int argNumber, const char *msg)
|
||||
{
|
||||
luaL_argcheck(globalL, condition, argNumber, msg);
|
||||
}
|
||||
|
||||
DLL_EXPORT int luaopen_libtorch(lua_State *L)
|
||||
{
|
||||
globalL = L;
|
||||
THSetErrorHandler(luaTorchErrorHandlerFunction);
|
||||
THSetArgCheckHandler(luaTorchArgCheckHandlerFunction);
|
||||
|
||||
lua_newtable(L);
|
||||
lua_pushvalue(L, -1);
|
||||
lua_setfield(L, LUA_GLOBALSINDEX, "torch");
|
||||
|
||||
torch_File_init(L);
|
||||
|
||||
torch_ByteStorage_init(L);
|
||||
torch_CharStorage_init(L);
|
||||
torch_ShortStorage_init(L);
|
||||
torch_IntStorage_init(L);
|
||||
torch_LongStorage_init(L);
|
||||
torch_FloatStorage_init(L);
|
||||
torch_DoubleStorage_init(L);
|
||||
|
||||
torch_ByteTensor_init(L);
|
||||
torch_CharTensor_init(L);
|
||||
torch_ShortTensor_init(L);
|
||||
torch_IntTensor_init(L);
|
||||
torch_LongTensor_init(L);
|
||||
torch_FloatTensor_init(L);
|
||||
torch_DoubleTensor_init(L);
|
||||
|
||||
torch_File_init_storage_id(L);
|
||||
|
||||
torch_ByteTensorOperator_init(L);
|
||||
torch_CharTensorOperator_init(L);
|
||||
torch_ShortTensorOperator_init(L);
|
||||
torch_IntTensorOperator_init(L);
|
||||
torch_LongTensorOperator_init(L);
|
||||
torch_FloatTensorOperator_init(L);
|
||||
torch_DoubleTensorOperator_init(L);
|
||||
|
||||
torch_Timer_init(L);
|
||||
torch_DiskFile_init(L);
|
||||
torch_PipeFile_init(L);
|
||||
torch_MemoryFile_init(L);
|
||||
|
||||
torch_TensorMath_init(L);
|
||||
|
||||
torch_utils_init(L);
|
||||
torch_random_init(L);
|
||||
|
||||
return 1;
|
||||
}
|
||||
78 init.lua Normal file
@@ -0,0 +1,78 @@
|
||||
-- We are using paths.require to appease mkl
|
||||
require "paths"
|
||||
paths.require "libtorch"
|
||||
require "libtorch"
|
||||
|
||||
--- package stuff
|
||||
function torch.packageLuaPath(name)
|
||||
if not name then
|
||||
local ret = string.match(torch.packageLuaPath('torch'), '(.*)/')
|
||||
if not ret then --windows?
|
||||
ret = string.match(torch.packageLuaPath('torch'), '(.*)\\')
|
||||
end
|
||||
return ret
|
||||
end
|
||||
for path in string.gmatch(package.path, "(.-);") do
|
||||
path = string.gsub(path, "%?", name)
|
||||
local f = io.open(path)
|
||||
if f then
|
||||
f:close()
|
||||
local ret = string.match(path, "(.*)/")
|
||||
if not ret then --windows?
|
||||
ret = string.match(path, "(.*)\\")
|
||||
end
|
||||
return ret
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function torch.include(package, file)
|
||||
dofile(torch.packageLuaPath(package) .. '/' .. file)
|
||||
end
|
||||
|
||||
function torch.class(tname, parenttname)
|
||||
|
||||
local function constructor(...)
|
||||
local self = {}
|
||||
torch.setmetatable(self, tname)
|
||||
if self.__init then
|
||||
self:__init(...)
|
||||
end
|
||||
return self
|
||||
end
|
||||
|
||||
local function factory()
|
||||
local self = {}
|
||||
torch.setmetatable(self, tname)
|
||||
return self
|
||||
end
|
||||
|
||||
local mt = torch.newmetatable(tname, parenttname, constructor, nil, factory)
|
||||
local mpt
|
||||
if parenttname then
|
||||
mpt = torch.getmetatable(parenttname)
|
||||
end
|
||||
return mt, mpt
|
||||
end
|
||||
|
||||
function torch.setdefaulttensortype(typename)
|
||||
assert(type(typename) == 'string', 'string expected')
|
||||
if torch.getconstructortable(typename) then
|
||||
torch.Tensor = torch.getconstructortable(typename)
|
||||
torch.Storage = torch.getconstructortable(torch.typename(torch.Tensor(1):storage()))
|
||||
torch.__setdefaulttensortype(typename)
|
||||
else
|
||||
error(string.format("<%s> is not a string describing a torch object", typename))
|
||||
end
|
||||
end
|
||||
|
||||
torch.setdefaulttensortype('torch.DoubleTensor')
|
||||
|
||||
torch.include('torch', 'Tensor.lua')
|
||||
torch.include('torch', 'File.lua')
|
||||
torch.include('torch', 'CmdLine.lua')
|
||||
torch.include('torch', 'Tester.lua')
|
||||
torch.include('torch', 'TensorMath.lua')
|
||||
torch.include('torch', 'test.lua')
|
||||
return torch
|
||||
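The torch.class helper defined above is what the Lua-side classes in the tree are built on. A hedged sketch of declaring and using a class with it (it assumes, as the code above implies, that torch.newmetatable exposes the constructor under the given name and returns the table on which methods are defined):

-- hedged example of torch.class
local Counter = torch.class('torch.Counter')

function Counter:__init(start)     -- run by the generated constructor
   self.n = start or 0
end

function Counter:tick()
   self.n = self.n + 1
   return self.n
end

local c = torch.Counter(10)
print(c:tick())                    -- 11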
2 lib/CMakeLists.txt Normal file
@@ -0,0 +1,2 @@
ADD_SUBDIRECTORY(TH)
|
||||
ADD_SUBDIRECTORY(luaT)
|
||||
117 lib/TH/CMakeLists.txt Normal file
@@ -0,0 +1,117 @@
# -*- cmake -*-
|
||||
|
||||
SET(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake ${CMAKE_MODULE_PATH})
|
||||
|
||||
SET(hdr
|
||||
THGeneral.h THStorage.h THTensor.h THTensorApply.h
|
||||
THBlas.h THLapack.h THLogAdd.h THRandom.h THVector.h)
|
||||
SET(src
|
||||
THGeneral.c THStorage.c THTensor.c THBlas.c THLapack.c
|
||||
THLogAdd.c THRandom.c
|
||||
THFile.c THDiskFile.c THMemoryFile.c)
|
||||
|
||||
SET(src ${src} ${hdr})
|
||||
|
||||
IF(UNIX)
|
||||
INCLUDE(CheckFunctionExists)
|
||||
SET(CMAKE_EXTRA_INCLUDE_FILES "sys/mman.h")
|
||||
CHECK_FUNCTION_EXISTS(mmap HAVE_MMAP)
|
||||
IF(HAVE_MMAP)
|
||||
ADD_DEFINITIONS(-DHAVE_MMAP=1)
|
||||
ENDIF(HAVE_MMAP)
|
||||
ENDIF(UNIX)
|
||||
|
||||
ADD_LIBRARY(TH SHARED ${src})
|
||||
|
||||
FIND_PACKAGE(BLAS)
|
||||
FIND_PACKAGE(LAPACK)
|
||||
|
||||
IF (LAPACK_FOUND)
|
||||
SET(CMAKE_C_FLAGS "-D__LAPACK__ ${CMAKE_C_FLAGS}")
|
||||
ENDIF(LAPACK_FOUND)
|
||||
|
||||
FIND_PACKAGE(SSE)
|
||||
|
||||
IF (SSE2_FOUND)
|
||||
SET(CMAKE_C_FLAGS "-msse2 -D__SSE2__ ${CMAKE_C_FLAGS}")
|
||||
ENDIF (SSE2_FOUND)
|
||||
IF (SSE3_FOUND)
|
||||
SET(CMAKE_C_FLAGS "-msse3 -D__SSE3__ ${CMAKE_C_FLAGS}")
|
||||
ENDIF (SSE3_FOUND)
|
||||
IF (SSSE3_FOUND)
|
||||
SET(CMAKE_C_FLAGS "-mssse3 -D__SSSE3__ ${CMAKE_C_FLAGS}")
|
||||
ENDIF (SSSE3_FOUND)
|
||||
IF (SSE4.1_FOUND)
|
||||
SET(CMAKE_C_FLAGS "-msse4.1 -D__SSE4_1__ ${CMAKE_C_FLAGS}")
|
||||
ENDIF (SSE4.1_FOUND)
|
||||
|
||||
IF(BLAS_FOUND)
|
||||
ADD_DEFINITIONS(-DUSE_LAPACK)
|
||||
# INCLUDE_DIRECTORIES(${CBLAS_INCLUDE_DIR})
|
||||
TARGET_LINK_LIBRARIES(TH ${BLAS_LIBRARIES})
|
||||
ENDIF(BLAS_FOUND)
|
||||
|
||||
#CONFIGURE_FILE("THCBlas.h.in" "${CMAKE_CURRENT_BINARY_DIR}/THCBlas.h")
|
||||
#INCLUDE_DIRECTORIES("${CMAKE_CURRENT_BINARY_DIR}")
|
||||
#INSTALL(FILES "${CMAKE_CURRENT_BINARY_DIR}/THCBlas.h"
|
||||
# DESTINATION "${Torch_INSTALL_INCLUDE_SUBDIR}/TH")
|
||||
|
||||
INSTALL(TARGETS TH
|
||||
RUNTIME DESTINATION "${Torch_INSTALL_BIN_SUBDIR}"
|
||||
LIBRARY DESTINATION "${Torch_INSTALL_LIB_SUBDIR}"
|
||||
ARCHIVE DESTINATION "${Torch_INSTALL_LIB_SUBDIR}")
|
||||
|
||||
INSTALL(FILES
|
||||
TH.h
|
||||
THBlas.h
|
||||
THDiskFile.h
|
||||
THFile.h
|
||||
THFilePrivate.h
|
||||
THGeneral.h
|
||||
THGenerateAllTypes.h
|
||||
THGenerateFloatTypes.h
|
||||
THGenerateIntTypes.h
|
||||
THLapack.h
|
||||
THLogAdd.h
|
||||
THMemoryFile.h
|
||||
THRandom.h
|
||||
THStorage.h
|
||||
THTensor.h
|
||||
THTensorApply.h
|
||||
THTensorDimApply.h
|
||||
THTensorMacros.h
|
||||
THVector.h
|
||||
DESTINATION "${Torch_INSTALL_INCLUDE_SUBDIR}/TH")
|
||||
|
||||
INSTALL(FILES
|
||||
generic/THBlas.c
|
||||
generic/THBlas.h
|
||||
generic/THLapack.c
|
||||
generic/THLapack.h
|
||||
generic/THStorage.c
|
||||
generic/THStorage.h
|
||||
generic/THStorageCopy.c
|
||||
generic/THStorageCopy.h
|
||||
generic/THTensor.c
|
||||
generic/THTensor.h
|
||||
generic/THTensorConv.c
|
||||
generic/THTensorConv.h
|
||||
generic/THTensorCopy.c
|
||||
generic/THTensorCopy.h
|
||||
generic/THTensorLapack.c
|
||||
generic/THTensorLapack.h
|
||||
generic/THTensorMath.c
|
||||
generic/THTensorMath.h
|
||||
generic/THTensorRandom.c
|
||||
generic/THTensorRandom.h
|
||||
generic/THVector.c
|
||||
DESTINATION "${Torch_INSTALL_INCLUDE_SUBDIR}/TH/generic")
|
||||
|
||||
# Create THConfig.cmake
|
||||
GET_TARGET_PROPERTY(TH_OUTPUT_NAME TH LOCATION)
|
||||
GET_FILENAME_COMPONENT(TH_OUTPUT_NAME ${TH_OUTPUT_NAME} NAME)
|
||||
SET(TH_LIBRARIES "${Torch_INSTALL_LIB}/${TH_OUTPUT_NAME}")
|
||||
SET(TH_INCLUDE_DIR "${Torch_INSTALL_INCLUDE}/TH")
|
||||
CONFIGURE_FILE(THConfig.cmake.in "${Torch_BINARY_DIR}/cmake-external/THConfig.cmake")
|
||||
INSTALL(FILES "${Torch_BINARY_DIR}/cmake-external/THConfig.cmake"
|
||||
DESTINATION "${Torch_INSTALL_CMAKE_SUBDIR}")
|
||||
23 lib/TH/TH.h Normal file
@@ -0,0 +1,23 @@
#ifndef TH_INC
|
||||
#define TH_INC
|
||||
|
||||
#include "THBlas.h"
|
||||
|
||||
#ifdef __LAPACK__
|
||||
#include "THLapack.h"
|
||||
#endif
|
||||
|
||||
#include "THVector.h"
|
||||
#include "THGeneral.h"
|
||||
#include "THLogAdd.h"
|
||||
#include "THRandom.h"
|
||||
#include "THStorage.h"
|
||||
#include "THTensor.h"
|
||||
#include "THTensorApply.h"
|
||||
#include "THTensorDimApply.h"
|
||||
|
||||
#include "THFile.h"
|
||||
#include "THDiskFile.h"
|
||||
#include "THMemoryFile.h"
|
||||
|
||||
#endif
|
||||
5 lib/TH/THBlas.c Normal file
@@ -0,0 +1,5 @@
#include "THBlas.h"
|
||||
|
||||
/* #include "THCBlas.h" */
|
||||
#include "generic/THBlas.c"
|
||||
#include "THGenerateAllTypes.h"
|
||||
11 lib/TH/THBlas.h Normal file
@@ -0,0 +1,11 @@
#ifndef TH_BLAS_INC
|
||||
#define TH_BLAS_INC
|
||||
|
||||
#include "THGeneral.h"
|
||||
|
||||
#define THBlas_(NAME) TH_CONCAT_4(TH,Real,Blas_,NAME)
|
||||
|
||||
#include "generic/THBlas.h"
|
||||
#include "THGenerateAllTypes.h"
|
||||
|
||||
#endif
|
||||
8 lib/TH/THCBlas.h.in Normal file
@@ -0,0 +1,8 @@
/* -*- C -*- */
|
||||
|
||||
#cmakedefine USE_CBLAS @USE_CBLAS@
|
||||
|
||||
#if USE_CBLAS
|
||||
# include "@CBLAS_INCLUDE_FILE@"
|
||||
#endif
|
||||
|
||||
9 lib/TH/THConfig.cmake.in Normal file
@@ -0,0 +1,9 @@
# Find the TH includes and library
|
||||
#
|
||||
# TH_INCLUDE_DIR -- where to find the includes
|
||||
# TH_LIBRARIES -- list of libraries to link against
|
||||
# TH_FOUND -- set to 1 if found
|
||||
|
||||
SET(TH_FOUND 1)
|
||||
SET(TH_INCLUDE_DIR "@TH_INCLUDE_DIR@")
|
||||
SET(TH_LIBRARIES "@TH_LIBRARIES@")
|
||||
592 lib/TH/THDiskFile.c Normal file
@@ -0,0 +1,592 @@
#include "THGeneral.h"
|
||||
#include "THDiskFile.h"
|
||||
#include "THFilePrivate.h"
|
||||
|
||||
typedef struct THDiskFile__
|
||||
{
|
||||
THFile file;
|
||||
|
||||
FILE *handle;
|
||||
char *name;
|
||||
int isNativeEncoding;
|
||||
|
||||
} THDiskFile;
|
||||
|
||||
static int THDiskFile_isOpened(THFile *self)
|
||||
{
|
||||
THDiskFile *dfself = (THDiskFile*)self;
|
||||
return (dfself->handle != NULL);
|
||||
}
|
||||
|
||||
const char *THDiskFile_name(THFile *self)
|
||||
{
|
||||
THDiskFile *dfself = (THDiskFile*)self;
|
||||
return dfself->name;
|
||||
}
|
||||
|
||||
|
||||
#define READ_WRITE_METHODS(TYPE, TYPEC, ASCII_READ_ELEM, ASCII_WRITE_ELEM) \
|
||||
static long THDiskFile_read##TYPEC(THFile *self, TYPE *data, long n) \
|
||||
{ \
|
||||
THDiskFile *dfself = (THDiskFile*)(self); \
|
||||
long nread = 0L; \
|
||||
\
|
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); \
|
||||
THArgCheck(dfself->file.isReadable, 1, "attempt to read in a write-only file"); \
|
||||
\
|
||||
if(dfself->file.isBinary) \
|
||||
{ \
|
||||
nread = fread(data, sizeof(TYPE), n, dfself->handle); \
|
||||
if(!dfself->isNativeEncoding && (sizeof(TYPE) > 1) && (nread > 0)) \
|
||||
THDiskFile_reverseMemory(data, data, sizeof(TYPE), nread); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
long i; \
|
||||
for(i = 0; i < n; i++) \
|
||||
{ \
|
||||
ASCII_READ_ELEM; /* increment nread here and break on error */ \
|
||||
} \
|
||||
if(dfself->file.isAutoSpacing && (n > 0)) \
|
||||
{ \
|
||||
int c = fgetc(dfself->handle); \
|
||||
if( (c != '\n') && (c != EOF) ) \
|
||||
ungetc(c, dfself->handle); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
if(nread != n) \
|
||||
{ \
|
||||
dfself->file.hasError = 1; /* shouldn't we put hasError to 0 all the time ? */ \
|
||||
if(!dfself->file.isQuiet) \
|
||||
THError("read error: read %d blocks instead of %d", nread, n); \
|
||||
} \
|
||||
\
|
||||
return nread; \
|
||||
} \
|
||||
\
|
||||
static long THDiskFile_write##TYPEC(THFile *self, TYPE *data, long n) \
|
||||
{ \
|
||||
THDiskFile *dfself = (THDiskFile*)(self); \
|
||||
long nwrite = 0L; \
|
||||
\
|
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); \
|
||||
THArgCheck(dfself->file.isWritable, 1, "attempt to write in a read-only file"); \
|
||||
\
|
||||
if(dfself->file.isBinary) \
|
||||
{ \
|
||||
if(dfself->isNativeEncoding) \
|
||||
{ \
|
||||
nwrite = fwrite(data, sizeof(TYPE), n, dfself->handle); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
if(sizeof(TYPE) > 1) \
|
||||
{ \
|
||||
char *buffer = THAlloc(sizeof(TYPE)*n); \
|
||||
THDiskFile_reverseMemory(buffer, data, sizeof(TYPE), n); \
|
||||
nwrite = fwrite(buffer, sizeof(TYPE), n, dfself->handle); \
|
||||
THFree(buffer); \
|
||||
} \
|
||||
else \
|
||||
nwrite = fwrite(data, sizeof(TYPE), n, dfself->handle); \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
long i; \
|
||||
for(i = 0; i < n; i++) \
|
||||
{ \
|
||||
ASCII_WRITE_ELEM; \
|
||||
if( dfself->file.isAutoSpacing && (i < n-1) ) \
|
||||
fprintf(dfself->handle, " "); \
|
||||
} \
|
||||
if(dfself->file.isAutoSpacing && (n > 0)) \
|
||||
fprintf(dfself->handle, "\n"); \
|
||||
} \
|
||||
\
|
||||
if(nwrite != n) \
|
||||
{ \
|
||||
dfself->file.hasError = 1; \
|
||||
if(!dfself->file.isQuiet) \
|
||||
THError("write error: wrote %d blocks instead of %d", nwrite, n); \
|
||||
} \
|
||||
\
|
||||
return nwrite; \
|
||||
}
|
||||
|
||||
static int THDiskFile_mode(const char *mode, int *isReadable, int *isWritable)
|
||||
{
|
||||
*isReadable = 0;
|
||||
*isWritable = 0;
|
||||
if(strlen(mode) == 1)
|
||||
{
|
||||
if(*mode == 'r')
|
||||
{
|
||||
*isReadable = 1;
|
||||
return 1;
|
||||
}
|
||||
else if(*mode == 'w')
|
||||
{
|
||||
*isWritable = 1;
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
else if(strlen(mode) == 2)
|
||||
{
|
||||
if(mode[0] == 'r' && mode[1] == 'w')
|
||||
{
|
||||
*isReadable = 1;
|
||||
*isWritable = 1;
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void THDiskFile_synchronize(THFile *self)
|
||||
{
|
||||
THDiskFile *dfself = (THDiskFile*)(self);
|
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file");
|
||||
fflush(dfself->handle);
|
||||
}
|
||||
|
||||
static void THDiskFile_seek(THFile *self, long position)
|
||||
{
|
||||
THDiskFile *dfself = (THDiskFile*)(self);
|
||||
|
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file");
|
||||
THArgCheck(position >= 0, 2, "position must be positive");
|
||||
|
||||
if(fseek(dfself->handle, position, SEEK_SET) < 0)
|
||||
{
|
||||
dfself->file.hasError = 1;
|
||||
if(!dfself->file.isQuiet)
|
||||
THError("unable to seek at position %d", position);
|
||||
}
|
||||
}
|
||||
|
||||
static void THDiskFile_seekEnd(THFile *self)
|
||||
{
|
||||
THDiskFile *dfself = (THDiskFile*)(self);
|
||||
|
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file");
|
||||
|
||||
if(fseek(dfself->handle, 0L, SEEK_END) < 0)
|
||||
{
|
||||
dfself->file.hasError = 1;
|
||||
if(!dfself->file.isQuiet)
|
||||
THError("unable to seek at end of file");
|
||||
}
|
||||
}
|
||||
|
||||
static long THDiskFile_position(THFile *self)
|
||||
{
|
||||
THDiskFile *dfself = (THDiskFile*)(self);
|
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file");
|
||||
return ftell(dfself->handle);
|
||||
}
|
||||
|
||||
static void THDiskFile_close(THFile *self)
|
||||
{
|
||||
THDiskFile *dfself = (THDiskFile*)(self);
|
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file");
|
||||
fclose(dfself->handle);
|
||||
dfself->handle = NULL;
|
||||
}
|
||||
|
||||
/* Little and Big Endian */
|
||||
|
||||
static void THDiskFile_reverseMemory(void *dst, const void *src, long blockSize, long numBlocks)
|
||||
{
|
||||
if(blockSize != 1)
|
||||
{
|
||||
long halfBlockSize = blockSize/2;
|
||||
char *charSrc = (char*)src;
|
||||
char *charDst = (char*)dst;
|
||||
long b, i;
|
||||
for(b = 0; b < numBlocks; b++)
|
||||
{
|
||||
for(i = 0; i < halfBlockSize; i++)
|
||||
{
|
||||
char z = charSrc[i];
|
||||
charDst[i] = charSrc[blockSize-1-i];
|
||||
charDst[blockSize-1-i] = z;
|
||||
}
|
||||
charSrc += blockSize;
|
||||
charDst += blockSize;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int THDiskFile_isLittleEndianCPU(void)
|
||||
{
|
||||
int x = 7;
|
||||
char *ptr = (char *)&x;
|
||||
|
||||
if(ptr[0] == 0)
|
||||
return 0;
|
||||
else
|
||||
return 1;
|
||||
}
|
||||
|
||||
int THDiskFile_isBigEndianCPU(void)
|
||||
{
|
||||
return(!THDiskFile_isLittleEndianCPU());
|
||||
}
|
||||
|
||||
void THDiskFile_nativeEndianEncoding(THFile *self)
|
||||
{
|
||||
THDiskFile *dfself = (THDiskFile*)(self);
|
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file");
|
||||
dfself->isNativeEncoding = 1;
|
||||
}
|
||||
|
||||
void THDiskFile_littleEndianEncoding(THFile *self)
|
||||
{
|
||||
THDiskFile *dfself = (THDiskFile*)(self);
|
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file");
|
||||
dfself->isNativeEncoding = THDiskFile_isLittleEndianCPU();
|
||||
}
|
||||
|
||||
void THDiskFile_bigEndianEncoding(THFile *self)
|
||||
{
|
||||
THDiskFile *dfself = (THDiskFile*)(self);
|
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file");
|
||||
dfself->isNativeEncoding = !THDiskFile_isLittleEndianCPU();
|
||||
}
|
||||
|
||||
/* End of Little and Big Endian Stuff */
|
||||
|
||||
static void THDiskFile_free(THFile *self)
|
||||
{
|
||||
THDiskFile *dfself = (THDiskFile*)(self);
|
||||
if(dfself->handle)
|
||||
fclose(dfself->handle);
|
||||
THFree(dfself->name);
|
||||
THFree(dfself);
|
||||
}
|
||||
|
||||
/* READ_WRITE_METHODS(int, Bool, */
|
||||
/* int value = 0; int ret = fscanf(file->handle, "%d", &value); array[i] = (value ? 1 : 0); if(ret <= 0) break; else result++, */
|
||||
/* int value = (array[i] ? 1 : 0); nElemWritten = fprintf(file->handle, "%d", value), */
|
||||
/* true) */
|
||||
|
||||
/* Note that we do a trick */
|
||||
READ_WRITE_METHODS(unsigned char, Byte,
|
||||
nread = fread(data, 1, n, dfself->handle); break,
|
||||
nwrite = fwrite(data, 1, n, dfself->handle); break)
|
||||
|
||||
READ_WRITE_METHODS(char, Char,
|
||||
nread = fread(data, 1, n, dfself->handle); break,
|
||||
nwrite = fwrite(data, 1, n, dfself->handle); break)
|
||||
|
||||
READ_WRITE_METHODS(short, Short,
|
||||
int ret = fscanf(dfself->handle, "%hd", &data[i]); if(ret <= 0) break; else nread++,
|
||||
int ret = fprintf(dfself->handle, "%hd", data[i]); if(ret <= 0) break; else nwrite++)
|
||||
|
||||
READ_WRITE_METHODS(int, Int,
|
||||
int ret = fscanf(dfself->handle, "%d", &data[i]); if(ret <= 0) break; else nread++,
|
||||
int ret = fprintf(dfself->handle, "%d", data[i]); if(ret <= 0) break; else nwrite++)
|
||||
|
||||
READ_WRITE_METHODS(long, Long,
|
||||
int ret = fscanf(dfself->handle, "%ld", &data[i]); if(ret <= 0) break; else nread++,
|
||||
int ret = fprintf(dfself->handle, "%ld", data[i]); if(ret <= 0) break; else nwrite++)
|
||||
|
||||
READ_WRITE_METHODS(float, Float,
|
||||
int ret = fscanf(dfself->handle, "%g", &data[i]); if(ret <= 0) break; else nread++,
|
||||
int ret = fprintf(dfself->handle, "%g", data[i]); if(ret <= 0) break; else nwrite++)
|
||||
|
||||
READ_WRITE_METHODS(double, Double,
|
||||
int ret = fscanf(dfself->handle, "%lg", &data[i]); if(ret <= 0) break; else nread++,
|
||||
int ret = fprintf(dfself->handle, "%lg", data[i]); if(ret <= 0) break; else nwrite++)
|
||||
|
||||
static long THDiskFile_readString(THFile *self, const char *format, char **str_)
|
||||
{
|
||||
THDiskFile *dfself = (THDiskFile*)(self);
|
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file");
|
||||
THArgCheck(dfself->file.isReadable, 1, "attempt to read in a write-only file");
|
||||
THArgCheck((strlen(format) >= 2 ? (format[0] == '*') && (format[1] == 'a' || format[1] == 'l') : 0), 2, "format must be '*a' or '*l'");
|
||||
|
||||
/* note: the string won't survive long, as it is copied into lua */
|
||||
/* so 1024 is not that big... */
|
||||
#define TBRS_BSZ 1024L
|
||||
|
||||
if(format[1] == 'a')
|
||||
{
|
||||
char *p = THAlloc(TBRS_BSZ);
|
||||
long total = TBRS_BSZ;
|
||||
long pos = 0L;
|
||||
|
||||
for (;;)
|
||||
{
|
||||
if(total-pos == 0) /* we need more space! */
|
||||
{
|
||||
total += TBRS_BSZ;
|
||||
p = THRealloc(p, total);
|
||||
}
|
||||
pos += fread(p+pos, 1, total-pos, dfself->handle);
|
||||
if (pos < total) /* eof? */
|
||||
{
|
||||
if(pos == 0L)
|
||||
{
|
||||
THFree(p);
|
||||
dfself->file.hasError = 1;
|
||||
if(!dfself->file.isQuiet)
|
||||
THError("read error: read 0 blocks instead of 1");
|
||||
|
||||
*str_ = NULL;
|
||||
return 0;
|
||||
}
|
||||
*str_ = p;
|
||||
return pos;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
char *p = THAlloc(TBRS_BSZ);
|
||||
long total = TBRS_BSZ;
|
||||
long pos = 0L;
|
||||
long size;
|
||||
|
||||
for (;;)
|
||||
{
|
||||
if(total-pos <= 1) /* we can only write '\0' in there! */
|
||||
{
|
||||
total += TBRS_BSZ;
|
||||
p = THRealloc(p, total);
|
||||
}
|
||||
if (fgets(p+pos, total-pos, dfself->handle) == NULL) /* eof? */
|
||||
{
|
||||
if(pos == 0L)
|
||||
{
|
||||
THFree(p);
|
||||
dfself->file.hasError = 1;
|
||||
if(!dfself->file.isQuiet)
|
||||
THError("read error: read 0 blocks instead of 1");
|
||||
|
||||
*str_ = NULL;
|
||||
return 0;
|
||||
}
|
||||
*str_ = p;
|
||||
return pos;
|
||||
}
|
||||
size = strlen(p+pos);
|
||||
if (size == 0L || (p+pos)[size-1] != '\n')
|
||||
{
|
||||
pos += size;
|
||||
}
|
||||
else
|
||||
{
|
||||
pos += size-1L; /* do not include `eol' */
|
||||
*str_ = p;
|
||||
return pos;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
*str_ = NULL;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
static long THDiskFile_writeString(THFile *self, const char *str, long size)
|
||||
{
|
||||
THDiskFile *dfself = (THDiskFile*)(self);
|
||||
long nwrite;
|
||||
|
||||
THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file");
|
||||
THArgCheck(dfself->file.isWritable, 1, "attempt to write in a read-only file");
|
||||
|
||||
nwrite = fwrite(str, 1, size, dfself->handle);
|
||||
if(nwrite != size)
|
||||
{
|
||||
dfself->file.hasError = 1;
|
||||
if(!dfself->file.isQuiet)
|
||||
THError("write error: wrote %ld blocks instead of %ld", nwrite, size);
|
||||
}
|
||||
|
||||
return nwrite;
|
||||
}
|
||||
|
||||
THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet)
|
||||
{
|
||||
static struct THFileVTable vtable = {
|
||||
THDiskFile_isOpened,
|
||||
|
||||
THDiskFile_readByte,
|
||||
THDiskFile_readChar,
|
||||
THDiskFile_readShort,
|
||||
THDiskFile_readInt,
|
||||
THDiskFile_readLong,
|
||||
THDiskFile_readFloat,
|
||||
THDiskFile_readDouble,
|
||||
THDiskFile_readString,
|
||||
|
||||
THDiskFile_writeByte,
|
||||
THDiskFile_writeChar,
|
||||
THDiskFile_writeShort,
|
||||
THDiskFile_writeInt,
|
||||
THDiskFile_writeLong,
|
||||
THDiskFile_writeFloat,
|
||||
THDiskFile_writeDouble,
|
||||
THDiskFile_writeString,
|
||||
|
||||
THDiskFile_synchronize,
|
||||
THDiskFile_seek,
|
||||
THDiskFile_seekEnd,
|
||||
THDiskFile_position,
|
||||
THDiskFile_close,
|
||||
THDiskFile_free
|
||||
};
|
||||
|
||||
int isReadable;
|
||||
int isWritable;
|
||||
FILE *handle;
|
||||
THDiskFile *self;
|
||||
|
||||
THArgCheck(THDiskFile_mode(mode, &isReadable, &isWritable), 2, "file mode should be 'r','w' or 'rw'");
|
||||
|
||||
if( isReadable && isWritable )
|
||||
{
|
||||
handle = fopen(name, "r+b");
|
||||
if(!handle)
|
||||
{
|
||||
handle = fopen(name, "wb");
|
||||
if(handle)
|
||||
{
|
||||
fclose(handle);
|
||||
handle = fopen(name, "r+b");
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
handle = fopen(name, (isReadable ? "rb" : "wb"));
|
||||
|
||||
if(!handle)
|
||||
{
|
||||
if(isQuiet)
|
||||
return 0;
|
||||
else
|
||||
THError("cannot open <%s> in mode %c%c", name, (isReadable ? 'r' : ' '), (isWritable ? 'w' : ' '));
|
||||
}
|
||||
|
||||
self = THAlloc(sizeof(THDiskFile));
|
||||
|
||||
self->handle = handle;
|
||||
self->name = THAlloc(strlen(name)+1);
|
||||
strcpy(self->name, name);
|
||||
self->isNativeEncoding = 1;
|
||||
|
||||
self->file.vtable = &vtable;
|
||||
self->file.isQuiet = isQuiet;
|
||||
self->file.isReadable = isReadable;
|
||||
self->file.isWritable = isWritable;
|
||||
self->file.isBinary = 0;
|
||||
self->file.isAutoSpacing = 1;
|
||||
self->file.hasError = 0;
|
||||
|
||||
return (THFile*)self;
|
||||
}
|
||||
|
||||
/* PipeFile */
|
||||
|
||||
static int THPipeFile_mode(const char *mode, int *isReadable, int *isWritable)
|
||||
{
|
||||
*isReadable = 0;
|
||||
*isWritable = 0;
|
||||
if(strlen(mode) == 1)
|
||||
{
|
||||
if(*mode == 'r')
|
||||
{
|
||||
*isReadable = 1;
|
||||
return 1;
|
||||
}
|
||||
else if(*mode == 'w')
|
||||
{
|
||||
*isWritable = 1;
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void THPipeFile_free(THFile *self)
|
||||
{
|
||||
THDiskFile *dfself = (THDiskFile*)(self);
|
||||
if(dfself->handle)
|
||||
pclose(dfself->handle);
|
||||
THFree(dfself->name);
|
||||
THFree(dfself);
|
||||
}
|
||||
|
||||
THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet)
|
||||
{
|
||||
static struct THFileVTable vtable = {
|
||||
THDiskFile_isOpened,
|
||||
|
||||
THDiskFile_readByte,
|
||||
THDiskFile_readChar,
|
||||
THDiskFile_readShort,
|
||||
THDiskFile_readInt,
|
||||
THDiskFile_readLong,
|
||||
THDiskFile_readFloat,
|
||||
THDiskFile_readDouble,
|
||||
THDiskFile_readString,
|
||||
|
||||
THDiskFile_writeByte,
|
||||
THDiskFile_writeChar,
|
||||
THDiskFile_writeShort,
|
||||
THDiskFile_writeInt,
|
||||
THDiskFile_writeLong,
|
||||
THDiskFile_writeFloat,
|
||||
THDiskFile_writeDouble,
|
||||
THDiskFile_writeString,
|
||||
|
||||
THDiskFile_synchronize,
|
||||
THDiskFile_seek,
|
||||
THDiskFile_seekEnd,
|
||||
THDiskFile_position,
|
||||
THDiskFile_close,
|
||||
THPipeFile_free
|
||||
};
|
||||
|
||||
int isReadable;
|
||||
int isWritable;
|
||||
FILE *handle;
|
||||
THDiskFile *self;
|
||||
|
||||
THArgCheck(THPipeFile_mode(mode, &isReadable, &isWritable), 2, "file mode should be 'r','w'");
|
||||
|
||||
#ifdef _WIN32
|
||||
handle = popen(name, (isReadable ? "rb" : "wb"));
|
||||
#else
|
||||
handle = popen(name, (isReadable ? "r" : "w"));
|
||||
#endif
|
||||
|
||||
if(!handle)
|
||||
{
|
||||
if(isQuiet)
|
||||
return 0;
|
||||
else
|
||||
THError("cannot open <%s> in mode %c%c", name, (isReadable ? 'r' : ' '), (isWritable ? 'w' : ' '));
|
||||
}
|
||||
|
||||
self = THAlloc(sizeof(THDiskFile));
|
||||
|
||||
self->handle = handle;
|
||||
self->name = THAlloc(strlen(name)+1);
|
||||
strcpy(self->name, name);
|
||||
self->isNativeEncoding = 1;
|
||||
|
||||
self->file.vtable = &vtable;
|
||||
self->file.isQuiet = isQuiet;
|
||||
self->file.isReadable = isReadable;
|
||||
self->file.isWritable = isWritable;
|
||||
self->file.isBinary = 0;
|
||||
self->file.isAutoSpacing = 1;
|
||||
self->file.hasError = 0;
|
||||
|
||||
return (THFile*)self;
|
||||
}
|
||||
17 lib/TH/THDiskFile.h Normal file
@@ -0,0 +1,17 @@
#ifndef TH_DISK_FILE_INC
|
||||
#define TH_DISK_FILE_INC
|
||||
|
||||
#include "THFile.h"
|
||||
|
||||
THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet);
|
||||
THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet);
|
||||
|
||||
const char *THDiskFile_name(THFile *self);
|
||||
|
||||
int THDiskFile_isLittleEndianCPU(void);
|
||||
int THDiskFile_isBigEndianCPU(void);
|
||||
void THDiskFile_nativeEndianEncoding(THFile *self);
|
||||
void THDiskFile_littleEndianEncoding(THFile *self);
|
||||
void THDiskFile_bigEndianEncoding(THFile *self);
|
||||
|
||||
#endif
|
||||
154 lib/TH/THFile.c Normal file
@@ -0,0 +1,154 @@
#include "THFile.h"
|
||||
#include "THFilePrivate.h"
|
||||
|
||||
#define IMPLEMENT_THFILE_RW(TYPEC, TYPE) \
|
||||
long THFile_read##TYPEC##Raw(THFile *self, TYPE *data, long n) \
|
||||
{ \
|
||||
return (*self->vtable->read##TYPEC)(self, data, n); \
|
||||
} \
|
||||
\
|
||||
long THFile_write##TYPEC##Raw(THFile *self, TYPE *data, long n) \
|
||||
{ \
|
||||
return (*self->vtable->write##TYPEC)(self, data, n); \
|
||||
}
|
||||
|
||||
IMPLEMENT_THFILE_RW(Byte, unsigned char)
|
||||
IMPLEMENT_THFILE_RW(Char, char)
|
||||
IMPLEMENT_THFILE_RW(Short, short)
|
||||
IMPLEMENT_THFILE_RW(Int, int)
|
||||
IMPLEMENT_THFILE_RW(Long, long)
|
||||
IMPLEMENT_THFILE_RW(Float, float)
|
||||
IMPLEMENT_THFILE_RW(Double, double)
|
||||
|
||||
long THFile_readStringRaw(THFile *self, const char *format, char **str_)
|
||||
{
|
||||
return self->vtable->readString(self, format, str_);
|
||||
}
|
||||
|
||||
long THFile_writeStringRaw(THFile *self, const char *str, long size)
|
||||
{
|
||||
return self->vtable->writeString(self, str, size);
|
||||
}
|
||||
|
||||
void THFile_synchronize(THFile *self)
|
||||
{
|
||||
self->vtable->synchronize(self);
|
||||
}
|
||||
|
||||
void THFile_seek(THFile *self, long position)
|
||||
{
|
||||
self->vtable->seek(self, position);
|
||||
}
|
||||
|
||||
void THFile_seekEnd(THFile *self)
|
||||
{
|
||||
self->vtable->seekEnd(self);
|
||||
}
|
||||
|
||||
long THFile_position(THFile *self)
|
||||
{
|
||||
return self->vtable->position(self);
|
||||
}
|
||||
|
||||
void THFile_close(THFile *self)
|
||||
{
|
||||
self->vtable->close(self);
|
||||
}
|
||||
|
||||
void THFile_free(THFile *self)
|
||||
{
|
||||
self->vtable->free(self);
|
||||
}
|
||||
|
||||
int THFile_isOpened(THFile *self)
|
||||
{
|
||||
return self->vtable->isOpened(self);
|
||||
}
|
||||
|
||||
#define IMPLEMENT_THFILE_FLAGS(FLAG) \
|
||||
int THFile_##FLAG(THFile *self) \
|
||||
{ \
|
||||
return self->FLAG; \
|
||||
}
|
||||
|
||||
IMPLEMENT_THFILE_FLAGS(isQuiet)
|
||||
IMPLEMENT_THFILE_FLAGS(isReadable)
|
||||
IMPLEMENT_THFILE_FLAGS(isWritable)
|
||||
IMPLEMENT_THFILE_FLAGS(isBinary)
|
||||
IMPLEMENT_THFILE_FLAGS(isAutoSpacing)
|
||||
IMPLEMENT_THFILE_FLAGS(hasError)
|
||||
|
||||
void THFile_binary(THFile *self)
|
||||
{
|
||||
self->isBinary = 1;
|
||||
}
|
||||
|
||||
void THFile_ascii(THFile *self)
|
||||
{
|
||||
self->isBinary = 0;
|
||||
}
|
||||
|
||||
void THFile_autoSpacing(THFile *self)
|
||||
{
|
||||
self->isAutoSpacing = 1;
|
||||
}
|
||||
|
||||
void THFile_noAutoSpacing(THFile *self)
|
||||
{
|
||||
self->isAutoSpacing = 0;
|
||||
}
|
||||
|
||||
void THFile_quiet(THFile *self)
|
||||
{
|
||||
self->isQuiet = 1;
|
||||
}
|
||||
|
||||
void THFile_pedantic(THFile *self)
|
||||
{
|
||||
self->isQuiet = 0;
|
||||
}
|
||||
|
||||
void THFile_clearError(THFile *self)
|
||||
{
|
||||
self->hasError = 0;
|
||||
}
|
||||
|
||||
#define IMPLEMENT_THFILE_SCALAR(TYPEC, TYPE) \
|
||||
TYPE THFile_read##TYPEC##Scalar(THFile *self) \
|
||||
{ \
|
||||
TYPE scalar; \
|
||||
THFile_read##TYPEC##Raw(self, &scalar, 1); \
|
||||
return scalar; \
|
||||
} \
|
||||
\
|
||||
void THFile_write##TYPEC##Scalar(THFile *self, TYPE scalar) \
|
||||
{ \
|
||||
THFile_write##TYPEC##Raw(self, &scalar, 1); \
|
||||
}
|
||||
|
||||
IMPLEMENT_THFILE_SCALAR(Byte, unsigned char)
|
||||
IMPLEMENT_THFILE_SCALAR(Char, char)
|
||||
IMPLEMENT_THFILE_SCALAR(Short, short)
|
||||
IMPLEMENT_THFILE_SCALAR(Int, int)
|
||||
IMPLEMENT_THFILE_SCALAR(Long, long)
|
||||
IMPLEMENT_THFILE_SCALAR(Float, float)
|
||||
IMPLEMENT_THFILE_SCALAR(Double, double)
|
||||
|
||||
#define IMPLEMENT_THFILE_STORAGE(TYPEC, TYPE) \
|
||||
long THFile_read##TYPEC(THFile *self, TH##TYPEC##Storage *storage) \
|
||||
{ \
|
||||
return THFile_read##TYPEC##Raw(self, storage->data, storage->size); \
|
||||
} \
|
||||
\
|
||||
long THFile_write##TYPEC(THFile *self, TH##TYPEC##Storage *storage) \
|
||||
{ \
|
||||
return THFile_write##TYPEC##Raw(self, storage->data, storage->size); \
|
||||
}
|
||||
|
||||
IMPLEMENT_THFILE_STORAGE(Byte, unsigned char)
|
||||
IMPLEMENT_THFILE_STORAGE(Char, char)
|
||||
IMPLEMENT_THFILE_STORAGE(Short, short)
|
||||
IMPLEMENT_THFILE_STORAGE(Int, int)
|
||||
IMPLEMENT_THFILE_STORAGE(Long, long)
|
||||
IMPLEMENT_THFILE_STORAGE(Float, float)
|
||||
IMPLEMENT_THFILE_STORAGE(Double, double)
|
||||
84 lib/TH/THFile.h Normal file
@@ -0,0 +1,84 @@
#ifndef TH_FILE_INC
|
||||
#define TH_FILE_INC
|
||||
|
||||
#include "THStorage.h"
|
||||
|
||||
typedef struct THFile__ THFile;
|
||||
|
||||
int THFile_isOpened(THFile *self);
|
||||
int THFile_isQuiet(THFile *self);
|
||||
int THFile_isReadable(THFile *self);
|
||||
int THFile_isWritable(THFile *self);
|
||||
int THFile_isBinary(THFile *self);
|
||||
int THFile_isAutoSpacing(THFile *self);
|
||||
int THFile_hasError(THFile *self);
|
||||
|
||||
void THFile_binary(THFile *self);
|
||||
void THFile_ascii(THFile *self);
|
||||
void THFile_autoSpacing(THFile *self);
|
||||
void THFile_noAutoSpacing(THFile *self);
|
||||
void THFile_quiet(THFile *self);
|
||||
void THFile_pedantic(THFile *self);
|
||||
void THFile_clearError(THFile *self);
|
||||
|
||||
/* scalar */
|
||||
unsigned char THFile_readByteScalar(THFile *self);
|
||||
char THFile_readCharScalar(THFile *self);
|
||||
short THFile_readShortScalar(THFile *self);
|
||||
int THFile_readIntScalar(THFile *self);
|
||||
long THFile_readLongScalar(THFile *self);
|
||||
float THFile_readFloatScalar(THFile *self);
|
||||
double THFile_readDoubleScalar(THFile *self);
|
||||
|
||||
void THFile_writeByteScalar(THFile *self, unsigned char scalar);
|
||||
void THFile_writeCharScalar(THFile *self, char scalar);
|
||||
void THFile_writeShortScalar(THFile *self, short scalar);
|
||||
void THFile_writeIntScalar(THFile *self, int scalar);
|
||||
void THFile_writeLongScalar(THFile *self, long scalar);
|
||||
void THFile_writeFloatScalar(THFile *self, float scalar);
|
||||
void THFile_writeDoubleScalar(THFile *self, double scalar);
|
||||
|
||||
/* storage */
|
||||
long THFile_readByte(THFile *self, THByteStorage *storage);
|
||||
long THFile_readChar(THFile *self, THCharStorage *storage);
|
||||
long THFile_readShort(THFile *self, THShortStorage *storage);
|
||||
long THFile_readInt(THFile *self, THIntStorage *storage);
|
||||
long THFile_readLong(THFile *self, THLongStorage *storage);
|
||||
long THFile_readFloat(THFile *self, THFloatStorage *storage);
|
||||
long THFile_readDouble(THFile *self, THDoubleStorage *storage);
|
||||
|
||||
long THFile_writeByte(THFile *self, THByteStorage *storage);
|
||||
long THFile_writeChar(THFile *self, THCharStorage *storage);
|
||||
long THFile_writeShort(THFile *self, THShortStorage *storage);
|
||||
long THFile_writeInt(THFile *self, THIntStorage *storage);
|
||||
long THFile_writeLong(THFile *self, THLongStorage *storage);
|
||||
long THFile_writeFloat(THFile *self, THFloatStorage *storage);
|
||||
long THFile_writeDouble(THFile *self, THDoubleStorage *storage);
|
||||
|
||||
/* raw */
|
||||
long THFile_readByteRaw(THFile *self, unsigned char *data, long n);
|
||||
long THFile_readCharRaw(THFile *self, char *data, long n);
|
||||
long THFile_readShortRaw(THFile *self, short *data, long n);
|
||||
long THFile_readIntRaw(THFile *self, int *data, long n);
|
||||
long THFile_readLongRaw(THFile *self, long *data, long n);
|
||||
long THFile_readFloatRaw(THFile *self, float *data, long n);
|
||||
long THFile_readDoubleRaw(THFile *self, double *data, long n);
|
||||
long THFile_readStringRaw(THFile *self, const char *format, char **str_); /* you must deallocate str_ */
|
||||
|
||||
long THFile_writeByteRaw(THFile *self, unsigned char *data, long n);
|
||||
long THFile_writeCharRaw(THFile *self, char *data, long n);
|
||||
long THFile_writeShortRaw(THFile *self, short *data, long n);
|
||||
long THFile_writeIntRaw(THFile *self, int *data, long n);
|
||||
long THFile_writeLongRaw(THFile *self, long *data, long n);
|
||||
long THFile_writeFloatRaw(THFile *self, float *data, long n);
|
||||
long THFile_writeDoubleRaw(THFile *self, double *data, long n);
|
||||
long THFile_writeStringRaw(THFile *self, const char *str, long size);
|
||||
|
||||
void THFile_synchronize(THFile *self);
|
||||
void THFile_seek(THFile *self, long position);
|
||||
void THFile_seekEnd(THFile *self);
|
||||
long THFile_position(THFile *self);
|
||||
void THFile_close(THFile *self);
|
||||
void THFile_free(THFile *self);
|
||||
|
||||
#endif
|
||||
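The Lua File objects (DiskFile, MemoryFile, PipeFile) are thin wrappers over the C entry points declared here, with the flag toggles and the scalar read/write functions surfacing as methods. A hedged sketch of the expected Lua-side usage (constructor and method names are assumed, not taken from this diff):

-- hedged sketch of the Lua-side mirror of this C API
local f = torch.DiskFile('data.bin', 'w')
f:binary()                         -- THFile_binary: binary instead of ASCII encoding
f:writeDouble(3.14)                -- THFile_writeDoubleScalar
f:close()                          -- THFile_close

f = torch.DiskFile('data.bin', 'r')
f:binary()
print(f:readDouble())              -- THFile_readDoubleScalar -> 3.14
f:close()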
43 lib/TH/THFilePrivate.h Normal file
@@ -0,0 +1,43 @@
struct THFile__
|
||||
{
|
||||
struct THFileVTable *vtable;
|
||||
|
||||
int isQuiet;
|
||||
int isReadable;
|
||||
int isWritable;
|
||||
int isBinary;
|
||||
int isAutoSpacing;
|
||||
int hasError;
|
||||
};
|
||||
|
||||
/* virtual table definition */
|
||||
|
||||
struct THFileVTable
|
||||
{
|
||||
int (*isOpened)(THFile *self);
|
||||
|
||||
long (*readByte)(THFile *self, unsigned char *data, long n);
|
||||
long (*readChar)(THFile *self, char *data, long n);
|
||||
long (*readShort)(THFile *self, short *data, long n);
|
||||
long (*readInt)(THFile *self, int *data, long n);
|
||||
long (*readLong)(THFile *self, long *data, long n);
|
||||
long (*readFloat)(THFile *self, float *data, long n);
|
||||
long (*readDouble)(THFile *self, double *data, long n);
|
||||
long (*readString)(THFile *self, const char *format, char **str_);
|
||||
|
||||
long (*writeByte)(THFile *self, unsigned char *data, long n);
|
||||
long (*writeChar)(THFile *self, char *data, long n);
|
||||
long (*writeShort)(THFile *self, short *data, long n);
|
||||
long (*writeInt)(THFile *self, int *data, long n);
|
||||
long (*writeLong)(THFile *self, long *data, long n);
|
||||
long (*writeFloat)(THFile *self, float *data, long n);
|
||||
long (*writeDouble)(THFile *self, double *data, long n);
|
||||
long (*writeString)(THFile *self, const char *str, long size);
|
||||
|
||||
void (*synchronize)(THFile *self);
|
||||
void (*seek)(THFile *self, long position);
|
||||
void (*seekEnd)(THFile *self);
|
||||
long (*position)(THFile *self);
|
||||
void (*close)(THFile *self);
|
||||
void (*free)(THFile *self);
|
||||
};
|
||||
110 lib/TH/THGeneral.c Normal file
@@ -0,0 +1,110 @@
#include "THGeneral.h"
|
||||
|
||||
/* Torch Error Handling */
|
||||
static void defaultTorchErrorHandlerFunction(const char *msg)
|
||||
{
|
||||
printf("$ Error: %s\n", msg);
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
static void (*torchErrorHandlerFunction)(const char *msg) = defaultTorchErrorHandlerFunction;
|
||||
|
||||
void THError(const char *fmt, ...)
|
||||
{
|
||||
static char msg[1024];
|
||||
va_list args;
|
||||
|
||||
/* vasprintf not standard */
|
||||
/* vsnprintf: how to handle it if it does not exist? */
|
||||
va_start(args, fmt);
|
||||
vsnprintf(msg, 1024, fmt, args);
|
||||
va_end(args);
|
||||
|
||||
(*torchErrorHandlerFunction)(msg);
|
||||
}
|
||||
|
||||
void THSetErrorHandler( void (*torchErrorHandlerFunction_)(const char *msg) )
|
||||
{
|
||||
if(torchErrorHandlerFunction_)
|
||||
torchErrorHandlerFunction = torchErrorHandlerFunction_;
|
||||
else
|
||||
torchErrorHandlerFunction = defaultTorchErrorHandlerFunction;
|
||||
}
|
||||
|
||||
/* Torch Arg Checking Handling */
|
||||
static void defaultTorchArgCheckHandlerFunction(int condition, int argNumber, const char *msg)
|
||||
{
|
||||
if(!condition)
|
||||
{
|
||||
if(msg)
|
||||
printf("$ Invalid argument %d: %s\n", argNumber, msg);
|
||||
else
|
||||
printf("$ Invalid argument %d\n", argNumber);
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
static void (*torchArgCheckHandlerFunction)(int condition, int argNumber, const char *msg) = defaultTorchArgCheckHandlerFunction;
|
||||
|
||||
void THArgCheck(int condition, int argNumber, const char *msg)
|
||||
{
|
||||
(*torchArgCheckHandlerFunction)(condition, argNumber, msg);
|
||||
}
|
||||
|
||||
void THSetArgCheckHandler( void (*torchArgCheckHandlerFunction_)(int condition, int argNumber, const char *msg) )
|
||||
{
|
||||
if(torchArgCheckHandlerFunction_)
|
||||
torchArgCheckHandlerFunction = torchArgCheckHandlerFunction_;
|
||||
else
|
||||
torchArgCheckHandlerFunction = defaultTorchArgCheckHandlerFunction;
|
||||
}
|
||||
|
||||
void* THAlloc(long size)
|
||||
{
|
||||
void *ptr;
|
||||
|
||||
if(size < 0)
|
||||
THError("$ Torch: invalid memory size -- maybe an overflow?");
|
||||
|
||||
if(size == 0)
|
||||
return NULL;
|
||||
|
||||
ptr = malloc(size);
|
||||
if(!ptr)
|
||||
THError("$ Torch: not enough memory: you tried to allocate %dGB. Buy new RAM!", size/1073741824);
|
||||
|
||||
return ptr;
|
||||
}
|
||||
|
||||
void* THRealloc(void *ptr, long size)
|
||||
{
|
||||
if(!ptr)
|
||||
return(THAlloc(size));
|
||||
|
||||
if(size == 0)
|
||||
{
|
||||
THFree(ptr);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if(size < 0)
|
||||
THError("$ Torch: invalid memory size -- maybe an overflow?");
|
||||
|
||||
ptr = realloc(ptr, size);
|
||||
if(!ptr)
|
||||
THError("$ Torch: not enough memory: you tried to reallocate %dGB. Buy new RAM!", size/1073741824);
|
||||
return ptr;
|
||||
}
|
||||
|
||||
void THFree(void *ptr)
|
||||
{
|
||||
free(ptr);
|
||||
}
|
||||
|
||||
#ifdef _MSC_VER
|
||||
double log1p(const double x)
|
||||
{
|
||||
volatile double y;
|
||||
y = 1 + x;
|
||||
return log(y) - ((y-1)-x)/y ; /* cancels errors with IEEE arithmetic */
|
||||
}
|
||||
#endif
|
||||
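Because init.c replaces these default handlers with luaL_error and luaL_argcheck, TH-level failures surface in Lua as ordinary errors rather than aborting the process. A hedged sketch of what that means in practice (the exact failing call is only an assumed example):

-- hedged sketch: TH errors become catchable Lua errors once the handlers are swapped
local ok, msg = pcall(function()
   return torch.Tensor(-1)         -- an invalid size should trip a TH-level check
end)
print(ok, msg)                     -- false, plus the message produced through THError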
72 lib/TH/THGeneral.h Normal file
@@ -0,0 +1,72 @@
#ifndef TH_GENERAL_INC
#define TH_GENERAL_INC

#include <stdlib.h>
#include <stdio.h>
#include <stdarg.h>
#include <math.h>
#include <limits.h>
#include <float.h>
#include <time.h>
#include <string.h>

#ifdef __cplusplus
# define TH_EXTERNC extern "C"
#else
# define TH_EXTERNC extern
#endif

#ifdef WIN32
# ifdef TH_EXPORTS
#  define TH_API TH_EXTERNC __declspec(dllexport)
# else
#  define TH_API TH_EXTERNC __declspec(dllimport)
# endif
#else
# define TH_API TH_EXTERNC
#endif

#define THInf DBL_MAX

#if !defined(inline)
# define inline
#endif

#ifndef M_PI
# define M_PI 3.14159265358979323846
#endif

#ifdef _MSC_VER
TH_API double log1p(const double x);
#endif

TH_API void THError(const char *fmt, ...);
TH_API void THSetErrorHandler( void (*torchErrorHandlerFunction)(const char *msg) );
TH_API void THArgCheck(int condition, int argNumber, const char *msg);
TH_API void THSetArgCheckHandler( void (*torchArgCheckHandlerFunction)(int condition, int argNumber, const char *msg) );
TH_API void* THAlloc(long size);
TH_API void* THRealloc(void *ptr, long size);
TH_API void THFree(void *ptr);

#define TH_CONCAT_STRING_2(x,y) TH_CONCAT_STRING_2_EXPAND(x,y)
#define TH_CONCAT_STRING_2_EXPAND(x,y) #x #y

#define TH_CONCAT_STRING_3(x,y,z) TH_CONCAT_STRING_3_EXPAND(x,y,z)
#define TH_CONCAT_STRING_3_EXPAND(x,y,z) #x #y #z

#define TH_CONCAT_STRING_4(x,y,z,w) TH_CONCAT_STRING_4_EXPAND(x,y,z,w)
#define TH_CONCAT_STRING_4_EXPAND(x,y,z,w) #x #y #z #w

#define TH_CONCAT_2(x,y) TH_CONCAT_2_EXPAND(x,y)
#define TH_CONCAT_2_EXPAND(x,y) x ## y

#define TH_CONCAT_3(x,y,z) TH_CONCAT_3_EXPAND(x,y,z)
#define TH_CONCAT_3_EXPAND(x,y,z) x ## y ## z

#define TH_CONCAT_4_EXPAND(x,y,z,w) x ## y ## z ## w
#define TH_CONCAT_4(x,y,z,w) TH_CONCAT_4_EXPAND(x,y,z,w)

#define THMin(X, Y) ((X) < (Y) ? (X) : (Y))
#define THMax(X, Y) ((X) > (Y) ? (X) : (Y))

#endif
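A quick illustration of the token-pasting helpers (my own example; THFloatDummy and dummyName are made-up names): thanks to the *_EXPAND indirection, macro arguments such as Real are fully expanded before pasting, which is what lets the generated headers below build per-type identifiers.

  #include "THGeneral.h"

  #define Real Float
  /* TH_CONCAT_3(TH,Real,Dummy) pastes to the identifier THFloatDummy... */
  typedef struct TH_CONCAT_3(TH,Real,Dummy) { int unused; } TH_CONCAT_3(TH,Real,Dummy);
  /* ...and TH_CONCAT_STRING_2(torch., FloatDummy) to the literal "torch.FloatDummy". */
  static const char *dummyName = TH_CONCAT_STRING_2(torch., FloatDummy);
  #undef Real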
83
lib/TH/THGenerateAllTypes.h
Normal file
@@ -0,0 +1,83 @@
#ifndef TH_GENERIC_FILE
#error "You must define TH_GENERIC_FILE before including THGenerateAllTypes.h"
#endif

#define real unsigned char
#define accreal long
#define Real Byte
#define TH_REAL_IS_BYTE
#line 1 TH_GENERIC_FILE
/*#line 1 "THByteStorage.h"*/
#include TH_GENERIC_FILE
#undef real
#undef accreal
#undef Real
#undef TH_REAL_IS_BYTE

#define real char
#define accreal long
#define Real Char
#define TH_REAL_IS_CHAR
#line 1 TH_GENERIC_FILE
#include TH_GENERIC_FILE
#undef real
#undef accreal
#undef Real
#undef TH_REAL_IS_CHAR

#define real short
#define accreal long
#define Real Short
#define TH_REAL_IS_SHORT
#line 1 TH_GENERIC_FILE
#include TH_GENERIC_FILE
#undef real
#undef accreal
#undef Real
#undef TH_REAL_IS_SHORT

#define real int
#define accreal long
#define Real Int
#define TH_REAL_IS_INT
#line 1 TH_GENERIC_FILE
#include TH_GENERIC_FILE
#undef real
#undef accreal
#undef Real
#undef TH_REAL_IS_INT

#define real long
#define accreal long
#define Real Long
#define TH_REAL_IS_LONG
#line 1 TH_GENERIC_FILE
#include TH_GENERIC_FILE
#undef real
#undef accreal
#undef Real
#undef TH_REAL_IS_LONG

#define real float
#define accreal double
#define Real Float
#define TH_REAL_IS_FLOAT
#line 1 TH_GENERIC_FILE
#include TH_GENERIC_FILE
#undef real
#undef accreal
#undef Real
#undef TH_REAL_IS_FLOAT

#define real double
#define accreal double
#define Real Double
#define TH_REAL_IS_DOUBLE
#line 1 TH_GENERIC_FILE
#include TH_GENERIC_FILE
#undef real
#undef accreal
#undef Real
#undef TH_REAL_IS_DOUBLE

#undef TH_GENERIC_FILE
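This header is the instantiation engine for TH's "generic" sources: a file written once against real/accreal/Real gets declared or compiled once per element type. A hedged sketch of the calling side (both file names here are hypothetical):

  /* generic/myVector.h -- written once in terms of `real` and `Real` */
  TH_API void TH_CONCAT_4(TH,Real,Vector_,fill)(real *x, real value, long n);

  /* myVector.h -- stamps the declaration out for every element type */
  #include "THGeneral.h"
  #define TH_GENERIC_FILE "generic/myVector.h"
  #include "THGenerateAllTypes.h"
  /* After this include, THByteVector_fill ... THDoubleVector_fill are all declared,
     and real/accreal/Real/TH_GENERIC_FILE have been #undef'd again. */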
27
lib/TH/THGenerateFloatTypes.h
Normal file
@@ -0,0 +1,27 @@
#ifndef TH_GENERIC_FILE
#error "You must define TH_GENERIC_FILE before including THGenerateFloatTypes.h"
#endif

#define real float
#define accreal double
#define Real Float
#define TH_REAL_IS_FLOAT
#line 1 TH_GENERIC_FILE
#include TH_GENERIC_FILE
#undef accreal
#undef real
#undef Real
#undef TH_REAL_IS_FLOAT

#define real double
#define accreal double
#define Real Double
#define TH_REAL_IS_DOUBLE
#line 1 TH_GENERIC_FILE
#include TH_GENERIC_FILE
#undef accreal
#undef real
#undef Real
#undef TH_REAL_IS_DOUBLE

#undef TH_GENERIC_FILE
60
lib/TH/THGenerateIntTypes.h
Normal file
@@ -0,0 +1,60 @@
#ifndef TH_GENERIC_FILE
#error "You must define TH_GENERIC_FILE before including THGenerateIntTypes.h"
#endif

#define real unsigned char
#define accreal long
#define Real Byte
#define TH_REAL_IS_BYTE
#line 1 TH_GENERIC_FILE
#include TH_GENERIC_FILE
#undef real
#undef accreal
#undef Real
#undef TH_REAL_IS_BYTE

#define real char
#define accreal long
#define Real Char
#define TH_REAL_IS_CHAR
#line 1 TH_GENERIC_FILE
#include TH_GENERIC_FILE
#undef real
#undef accreal
#undef Real
#undef TH_REAL_IS_CHAR

#define real short
#define accreal long
#define Real Short
#define TH_REAL_IS_SHORT
#line 1 TH_GENERIC_FILE
#include TH_GENERIC_FILE
#undef real
#undef accreal
#undef Real
#undef TH_REAL_IS_SHORT

#define real int
#define accreal long
#define Real Int
#define TH_REAL_IS_INT
#line 1 TH_GENERIC_FILE
#include TH_GENERIC_FILE
#undef real
#undef accreal
#undef Real
#undef TH_REAL_IS_INT

#define real long
#define accreal long
#define Real Long
#define TH_REAL_IS_LONG
#line 1 TH_GENERIC_FILE
#include TH_GENERIC_FILE
#undef real
#undef accreal
#undef Real
#undef TH_REAL_IS_LONG

#undef TH_GENERIC_FILE
5
lib/TH/THLapack.c
Normal file
@@ -0,0 +1,5 @@
#include "THLapack.h"
|
||||
|
||||
/* #include "THCBlas.h" */
|
||||
#include "generic/THLapack.c"
|
||||
#include "THGenerateFloatTypes.h"
|
||||
11
lib/TH/THLapack.h
Normal file
@@ -0,0 +1,11 @@
#ifndef TH_LAPACK_INC
#define TH_LAPACK_INC

#include "THGeneral.h"

#define THLapack_(NAME) TH_CONCAT_4(TH,Real,Lapack_,NAME)

#include "generic/THLapack.h"
#include "THGenerateAllTypes.h"

#endif
86
lib/TH/THLogAdd.c
Normal file
@@ -0,0 +1,86 @@
#include "THLogAdd.h"
|
||||
|
||||
#ifdef USE_DOUBLE
|
||||
#define MINUS_LOG_THRESHOLD -39.14
|
||||
#else
|
||||
#define MINUS_LOG_THRESHOLD -18.42
|
||||
#endif
|
||||
|
||||
const double THLog2Pi=1.83787706640934548355;
|
||||
const double THLogZero=-THInf;
|
||||
const double THLogOne=0;
|
||||
|
||||
double THLogAdd(double log_a, double log_b)
|
||||
{
|
||||
double minusdif;
|
||||
|
||||
if (log_a < log_b)
|
||||
{
|
||||
double tmp = log_a;
|
||||
log_a = log_b;
|
||||
log_b = tmp;
|
||||
}
|
||||
|
||||
minusdif = log_b - log_a;
|
||||
#ifdef DEBUG
|
||||
if (isnan(minusdif))
|
||||
THError("THLogAdd: minusdif (%f) log_b (%f) or log_a (%f) is nan", minusdif, log_b, log_a);
|
||||
#endif
|
||||
if (minusdif < MINUS_LOG_THRESHOLD)
|
||||
return log_a;
|
||||
else
|
||||
return log_a + log1p(exp(minusdif));
|
||||
}
|
||||
|
||||
double THLogSub(double log_a, double log_b)
|
||||
{
|
||||
double minusdif;
|
||||
|
||||
if (log_a < log_b)
|
||||
THError("LogSub: log_a (%f) should be greater than log_b (%f)", log_a, log_b);
|
||||
|
||||
minusdif = log_b - log_a;
|
||||
#ifdef DEBUG
|
||||
if (isnan(minusdif))
|
||||
THError("LogSub: minusdif (%f) log_b (%f) or log_a (%f) is nan", minusdif, log_b, log_a);
|
||||
#endif
|
||||
if (log_a == log_b)
|
||||
return THLogZero;
|
||||
else if (minusdif < MINUS_LOG_THRESHOLD)
|
||||
return log_a;
|
||||
else
|
||||
return log_a + log1p(-exp(minusdif));
|
||||
}
|
||||
|
||||
/* Credits to Leon Bottou */
|
||||
double THExpMinusApprox(double x)
|
||||
{
|
||||
#define EXACT_EXPONENTIAL 0
|
||||
#if EXACT_EXPONENTIAL
|
||||
return exp(-x);
|
||||
#else
|
||||
/* fast approximation of exp(-x) for x positive */
|
||||
# define A0 (1.0)
|
||||
# define A1 (0.125)
|
||||
# define A2 (0.0078125)
|
||||
# define A3 (0.00032552083)
|
||||
# define A4 (1.0172526e-5)
|
||||
if (x < 13.0)
|
||||
{
|
||||
/* assert(x>=0); */
|
||||
double y;
|
||||
y = A0+x*(A1+x*(A2+x*(A3+x*A4)));
|
||||
y *= y;
|
||||
y *= y;
|
||||
y *= y;
|
||||
y = 1/y;
|
||||
return y;
|
||||
}
|
||||
return 0;
|
||||
# undef A0
|
||||
# undef A1
|
||||
# undef A2
|
||||
# undef A3
|
||||
# undef A4
|
||||
#endif
|
||||
}
|
||||
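A small worked example (mine, not from the commit): summing probabilities that are awkward to handle as plain doubles by staying in log space with THLogAdd.

  #include <math.h>
  #include <stdio.h>
  #include "THLogAdd.h"

  int main(void)
  {
    /* p1 = 1e-320 and p2 = 2e-320 would lose precision if exponentiated directly,
       but their logarithms are ordinary doubles. */
    double log_p1 = -320.0 * log(10.0);        /* log(1e-320) */
    double log_p2 = log_p1 + log(2.0);         /* log(2e-320) */
    double log_sum = THLogAdd(log_p1, log_p2); /* log(p1 + p2) */
    printf("log(p1+p2) = %f, expected %f\n", log_sum, log_p1 + log(3.0));
    return 0;
  }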
14
lib/TH/THLogAdd.h
Normal file
@@ -0,0 +1,14 @@
#ifndef TH_LOG_ADD_INC
#define TH_LOG_ADD_INC

#include "THGeneral.h"

TH_API const double THLog2Pi;
TH_API const double THLogZero;
TH_API const double THLogOne;

TH_API double THLogAdd(double log_a, double log_b);
TH_API double THLogSub(double log_a, double log_b);
TH_API double THExpMinusApprox(const double x);

#endif
492
lib/TH/THMemoryFile.c
Normal file
@@ -0,0 +1,492 @@
#include "THMemoryFile.h"
|
||||
#include "THFilePrivate.h"
|
||||
|
||||
typedef struct THMemoryFile__
|
||||
{
|
||||
THFile file;
|
||||
THCharStorage *storage;
|
||||
long size;
|
||||
long position;
|
||||
|
||||
} THMemoryFile;
|
||||
|
||||
static int THMemoryFile_isOpened(THFile *self)
|
||||
{
|
||||
THMemoryFile *mfself = (THMemoryFile*)self;
|
||||
return (mfself->storage != NULL);
|
||||
}
|
||||
|
||||
static char *THMemoryFile_strnextspace(char *str_, char *c_)
|
||||
{
|
||||
char c;
|
||||
|
||||
while( (c = *str_) )
|
||||
{
|
||||
if( (c != ' ') && (c != '\n') && (c != ':') && (c != ';') )
|
||||
break;
|
||||
str_++;
|
||||
}
|
||||
|
||||
while( (c = *str_) )
|
||||
{
|
||||
if( (c == ' ') || (c == '\n') || (c == ':') || (c == ';') )
|
||||
{
|
||||
*c_ = c;
|
||||
*str_ = '\0';
|
||||
return(str_);
|
||||
}
|
||||
str_++;
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static void THMemoryFile_grow(THMemoryFile *self, long size)
|
||||
{
|
||||
long missingSpace;
|
||||
|
||||
if(size <= self->size)
|
||||
return;
|
||||
else
|
||||
{
|
||||
if(size < self->storage->size) /* note the "<" and not "<=" */
|
||||
{
|
||||
self->size = size;
|
||||
self->storage->data[self->size] = '\0';
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
missingSpace = size-self->storage->size+1; /* +1 for the '\0' */
|
||||
THCharStorage_resize(self->storage, (self->storage->size/2 > missingSpace ?
|
||||
self->storage->size + (self->storage->size/2)
|
||||
: self->storage->size + missingSpace));
|
||||
}
|
||||
|
||||
static int THMemoryFile_mode(const char *mode, int *isReadable, int *isWritable)
|
||||
{
|
||||
*isReadable = 0;
|
||||
*isWritable = 0;
|
||||
if(strlen(mode) == 1)
|
||||
{
|
||||
if(*mode == 'r')
|
||||
{
|
||||
*isReadable = 1;
|
||||
return 1;
|
||||
}
|
||||
else if(*mode == 'w')
|
||||
{
|
||||
*isWritable = 1;
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
else if(strlen(mode) == 2)
|
||||
{
|
||||
if(mode[0] == 'r' && mode[1] == 'w')
|
||||
{
|
||||
*isReadable = 1;
|
||||
*isWritable = 1;
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
/********************************************************/
|
||||
|
||||
#define READ_WRITE_METHODS(TYPE, TYPEC, ASCII_READ_ELEM, ASCII_WRITE_ELEM, INSIDE_SPACING) \
|
||||
static long THMemoryFile_read##TYPEC(THFile *self, TYPE *data, long n) \
|
||||
{ \
|
||||
THMemoryFile *mfself = (THMemoryFile*)self; \
|
||||
long nread = 0L; \
|
||||
\
|
||||
THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file"); \
|
||||
THArgCheck(mfself->file.isReadable, 1, "attempt to read in a write-only file"); \
|
||||
\
|
||||
if(mfself->file.isBinary) \
|
||||
{ \
|
||||
long nByte = sizeof(TYPE)*n; \
|
||||
long nByteRemaining = (mfself->position + nByte <= mfself->size ? nByte : mfself->size-mfself->position); \
|
||||
nread = nByteRemaining/sizeof(TYPE); \
|
||||
memmove(data, mfself->storage->data+mfself->position, nread*sizeof(TYPE)); \
|
||||
mfself->position += nread*sizeof(TYPE); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
long i; \
|
||||
for(i = 0; i < n; i++) \
|
||||
{ \
|
||||
long nByteRead = 0; \
|
||||
char spaceChar = 0; \
|
||||
char *spacePtr = THMemoryFile_strnextspace(mfself->storage->data+mfself->position, &spaceChar); \
|
||||
ASCII_READ_ELEM; \
|
||||
if(ret == EOF) \
|
||||
{ \
|
||||
while(mfself->storage->data[mfself->position]) \
|
||||
mfself->position++; \
|
||||
} \
|
||||
else \
|
||||
mfself->position += nByteRead; \
|
||||
if(spacePtr) \
|
||||
*spacePtr = spaceChar; \
|
||||
} \
|
||||
if(mfself->file.isAutoSpacing && (n > 0)) \
|
||||
{ \
|
||||
if( (mfself->position < mfself->size) && (mfself->storage->data[mfself->position] == '\n') ) \
|
||||
mfself->position++; \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
if(nread != n) \
|
||||
{ \
|
||||
mfself->file.hasError = 1; /* shouldn't we put hasError to 0 all the time ? */ \
|
||||
if(!mfself->file.isQuiet) \
|
||||
THError("read error: read %d blocks instead of %d", nread, n); \
|
||||
} \
|
||||
\
|
||||
return nread; \
|
||||
} \
|
||||
\
|
||||
static long THMemoryFile_write##TYPEC(THFile *self, TYPE *data, long n) \
|
||||
{ \
|
||||
THMemoryFile *mfself = (THMemoryFile*)self; \
|
||||
long nread = 0L; \
|
||||
\
|
||||
THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file"); \
|
||||
THArgCheck(mfself->file.isWritable, 1, "attempt to write in a read-only file"); \
|
||||
\
|
||||
if(mfself->file.isBinary) \
|
||||
{ \
|
||||
long nByte = sizeof(TYPE)*n; \
|
||||
THMemoryFile_grow(mfself, mfself->position+nByte); \
|
||||
memmove(mfself->storage->data+mfself->position, data, nByte); \
|
||||
mfself->position += nByte; \
|
||||
if(mfself->position > mfself->size) \
|
||||
{ \
|
||||
mfself->size = mfself->position; \
|
||||
mfself->storage->data[mfself->size] = '\0'; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
long i; \
|
||||
for(i = 0; i < n; i++) \
|
||||
{ \
|
||||
long nByteWritten; \
|
||||
while (1) \
|
||||
{ \
|
||||
ASCII_WRITE_ELEM; \
|
||||
if( (nByteWritten > -1) && (nByteWritten < mfself->storage->size-mfself->position) ) \
|
||||
{ \
|
||||
mfself->position += nByteWritten; \
|
||||
break; \
|
||||
} \
|
||||
THMemoryFile_grow(mfself, mfself->storage->size + (mfself->storage->size/2) + 2); \
|
||||
} \
|
||||
if(mfself->file.isAutoSpacing) \
|
||||
{ \
|
||||
if(i < n-1) \
|
||||
{ \
|
||||
THMemoryFile_grow(mfself, mfself->position+1); \
|
||||
sprintf(mfself->storage->data+mfself->position, " "); \
|
||||
mfself->position++; \
|
||||
} \
|
||||
if(i == n-1) \
|
||||
{ \
|
||||
THMemoryFile_grow(mfself, mfself->position+1); \
|
||||
sprintf(mfself->storage->data+mfself->position, "\n"); \
|
||||
mfself->position++; \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
if(mfself->position > mfself->size) \
|
||||
{ \
|
||||
mfself->size = mfself->position; \
|
||||
mfself->storage->data[mfself->size] = '\0'; \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
return n; \
|
||||
}
|
||||
|
||||
|
||||
THCharStorage *THMemoryFile_storage(THFile *self)
|
||||
{
|
||||
THMemoryFile *mfself = (THMemoryFile*)self;
|
||||
THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file");
|
||||
|
||||
THCharStorage_resize(mfself->storage, mfself->size+1);
|
||||
|
||||
return mfself->storage;
|
||||
}
|
||||
|
||||
static void THMemoryFile_synchronize(THFile *self)
|
||||
{
|
||||
THMemoryFile *mfself = (THMemoryFile*)self;
|
||||
THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file");
|
||||
}
|
||||
|
||||
static void THMemoryFile_seek(THFile *self, long position)
|
||||
{
|
||||
THMemoryFile *mfself = (THMemoryFile*)self;
|
||||
|
||||
THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file");
|
||||
THArgCheck(position >= 0, 2, "position must be positive");
|
||||
|
||||
if(position <= mfself->size)
|
||||
mfself->position = position;
|
||||
else
|
||||
{
|
||||
mfself->file.hasError = 1;
|
||||
if(!mfself->file.isQuiet)
|
||||
THError("unable to seek at position %d", position);
|
||||
}
|
||||
}
|
||||
|
||||
static void THMemoryFile_seekEnd(THFile *self)
|
||||
{
|
||||
THMemoryFile *mfself = (THMemoryFile*)self;
|
||||
THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file");
|
||||
|
||||
mfself->position = mfself->size;
|
||||
}
|
||||
|
||||
static long THMemoryFile_position(THFile *self)
|
||||
{
|
||||
THMemoryFile *mfself = (THMemoryFile*)self;
|
||||
THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file");
|
||||
return mfself->position;
|
||||
}
|
||||
|
||||
static void THMemoryFile_close(THFile *self)
|
||||
{
|
||||
THMemoryFile *mfself = (THMemoryFile*)self;
|
||||
THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file");
|
||||
THCharStorage_free(mfself->storage);
|
||||
mfself->storage = NULL;
|
||||
}
|
||||
|
||||
static void THMemoryFile_free(THFile *self)
|
||||
{
|
||||
THMemoryFile *mfself = (THMemoryFile*)self;
|
||||
|
||||
if(mfself->storage)
|
||||
THCharStorage_free(mfself->storage);
|
||||
|
||||
THFree(mfself);
|
||||
}
|
||||
|
||||
/* READ_WRITE_METHODS(bool, Bool, */
|
||||
/* int value = 0; int ret = sscanf(mfself->storage->data+mfself->position, "%d%n", &value, &nByteRead); data[i] = (value ? 1 : 0), */
|
||||
/* int value = (data[i] ? 1 : 0); nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%d", value), */
|
||||
/* 1) */
|
||||
|
||||
READ_WRITE_METHODS(unsigned char, Byte,
|
||||
long ret = (mfself->position + n <= mfself->size ? n : mfself->size-mfself->position); \
|
||||
if(spacePtr) *spacePtr = spaceChar; \
|
||||
nByteRead = ret; \
|
||||
nread = ret; \
|
||||
i = n-1; \
|
||||
memmove(data, mfself->storage->data+mfself->position, nByteRead),
|
||||
nByteWritten = (n < mfself->storage->size-mfself->position ? n : -1); \
|
||||
i = n-1; \
|
||||
if(nByteWritten > -1)
|
||||
memmove(mfself->storage->data+mfself->position, data, nByteWritten),
|
||||
0)
|
||||
|
||||
/* DEBUG: we should check if %n is count or not as a element (so ret might need to be ret-- on some systems) */
|
||||
/* Note that we do a trick for char */
|
||||
READ_WRITE_METHODS(char, Char,
|
||||
long ret = (mfself->position + n <= mfself->size ? n : mfself->size-mfself->position); \
|
||||
if(spacePtr) *spacePtr = spaceChar; \
|
||||
nByteRead = ret; \
|
||||
nread = ret; \
|
||||
i = n-1; \
|
||||
memmove(data, mfself->storage->data+mfself->position, nByteRead),
|
||||
nByteWritten = (n < mfself->storage->size-mfself->position ? n : -1); \
|
||||
i = n-1; \
|
||||
if(nByteWritten > -1)
|
||||
memmove(mfself->storage->data+mfself->position, data, nByteWritten),
|
||||
0)
|
||||
|
||||
READ_WRITE_METHODS(short, Short,
|
||||
int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%hd%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++,
|
||||
nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%hd", data[i]),
|
||||
1)
|
||||
|
||||
READ_WRITE_METHODS(int, Int,
|
||||
int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%d%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++,
|
||||
nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%d", data[i]),
|
||||
1)
|
||||
|
||||
READ_WRITE_METHODS(long, Long,
|
||||
int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%ld%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++,
|
||||
nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%ld", data[i]),
|
||||
1)
|
||||
|
||||
READ_WRITE_METHODS(float, Float,
|
||||
int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%g%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++,
|
||||
nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%g", data[i]),
|
||||
1)
|
||||
|
||||
READ_WRITE_METHODS(double, Double,
|
||||
int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%lg%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++,
|
||||
nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%lg", data[i]),
|
||||
1)
|
||||
|
||||
static char* THMemoryFile_cloneString(const char *str, long size)
|
||||
{
|
||||
char *cstr = THAlloc(size);
|
||||
memcpy(cstr, str, size);
|
||||
return cstr;
|
||||
}
|
||||
|
||||
static long THMemoryFile_readString(THFile *self, const char *format, char **str_)
|
||||
{
|
||||
THMemoryFile *mfself = (THMemoryFile*)self;
|
||||
|
||||
THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file");
|
||||
THArgCheck(mfself->file.isReadable, 1, "attempt to read in a write-only file");
|
||||
THArgCheck((strlen(format) >= 2 ? (format[0] == '*') && (format[1] == 'a' || format[1] == 'l') : 0), 2, "format must be '*a' or '*l'");
|
||||
|
||||
if(mfself->position == mfself->size) /* eof ? */
|
||||
{
|
||||
mfself->file.hasError = 1;
|
||||
if(!mfself->file.isQuiet)
|
||||
THError("read error: read 0 blocks instead of 1");
|
||||
|
||||
*str_ = NULL;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if(format[1] == 'a')
|
||||
{
|
||||
long str_size = mfself->size-mfself->position;
|
||||
|
||||
*str_ = THMemoryFile_cloneString(mfself->storage->data+mfself->position, str_size);
|
||||
mfself->position = mfself->size;
|
||||
|
||||
return str_size;
|
||||
}
|
||||
else
|
||||
{
|
||||
char *p = mfself->storage->data+mfself->position;
|
||||
long posEol = -1;
|
||||
long i;
|
||||
for(i = 0L; i < mfself->size-mfself->position; i++)
|
||||
{
|
||||
if(p[i] == '\n')
|
||||
{
|
||||
posEol = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if(posEol >= 0)
|
||||
{
|
||||
*str_ = THMemoryFile_cloneString(mfself->storage->data+mfself->position, posEol);
|
||||
mfself->position += posEol+1;
|
||||
return posEol;
|
||||
}
|
||||
else /* well, we read all! */
|
||||
{
|
||||
long str_size = mfself->size-mfself->position;
|
||||
|
||||
*str_ = THMemoryFile_cloneString(mfself->storage->data+mfself->position, str_size);
|
||||
mfself->position = mfself->size;
|
||||
|
||||
return str_size;
|
||||
}
|
||||
}
|
||||
|
||||
*str_ = NULL;
|
||||
return 0;
|
||||
}
|
||||
|
||||
static long THMemoryFile_writeString(THFile *self, const char *str, long size)
|
||||
{
|
||||
THMemoryFile *mfself = (THMemoryFile*)self;
|
||||
|
||||
THArgCheck(mfself->storage != NULL, 1, "attempt to use a closed file");
|
||||
THArgCheck(mfself->file.isWritable, 1, "attempt to write in a read-only file");
|
||||
|
||||
THMemoryFile_grow(mfself, mfself->position+size);
|
||||
memmove(mfself->storage->data+mfself->position, str, size);
|
||||
mfself->position += size;
|
||||
if(mfself->position > mfself->size)
|
||||
{
|
||||
mfself->size = mfself->position;
|
||||
mfself->storage->data[mfself->size] = '\0';
|
||||
}
|
||||
|
||||
return size;
|
||||
}
|
||||
|
||||
THFile *THMemoryFile_newWithStorage(THCharStorage *storage, const char *mode)
|
||||
{
|
||||
static struct THFileVTable vtable = {
|
||||
THMemoryFile_isOpened,
|
||||
|
||||
THMemoryFile_readByte,
|
||||
THMemoryFile_readChar,
|
||||
THMemoryFile_readShort,
|
||||
THMemoryFile_readInt,
|
||||
THMemoryFile_readLong,
|
||||
THMemoryFile_readFloat,
|
||||
THMemoryFile_readDouble,
|
||||
THMemoryFile_readString,
|
||||
|
||||
THMemoryFile_writeByte,
|
||||
THMemoryFile_writeChar,
|
||||
THMemoryFile_writeShort,
|
||||
THMemoryFile_writeInt,
|
||||
THMemoryFile_writeLong,
|
||||
THMemoryFile_writeFloat,
|
||||
THMemoryFile_writeDouble,
|
||||
THMemoryFile_writeString,
|
||||
|
||||
THMemoryFile_synchronize,
|
||||
THMemoryFile_seek,
|
||||
THMemoryFile_seekEnd,
|
||||
THMemoryFile_position,
|
||||
THMemoryFile_close,
|
||||
THMemoryFile_free
|
||||
};
|
||||
|
||||
THMemoryFile *mfself;
|
||||
int isReadable;
|
||||
int isWritable;
|
||||
|
||||
if(storage)
|
||||
{
|
||||
THArgCheck(storage->data[storage->size-1] == '\0', 1, "provided CharStorage must be terminated by 0");
|
||||
THArgCheck(THMemoryFile_mode(mode, &isReadable, &isWritable), 2, "file mode should be 'r','w' or 'rw'");
|
||||
THCharStorage_retain(storage);
|
||||
}
|
||||
else
|
||||
{
|
||||
THArgCheck(THMemoryFile_mode(mode, &isReadable, &isWritable), 2, "file mode should be 'r','w' or 'rw'");
|
||||
storage = THCharStorage_newWithSize(1);
|
||||
storage->data[0] = '\0';
|
||||
}
|
||||
|
||||
mfself = THAlloc(sizeof(THMemoryFile));
|
||||
|
||||
mfself->storage = storage;
|
||||
mfself->size = (storage ? storage->size-1 : 0);
|
||||
mfself->position = 0;
|
||||
|
||||
mfself->file.vtable = &vtable;
|
||||
mfself->file.isQuiet = 0;
|
||||
mfself->file.isReadable = isReadable;
|
||||
mfself->file.isWritable = isWritable;
|
||||
mfself->file.isBinary = 0;
|
||||
mfself->file.isAutoSpacing = 1;
|
||||
mfself->file.hasError = 0;
|
||||
|
||||
return (THFile*)mfself;
|
||||
}
|
||||
|
||||
THFile *THMemoryFile_new(const char *mode)
|
||||
{
|
||||
return THMemoryFile_newWithStorage(NULL, mode);
|
||||
}
|
||||
12
lib/TH/THMemoryFile.h
Normal file
@@ -0,0 +1,12 @@
#ifndef TH_MEMORY_FILE_INC
#define TH_MEMORY_FILE_INC

#include "THFile.h"
#include "THStorage.h"

THFile *THMemoryFile_newWithStorage(THCharStorage *storage, const char *mode);
THFile *THMemoryFile_new(const char *mode);

THCharStorage *THMemoryFile_storage(THFile *self);

#endif
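For orientation, a hedged usage sketch (mine): it assumes THFile.h, which appears elsewhere in this commit, exposes scalar wrappers named THFile_writeIntScalar/THFile_readIntScalar plus THFile_seek/THFile_free over the vtable slots seen in THMemoryFile.c; those names are an assumption, not something shown in this hunk.

  #include <stdio.h>
  #include "THFile.h"
  #include "THMemoryFile.h"

  int main(void)
  {
    THFile *f = THMemoryFile_new("rw");    /* growable in-memory buffer, readable and writable */

    THFile_writeIntScalar(f, 42);          /* assumed wrapper over the writeInt vtable slot */
    THFile_seek(f, 0);                     /* rewind before reading back */
    printf("%d\n", THFile_readIntScalar(f));

    /* In the default ASCII mode the buffer is text; it can be inspected via the storage. */
    printf("buffer: %s\n", THMemoryFile_storage(f)->data);

    THFile_free(f);                        /* assumed to release the file and its storage */
    return 0;
  }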
238
lib/TH/THRandom.c
Normal file
@@ -0,0 +1,238 @@
#include "THGeneral.h"
|
||||
#include "THRandom.h"
|
||||
|
||||
/* The initial seed. */
|
||||
static unsigned long the_initial_seed;
|
||||
|
||||
/* Code for the Mersenne Twister random generator.... */
|
||||
#define n 624
|
||||
#define m 397
|
||||
static int left = 1;
|
||||
static int initf = 0;
|
||||
static unsigned long *next;
|
||||
static unsigned long state[n]; /* the array for the state vector */
|
||||
/********************************/
|
||||
|
||||
/* For normal distribution */
|
||||
static double normal_x;
|
||||
static double normal_y;
|
||||
static double normal_rho;
|
||||
static int normal_is_valid = 0;
|
||||
|
||||
unsigned long THRandom_seed()
|
||||
{
|
||||
unsigned long s = (unsigned long)time(0);
|
||||
THRandom_manualSeed(s);
|
||||
return s;
|
||||
}
|
||||
|
||||
/* The next 4 methods are taken from http://www.math.keio.ac.jp/matumoto/emt.html
|
||||
Here is the copyright:
|
||||
Some minor modifications have been made to adapt to "my" C... */
|
||||
|
||||
/*
|
||||
A C-program for MT19937, with initialization improved 2002/2/10.
|
||||
Coded by Takuji Nishimura and Makoto Matsumoto.
|
||||
This is a faster version by taking Shawn Cokus's optimization,
|
||||
Matthe Bellew's simplification, Isaku Wada's double version.
|
||||
|
||||
Before using, initialize the state by using init_genrand(seed)
|
||||
or init_by_array(init_key, key_length).
|
||||
|
||||
Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura,
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
3. The names of its contributors may not be used to endorse or promote
|
||||
products derived from this software without specific prior written
|
||||
permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
Any feedback is very welcome.
|
||||
http://www.math.keio.ac.jp/matumoto/emt.html
|
||||
email: matumoto@math.keio.ac.jp
|
||||
*/
|
||||
|
||||
/* Macros for the Mersenne Twister random generator... */
|
||||
/* Period parameters */
|
||||
/* #define n 624 */
|
||||
/* #define m 397 */
|
||||
#define MATRIX_A 0x9908b0dfUL /* constant vector a */
|
||||
#define UMASK 0x80000000UL /* most significant w-r bits */
|
||||
#define LMASK 0x7fffffffUL /* least significant r bits */
|
||||
#define MIXBITS(u,v) ( ((u) & UMASK) | ((v) & LMASK) )
|
||||
#define TWIST(u,v) ((MIXBITS(u,v) >> 1) ^ ((v)&1UL ? MATRIX_A : 0UL))
|
||||
/*********************************************************** That's it. */
|
||||
|
||||
void THRandom_manualSeed(unsigned long the_seed_)
|
||||
{
|
||||
int j;
|
||||
the_initial_seed = the_seed_;
|
||||
state[0]= the_initial_seed & 0xffffffffUL;
|
||||
for(j = 1; j < n; j++)
|
||||
{
|
||||
state[j] = (1812433253UL * (state[j-1] ^ (state[j-1] >> 30)) + j);
|
||||
/* See Knuth TAOCP Vol2. 3rd Ed. P.106 for multiplier. */
|
||||
/* In the previous versions, mSBs of the seed affect */
|
||||
/* only mSBs of the array state[]. */
|
||||
/* 2002/01/09 modified by makoto matsumoto */
|
||||
state[j] &= 0xffffffffUL; /* for >32 bit machines */
|
||||
}
|
||||
left = 1;
|
||||
initf = 1;
|
||||
}
|
||||
|
||||
unsigned long THRandom_initialSeed()
|
||||
{
|
||||
if(initf == 0)
|
||||
{
|
||||
THRandom_seed();
|
||||
}
|
||||
|
||||
return the_initial_seed;
|
||||
}
|
||||
|
||||
void THRandom_nextState()
|
||||
{
|
||||
unsigned long *p=state;
|
||||
int j;
|
||||
|
||||
/* if init_genrand() has not been called, */
|
||||
/* a default initial seed is used */
|
||||
if(initf == 0)
|
||||
THRandom_seed();
|
||||
|
||||
left = n;
|
||||
next = state;
|
||||
|
||||
for(j = n-m+1; --j; p++)
|
||||
*p = p[m] ^ TWIST(p[0], p[1]);
|
||||
|
||||
for(j = m; --j; p++)
|
||||
*p = p[m-n] ^ TWIST(p[0], p[1]);
|
||||
|
||||
*p = p[m-n] ^ TWIST(p[0], state[0]);
|
||||
}
|
||||
|
||||
unsigned long THRandom_random()
|
||||
{
|
||||
unsigned long y;
|
||||
|
||||
if (--left == 0)
|
||||
THRandom_nextState();
|
||||
y = *next++;
|
||||
|
||||
/* Tempering */
|
||||
y ^= (y >> 11);
|
||||
y ^= (y << 7) & 0x9d2c5680UL;
|
||||
y ^= (y << 15) & 0xefc60000UL;
|
||||
y ^= (y >> 18);
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
/* generates a random number on [0,1)-double-interval */
|
||||
static double __uniform__()
|
||||
{
|
||||
unsigned long y;
|
||||
|
||||
if(--left == 0)
|
||||
THRandom_nextState();
|
||||
y = *next++;
|
||||
|
||||
/* Tempering */
|
||||
y ^= (y >> 11);
|
||||
y ^= (y << 7) & 0x9d2c5680UL;
|
||||
y ^= (y << 15) & 0xefc60000UL;
|
||||
y ^= (y >> 18);
|
||||
|
||||
return (double)y * (1.0/4294967296.0);
|
||||
/* divided by 2^32 */
|
||||
}
|
||||
|
||||
/*********************************************************
|
||||
|
||||
Thanks *a lot* Takuji Nishimura and Makoto Matsumoto!
|
||||
|
||||
Now my own code...
|
||||
|
||||
*********************************************************/
|
||||
|
||||
double THRandom_uniform(double a, double b)
|
||||
{
|
||||
return(__uniform__() * (b - a) + a);
|
||||
}
|
||||
|
||||
double THRandom_normal(double mean, double stdv)
|
||||
{
|
||||
THArgCheck(stdv > 0, 2, "standard deviation must be strictly positive");
|
||||
|
||||
if(!normal_is_valid)
|
||||
{
|
||||
normal_x = __uniform__();
|
||||
normal_y = __uniform__();
|
||||
normal_rho = sqrt(-2. * log(1.0-normal_y));
|
||||
normal_is_valid = 1;
|
||||
}
|
||||
else
|
||||
normal_is_valid = 0;
|
||||
|
||||
if(normal_is_valid)
|
||||
return normal_rho*cos(2.*M_PI*normal_x)*stdv+mean;
|
||||
else
|
||||
return normal_rho*sin(2.*M_PI*normal_x)*stdv+mean;
|
||||
}
|
||||
|
||||
double THRandom_exponential(double lambda)
|
||||
{
|
||||
return(-1. / lambda * log(1-__uniform__()));
|
||||
}
|
||||
|
||||
double THRandom_cauchy(double median, double sigma)
|
||||
{
|
||||
return(median + sigma * tan(M_PI*(__uniform__()-0.5)));
|
||||
}
|
||||
|
||||
/* You'd have to be crazy to use this. Oh well. */
|
||||
double THRandom_logNormal(double mean, double stdv)
|
||||
{
|
||||
double zm = mean*mean;
|
||||
double zs = stdv*stdv;
|
||||
THArgCheck(stdv > 0, 2, "standard deviation must be strictly positive");
|
||||
return(exp(THRandom_normal(log(zm/sqrt(zs + zm)), sqrt(log(zs/zm+1)) )));
|
||||
}
|
||||
|
||||
int THRandom_geometric(double p)
|
||||
{
|
||||
THArgCheck(p > 0 && p < 1, 1, "must be > 0 and < 1");
|
||||
return((int)(log(1-__uniform__()) / log(p)) + 1);
|
||||
}
|
||||
|
||||
int THRandom_bernoulli(double p)
|
||||
{
|
||||
THArgCheck(p > 0 && p < 1, 1, "must be > 0 and < 1");
|
||||
return(__uniform__() <= p);
|
||||
}
|
||||
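A note on THRandom_normal above (mine): it is the Box-Muller transform with the second coordinate cached across calls. From two uniforms u1 = normal_x and u2 = normal_y it forms rho = sqrt(-2*log(1 - u2)) and returns mean + stdv*rho*cos(2*pi*u1) on one call and mean + stdv*rho*sin(2*pi*u1) on the next; both coordinates of the rotated pair are independent standard normals, so alternating between them halves the number of log/sqrt evaluations.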
52
lib/TH/THRandom.h
Normal file
@@ -0,0 +1,52 @@
#ifndef TH_RANDOM_INC
#define TH_RANDOM_INC

#include "THGeneral.h"

/* Initializes the random number generator with the current time (granularity: seconds) and returns the seed. */
TH_API unsigned long THRandom_seed();

/* Initializes the random number generator with the given long "the_seed_". */
TH_API void THRandom_manualSeed(unsigned long the_seed_);

/* Returns the starting seed used. */
TH_API unsigned long THRandom_initialSeed();

/* Generates a uniform 32 bits integer. */
TH_API unsigned long THRandom_random();

/* Generates a uniform random number on [a,b[. */
TH_API double THRandom_uniform(double a, double b);

/** Generates a random number from a normal distribution.
    (With mean #mean# and standard deviation #stdv > 0#).
*/
TH_API double THRandom_normal(double mean, double stdv);

/** Generates a random number from an exponential distribution.
    The density is $p(x) = lambda * exp(-lambda * x)$, where
    lambda is a positive number.
*/
TH_API double THRandom_exponential(double lambda);

/** Returns a random number from a Cauchy distribution.
    The Cauchy density is $p(x) = sigma/(pi*(sigma^2 + (x-median)^2))$
*/
TH_API double THRandom_cauchy(double median, double sigma);

/** Generates a random number from a log-normal distribution.
    (#mean > 0# is the mean of the log-normal distribution
    and #stdv# is its standard deviation).
*/
TH_API double THRandom_logNormal(double mean, double stdv);

/** Generates a random number from a geometric distribution.
    It returns an integer #i#, where $p(i) = (1-p) * p^(i-1)$.
    p must satisfy $0 < p < 1$.
*/
TH_API int THRandom_geometric(double p);

/* Returns true with probability $p$ and false with probability $1-p$ (0 < p < 1). */
TH_API int THRandom_bernoulli(double p);

#endif
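A minimal, reproducible-draw sketch using only the declarations above (not part of the commit):

  #include <stdio.h>
  #include "THRandom.h"

  int main(void)
  {
    THRandom_manualSeed(1234UL);                 /* fixed seed, so the sequence is reproducible */
    printf("uniform(0,1)  : %f\n",  THRandom_uniform(0, 1));
    printf("normal(0,1)   : %f\n",  THRandom_normal(0, 1));
    printf("geometric(0.3): %d\n",  THRandom_geometric(0.3));
    printf("initial seed  : %lu\n", THRandom_initialSeed());
    return 0;
  }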
7
lib/TH/THStorage.c
Normal file
@@ -0,0 +1,7 @@
#include "THStorage.h"
|
||||
|
||||
#include "generic/THStorage.c"
|
||||
#include "THGenerateAllTypes.h"
|
||||
|
||||
#include "generic/THStorageCopy.c"
|
||||
#include "THGenerateAllTypes.h"
|
||||
33
lib/TH/THStorage.h
Normal file
@@ -0,0 +1,33 @@
#ifndef TH_STORAGE_INC
#define TH_STORAGE_INC

#include "THGeneral.h"

/* stuff for mapped files */
#ifdef _WIN32
#include <windows.h>
#endif

#if HAVE_MMAP
#include <sys/types.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>
#endif
/* end of stuff for mapped files */

#define THStorage        TH_CONCAT_3(TH,Real,Storage)
#define THStorage_(NAME) TH_CONCAT_4(TH,Real,Storage_,NAME)

/* fast access methods */
#define TH_STORAGE_GET(storage, idx) ((storage)->data[(idx)])
#define TH_STORAGE_SET(storage, idx, value) ((storage)->data[(idx)] = (value))

#include "generic/THStorage.h"
#include "THGenerateAllTypes.h"

#include "generic/THStorageCopy.h"
#include "THGenerateAllTypes.h"

#endif
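The TH_STORAGE_GET/SET accessors work on any generated storage type; a small sketch using THCharStorage, whose constructor and destructor already appear in THMemoryFile.c earlier in this commit:

  #include <stdio.h>
  #include "THStorage.h"

  int main(void)
  {
    THCharStorage *s = THCharStorage_newWithSize(4);
    TH_STORAGE_SET(s, 0, 'T');
    TH_STORAGE_SET(s, 1, 'H');
    TH_STORAGE_SET(s, 2, '\0');
    printf("%s (first byte: %c)\n", s->data, TH_STORAGE_GET(s, 0));
    THCharStorage_free(s);
    return 0;
  }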
24
lib/TH/THTensor.c
Normal file
@@ -0,0 +1,24 @@
#include "THTensor.h"
|
||||
#include "THVector.h"
|
||||
#include "THBlas.h"
|
||||
#include "THLapack.h"
|
||||
#include "THRandom.h"
|
||||
#include "THTensorDimApply.h"
|
||||
|
||||
#include "generic/THTensor.c"
|
||||
#include "THGenerateAllTypes.h"
|
||||
|
||||
#include "generic/THTensorCopy.c"
|
||||
#include "THGenerateAllTypes.h"
|
||||
|
||||
#include "generic/THTensorRandom.c"
|
||||
#include "THGenerateAllTypes.h"
|
||||
|
||||
#include "generic/THTensorMath.c"
|
||||
#include "THGenerateAllTypes.h"
|
||||
|
||||
#include "generic/THTensorConv.c"
|
||||
#include "THGenerateAllTypes.h"
|
||||
|
||||
#include "generic/THTensorLapack.c"
|
||||
#include "THGenerateFloatTypes.h"
|
||||
35
lib/TH/THTensor.h
Normal file
@@ -0,0 +1,35 @@
#ifndef TH_TENSOR_INC
#define TH_TENSOR_INC

#include "THStorage.h"
#include "THTensorApply.h"

#define THTensor        TH_CONCAT_3(TH,Real,Tensor)
#define THTensor_(NAME) TH_CONCAT_4(TH,Real,Tensor_,NAME)

/* basics */
#include "generic/THTensor.h"
#include "THGenerateAllTypes.h"

#include "generic/THTensorCopy.h"
#include "THGenerateAllTypes.h"

#include "THTensorMacros.h"

/* random numbers */
#include "generic/THTensorRandom.h"
#include "THGenerateAllTypes.h"

/* maths */
#include "generic/THTensorMath.h"
#include "THGenerateAllTypes.h"

/* convolutions */
#include "generic/THTensorConv.h"
#include "THGenerateAllTypes.h"

/* lapack support */
#include "generic/THTensorLapack.h"
#include "THGenerateFloatTypes.h"

#endif
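Inside a generic file compiled with real = float, THTensor expands to THFloatTensor and THTensor_(new) to THFloatTensor_new; ordinary (non-generic) callers use the expanded names directly. A hedged sketch, assuming the usual 1-d constructor and accessors from generic/THTensor.h (newWithSize1d/set1d/get1d/free), which are not shown in this hunk:

  #include <stdio.h>
  #include "THTensor.h"

  int main(void)
  {
    THFloatTensor *t = THFloatTensor_newWithSize1d(3);  /* THTensor_(newWithSize1d) for Real=Float */
    THFloatTensor_set1d(t, 0, 1.0f);
    THFloatTensor_set1d(t, 1, 2.0f);
    THFloatTensor_set1d(t, 2, 3.0f);
    printf("t[1] = %f\n", THFloatTensor_get1d(t, 1));
    THFloatTensor_free(t);
    return 0;
  }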
428
lib/TH/THTensorApply.h
Normal file
@@ -0,0 +1,428 @@
#ifndef TH_TENSOR_APPLY_INC
|
||||
#define TH_TENSOR_APPLY_INC
|
||||
|
||||
#define TH_TENSOR_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, CODE) \
|
||||
{ \
|
||||
TYPE1 *TENSOR1##_data = NULL; \
|
||||
long *TENSOR1##_counter = NULL; \
|
||||
long TENSOR1##_stride = 0, TENSOR1##_size = 0, TENSOR1##_dim = 0, TENSOR1##_i, TENSOR1##_n; \
|
||||
TYPE2 *TENSOR2##_data = NULL; \
|
||||
long *TENSOR2##_counter = NULL; \
|
||||
long TENSOR2##_stride = 0, TENSOR2##_size = 0, TENSOR2##_dim = 0, TENSOR2##_i, TENSOR2##_n; \
|
||||
TYPE2 *TENSOR3##_data = NULL; \
|
||||
long *TENSOR3##_counter = NULL; \
|
||||
long TENSOR3##_stride = 0, TENSOR3##_size = 0, TENSOR3##_dim = 0, TENSOR3##_i, TENSOR3##_n; \
|
||||
int TH_TENSOR_APPLY_hasFinished = 0; \
|
||||
\
|
||||
TENSOR1##_n = (TENSOR1->nDimension ? 1 : 0); \
|
||||
for(TENSOR1##_i = 0; TENSOR1##_i < TENSOR1->nDimension; TENSOR1##_i++) \
|
||||
TENSOR1##_n *= TENSOR1->size[TENSOR1##_i]; \
|
||||
\
|
||||
TENSOR2##_n = (TENSOR2->nDimension ? 1 : 0); \
|
||||
for(TENSOR2##_i = 0; TENSOR2##_i < TENSOR2->nDimension; TENSOR2##_i++) \
|
||||
TENSOR2##_n *= TENSOR2->size[TENSOR2##_i]; \
|
||||
\
|
||||
TENSOR3##_n = (TENSOR3->nDimension ? 1 : 0); \
|
||||
for(TENSOR3##_i = 0; TENSOR3##_i < TENSOR3->nDimension; TENSOR3##_i++) \
|
||||
TENSOR3##_n *= TENSOR3->size[TENSOR3##_i]; \
|
||||
\
|
||||
if(TENSOR1##_n != TENSOR2##_n || TENSOR1##_n != TENSOR3##_n) /* should we do the check in the function instead? i think so */ \
|
||||
THError("inconsistent tensor size"); \
|
||||
\
|
||||
if(TENSOR1->nDimension == 0) \
|
||||
TH_TENSOR_APPLY_hasFinished = 1; \
|
||||
else \
|
||||
{ \
|
||||
TENSOR1##_data = TENSOR1->storage->data+TENSOR1->storageOffset; \
|
||||
for(TENSOR1##_dim = TENSOR1->nDimension-1; TENSOR1##_dim >= 0; TENSOR1##_dim--) \
|
||||
{ \
|
||||
if(TENSOR1->size[TENSOR1##_dim] != 1) \
|
||||
break; \
|
||||
} \
|
||||
TENSOR1##_stride = (TENSOR1##_dim == -1 ? 0 : TENSOR1->stride[TENSOR1##_dim]); \
|
||||
TENSOR1##_size = 1; \
|
||||
for(TENSOR1##_dim = TENSOR1->nDimension-1; TENSOR1##_dim >= 0; TENSOR1##_dim--) \
|
||||
{ \
|
||||
if(TENSOR1->size[TENSOR1##_dim] != 1) \
|
||||
{ \
|
||||
if(TENSOR1->stride[TENSOR1##_dim] == TENSOR1##_size) \
|
||||
TENSOR1##_size *= TENSOR1->size[TENSOR1##_dim]; \
|
||||
else \
|
||||
break; \
|
||||
} \
|
||||
} \
|
||||
TENSOR1##_counter = (long*)THAlloc(sizeof(long)*(TENSOR1##_dim+1)); \
|
||||
for(TENSOR1##_i = 0; TENSOR1##_i <= TENSOR1##_dim; TENSOR1##_i++) \
|
||||
TENSOR1##_counter[TENSOR1##_i] = 0; \
|
||||
\
|
||||
TENSOR2##_data = TENSOR2->storage->data+TENSOR2->storageOffset; \
|
||||
for(TENSOR2##_dim = TENSOR2->nDimension-1; TENSOR2##_dim >= 0; TENSOR2##_dim--) \
|
||||
{ \
|
||||
if(TENSOR2->size[TENSOR2##_dim] != 1) \
|
||||
break; \
|
||||
} \
|
||||
TENSOR2##_stride = (TENSOR2##_dim == -1 ? 0 : TENSOR2->stride[TENSOR2##_dim]); \
|
||||
TENSOR2##_size = 1; \
|
||||
for(TENSOR2##_dim = TENSOR2->nDimension-1; TENSOR2##_dim >= 0; TENSOR2##_dim--) \
|
||||
{ \
|
||||
if(TENSOR2->size[TENSOR2##_dim] != 1) \
|
||||
{ \
|
||||
if(TENSOR2->stride[TENSOR2##_dim] == TENSOR2##_size) \
|
||||
TENSOR2##_size *= TENSOR2->size[TENSOR2##_dim]; \
|
||||
else \
|
||||
break; \
|
||||
} \
|
||||
} \
|
||||
TENSOR2##_counter = (long*)THAlloc(sizeof(long)*(TENSOR2##_dim+1)); \
|
||||
for(TENSOR2##_i = 0; TENSOR2##_i <= TENSOR2##_dim; TENSOR2##_i++) \
|
||||
TENSOR2##_counter[TENSOR2##_i] = 0; \
|
||||
\
|
||||
TENSOR3##_data = TENSOR3->storage->data+TENSOR3->storageOffset; \
|
||||
for(TENSOR3##_dim = TENSOR3->nDimension-1; TENSOR3##_dim >= 0; TENSOR3##_dim--) \
|
||||
{ \
|
||||
if(TENSOR3->size[TENSOR3##_dim] != 1) \
|
||||
break; \
|
||||
} \
|
||||
TENSOR3##_stride = (TENSOR3##_dim == -1 ? 0 : TENSOR3->stride[TENSOR3##_dim]); \
|
||||
TENSOR3##_size = 1; \
|
||||
for(TENSOR3##_dim = TENSOR3->nDimension-1; TENSOR3##_dim >= 0; TENSOR3##_dim--) \
|
||||
{ \
|
||||
if(TENSOR3->size[TENSOR3##_dim] != 1) \
|
||||
{ \
|
||||
if(TENSOR3->stride[TENSOR3##_dim] == TENSOR3##_size) \
|
||||
TENSOR3##_size *= TENSOR3->size[TENSOR3##_dim]; \
|
||||
else \
|
||||
break; \
|
||||
} \
|
||||
} \
|
||||
TENSOR3##_counter = (long*)THAlloc(sizeof(long)*(TENSOR3##_dim+1)); \
|
||||
for(TENSOR3##_i = 0; TENSOR3##_i <= TENSOR3##_dim; TENSOR3##_i++) \
|
||||
TENSOR3##_counter[TENSOR3##_i] = 0; \
|
||||
} \
|
||||
\
|
||||
TENSOR1##_i = 0; \
|
||||
TENSOR2##_i = 0; \
|
||||
TENSOR3##_i = 0; \
|
||||
while(!TH_TENSOR_APPLY_hasFinished) \
|
||||
{ \
|
||||
for(; TENSOR1##_i < TENSOR1##_size && TENSOR2##_i < TENSOR2##_size && TENSOR3##_i < TENSOR3##_size; TENSOR1##_i++, TENSOR2##_i++, TENSOR3##_i++, TENSOR1##_data += TENSOR1##_stride, TENSOR2##_data += TENSOR2##_stride, TENSOR3##_data += TENSOR3##_stride) /* 0 et pas TENSOR##_dim! */ \
|
||||
{ \
|
||||
CODE \
|
||||
} \
|
||||
\
|
||||
if(TENSOR1##_i == TENSOR1##_size) \
|
||||
{ \
|
||||
if(TENSOR1##_dim == -1) \
|
||||
break; \
|
||||
\
|
||||
TENSOR1##_data -= TENSOR1##_size*TENSOR1##_stride; \
|
||||
for(TENSOR1##_i = TENSOR1##_dim; TENSOR1##_i >= 0; TENSOR1##_i--) \
|
||||
{ \
|
||||
TENSOR1##_counter[TENSOR1##_i]++; \
|
||||
TENSOR1##_data += TENSOR1->stride[TENSOR1##_i]; \
|
||||
\
|
||||
if(TENSOR1##_counter[TENSOR1##_i] == TENSOR1->size[TENSOR1##_i]) \
|
||||
{ \
|
||||
if(TENSOR1##_i == 0) \
|
||||
{ \
|
||||
TH_TENSOR_APPLY_hasFinished = 1; \
|
||||
break; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
TENSOR1##_data -= TENSOR1##_counter[TENSOR1##_i]*TENSOR1->stride[TENSOR1##_i]; \
|
||||
TENSOR1##_counter[TENSOR1##_i] = 0; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
break; \
|
||||
} \
|
||||
TENSOR1##_i = 0; \
|
||||
} \
|
||||
\
|
||||
if(TENSOR2##_i == TENSOR2##_size) \
|
||||
{ \
|
||||
if(TENSOR2##_dim == -1) \
|
||||
break; \
|
||||
\
|
||||
TENSOR2##_data -= TENSOR2##_size*TENSOR2##_stride; \
|
||||
for(TENSOR2##_i = TENSOR2##_dim; TENSOR2##_i >= 0; TENSOR2##_i--) \
|
||||
{ \
|
||||
TENSOR2##_counter[TENSOR2##_i]++; \
|
||||
TENSOR2##_data += TENSOR2->stride[TENSOR2##_i]; \
|
||||
\
|
||||
if(TENSOR2##_counter[TENSOR2##_i] == TENSOR2->size[TENSOR2##_i]) \
|
||||
{ \
|
||||
if(TENSOR2##_i == 0) \
|
||||
{ \
|
||||
TH_TENSOR_APPLY_hasFinished = 1; \
|
||||
break; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
TENSOR2##_data -= TENSOR2##_counter[TENSOR2##_i]*TENSOR2->stride[TENSOR2##_i]; \
|
||||
TENSOR2##_counter[TENSOR2##_i] = 0; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
break; \
|
||||
} \
|
||||
TENSOR2##_i = 0; \
|
||||
} \
|
||||
\
|
||||
if(TENSOR3##_i == TENSOR3##_size) \
|
||||
{ \
|
||||
if(TENSOR3##_dim == -1) \
|
||||
break; \
|
||||
\
|
||||
TENSOR3##_data -= TENSOR3##_size*TENSOR3##_stride; \
|
||||
for(TENSOR3##_i = TENSOR3##_dim; TENSOR3##_i >= 0; TENSOR3##_i--) \
|
||||
{ \
|
||||
TENSOR3##_counter[TENSOR3##_i]++; \
|
||||
TENSOR3##_data += TENSOR3->stride[TENSOR3##_i]; \
|
||||
\
|
||||
if(TENSOR3##_counter[TENSOR3##_i] == TENSOR3->size[TENSOR3##_i]) \
|
||||
{ \
|
||||
if(TENSOR3##_i == 0) \
|
||||
{ \
|
||||
TH_TENSOR_APPLY_hasFinished = 1; \
|
||||
break; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
TENSOR3##_data -= TENSOR3##_counter[TENSOR3##_i]*TENSOR3->stride[TENSOR3##_i]; \
|
||||
TENSOR3##_counter[TENSOR3##_i] = 0; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
break; \
|
||||
} \
|
||||
TENSOR3##_i = 0; \
|
||||
} \
|
||||
} \
|
||||
THFree(TENSOR1##_counter); \
|
||||
THFree(TENSOR2##_counter); \
|
||||
THFree(TENSOR3##_counter); \
|
||||
}
|
||||
|
||||
#define TH_TENSOR_APPLY2(TYPE1, TENSOR1, TYPE2, TENSOR2, CODE) \
|
||||
{ \
|
||||
TYPE1 *TENSOR1##_data = NULL; \
|
||||
long *TENSOR1##_counter = NULL; \
|
||||
long TENSOR1##_stride = 0, TENSOR1##_size = 0, TENSOR1##_dim = 0, TENSOR1##_i, TENSOR1##_n; \
|
||||
TYPE2 *TENSOR2##_data = NULL; \
|
||||
long *TENSOR2##_counter = NULL; \
|
||||
long TENSOR2##_stride = 0, TENSOR2##_size = 0, TENSOR2##_dim = 0, TENSOR2##_i, TENSOR2##_n; \
|
||||
int TH_TENSOR_APPLY_hasFinished = 0; \
|
||||
\
|
||||
TENSOR1##_n = (TENSOR1->nDimension ? 1 : 0); \
|
||||
for(TENSOR1##_i = 0; TENSOR1##_i < TENSOR1->nDimension; TENSOR1##_i++) \
|
||||
TENSOR1##_n *= TENSOR1->size[TENSOR1##_i]; \
|
||||
\
|
||||
TENSOR2##_n = (TENSOR2->nDimension ? 1 : 0); \
|
||||
for(TENSOR2##_i = 0; TENSOR2##_i < TENSOR2->nDimension; TENSOR2##_i++) \
|
||||
TENSOR2##_n *= TENSOR2->size[TENSOR2##_i]; \
|
||||
\
|
||||
if(TENSOR1##_n != TENSOR2##_n) /* should we do the check in the function instead? i think so */ \
|
||||
THError("inconsistent tensor size"); \
|
||||
\
|
||||
if(TENSOR1->nDimension == 0) \
|
||||
TH_TENSOR_APPLY_hasFinished = 1; \
|
||||
else \
|
||||
{ \
|
||||
TENSOR1##_data = TENSOR1->storage->data+TENSOR1->storageOffset; \
|
||||
for(TENSOR1##_dim = TENSOR1->nDimension-1; TENSOR1##_dim >= 0; TENSOR1##_dim--) \
|
||||
{ \
|
||||
if(TENSOR1->size[TENSOR1##_dim] != 1) \
|
||||
break; \
|
||||
} \
|
||||
TENSOR1##_stride = (TENSOR1##_dim == -1 ? 0 : TENSOR1->stride[TENSOR1##_dim]); \
|
||||
TENSOR1##_size = 1; \
|
||||
for(TENSOR1##_dim = TENSOR1->nDimension-1; TENSOR1##_dim >= 0; TENSOR1##_dim--) \
|
||||
{ \
|
||||
if(TENSOR1->size[TENSOR1##_dim] != 1) \
|
||||
{ \
|
||||
if(TENSOR1->stride[TENSOR1##_dim] == TENSOR1##_size) \
|
||||
TENSOR1##_size *= TENSOR1->size[TENSOR1##_dim]; \
|
||||
else \
|
||||
break; \
|
||||
} \
|
||||
} \
|
||||
TENSOR1##_counter = (long*)THAlloc(sizeof(long)*(TENSOR1##_dim+1)); \
|
||||
for(TENSOR1##_i = 0; TENSOR1##_i <= TENSOR1##_dim; TENSOR1##_i++) \
|
||||
TENSOR1##_counter[TENSOR1##_i] = 0; \
|
||||
\
|
||||
TENSOR2##_data = TENSOR2->storage->data+TENSOR2->storageOffset; \
|
||||
for(TENSOR2##_dim = TENSOR2->nDimension-1; TENSOR2##_dim >= 0; TENSOR2##_dim--) \
|
||||
{ \
|
||||
if(TENSOR2->size[TENSOR2##_dim] != 1) \
|
||||
break; \
|
||||
} \
|
||||
TENSOR2##_stride = (TENSOR2##_dim == -1 ? 0 : TENSOR2->stride[TENSOR2##_dim]); \
|
||||
TENSOR2##_size = 1; \
|
||||
for(TENSOR2##_dim = TENSOR2->nDimension-1; TENSOR2##_dim >= 0; TENSOR2##_dim--) \
|
||||
{ \
|
||||
if(TENSOR2->size[TENSOR2##_dim] != 1) \
|
||||
{ \
|
||||
if(TENSOR2->stride[TENSOR2##_dim] == TENSOR2##_size) \
|
||||
TENSOR2##_size *= TENSOR2->size[TENSOR2##_dim]; \
|
||||
else \
|
||||
break; \
|
||||
} \
|
||||
} \
|
||||
TENSOR2##_counter = (long*)THAlloc(sizeof(long)*(TENSOR2##_dim+1)); \
|
||||
for(TENSOR2##_i = 0; TENSOR2##_i <= TENSOR2##_dim; TENSOR2##_i++) \
|
||||
TENSOR2##_counter[TENSOR2##_i] = 0; \
|
||||
} \
|
||||
\
|
||||
TENSOR1##_i = 0; \
|
||||
TENSOR2##_i = 0; \
|
||||
while(!TH_TENSOR_APPLY_hasFinished) \
|
||||
{ \
|
||||
for(; TENSOR1##_i < TENSOR1##_size && TENSOR2##_i < TENSOR2##_size; TENSOR1##_i++, TENSOR2##_i++, TENSOR1##_data += TENSOR1##_stride, TENSOR2##_data += TENSOR2##_stride) /* 0 et pas TENSOR##_dim! */ \
|
||||
{ \
|
||||
CODE \
|
||||
} \
|
||||
\
|
||||
if(TENSOR1##_i == TENSOR1##_size) \
|
||||
{ \
|
||||
if(TENSOR1##_dim == -1) \
|
||||
break; \
|
||||
\
|
||||
TENSOR1##_data -= TENSOR1##_size*TENSOR1##_stride; \
|
||||
for(TENSOR1##_i = TENSOR1##_dim; TENSOR1##_i >= 0; TENSOR1##_i--) \
|
||||
{ \
|
||||
TENSOR1##_counter[TENSOR1##_i]++; \
|
||||
TENSOR1##_data += TENSOR1->stride[TENSOR1##_i]; \
|
||||
\
|
||||
if(TENSOR1##_counter[TENSOR1##_i] == TENSOR1->size[TENSOR1##_i]) \
|
||||
{ \
|
||||
if(TENSOR1##_i == 0) \
|
||||
{ \
|
||||
TH_TENSOR_APPLY_hasFinished = 1; \
|
||||
break; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
TENSOR1##_data -= TENSOR1##_counter[TENSOR1##_i]*TENSOR1->stride[TENSOR1##_i]; \
|
||||
TENSOR1##_counter[TENSOR1##_i] = 0; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
break; \
|
||||
} \
|
||||
TENSOR1##_i = 0; \
|
||||
} \
|
||||
\
|
||||
if(TENSOR2##_i == TENSOR2##_size) \
|
||||
{ \
|
||||
if(TENSOR2##_dim == -1) \
|
||||
break; \
|
||||
\
|
||||
TENSOR2##_data -= TENSOR2##_size*TENSOR2##_stride; \
|
||||
for(TENSOR2##_i = TENSOR2##_dim; TENSOR2##_i >= 0; TENSOR2##_i--) \
|
||||
{ \
|
||||
TENSOR2##_counter[TENSOR2##_i]++; \
|
||||
TENSOR2##_data += TENSOR2->stride[TENSOR2##_i]; \
|
||||
\
|
||||
if(TENSOR2##_counter[TENSOR2##_i] == TENSOR2->size[TENSOR2##_i]) \
|
||||
{ \
|
||||
if(TENSOR2##_i == 0) \
|
||||
{ \
|
||||
TH_TENSOR_APPLY_hasFinished = 1; \
|
||||
break; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
TENSOR2##_data -= TENSOR2##_counter[TENSOR2##_i]*TENSOR2->stride[TENSOR2##_i]; \
|
||||
TENSOR2##_counter[TENSOR2##_i] = 0; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
break; \
|
||||
} \
|
||||
TENSOR2##_i = 0; \
|
||||
} \
|
||||
} \
|
||||
THFree(TENSOR1##_counter); \
|
||||
THFree(TENSOR2##_counter); \
|
||||
}
|
||||
|
||||
#define TH_TENSOR_APPLY(TYPE, TENSOR, CODE) \
|
||||
{ \
|
||||
TYPE *TENSOR##_data = NULL; \
|
||||
long *TENSOR##_counter = NULL; \
|
||||
long TENSOR##_stride = 0, TENSOR##_size = 0, TENSOR##_dim = 0, TENSOR##_i; \
|
||||
int TH_TENSOR_APPLY_hasFinished = 0; \
|
||||
\
|
||||
if(TENSOR->nDimension == 0) \
|
||||
TH_TENSOR_APPLY_hasFinished = 1; \
|
||||
else \
|
||||
{ \
|
||||
TENSOR##_data = TENSOR->storage->data+TENSOR->storageOffset; \
|
||||
\
|
||||
/* what is the first stride (ignore first dims=1)? */ \
|
||||
/* it will be used for the whole largest contiguous section */ \
|
||||
for(TENSOR##_dim = TENSOR->nDimension-1; TENSOR##_dim >= 0; TENSOR##_dim--) \
|
||||
{ \
|
||||
if(TENSOR->size[TENSOR##_dim] != 1) \
|
||||
break; \
|
||||
} \
|
||||
TENSOR##_stride = (TENSOR##_dim == -1 ? 0 : TENSOR->stride[TENSOR##_dim]); \
|
||||
\
|
||||
/* what is the largest contiguous section? */ \
|
||||
TENSOR##_size = 1; \
|
||||
for(TENSOR##_dim = TENSOR->nDimension-1; TENSOR##_dim >= 0; TENSOR##_dim--) \
|
||||
{ \
|
||||
if(TENSOR->size[TENSOR##_dim] != 1) \
|
||||
{ \
|
||||
if(TENSOR->stride[TENSOR##_dim] == TENSOR##_size) \
|
||||
TENSOR##_size *= TENSOR->size[TENSOR##_dim]; \
|
||||
else \
|
||||
break; \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
/* counter over found dimensions */ \
|
||||
TENSOR##_counter = (long*)THAlloc(sizeof(long)*(TENSOR##_dim+1)); \
|
||||
for(TENSOR##_i = 0; TENSOR##_i <= TENSOR##_dim; TENSOR##_i++) \
|
||||
TENSOR##_counter[TENSOR##_i] = 0; \
|
||||
} \
|
||||
\
|
||||
while(!TH_TENSOR_APPLY_hasFinished) \
|
||||
{ \
|
||||
for(TENSOR##_i = 0; TENSOR##_i < TENSOR##_size; TENSOR##_i++, TENSOR##_data += TENSOR##_stride) /* 0 et pas TENSOR##_dim! */ \
|
||||
{ \
|
||||
CODE \
|
||||
} \
|
||||
\
|
||||
if(TENSOR##_dim == -1) \
|
||||
break; \
|
||||
\
|
||||
TENSOR##_data -= TENSOR##_i*TENSOR##_stride; \
|
||||
for(TENSOR##_i = TENSOR##_dim; TENSOR##_i >= 0; TENSOR##_i--) \
|
||||
{ \
|
||||
TENSOR##_counter[TENSOR##_i]++; \
|
||||
TENSOR##_data += TENSOR->stride[TENSOR##_i]; \
|
||||
\
|
||||
if(TENSOR##_counter[TENSOR##_i] == TENSOR->size[TENSOR##_i]) \
|
||||
{ \
|
||||
if(TENSOR##_i == 0) \
|
||||
{ \
|
||||
TH_TENSOR_APPLY_hasFinished = 1; \
|
||||
break; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
TENSOR##_data -= TENSOR##_counter[TENSOR##_i]*TENSOR->stride[TENSOR##_i]; \
|
||||
TENSOR##_counter[TENSOR##_i] = 0; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
break; \
|
||||
} \
|
||||
} \
|
||||
THFree(TENSOR##_counter); \
|
||||
}
|
||||
|
||||
#endif
|
||||
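A hedged sketch of how the element-wise traversal macro above is used (this mirrors the pattern in generic/THTensorMath.c; the THFloatTensor constructor and accessor are assumed from generic/THTensor.h rather than shown in this hunk):

  #include <stdio.h>
  #include "THTensor.h"

  /* Fill a float tensor in place; the macro declares t_data and walks the
     storage over the largest contiguous runs it can find. */
  static void fillFloatTensor(THFloatTensor *t, float value)
  {
    TH_TENSOR_APPLY(float, t, *t_data = value;);
  }

  int main(void)
  {
    THFloatTensor *t = THFloatTensor_newWithSize2d(2, 3);   /* assumed constructor */
    fillFloatTensor(t, 7.0f);
    printf("t[1][2] = %f\n", THFloatTensor_get2d(t, 1, 2)); /* assumed accessor */
    THFloatTensor_free(t);
    return 0;
  }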
232
lib/TH/THTensorDimApply.h
Normal file
@@ -0,0 +1,232 @@
#ifndef TH_TENSOR_DIM_APPLY_INC
|
||||
#define TH_TENSOR_DIM_APPLY_INC
|
||||
|
||||
#define TH_TENSOR_DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIMENSION, CODE) \
|
||||
{ \
|
||||
TYPE1 *TENSOR1##_data = NULL; \
|
||||
long TENSOR1##_stride = 0, TENSOR1##_size = 0; \
|
||||
TYPE2 *TENSOR2##_data = NULL; \
|
||||
long TENSOR2##_stride = 0, TENSOR2##_size = 0; \
|
||||
TYPE3 *TENSOR3##_data = NULL; \
|
||||
long TENSOR3##_stride = 0, TENSOR3##_size = 0; \
|
||||
long *TH_TENSOR_DIM_APPLY_counter = NULL; \
|
||||
int TH_TENSOR_DIM_APPLY_hasFinished = 0; \
|
||||
int TH_TENSOR_DIM_APPLY_i; \
|
||||
\
|
||||
if( (DIMENSION < 0) || (DIMENSION >= TENSOR1->nDimension) ) \
|
||||
THError("invalid dimension"); \
|
||||
if( TENSOR1->nDimension != TENSOR2->nDimension ) \
|
||||
THError("inconsistent tensor sizes"); \
|
||||
if( TENSOR1->nDimension != TENSOR3->nDimension ) \
|
||||
THError("inconsistent tensor sizes"); \
|
||||
for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \
|
||||
{ \
|
||||
if(TH_TENSOR_DIM_APPLY_i == DIMENSION) \
|
||||
continue; \
|
||||
if(TENSOR1->size[TH_TENSOR_DIM_APPLY_i] != TENSOR2->size[TH_TENSOR_DIM_APPLY_i]) \
|
||||
THError("inconsistent tensor sizes"); \
|
||||
if(TENSOR1->size[TH_TENSOR_DIM_APPLY_i] != TENSOR3->size[TH_TENSOR_DIM_APPLY_i]) \
|
||||
THError("inconsistent tensor sizes"); \
|
||||
} \
|
||||
\
|
||||
TH_TENSOR_DIM_APPLY_counter = (long*)THAlloc(sizeof(long)*(TENSOR1->nDimension)); \
|
||||
for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \
|
||||
TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \
|
||||
\
|
||||
TENSOR1##_data = (TENSOR1)->storage->data+(TENSOR1)->storageOffset; \
|
||||
TENSOR1##_stride = (TENSOR1)->stride[DIMENSION]; \
|
||||
TENSOR1##_size = TENSOR1->size[DIMENSION]; \
|
||||
\
|
||||
TENSOR2##_data = (TENSOR2)->storage->data+(TENSOR2)->storageOffset; \
|
||||
TENSOR2##_stride = (TENSOR2)->stride[DIMENSION]; \
|
||||
TENSOR2##_size = TENSOR2->size[DIMENSION]; \
|
||||
\
|
||||
TENSOR3##_data = (TENSOR3)->storage->data+(TENSOR3)->storageOffset; \
|
||||
TENSOR3##_stride = (TENSOR3)->stride[DIMENSION]; \
|
||||
TENSOR3##_size = TENSOR3->size[DIMENSION]; \
|
||||
\
|
||||
while(!TH_TENSOR_DIM_APPLY_hasFinished) \
|
||||
{ \
|
||||
CODE \
|
||||
\
|
||||
if(TENSOR1->nDimension == 1) \
|
||||
break; \
|
||||
\
|
||||
for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \
|
||||
{ \
|
||||
if(TH_TENSOR_DIM_APPLY_i == DIMENSION) \
|
||||
{ \
|
||||
if(TH_TENSOR_DIM_APPLY_i == TENSOR1->nDimension-1) \
|
||||
{ \
|
||||
TH_TENSOR_DIM_APPLY_hasFinished = 1; \
|
||||
break; \
|
||||
} \
|
||||
continue; \
|
||||
} \
|
||||
\
|
||||
TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]++; \
|
||||
TENSOR1##_data += TENSOR1->stride[TH_TENSOR_DIM_APPLY_i]; \
|
||||
TENSOR2##_data += TENSOR2->stride[TH_TENSOR_DIM_APPLY_i]; \
|
||||
TENSOR3##_data += TENSOR3->stride[TH_TENSOR_DIM_APPLY_i]; \
|
||||
\
|
||||
if(TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] == TENSOR1->size[TH_TENSOR_DIM_APPLY_i]) \
|
||||
{ \
|
||||
if(TH_TENSOR_DIM_APPLY_i == TENSOR1->nDimension-1) \
|
||||
{ \
|
||||
TH_TENSOR_DIM_APPLY_hasFinished = 1; \
|
||||
break; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
TENSOR1##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR1->stride[TH_TENSOR_DIM_APPLY_i]; \
|
||||
TENSOR2##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR2->stride[TH_TENSOR_DIM_APPLY_i]; \
|
||||
TENSOR3##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR3->stride[TH_TENSOR_DIM_APPLY_i]; \
|
||||
TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
break; \
|
||||
} \
|
||||
} \
|
||||
THFree(TH_TENSOR_DIM_APPLY_counter); \
|
||||
}
|
||||
|
||||
#define TH_TENSOR_DIM_APPLY2(TYPE1, TENSOR1, TYPE2, TENSOR2, DIMENSION, CODE) \
|
||||
{ \
|
||||
TYPE1 *TENSOR1##_data = NULL; \
|
||||
long TENSOR1##_stride = 0, TENSOR1##_size = 0; \
|
||||
TYPE2 *TENSOR2##_data = NULL; \
|
||||
long TENSOR2##_stride = 0, TENSOR2##_size = 0; \
|
||||
long *TH_TENSOR_DIM_APPLY_counter = NULL; \
|
||||
int TH_TENSOR_DIM_APPLY_hasFinished = 0; \
|
||||
int TH_TENSOR_DIM_APPLY_i; \
|
||||
\
|
||||
if( (DIMENSION < 0) || (DIMENSION >= TENSOR1->nDimension) ) \
|
||||
THError("invalid dimension"); \
|
||||
if( TENSOR1->nDimension != TENSOR2->nDimension ) \
|
||||
THError("inconsistent tensor sizes"); \
|
||||
for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \
|
||||
{ \
|
||||
if(TH_TENSOR_DIM_APPLY_i == DIMENSION) \
|
||||
continue; \
|
||||
if(TENSOR1->size[TH_TENSOR_DIM_APPLY_i] != TENSOR2->size[TH_TENSOR_DIM_APPLY_i]) \
|
||||
THError("inconsistent tensor sizes"); \
|
||||
} \
|
||||
\
|
||||
TH_TENSOR_DIM_APPLY_counter = (long*)THAlloc(sizeof(long)*(TENSOR1->nDimension)); \
|
||||
for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \
|
||||
TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \
|
||||
\
|
||||
TENSOR1##_data = (TENSOR1)->storage->data+(TENSOR1)->storageOffset; \
|
||||
TENSOR1##_stride = (TENSOR1)->stride[DIMENSION]; \
|
||||
TENSOR1##_size = TENSOR1->size[DIMENSION]; \
|
||||
\
|
||||
TENSOR2##_data = (TENSOR2)->storage->data+(TENSOR2)->storageOffset; \
|
||||
TENSOR2##_stride = (TENSOR2)->stride[DIMENSION]; \
|
||||
TENSOR2##_size = TENSOR2->size[DIMENSION]; \
|
||||
\
|
||||
while(!TH_TENSOR_DIM_APPLY_hasFinished) \
|
||||
{ \
|
||||
CODE \
|
||||
\
|
||||
if(TENSOR1->nDimension == 1) \
|
||||
break; \
|
||||
\
|
||||
for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \
|
||||
{ \
|
||||
if(TH_TENSOR_DIM_APPLY_i == DIMENSION) \
|
||||
{ \
|
||||
if(TH_TENSOR_DIM_APPLY_i == TENSOR1->nDimension-1) \
|
||||
{ \
|
||||
TH_TENSOR_DIM_APPLY_hasFinished = 1; \
|
||||
break; \
|
||||
} \
|
||||
continue; \
|
||||
} \
|
||||
\
|
||||
TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]++; \
|
||||
TENSOR1##_data += TENSOR1->stride[TH_TENSOR_DIM_APPLY_i]; \
|
||||
TENSOR2##_data += TENSOR2->stride[TH_TENSOR_DIM_APPLY_i]; \
|
||||
\
|
||||
if(TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] == TENSOR1->size[TH_TENSOR_DIM_APPLY_i]) \
|
||||
{ \
|
||||
if(TH_TENSOR_DIM_APPLY_i == TENSOR1->nDimension-1) \
|
||||
{ \
|
||||
TH_TENSOR_DIM_APPLY_hasFinished = 1; \
|
||||
break; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
TENSOR1##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR1->stride[TH_TENSOR_DIM_APPLY_i]; \
|
||||
TENSOR2##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR2->stride[TH_TENSOR_DIM_APPLY_i]; \
|
||||
TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
break; \
|
||||
} \
|
||||
} \
|
||||
THFree(TH_TENSOR_DIM_APPLY_counter); \
|
||||
}
|
||||
|
||||
#define TH_TENSOR_DIM_APPLY(TYPE, TENSOR, DIMENSION, CODE) \
|
||||
{ \
|
||||
TYPE *TENSOR##_data = NULL; \
|
||||
long TENSOR##_stride = 0, TENSOR##_size = 0; \
|
||||
long *TH_TENSOR_DIM_APPLY_counter = NULL; \
|
||||
int TH_TENSOR_DIM_APPLY_hasFinished = 0; \
|
||||
int TH_TENSOR_DIM_APPLY_i; \
|
||||
\
|
||||
if( (DIMENSION < 0) || (DIMENSION >= TENSOR->nDimension) ) \
|
||||
THError("invalid dimension"); \
|
||||
\
|
||||
TENSOR##_data = (TENSOR)->storage->data+(TENSOR)->storageOffset; \
|
||||
TENSOR##_stride = (TENSOR)->stride[DIMENSION]; \
|
||||
TENSOR##_size = TENSOR->size[DIMENSION]; \
|
||||
TH_TENSOR_DIM_APPLY_counter = (long*)THAlloc(sizeof(long)*(TENSOR->nDimension)); \
|
||||
for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR->nDimension; TH_TENSOR_DIM_APPLY_i++) \
|
||||
TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \
|
||||
\
|
||||
while(!TH_TENSOR_DIM_APPLY_hasFinished) \
|
||||
{ \
|
||||
CODE \
|
||||
\
|
||||
if(TENSOR->nDimension == 1) \
|
||||
break; \
|
||||
\
|
||||
for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR->nDimension; TH_TENSOR_DIM_APPLY_i++) \
|
||||
{ \
|
||||
if(TH_TENSOR_DIM_APPLY_i == DIMENSION) \
|
||||
{ \
|
||||
if(TH_TENSOR_DIM_APPLY_i == TENSOR->nDimension-1) \
|
||||
{ \
|
||||
TH_TENSOR_DIM_APPLY_hasFinished = 1; \
|
||||
break; \
|
||||
} \
|
||||
continue; \
|
||||
} \
|
||||
\
|
||||
TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]++; \
|
||||
TENSOR##_data += TENSOR->stride[TH_TENSOR_DIM_APPLY_i]; \
|
||||
\
|
||||
if(TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] == TENSOR->size[TH_TENSOR_DIM_APPLY_i]) \
|
||||
{ \
|
||||
if(TH_TENSOR_DIM_APPLY_i == TENSOR->nDimension-1) \
|
||||
{ \
|
||||
TH_TENSOR_DIM_APPLY_hasFinished = 1; \
|
||||
break; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
TENSOR##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR->stride[TH_TENSOR_DIM_APPLY_i]; \
|
||||
TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
break; \
|
||||
} \
|
||||
} \
|
||||
THFree(TH_TENSOR_DIM_APPLY_counter); \
|
||||
}
|
||||
|
||||
#endif
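TH_TENSOR_DIM_APPLY and its two- and three-tensor variants hand CODE one fiber at a time: a data pointer together with the stride and size along DIMENSION, while the counter array enumerates every position in the remaining dimensions. A minimal standalone C sketch of the same idea for a 2-D array follows (hypothetical names, not the TH API): it sums each fiber along dimension 0 of a row-major matrix.

#include <stdio.h>

int main(void)
{
  long size[2]   = {2, 3};   /* a 2x3 matrix            */
  long stride[2] = {3, 1};   /* row-major strides       */
  double data[6] = {0, 1, 2, 3, 4, 5};
  int dim = 0;               /* the DIMENSION argument  */

  /* iterate over the remaining dimension; for 2-D this is a single loop */
  long j;
  for(j = 0; j < size[1-dim]; j++)
  {
    double *fiber = data + j*stride[1-dim];  /* what TENSOR##_data points at */
    long fsize = size[dim], fstride = stride[dim];

    /* this block corresponds to CODE: reduce one fiber along dim */
    double sum = 0;
    long i;
    for(i = 0; i < fsize; i++)
      sum += fiber[i*fstride];
    printf("sum along dim %d at position %ld: %g\n", dim, j, sum);  /* 3, 5, 7 */
  }
  return 0;
}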
30
lib/TH/THTensorMacros.h
Normal file
@@ -0,0 +1,30 @@
#ifndef TH_TENSOR_MACROS_INC
#define TH_TENSOR_MACROS_INC

/* fast methods to access tensor data */

#define THTensor_fastGet1d(self, x0) \
(((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]])

#define THTensor_fastGet2d(self, x0, x1) \
(((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]])

#define THTensor_fastGet3d(self, x0, x1, x2) \
(((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]+(x2)*(self)->stride[2]])

#define THTensor_fastGet4d(self, x0, x1, x2, x3) \
(((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]+(x2)*(self)->stride[2]+(x3)*(self)->stride[3]])

#define THTensor_fastSet1d(self, x0, value) \
(((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]] = value)

#define THTensor_fastSet2d(self, x0, x1, value) \
(((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]] = value)

#define THTensor_fastSet3d(self, x0, x1, x2, value) \
(((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]+(x2)*(self)->stride[2]] = value)

#define THTensor_fastSet4d(self, x0, x1, x2, x3, value) \
(((self)->storage->data+(self)->storageOffset)[(x0)*(self)->stride[0]+(x1)*(self)->stride[1]+(x2)*(self)->stride[2]+(x3)*(self)->stride[3]] = value)

#endif
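These macros are plain offset arithmetic: element (x0, x1, ...) of a tensor lives at storage->data[storageOffset + x0*stride[0] + x1*stride[1] + ...]. The short self-contained sketch below spells out the same arithmetic on an ordinary C buffer (hypothetical values; TH's actual struct layout is defined elsewhere in this tree).

#include <stdio.h>

int main(void)
{
  double storage[8] = {0};
  long storageOffset = 2;          /* the view starts 2 elements into the storage */
  long stride0 = 3, stride1 = 1;   /* a 2x3 row-major view                        */

  /* equivalent of fastSet2d / fastGet2d for (x0, x1) = (1, 2) */
  long x0 = 1, x1 = 2;
  (storage + storageOffset)[x0*stride0 + x1*stride1] = 42.0;
  printf("%g\n", (storage + storageOffset)[x0*stride0 + x1*stride1]);  /* 42 */
  return 0;
}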
240
lib/TH/THVector.h
Normal file
@@ -0,0 +1,240 @@
#ifndef TH_VECTOR_INC
|
||||
#define TH_VECTOR_INC
|
||||
|
||||
#include "THGeneral.h"
|
||||
|
||||
#define THVector_(NAME) TH_CONCAT_4(TH,Real,Vector_,NAME)
|
||||
|
||||
#if defined __SSE2__ || defined __SSE3__ || defined __SSSE3__ \
|
||||
|| defined __SSE4_1__ || defined __SSE4_2__
|
||||
|
||||
#ifdef __SSE2__
|
||||
#include <emmintrin.h>
|
||||
#endif
|
||||
|
||||
#ifdef __SSE3__
|
||||
#include <pmmintrin.h>
|
||||
#endif
|
||||
|
||||
#ifdef __SSSE3__
|
||||
#include <tmmintrin.h>
|
||||
#endif
|
||||
|
||||
#if defined (__SSE4_2__) || defined (__SSE4_1__)
|
||||
#include <smmintrin.h>
|
||||
#endif
|
||||
|
||||
#define THDoubleVector_fill(x, c, n) { \
|
||||
long i; \
|
||||
__m128d XMM0 = _mm_set1_pd(c); \
|
||||
for (i=0; i<=((n)-8); i+=8) { \
|
||||
_mm_storeu_pd((x)+i , XMM0); \
|
||||
_mm_storeu_pd((x)+i+2, XMM0); \
|
||||
_mm_storeu_pd((x)+i+4, XMM0); \
|
||||
_mm_storeu_pd((x)+i+6, XMM0); \
|
||||
} \
|
||||
long off = (n) - ((n)%8); \
|
||||
for (i=0; i<((n)%8); i++) { \
|
||||
x[off+i] = c; \
|
||||
} \
|
||||
}
|
||||
|
||||
|
||||
#define THDoubleVector_add(y, x, c, n) { \
|
||||
long i = 0; \
|
||||
__m128d XMM7 = _mm_set1_pd(c); \
|
||||
__m128d XMM0,XMM2; \
|
||||
for (; i<=((n)-2); i+=2) { \
|
||||
XMM0 = _mm_loadu_pd((x)+i); \
|
||||
XMM2 = _mm_loadu_pd((y)+i); \
|
||||
XMM0 = _mm_mul_pd(XMM0, XMM7); \
|
||||
XMM2 = _mm_add_pd(XMM2, XMM0); \
|
||||
_mm_storeu_pd((y)+i , XMM2); \
|
||||
} \
|
||||
for (; i<(n); i++) { \
|
||||
y[i] += c * x[i]; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define THDoubleVector_diff(z, x, y, n) { \
|
||||
long i; \
|
||||
for (i=0; i<=((n)-8); i+=8) { \
|
||||
__m128d XMM0 = _mm_loadu_pd((x)+i ); \
|
||||
__m128d XMM1 = _mm_loadu_pd((x)+i+2); \
|
||||
__m128d XMM2 = _mm_loadu_pd((x)+i+4); \
|
||||
__m128d XMM3 = _mm_loadu_pd((x)+i+6); \
|
||||
__m128d XMM4 = _mm_loadu_pd((y)+i ); \
|
||||
__m128d XMM5 = _mm_loadu_pd((y)+i+2); \
|
||||
__m128d XMM6 = _mm_loadu_pd((y)+i+4); \
|
||||
__m128d XMM7 = _mm_loadu_pd((y)+i+6); \
|
||||
XMM0 = _mm_sub_pd(XMM0, XMM4); \
|
||||
XMM1 = _mm_sub_pd(XMM1, XMM5); \
|
||||
XMM2 = _mm_sub_pd(XMM2, XMM6); \
|
||||
XMM3 = _mm_sub_pd(XMM3, XMM7); \
|
||||
_mm_storeu_pd((z)+i , XMM0); \
|
||||
_mm_storeu_pd((z)+i+2, XMM1); \
|
||||
_mm_storeu_pd((z)+i+4, XMM2); \
|
||||
_mm_storeu_pd((z)+i+6, XMM3); \
|
||||
} \
|
||||
long off = (n) - ((n)%8); \
|
||||
for (i=0; i<((n)%8); i++) { \
|
||||
z[off+i] = x[off+i] - y[off+i]; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define THDoubleVector_scale(y, c, n) { \
|
||||
long i; \
|
||||
__m128d XMM7 = _mm_set1_pd(c); \
|
||||
for (i=0; i<=((n)-4); i+=4) { \
|
||||
__m128d XMM0 = _mm_loadu_pd((y)+i ); \
|
||||
__m128d XMM1 = _mm_loadu_pd((y)+i+2); \
|
||||
XMM0 = _mm_mul_pd(XMM0, XMM7); \
|
||||
XMM1 = _mm_mul_pd(XMM1, XMM7); \
|
||||
_mm_storeu_pd((y)+i , XMM0); \
|
||||
_mm_storeu_pd((y)+i+2, XMM1); \
|
||||
} \
|
||||
long off = (n) - ((n)%4); \
|
||||
for (i=0; i<((n)%4); i++) { \
|
||||
y[off+i] *= c; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define THDoubleVector_mul(y, x, n) { \
|
||||
long i; \
|
||||
for (i=0; i<=((n)-8); i+=8) { \
|
||||
__m128d XMM0 = _mm_loadu_pd((x)+i ); \
|
||||
__m128d XMM1 = _mm_loadu_pd((x)+i+2); \
|
||||
__m128d XMM2 = _mm_loadu_pd((x)+i+4); \
|
||||
__m128d XMM3 = _mm_loadu_pd((x)+i+6); \
|
||||
__m128d XMM4 = _mm_loadu_pd((y)+i ); \
|
||||
__m128d XMM5 = _mm_loadu_pd((y)+i+2); \
|
||||
__m128d XMM6 = _mm_loadu_pd((y)+i+4); \
|
||||
__m128d XMM7 = _mm_loadu_pd((y)+i+6); \
|
||||
XMM4 = _mm_mul_pd(XMM4, XMM0); \
|
||||
XMM5 = _mm_mul_pd(XMM5, XMM1); \
|
||||
XMM6 = _mm_mul_pd(XMM6, XMM2); \
|
||||
XMM7 = _mm_mul_pd(XMM7, XMM3); \
|
||||
_mm_storeu_pd((y)+i , XMM4); \
|
||||
_mm_storeu_pd((y)+i+2, XMM5); \
|
||||
_mm_storeu_pd((y)+i+4, XMM6); \
|
||||
_mm_storeu_pd((y)+i+6, XMM7); \
|
||||
} \
|
||||
long off = (n) - ((n)%8); \
|
||||
for (i=0; i<((n)%8); i++) { \
|
||||
y[off+i] *= x[off+i]; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define THFloatVector_fill(x, c, n) { \
|
||||
long i; \
|
||||
__m128 XMM0 = _mm_set_ps1(c); \
|
||||
for (i=0; i<=((n)-16); i+=16) { \
|
||||
_mm_storeu_ps((x)+i , XMM0); \
|
||||
_mm_storeu_ps((x)+i+4, XMM0); \
|
||||
_mm_storeu_ps((x)+i+8, XMM0); \
|
||||
_mm_storeu_ps((x)+i+12, XMM0); \
|
||||
} \
|
||||
long off = (n) - ((n)%16); \
|
||||
for (i=0; i<((n)%16); i++) { \
|
||||
x[off+i] = c; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define THFloatVector_add(y, x, c, n) { \
|
||||
long i = 0; \
|
||||
__m128 XMM7 = _mm_set_ps1(c); \
|
||||
__m128 XMM0,XMM2; \
|
||||
for (; i<=((n)-4); i+=4) { \
|
||||
XMM0 = _mm_loadu_ps((x)+i); \
|
||||
XMM2 = _mm_loadu_ps((y)+i); \
|
||||
XMM0 = _mm_mul_ps(XMM0, XMM7); \
|
||||
XMM2 = _mm_add_ps(XMM2, XMM0); \
|
||||
_mm_storeu_ps((y)+i , XMM2); \
|
||||
} \
|
||||
for (; i<(n); i++) { \
|
||||
y[i] += c * x[i]; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define THFloatVector_diff(z, x, y, n) { \
|
||||
long i; \
|
||||
for (i=0; i<=((n)-16); i+=16) { \
|
||||
__m128 XMM0 = _mm_loadu_ps((x)+i ); \
|
||||
__m128 XMM1 = _mm_loadu_ps((x)+i+ 4); \
|
||||
__m128 XMM2 = _mm_loadu_ps((x)+i+ 8); \
|
||||
__m128 XMM3 = _mm_loadu_ps((x)+i+12); \
|
||||
__m128 XMM4 = _mm_loadu_ps((y)+i ); \
|
||||
__m128 XMM5 = _mm_loadu_ps((y)+i+ 4); \
|
||||
__m128 XMM6 = _mm_loadu_ps((y)+i+ 8); \
|
||||
__m128 XMM7 = _mm_loadu_ps((y)+i+12); \
|
||||
XMM0 = _mm_sub_ps(XMM0, XMM4); \
|
||||
XMM1 = _mm_sub_ps(XMM1, XMM5); \
|
||||
XMM2 = _mm_sub_ps(XMM2, XMM6); \
|
||||
XMM3 = _mm_sub_ps(XMM3, XMM7); \
|
||||
_mm_storeu_ps((z)+i , XMM0); \
|
||||
_mm_storeu_ps((z)+i+ 4, XMM1); \
|
||||
_mm_storeu_ps((z)+i+ 8, XMM2); \
|
||||
_mm_storeu_ps((z)+i+12, XMM3); \
|
||||
} \
|
||||
long off = (n) - ((n)%16); \
|
||||
for (i=0; i<((n)%16); i++) { \
|
||||
z[off+i] = x[off+i] - y[off+i]; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define THFloatVector_scale(y, c, n) { \
|
||||
long i; \
|
||||
__m128 XMM7 = _mm_set_ps1(c); \
|
||||
for (i=0; i<=((n)-8); i+=8) { \
|
||||
__m128 XMM0 = _mm_loadu_ps((y)+i ); \
|
||||
__m128 XMM1 = _mm_loadu_ps((y)+i+4); \
|
||||
XMM0 = _mm_mul_ps(XMM0, XMM7); \
|
||||
XMM1 = _mm_mul_ps(XMM1, XMM7); \
|
||||
_mm_storeu_ps((y)+i , XMM0); \
|
||||
_mm_storeu_ps((y)+i+4, XMM1); \
|
||||
} \
|
||||
long off = (n) - ((n)%8); \
|
||||
for (i=0; i<((n)%8); i++) { \
|
||||
y[off+i] *= c; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define THFloatVector_mul(y, x, n) { \
|
||||
long i; \
|
||||
for (i=0; i<=((n)-16); i+=16) { \
|
||||
__m128 XMM0 = _mm_loadu_ps((x)+i ); \
|
||||
__m128 XMM1 = _mm_loadu_ps((x)+i+ 4); \
|
||||
__m128 XMM2 = _mm_loadu_ps((x)+i+ 8); \
|
||||
__m128 XMM3 = _mm_loadu_ps((x)+i+12); \
|
||||
__m128 XMM4 = _mm_loadu_ps((y)+i ); \
|
||||
__m128 XMM5 = _mm_loadu_ps((y)+i+ 4); \
|
||||
__m128 XMM6 = _mm_loadu_ps((y)+i+ 8); \
|
||||
__m128 XMM7 = _mm_loadu_ps((y)+i+12); \
|
||||
XMM4 = _mm_mul_ps(XMM4, XMM0); \
|
||||
XMM5 = _mm_mul_ps(XMM5, XMM1); \
|
||||
XMM6 = _mm_mul_ps(XMM6, XMM2); \
|
||||
XMM7 = _mm_mul_ps(XMM7, XMM3); \
|
||||
_mm_storeu_ps((y)+i , XMM4); \
|
||||
_mm_storeu_ps((y)+i+ 4, XMM5); \
|
||||
_mm_storeu_ps((y)+i+ 8, XMM6); \
|
||||
_mm_storeu_ps((y)+i+12, XMM7); \
|
||||
} \
|
||||
long off = (n) - ((n)%16); \
|
||||
for (i=0; i<((n)%16); i++) { \
|
||||
y[off+i] *= x[off+i]; \
|
||||
} \
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
/* If SSE2 not defined, then generate plain C operators */
|
||||
#include "generic/THVector.c"
|
||||
#include "THGenerateFloatTypes.h"
|
||||
|
||||
#endif
|
||||
|
||||
/* For non-float types, generate plain C operators */
|
||||
#include "generic/THVector.c"
|
||||
#include "THGenerateIntTypes.h"
|
||||
|
||||
#endif
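All of the SSE variants above follow the same pattern: an unrolled main loop over unaligned 128-bit loads and stores, followed by a scalar tail loop for the leftover elements. The standalone sketch below shows that pattern for y += c*x with SSE2 intrinsics (hypothetical function name, not part of TH; requires an SSE2-capable target to compile and run).

#include <emmintrin.h>
#include <stdio.h>

static void add_scaled(double *y, const double *x, double c, long n)
{
  long i = 0;
  __m128d vc = _mm_set1_pd(c);

  /* vector main loop: two doubles per iteration, unaligned loads/stores */
  for(; i <= n - 2; i += 2)
  {
    __m128d vx = _mm_loadu_pd(x + i);
    __m128d vy = _mm_loadu_pd(y + i);
    vy = _mm_add_pd(vy, _mm_mul_pd(vx, vc));
    _mm_storeu_pd(y + i, vy);
  }

  /* scalar tail for the remaining elements when n is odd */
  for(; i < n; i++)
    y[i] += c * x[i];
}

int main(void)
{
  double x[5] = {1, 2, 3, 4, 5}, y[5] = {0, 0, 0, 0, 0};
  add_scaled(y, x, 0.5, 5);
  for(int i = 0; i < 5; i++) printf("%g ", y[i]);   /* 0.5 1 1.5 2 2.5 */
  printf("\n");
  return 0;
}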
212
lib/TH/cmake/FindBLAS.cmake
Normal file
@@ -0,0 +1,212 @@
# - Find BLAS library
|
||||
# This module finds an installed fortran library that implements the BLAS
|
||||
# linear-algebra interface (see http://www.netlib.org/blas/).
|
||||
# The list of libraries searched for is taken
|
||||
# from the autoconf macro file, acx_blas.m4 (distributed at
|
||||
# http://ac-archive.sourceforge.net/ac-archive/acx_blas.html).
|
||||
#
|
||||
# This module sets the following variables:
|
||||
# BLAS_FOUND - set to true if a library implementing the BLAS interface is found.
|
||||
# BLAS_INFO - name of the detected BLAS library.
|
||||
# BLAS_F2C - set to true if following the f2c return convention
|
||||
# BLAS_LIBRARIES - list of libraries to link against to use BLAS
|
||||
# BLAS_INCLUDE_DIR - include directory
|
||||
|
||||
SET(BLAS_LIBRARIES)
|
||||
SET(BLAS_INCLUDE_DIR)
|
||||
SET(BLAS_INFO)
|
||||
SET(BLAS_F2C)
|
||||
|
||||
# CBLAS in Intel mkl
|
||||
FIND_PACKAGE(MKL)
|
||||
IF (MKL_FOUND AND NOT BLAS_LIBRARIES)
|
||||
SET(BLAS_INFO imkl)
|
||||
SET(BLAS_LIBRARIES ${MKL_LIBRARIES})
|
||||
SET(BLAS_INCLUDE_DIR ${MKL_INCLUDE_DIR})
|
||||
SET(BLAS_VERSION ${MKL_VERSION})
|
||||
ENDIF (MKL_FOUND AND NOT BLAS_LIBRARIES)
|
||||
|
||||
# Old FindBlas
|
||||
INCLUDE(CheckCSourceRuns)
|
||||
INCLUDE(CheckFortranFunctionExists)
|
||||
SET(_verbose TRUE)
|
||||
|
||||
MACRO(Check_Fortran_Libraries LIBRARIES _prefix _name _flags _list)
|
||||
# This macro checks for the existence of the combination of fortran libraries
|
||||
# given by _list. If the combination is found, this macro checks (using the
|
||||
# Check_Fortran_Function_Exists macro) whether we can link against that library
|
||||
# combination using the name of a routine given by _name using the linker
|
||||
# flags given by _flags. If the combination of libraries is found and passes
|
||||
# the link test, LIBRARIES is set to the list of complete library paths that
|
||||
# have been found. Otherwise, LIBRARIES is set to NOTFOUND.
|
||||
# N.B. _prefix is the prefix applied to the names of all cached variables that
|
||||
# are generated internally and marked advanced by this macro.
|
||||
|
||||
set(__list)
|
||||
foreach(_elem ${_list})
|
||||
if(__list)
|
||||
set(__list "${__list} - ${_elem}")
|
||||
else(__list)
|
||||
set(__list "${_elem}")
|
||||
endif(__list)
|
||||
endforeach(_elem)
|
||||
if(_verbose)
|
||||
message(STATUS "Checking for [${__list}]")
|
||||
endif(_verbose)
|
||||
|
||||
set(_libraries_work TRUE)
|
||||
set(${LIBRARIES})
|
||||
set(_combined_name)
|
||||
foreach(_library ${_list})
|
||||
set(_combined_name ${_combined_name}_${_library})
|
||||
if(_libraries_work)
|
||||
if ( WIN32 )
|
||||
find_library(${_prefix}_${_library}_LIBRARY
|
||||
NAMES ${_library}
|
||||
PATHS ENV LIB
|
||||
PATHS ENV PATH )
|
||||
endif ( WIN32 )
|
||||
if ( APPLE )
|
||||
find_library(${_prefix}_${_library}_LIBRARY
|
||||
NAMES ${_library}
|
||||
PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64
|
||||
ENV DYLD_LIBRARY_PATH )
|
||||
else ( APPLE )
|
||||
find_library(${_prefix}_${_library}_LIBRARY
|
||||
NAMES ${_library}
|
||||
PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64
|
||||
ENV LD_LIBRARY_PATH )
|
||||
endif( APPLE )
|
||||
mark_as_advanced(${_prefix}_${_library}_LIBRARY)
|
||||
set(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY})
|
||||
set(_libraries_work ${${_prefix}_${_library}_LIBRARY})
|
||||
endif(_libraries_work)
|
||||
endforeach(_library ${_list})
|
||||
if(_libraries_work)
|
||||
# Test this combination of libraries.
|
||||
set(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}})
|
||||
if (CMAKE_Fortran_COMPILER_WORKS)
|
||||
check_fortran_function_exists(${_name} ${_prefix}${_combined_name}_WORKS)
|
||||
else (CMAKE_Fortran_COMPILER_WORKS)
|
||||
check_function_exists("${_name}_" ${_prefix}${_combined_name}_WORKS)
|
||||
endif (CMAKE_Fortran_COMPILER_WORKS)
|
||||
set(CMAKE_REQUIRED_LIBRARIES)
|
||||
mark_as_advanced(${_prefix}${_combined_name}_WORKS)
|
||||
set(_libraries_work ${${_prefix}${_combined_name}_WORKS})
|
||||
endif(_libraries_work)
|
||||
if(NOT _libraries_work)
|
||||
set(${LIBRARIES} NOTFOUND)
|
||||
endif(NOT _libraries_work)
|
||||
endmacro(Check_Fortran_Libraries)
|
||||
|
||||
|
||||
# Apple BLAS library?
|
||||
if(NOT BLAS_LIBRARIES)
|
||||
check_fortran_libraries(
|
||||
BLAS_LIBRARIES
|
||||
BLAS
|
||||
sgemm
|
||||
""
|
||||
"Accelerate")
|
||||
if (BLAS_LIBRARIES)
|
||||
set(BLAS_INFO "accelerate")
|
||||
endif (BLAS_LIBRARIES)
|
||||
endif(NOT BLAS_LIBRARIES)
|
||||
if ( NOT BLAS_LIBRARIES )
|
||||
check_fortran_libraries(
|
||||
BLAS_LIBRARIES
|
||||
BLAS
|
||||
sgemm
|
||||
""
|
||||
"vecLib")
|
||||
if (BLAS_LIBRARIES)
|
||||
set(BLAS_INFO "veclib")
|
||||
endif (BLAS_LIBRARIES)
|
||||
endif ( NOT BLAS_LIBRARIES )
|
||||
|
||||
# BLAS in ATLAS library? (http://math-atlas.sourceforge.net/)
|
||||
if(NOT BLAS_LIBRARIES)
|
||||
check_fortran_libraries(
|
||||
BLAS_LIBRARIES
|
||||
BLAS
|
||||
sgemm
|
||||
""
|
||||
"cblas;f77blas;atlas")
|
||||
if (BLAS_LIBRARIES)
|
||||
set(BLAS_INFO "atlas")
|
||||
endif (BLAS_LIBRARIES)
|
||||
endif(NOT BLAS_LIBRARIES)
|
||||
|
||||
# Generic BLAS library?
|
||||
if(NOT BLAS_LIBRARIES)
|
||||
check_fortran_libraries(
|
||||
BLAS_LIBRARIES
|
||||
BLAS
|
||||
sgemm
|
||||
""
|
||||
"blas")
|
||||
if (BLAS_LIBRARIES)
|
||||
set(BLAS_INFO "generic")
|
||||
endif (BLAS_LIBRARIES)
|
||||
endif(NOT BLAS_LIBRARIES)
|
||||
|
||||
# Determine if blas was compiled with the f2c conventions
|
||||
IF (BLAS_LIBRARIES)
|
||||
SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES})
|
||||
CHECK_C_SOURCE_RUNS("
|
||||
#include <stdio.h>
|
||||
float x[4] = { 1, 2, 3, 4 };
|
||||
float y[4] = { .1, .01, .001, .0001 };
|
||||
int four = 4;
|
||||
int one = 1;
|
||||
extern double sdot_();
|
||||
int main() {
|
||||
int i;
|
||||
double r = sdot_(&four, x, &one, y, &one);
|
||||
exit((float)r != (float).1234);
|
||||
}" BLAS_F2C_DOUBLE_WORKS )
|
||||
CHECK_C_SOURCE_RUNS("
|
||||
#include <stdio.h>
|
||||
float x[4] = { 1, 2, 3, 4 };
|
||||
float y[4] = { .1, .01, .001, .0001 };
|
||||
int four = 4;
|
||||
int one = 1;
|
||||
extern float sdot_();
|
||||
int main() {
|
||||
int i;
|
||||
double r = sdot_(&four, x, &one, y, &one);
|
||||
exit((float)r != (float).1234);
|
||||
}" BLAS_F2C_FLOAT_WORKS )
|
||||
IF (BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS)
|
||||
IF (_verbose)
|
||||
MESSAGE(STATUS "This BLAS uses the F2C return conventions")
|
||||
ENDIF(_verbose)
|
||||
SET(BLAS_F2C TRUE)
|
||||
ELSE (BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS)
|
||||
SET(BLAS_F2C FALSE)
|
||||
ENDIF (BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS)
|
||||
ENDIF(BLAS_LIBRARIES)
|
||||
|
||||
# epilogue
|
||||
|
||||
if(BLAS_LIBRARIES)
|
||||
set(BLAS_FOUND TRUE)
|
||||
else(BLAS_LIBRARIES)
|
||||
set(BLAS_FOUND FALSE)
|
||||
endif(BLAS_LIBRARIES)
|
||||
|
||||
IF (NOT BLAS_FOUND AND BLAS_FIND_REQUIRED)
|
||||
message(FATAL_ERROR "Cannot find a library with BLAS API. Please specify library location.")
|
||||
ENDIF (NOT BLAS_FOUND AND BLAS_FIND_REQUIRED)
|
||||
IF(NOT BLAS_FIND_QUIETLY)
|
||||
IF(BLAS_FOUND)
|
||||
MESSAGE(STATUS "Found a library with BLAS API (${BLAS_INFO}).")
|
||||
ELSE(BLAS_FOUND)
|
||||
MESSAGE(STATUS "Cannot find a library with BLAS API. Not using BLAS.")
|
||||
ENDIF(BLAS_FOUND)
|
||||
ENDIF(NOT BLAS_FIND_QUIETLY)
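The two CHECK_C_SOURCE_RUNS programs above differ only in the declared return type of sdot_: under f2c conventions single-precision BLAS routines return double, otherwise they return float. The sketch below shows how a C caller might mirror the detected convention; BLAS_F2C is assumed here to be forwarded as a compile definition from the CMake variable, and the snippet only compiles as-is (it still has to be linked against a BLAS to run).

/* Hypothetical illustration, not part of TH: pick the sdot_ return type
   according to the f2c convention detected by the CMake check above. */
#ifdef BLAS_F2C
typedef double blas_sreturn_t;   /* f2c-style BLAS: float routines return double */
#else
typedef float blas_sreturn_t;    /* reference convention: float routines return float */
#endif

extern blas_sreturn_t sdot_(int *n, float *x, int *incx, float *y, int *incy);

float dot4(float *x, float *y)
{
  int n = 4, one = 1;
  return (float)sdot_(&n, x, &one, y, &one);   /* the cast normalizes either convention */
}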
166
lib/TH/cmake/FindLAPACK.cmake
Normal file
@@ -0,0 +1,166 @@
# - Find LAPACK library
|
||||
# This module finds an installed fortran library that implements the LAPACK
|
||||
# linear-algebra interface (see http://www.netlib.org/lapack/).
|
||||
#
|
||||
# The approach follows that taken for the autoconf macro file, acx_lapack.m4
|
||||
# (distributed at http://ac-archive.sourceforge.net/ac-archive/acx_lapack.html).
|
||||
#
|
||||
# This module sets the following variables:
|
||||
# LAPACK_FOUND - set to true if a library implementing the LAPACK interface is found
|
||||
# LAPACK_LIBRARIES - list of libraries (using full path name) for LAPACK
|
||||
|
||||
SET(LAPACK_LIBRARIES)
|
||||
SET(LAPACK_INFO)
|
||||
|
||||
IF(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED)
|
||||
FIND_PACKAGE(BLAS)
|
||||
ELSE(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED)
|
||||
FIND_PACKAGE(BLAS REQUIRED)
|
||||
ENDIF(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED)
|
||||
|
||||
# LAPACK in Intel mkl
|
||||
IF (MKL_FOUND AND NOT LAPACK_LIBRARIES)
|
||||
SET(LAPACK_LIBRARIES ${MKL_LAPACK_LIBRARIES} ${MKL_LIBRARIES})
|
||||
SET(LAPACK_INCLUDE_DIR ${MKL_INCLUDE_DIR})
|
||||
SET(LAPACK_INFO "mkl")
|
||||
ENDIF (MKL_FOUND AND NOT LAPACK_LIBRARIES)
|
||||
|
||||
# Old search lapack script
|
||||
include(CheckFortranFunctionExists)
|
||||
|
||||
macro(Check_Lapack_Libraries LIBRARIES _prefix _name _flags _list _blas)
|
||||
# This macro checks for the existence of the combination of fortran libraries
|
||||
# given by _list. If the combination is found, this macro checks (using the
|
||||
# Check_Fortran_Function_Exists macro) whether we can link against that library
|
||||
# combination using the name of a routine given by _name using the linker
|
||||
# flags given by _flags. If the combination of libraries is found and passes
|
||||
# the link test, LIBRARIES is set to the list of complete library paths that
|
||||
# have been found. Otherwise, LIBRARIES is set to FALSE.
|
||||
# N.B. _prefix is the prefix applied to the names of all cached variables that
|
||||
# are generated internally and marked advanced by this macro.
|
||||
set(_libraries_work TRUE)
|
||||
set(${LIBRARIES})
|
||||
set(_combined_name)
|
||||
foreach(_library ${_list})
|
||||
set(_combined_name ${_combined_name}_${_library})
|
||||
if(_libraries_work)
|
||||
if (WIN32)
|
||||
find_library(${_prefix}_${_library}_LIBRARY
|
||||
NAMES ${_library} PATHS ENV LIB PATHS ENV PATH)
|
||||
else (WIN32)
|
||||
if(APPLE)
|
||||
find_library(${_prefix}_${_library}_LIBRARY
|
||||
NAMES ${_library}
|
||||
PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64
|
||||
ENV DYLD_LIBRARY_PATH)
|
||||
else(APPLE)
|
||||
find_library(${_prefix}_${_library}_LIBRARY
|
||||
NAMES ${_library}
|
||||
PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64
|
||||
ENV LD_LIBRARY_PATH)
|
||||
endif(APPLE)
|
||||
endif(WIN32)
|
||||
mark_as_advanced(${_prefix}_${_library}_LIBRARY)
|
||||
set(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY})
|
||||
set(_libraries_work ${${_prefix}_${_library}_LIBRARY})
|
||||
endif(_libraries_work)
|
||||
endforeach(_library ${_list})
|
||||
if(_libraries_work)
|
||||
# Test this combination of libraries.
|
||||
set(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}} ${_blas})
|
||||
if (CMAKE_Fortran_COMPILER_WORKS)
|
||||
check_fortran_function_exists(${_name} ${_prefix}${_combined_name}_WORKS)
|
||||
else (CMAKE_Fortran_COMPILER_WORKS)
|
||||
check_function_exists("${_name}_" ${_prefix}${_combined_name}_WORKS)
|
||||
endif (CMAKE_Fortran_COMPILER_WORKS)
|
||||
set(CMAKE_REQUIRED_LIBRARIES)
|
||||
mark_as_advanced(${_prefix}${_combined_name}_WORKS)
|
||||
set(_libraries_work ${${_prefix}${_combined_name}_WORKS})
|
||||
endif(_libraries_work)
|
||||
if(NOT _libraries_work)
|
||||
set(${LIBRARIES} FALSE)
|
||||
endif(NOT _libraries_work)
|
||||
endmacro(Check_Lapack_Libraries)
|
||||
|
||||
|
||||
if(BLAS_FOUND)
|
||||
|
||||
#acml lapack
|
||||
if(NOT LAPACK_LIBRARIES)
|
||||
check_lapack_libraries(
|
||||
LAPACK_LIBRARIES
|
||||
LAPACK
|
||||
cheev
|
||||
""
|
||||
"acml"
|
||||
"${BLAS_LIBRARIES}"
|
||||
)
|
||||
if(LAPACK_LIBRARIES)
|
||||
SET(LAPACK_INFO "acml")
|
||||
endif(LAPACK_LIBRARIES)
|
||||
endif(NOT LAPACK_LIBRARIES)
|
||||
|
||||
# Apple LAPACK library?
|
||||
if(NOT LAPACK_LIBRARIES)
|
||||
check_lapack_libraries(
|
||||
LAPACK_LIBRARIES
|
||||
LAPACK
|
||||
cheev
|
||||
""
|
||||
"Accelerate"
|
||||
"${BLAS_LIBRARIES}"
|
||||
)
|
||||
if(LAPACK_LIBRARIES)
|
||||
SET(LAPACK_INFO "Accelerate")
|
||||
endif(LAPACK_LIBRARIES)
|
||||
endif(NOT LAPACK_LIBRARIES)
|
||||
|
||||
if ( NOT LAPACK_LIBRARIES )
|
||||
check_lapack_libraries(
|
||||
LAPACK_LIBRARIES
|
||||
LAPACK
|
||||
cheev
|
||||
""
|
||||
"vecLib"
|
||||
"${BLAS_LIBRARIES}"
|
||||
)
|
||||
if(LAPACK_LIBRARIES)
|
||||
SET(LAPACK_INFO "veclib")
|
||||
endif(LAPACK_LIBRARIES)
|
||||
endif ( NOT LAPACK_LIBRARIES )
|
||||
|
||||
# Generic LAPACK library?
|
||||
if ( NOT LAPACK_LIBRARIES )
|
||||
check_lapack_libraries(
|
||||
LAPACK_LIBRARIES
|
||||
LAPACK
|
||||
cheev
|
||||
""
|
||||
"lapack"
|
||||
"${BLAS_LIBRARIES}"
|
||||
)
|
||||
if(LAPACK_LIBRARIES)
|
||||
SET(LAPACK_INFO "generic")
|
||||
endif(LAPACK_LIBRARIES)
|
||||
endif ( NOT LAPACK_LIBRARIES )
|
||||
|
||||
else(BLAS_FOUND)
|
||||
message(STATUS "LAPACK requires BLAS")
|
||||
endif(BLAS_FOUND)
|
||||
|
||||
if(LAPACK_LIBRARIES)
|
||||
set(LAPACK_FOUND TRUE)
|
||||
else(LAPACK_LIBRARIES)
|
||||
set(LAPACK_FOUND FALSE)
|
||||
endif(LAPACK_LIBRARIES)
|
||||
|
||||
IF (NOT LAPACK_FOUND AND LAPACK_FIND_REQUIRED)
|
||||
message(FATAL_ERROR "Cannot find a library with LAPACK API. Please specify library location.")
|
||||
ENDIF (NOT LAPACK_FOUND AND LAPACK_FIND_REQUIRED)
|
||||
IF(NOT LAPACK_FIND_QUIETLY)
|
||||
IF(LAPACK_FOUND)
|
||||
MESSAGE(STATUS "Found a library with LAPACK API. (${LAPACK_INFO})")
|
||||
ELSE(LAPACK_FOUND)
|
||||
MESSAGE(STATUS "Cannot find a library with LAPACK API. Not using LAPACK.")
|
||||
ENDIF(LAPACK_FOUND)
|
||||
ENDIF(NOT LAPACK_FIND_QUIETLY)
274
lib/TH/cmake/FindMKL.cmake
Normal file
@@ -0,0 +1,274 @@
# - Find INTEL MKL library
|
||||
#
|
||||
# This module finds the Intel Mkl libraries.
|
||||
#
|
||||
# This module sets the following variables:
|
||||
# MKL_FOUND - set to true if a library implementing the CBLAS interface is found
|
||||
# MKL_VERSION - best guess
|
||||
# MKL_INCLUDE_DIR - path to include dir.
|
||||
# MKL_LIBRARIES - list of libraries for base mkl
|
||||
# MKL_LAPACK_LIBRARIES - list of libraries to add for lapack
|
||||
# MKL_SCALAPACK_LIBRARIES - list of libraries to add for scalapack
|
||||
# MKL_SOLVER_LIBRARIES - list of libraries to add for the solvers
|
||||
# MKL_CDFT_LIBRARIES - list of libraries to add for the solvers
|
||||
|
||||
|
||||
# Do nothing if MKL_FOUND was set before!
|
||||
IF (NOT MKL_FOUND)
|
||||
|
||||
SET(MKL_VERSION)
|
||||
SET(MKL_INCLUDE_DIR)
|
||||
SET(MKL_LIBRARIES)
|
||||
SET(MKL_LAPACK_LIBRARIES)
|
||||
SET(MKL_SCALAPACK_LIBRARIES)
|
||||
SET(MKL_SOLVER_LIBRARIES)
|
||||
SET(MKL_CDFT_LIBRARIES)
|
||||
|
||||
# Includes
|
||||
INCLUDE(CheckTypeSize)
|
||||
INCLUDE(CheckFunctionExists)
|
||||
|
||||
# Prints diagnostic
|
||||
# SET(_verbose TRUE)
|
||||
|
||||
# Intel Compiler Suite
|
||||
SET(INTEL_COMPILER_DIR CACHE STRING
|
||||
"Root directory of the Intel Compiler Suite (contains ipp, mkl, etc.)")
|
||||
SET(INTEL_MKL_DIR CACHE STRING
|
||||
"Root directory of the Intel MKL (standalone)")
|
||||
SET(INTEL_MKL_SEQUENTIAL OFF CACHE BOOL
|
||||
"Force using the sequential (non threaded) libraries")
|
||||
|
||||
# Checks
|
||||
CHECK_TYPE_SIZE("void*" SIZE_OF_VOIDP)
|
||||
IF ("${SIZE_OF_VOIDP}" EQUAL 8)
|
||||
SET(mklvers "em64t")
|
||||
SET(iccvers "intel64")
|
||||
SET(mkl64s "_lp64")
|
||||
ELSE ("${SIZE_OF_VOIDP}" EQUAL 8)
|
||||
SET(mklvers "32")
|
||||
SET(iccvers "ia32")
|
||||
SET(mkl64s)
|
||||
ENDIF ("${SIZE_OF_VOIDP}" EQUAL 8)
|
||||
IF (CMAKE_COMPILER_IS_GNUCC)
|
||||
SET(mklthreads "mkl_gnu_thread" "mkl_intel_thread")
|
||||
SET(mklifaces "gf" "intel")
|
||||
ELSE (CMAKE_COMPILER_IS_GNUCC)
|
||||
SET(mklthreads "mkl_intel_thread")
|
||||
SET(mklifaces "intel")
|
||||
ENDIF (CMAKE_COMPILER_IS_GNUCC)
|
||||
SET(mklrtls "iomp5" "guide")
|
||||
|
||||
# Kernel libraries dynamically loaded
|
||||
SET(mklkerlibs "mc" "mc3" "nc" "p4n" "p4m" "p4m3" "p4p" "def")
|
||||
SET(mklseq)
|
||||
|
||||
|
||||
|
||||
# Paths
|
||||
SET(saved_CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH})
|
||||
SET(saved_CMAKE_INCLUDE_PATH ${CMAKE_INCLUDE_PATH})
|
||||
IF (INTEL_COMPILER_DIR)
|
||||
# TODO: diagnostic if dir does not exist
|
||||
SET(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH}
|
||||
"${INTEL_COMPILER_DIR}/lib/${iccvers}")
|
||||
IF (NOT INTEL_MKL_DIR)
|
||||
SET(INTEL_MKL_DIR "${INTEL_COMPILER_DIR}/mkl")
|
||||
ENDIF (NOT INTEL_MKL_DIR)
|
||||
ENDIF (INTEL_COMPILER_DIR)
|
||||
IF (INTEL_MKL_DIR)
|
||||
# TODO: diagnostic if dir does not exist
|
||||
SET(CMAKE_INCLUDE_PATH ${CMAKE_INCLUDE_PATH}
|
||||
"${INTEL_MKL_DIR}/include")
|
||||
SET(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH}
|
||||
"${INTEL_MKL_DIR}/lib/${mklvers}")
|
||||
ENDIF (INTEL_MKL_DIR)
|
||||
|
||||
# Try linking multiple libs
|
||||
MACRO(CHECK_ALL_LIBRARIES LIBRARIES _name _list _flags)
|
||||
# This macro checks for the existence of the combination of libraries given by _list.
|
||||
# If the combination is found, this macro checks whether we can link against that library
|
||||
# combination using the name of a routine given by _name using the linker
|
||||
# flags given by _flags. If the combination of libraries is found and passes
|
||||
# the link test, LIBRARIES is set to the list of complete library paths that
|
||||
# have been found. Otherwise, LIBRARIES is set to FALSE.
|
||||
# N.B. _prefix is the prefix applied to the names of all cached variables that
|
||||
# are generated internally and marked advanced by this macro.
|
||||
SET(_prefix "${LIBRARIES}")
|
||||
IF (_verbose)
|
||||
SET(__list)
|
||||
FOREACH(_elem ${_list})
|
||||
IF(__list)
|
||||
SET(__list "${__list} - ${_elem}")
|
||||
ELSE(__list)
|
||||
SET(__list "${_elem}")
|
||||
ENDIF(__list)
|
||||
ENDFOREACH(_elem)
|
||||
ENDIF(_verbose)
|
||||
# start checking
|
||||
SET(_libraries_work TRUE)
|
||||
SET(${LIBRARIES})
|
||||
SET(_combined_name)
|
||||
SET(_paths)
|
||||
FOREACH(_library ${_list})
|
||||
SET(_combined_name ${_combined_name}_${_library})
|
||||
IF(_libraries_work)
|
||||
FIND_LIBRARY(${_prefix}_${_library}_LIBRARY NAMES ${_library})
|
||||
MARK_AS_ADVANCED(${_prefix}_${_library}_LIBRARY)
|
||||
SET(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY})
|
||||
SET(_libraries_work ${${_prefix}_${_library}_LIBRARY})
|
||||
ENDIF(_libraries_work)
|
||||
ENDFOREACH(_library ${_list})
|
||||
# Test this combination of libraries.
|
||||
IF(_libraries_work)
|
||||
SET(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}})
|
||||
CHECK_FUNCTION_EXISTS(${_name} ${_prefix}${_combined_name}_WORKS)
|
||||
SET(CMAKE_REQUIRED_LIBRARIES)
|
||||
MARK_AS_ADVANCED(${_prefix}${_combined_name}_WORKS)
|
||||
SET(_libraries_work ${${_prefix}${_combined_name}_WORKS})
|
||||
ENDIF(_libraries_work)
|
||||
# Fin
|
||||
IF(_libraries_work)
|
||||
IF (_verbose)
|
||||
MESSAGE(STATUS "FindMKL: ${__list} : ok")
|
||||
ENDIF (_verbose)
|
||||
ELSE (_libraries_work)
|
||||
SET(${LIBRARIES})
|
||||
MARK_AS_ADVANCED(${LIBRARIES})
|
||||
IF (_verbose)
|
||||
MESSAGE(STATUS "FindMKL: ${__list} : no")
|
||||
ENDIF (_verbose)
|
||||
ENDIF(_libraries_work)
|
||||
ENDMACRO(CHECK_ALL_LIBRARIES)
|
||||
|
||||
|
||||
# Check for version 10/11
|
||||
IF (NOT MKL_LIBRARIES)
|
||||
SET(MKL_VERSION 1011)
|
||||
ENDIF (NOT MKL_LIBRARIES)
|
||||
FOREACH(mklrtl ${mklrtls})
|
||||
FOREACH(mkliface ${mklifaces})
|
||||
FOREACH(mkl64 ${mkl64s} "")
|
||||
FOREACH(mklthread ${mklthreads})
|
||||
IF (NOT MKL_LIBRARIES AND NOT INTEL_MKL_SEQUENTIAL)
|
||||
CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm
|
||||
"mkl_${mkliface}${mkl64};${mklthread};mkl_core;${mklrtl};pthread;m" "")
|
||||
ENDIF (NOT MKL_LIBRARIES AND NOT INTEL_MKL_SEQUENTIAL)
|
||||
ENDFOREACH(mklthread)
|
||||
ENDFOREACH(mkl64)
|
||||
ENDFOREACH(mkliface)
|
||||
ENDFOREACH(mklrtl)
|
||||
FOREACH(mklrtl ${mklrtls})
|
||||
FOREACH(mkliface ${mklifaces})
|
||||
FOREACH(mkl64 ${mkl64s} "")
|
||||
IF (NOT MKL_LIBRARIES)
|
||||
CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm
|
||||
"mkl_${mkliface}${mkl64};mkl_sequential;mkl_core;m" "")
|
||||
IF (MKL_LIBRARIES)
|
||||
SET(mklseq "_sequential")
|
||||
ENDIF (MKL_LIBRARIES)
|
||||
ENDIF (NOT MKL_LIBRARIES)
|
||||
ENDFOREACH(mkl64)
|
||||
ENDFOREACH(mkliface)
|
||||
ENDFOREACH(mklrtl)
|
||||
FOREACH(mklrtl ${mklrtls})
|
||||
FOREACH(mkliface ${mklifaces})
|
||||
FOREACH(mkl64 ${mkl64s} "")
|
||||
FOREACH(mklthread ${mklthreads})
|
||||
IF (NOT MKL_LIBRARIES)
|
||||
CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm
|
||||
"mkl_${mkliface}${mkl64};${mklthread};mkl_core;${mklrtl};pthread;m" "")
|
||||
ENDIF (NOT MKL_LIBRARIES)
|
||||
ENDFOREACH(mklthread)
|
||||
ENDFOREACH(mkl64)
|
||||
ENDFOREACH(mkliface)
|
||||
ENDFOREACH(mklrtl)
|
||||
|
||||
# Check for older versions
|
||||
IF (NOT MKL_LIBRARIES)
|
||||
SET(MKL_VERSION 900)
|
||||
CHECK_ALL_LIBRARIES(MKL_LIBRARIES cblas_sgemm
|
||||
"mkl;guide;pthread;m" "")
|
||||
ENDIF (NOT MKL_LIBRARIES)
|
||||
|
||||
# Include files
|
||||
IF (MKL_LIBRARIES)
|
||||
FIND_PATH(MKL_INCLUDE_DIR "mkl_cblas.h")
|
||||
MARK_AS_ADVANCED(MKL_INCLUDE_DIR)
|
||||
ENDIF (MKL_LIBRARIES)
|
||||
|
||||
# Other libraries
|
||||
IF (MKL_LIBRARIES)
|
||||
FOREACH(mkl64 ${mkl64s} "_core" "")
|
||||
FOREACH(mkls ${mklseq} "")
|
||||
IF (NOT MKL_LAPACK_LIBRARIES)
|
||||
FIND_LIBRARY(MKL_LAPACK_LIBRARIES NAMES "mkl_lapack${mkl64}${mkls}")
|
||||
MARK_AS_ADVANCED(MKL_LAPACK_LIBRARIES)
|
||||
ENDIF (NOT MKL_LAPACK_LIBRARIES)
|
||||
IF (NOT MKL_SCALAPACK_LIBRARIES)
|
||||
FIND_LIBRARY(MKL_SCALAPACK_LIBRARIES NAMES "mkl_scalapack${mkl64}${mkls}")
|
||||
MARK_AS_ADVANCED(MKL_SCALAPACK_LIBRARIES)
|
||||
ENDIF (NOT MKL_SCALAPACK_LIBRARIES)
|
||||
IF (NOT MKL_SOLVER_LIBRARIES)
|
||||
FIND_LIBRARY(MKL_SOLVER_LIBRARIES NAMES "mkl_solver${mkl64}${mkls}")
|
||||
MARK_AS_ADVANCED(MKL_SOLVER_LIBRARIES)
|
||||
ENDIF (NOT MKL_SOLVER_LIBRARIES)
|
||||
IF (NOT MKL_CDFT_LIBRARIES)
|
||||
FIND_LIBRARY(MKL_CDFT_LIBRARIES NAMES "mkl_cdft${mkl64}${mkls}")
|
||||
MARK_AS_ADVANCED(MKL_CDFT_LIBRARIES)
|
||||
ENDIF (NOT MKL_CDFT_LIBRARIES)
|
||||
ENDFOREACH(mkls)
|
||||
ENDFOREACH(mkl64)
|
||||
ENDIF (MKL_LIBRARIES)
|
||||
|
||||
# LibIRC: intel compiler always links this;
|
||||
# gcc does not; but mkl kernels sometimes need it.
|
||||
IF (MKL_LIBRARIES)
|
||||
IF (CMAKE_COMPILER_IS_GNUCC)
|
||||
FIND_LIBRARY(MKL_KERNEL_libirc "irc")
|
||||
ELSEIF (CMAKE_C_COMPILER_ID AND NOT CMAKE_C_COMPILER_ID STREQUAL "Intel")
|
||||
FIND_LIBRARY(MKL_KERNEL_libirc "irc")
|
||||
ENDIF (CMAKE_COMPILER_IS_GNUCC)
|
||||
MARK_AS_ADVANCED(MKL_KERNEL_libirc)
|
||||
IF (MKL_KERNEL_libirc)
|
||||
SET(MKL_LIBRARIES ${MKL_LIBRARIES} ${MKL_KERNEL_libirc})
|
||||
ENDIF (MKL_KERNEL_libirc)
|
||||
ENDIF (MKL_LIBRARIES)
|
||||
|
||||
# Final
|
||||
SET(CMAKE_LIBRARY_PATH ${saved_CMAKE_LIBRARY_PATH})
|
||||
SET(CMAKE_INCLUDE_PATH ${saved_CMAKE_INCLUDE_PATH})
|
||||
IF (MKL_LIBRARIES)
|
||||
SET(MKL_FOUND TRUE)
|
||||
ELSE (MKL_LIBRARIES)
|
||||
SET(MKL_FOUND FALSE)
|
||||
SET(MKL_VERSION)
|
||||
ENDIF (MKL_LIBRARIES)
|
||||
|
||||
# Results
|
||||
IF (_verbose)
|
||||
MESSAGE(STATUS "*** MKL_FOUND = ${MKL_FOUND}")
|
||||
MESSAGE(STATUS "*** MKL_INCLUDE_DIR = ${MKL_INCLUDE_DIR}")
|
||||
MESSAGE(STATUS "*** MKL_LIBRARIES = ${MKL_LIBRARIES}")
|
||||
MESSAGE(STATUS "*** MKL_LAPACK_LIBRARIES = ${MKL_LAPACK_LIBRARIES}")
|
||||
MESSAGE(STATUS "*** MKL_SCALAPACK_LIBRARIES = ${MKL_SCALAPACK_LIBRARIES}")
|
||||
MESSAGE(STATUS "*** MKL_SOLVER_LIBRARIES = ${MKL_SOLVER_LIBRARIES}")
|
||||
MESSAGE(STATUS "*** MKL_CDFT_LIBRARIES = ${MKL_CDFT_LIBRARIES}")
|
||||
ENDIF(_verbose)
|
||||
|
||||
# Standard termination
|
||||
IF(NOT MKL_FOUND AND MKL_FIND_REQUIRED)
|
||||
MESSAGE(FATAL_ERROR "MKL library not found. Please specify library location")
|
||||
ENDIF(NOT MKL_FOUND AND MKL_FIND_REQUIRED)
|
||||
IF(NOT MKL_FIND_QUIETLY)
|
||||
IF(MKL_FOUND)
|
||||
MESSAGE(STATUS "MKL library found")
|
||||
ELSE(MKL_FOUND)
|
||||
MESSAGE(STATUS "MKL library not found")
|
||||
ENDIF(MKL_FOUND)
|
||||
ENDIF(NOT MKL_FIND_QUIETLY)
|
||||
|
||||
# Do nothing if MKL_FOUND was set before!
|
||||
ENDIF (NOT MKL_FOUND)
104
lib/TH/cmake/FindSSE.cmake
Normal file
@@ -0,0 +1,104 @@
# Check if SSE instructions are available on the machine where
|
||||
# the project is compiled.
|
||||
|
||||
IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
EXEC_PROGRAM(cat ARGS "/proc/cpuinfo" OUTPUT_VARIABLE CPUINFO)
|
||||
|
||||
STRING(REGEX REPLACE "^.*(sse2).*$" "\\1" SSE_THERE ${CPUINFO})
|
||||
STRING(COMPARE EQUAL "sse2" "${SSE_THERE}" SSE2_TRUE)
|
||||
IF (SSE2_TRUE)
|
||||
set(SSE2_FOUND true CACHE BOOL "SSE2 available on host")
|
||||
ELSE (SSE2_TRUE)
|
||||
set(SSE2_FOUND false CACHE BOOL "SSE2 available on host")
|
||||
ENDIF (SSE2_TRUE)
|
||||
|
||||
# /proc/cpuinfo apparently omits sse3 :(
|
||||
STRING(REGEX REPLACE "^.*[^s](sse3).*$" "\\1" SSE_THERE ${CPUINFO})
|
||||
STRING(COMPARE EQUAL "sse3" "${SSE_THERE}" SSE3_TRUE)
|
||||
IF (NOT SSE3_TRUE)
|
||||
STRING(REGEX REPLACE "^.*(T2300).*$" "\\1" SSE_THERE ${CPUINFO})
|
||||
STRING(COMPARE EQUAL "T2300" "${SSE_THERE}" SSE3_TRUE)
|
||||
ENDIF (NOT SSE3_TRUE)
|
||||
|
||||
STRING(REGEX REPLACE "^.*(ssse3).*$" "\\1" SSE_THERE ${CPUINFO})
|
||||
STRING(COMPARE EQUAL "ssse3" "${SSE_THERE}" SSSE3_TRUE)
|
||||
IF (SSE3_TRUE OR SSSE3_TRUE)
|
||||
set(SSE3_FOUND true CACHE BOOL "SSE3 available on host")
|
||||
ELSE (SSE3_TRUE OR SSSE3_TRUE)
|
||||
set(SSE3_FOUND false CACHE BOOL "SSE3 available on host")
|
||||
ENDIF (SSE3_TRUE OR SSSE3_TRUE)
|
||||
IF (SSSE3_TRUE)
|
||||
set(SSSE3_FOUND true CACHE BOOL "SSSE3 available on host")
|
||||
ELSE (SSSE3_TRUE)
|
||||
set(SSSE3_FOUND false CACHE BOOL "SSSE3 available on host")
|
||||
ENDIF (SSSE3_TRUE)
|
||||
|
||||
STRING(REGEX REPLACE "^.*(sse4_1).*$" "\\1" SSE_THERE ${CPUINFO})
|
||||
STRING(COMPARE EQUAL "sse4_1" "${SSE_THERE}" SSE41_TRUE)
|
||||
IF (SSE41_TRUE)
|
||||
set(SSE4_1_FOUND true CACHE BOOL "SSE4.1 available on host")
|
||||
ELSE (SSE41_TRUE)
|
||||
set(SSE4_1_FOUND false CACHE BOOL "SSE4.1 available on host")
|
||||
ENDIF (SSE41_TRUE)
|
||||
ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
||||
EXEC_PROGRAM("/usr/sbin/sysctl -n machdep.cpu.features" OUTPUT_VARIABLE
|
||||
CPUINFO)
|
||||
|
||||
STRING(REGEX REPLACE "^.*[^S](SSE2).*$" "\\1" SSE_THERE ${CPUINFO})
|
||||
STRING(COMPARE EQUAL "SSE2" "${SSE_THERE}" SSE2_TRUE)
|
||||
IF (SSE2_TRUE)
|
||||
set(SSE2_FOUND true CACHE BOOL "SSE2 available on host")
|
||||
ELSE (SSE2_TRUE)
|
||||
set(SSE2_FOUND false CACHE BOOL "SSE2 available on host")
|
||||
ENDIF (SSE2_TRUE)
|
||||
|
||||
STRING(REGEX REPLACE "^.*[^S](SSE3).*$" "\\1" SSE_THERE ${CPUINFO})
|
||||
STRING(COMPARE EQUAL "SSE3" "${SSE_THERE}" SSE3_TRUE)
|
||||
IF (SSE3_TRUE)
|
||||
set(SSE3_FOUND true CACHE BOOL "SSE3 available on host")
|
||||
ELSE (SSE3_TRUE)
|
||||
set(SSE3_FOUND false CACHE BOOL "SSE3 available on host")
|
||||
ENDIF (SSE3_TRUE)
|
||||
|
||||
STRING(REGEX REPLACE "^.*(SSSE3).*$" "\\1" SSE_THERE ${CPUINFO})
|
||||
STRING(COMPARE EQUAL "SSSE3" "${SSE_THERE}" SSSE3_TRUE)
|
||||
IF (SSSE3_TRUE)
|
||||
set(SSSE3_FOUND true CACHE BOOL "SSSE3 available on host")
|
||||
ELSE (SSSE3_TRUE)
|
||||
set(SSSE3_FOUND false CACHE BOOL "SSSE3 available on host")
|
||||
ENDIF (SSSE3_TRUE)
|
||||
|
||||
STRING(REGEX REPLACE "^.*(SSE4.1).*$" "\\1" SSE_THERE ${CPUINFO})
|
||||
STRING(COMPARE EQUAL "SSE4.1" "${SSE_THERE}" SSE41_TRUE)
|
||||
IF (SSE41_TRUE)
|
||||
set(SSE4_1_FOUND true CACHE BOOL "SSE4.1 available on host")
|
||||
ELSE (SSE41_TRUE)
|
||||
set(SSE4_1_FOUND false CACHE BOOL "SSE4.1 available on host")
|
||||
ENDIF (SSE41_TRUE)
|
||||
ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||
# TODO
|
||||
set(SSE2_FOUND true CACHE BOOL "SSE2 available on host")
|
||||
set(SSE3_FOUND false CACHE BOOL "SSE3 available on host")
|
||||
set(SSSE3_FOUND false CACHE BOOL "SSSE3 available on host")
|
||||
set(SSE4_1_FOUND false CACHE BOOL "SSE4.1 available on host")
|
||||
ELSE(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
set(SSE2_FOUND true CACHE BOOL "SSE2 available on host")
|
||||
set(SSE3_FOUND false CACHE BOOL "SSE3 available on host")
|
||||
set(SSSE3_FOUND false CACHE BOOL "SSSE3 available on host")
|
||||
set(SSE4_1_FOUND false CACHE BOOL "SSE4.1 available on host")
|
||||
ENDIF(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
|
||||
if(NOT SSE2_FOUND)
|
||||
MESSAGE(STATUS "Could not find hardware support for SSE2 on this machine.")
|
||||
endif(NOT SSE2_FOUND)
|
||||
if(NOT SSE3_FOUND)
|
||||
MESSAGE(STATUS "Could not find hardware support for SSE3 on this machine.")
|
||||
endif(NOT SSE3_FOUND)
|
||||
if(NOT SSSE3_FOUND)
|
||||
MESSAGE(STATUS "Could not find hardware support for SSSE3 on this machine.")
|
||||
endif(NOT SSSE3_FOUND)
|
||||
if(NOT SSE4_1_FOUND)
|
||||
MESSAGE(STATUS "Could not find hardware support for SSE4.1 on this machine.")
|
||||
endif(NOT SSE4_1_FOUND)
|
||||
|
||||
mark_as_advanced(SSE2_FOUND SSE3_FOUND SSSE3_FOUND SSE4_1_FOUND)
382
lib/TH/generic/THBlas.c
Normal file
@@ -0,0 +1,382 @@
#ifndef TH_GENERIC_FILE
|
||||
#define TH_GENERIC_FILE "generic/THBlas.c"
|
||||
#else
|
||||
|
||||
void THBlas_(swap)(long n, real *x, long incx, real *y, long incy)
|
||||
{
|
||||
if(n == 1)
|
||||
{
|
||||
incx = 1;
|
||||
incy = 1;
|
||||
}
|
||||
|
||||
#if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
|
||||
if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
|
||||
{
|
||||
int i_n = (int)n;
|
||||
int i_incx = (int)incx;
|
||||
int i_incy = (int)incy;
|
||||
|
||||
#if defined(TH_REAL_IS_DOUBLE)
|
||||
extern void dswap_(int *n, double *x, int *incx, double *y, int *incy);
|
||||
dswap_(&i_n, x, &i_incx, y, &i_incy);
|
||||
#else
|
||||
extern void sswap_(int *n, float *x, int *incx, float *y, int *incy);
|
||||
sswap_(&i_n, x, &i_incx, y, &i_incy);
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
{
|
||||
long i;
|
||||
for(i = 0; i < n; i++)
|
||||
{
|
||||
real z = x[i*incx];
|
||||
x[i*incx] = y[i*incy];
|
||||
y[i*incy] = z;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void THBlas_(scal)(long n, real a, real *x, long incx)
|
||||
{
|
||||
if(n == 1)
|
||||
incx = 1;
|
||||
|
||||
#if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
|
||||
if( (n <= INT_MAX) && (incx <= INT_MAX) )
|
||||
{
|
||||
int i_n = (int)n;
|
||||
int i_incx = (int)incx;
|
||||
|
||||
#if defined(TH_REAL_IS_DOUBLE)
|
||||
extern void dscal_(int *n, double *a, double *x, int *incx);
|
||||
dscal_(&i_n, &a, x, &i_incx);
|
||||
#else
|
||||
extern void sscal_(int *n, float *a, float *x, int *incx);
|
||||
sscal_(&i_n, &a, x, &i_incx);
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
{
|
||||
long i;
|
||||
for(i = 0; i < n; i++)
|
||||
x[i*incx] *= a;
|
||||
}
|
||||
}
|
||||
|
||||
void THBlas_(copy)(long n, real *x, long incx, real *y, long incy)
|
||||
{
|
||||
if(n == 1)
|
||||
{
|
||||
incx = 1;
|
||||
incy = 1;
|
||||
}
|
||||
|
||||
#if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
|
||||
if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
|
||||
{
|
||||
int i_n = (int)n;
|
||||
int i_incx = (int)incx;
|
||||
int i_incy = (int)incy;
|
||||
|
||||
#if defined(TH_REAL_IS_DOUBLE)
|
||||
extern void dcopy_(int *n, double *x, int *incx, double *y, int *incy);
|
||||
dcopy_(&i_n, x, &i_incx, y, &i_incy);
|
||||
#else
|
||||
extern void scopy_(int *n, float *x, int *incx, float *y, int *incy);
|
||||
scopy_(&i_n, x, &i_incx, y, &i_incy);
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
{
|
||||
long i;
|
||||
for(i = 0; i < n; i++)
|
||||
y[i*incy] = x[i*incx];
|
||||
}
|
||||
}
|
||||
|
||||
void THBlas_(axpy)(long n, real a, real *x, long incx, real *y, long incy)
|
||||
{
|
||||
if(n == 1)
|
||||
{
|
||||
incx = 1;
|
||||
incy = 1;
|
||||
}
|
||||
|
||||
#if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
|
||||
if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
|
||||
{
|
||||
int i_n = (int)n;
|
||||
int i_incx = (int)incx;
|
||||
int i_incy = (int)incy;
|
||||
|
||||
#if defined(TH_REAL_IS_DOUBLE)
|
||||
extern void daxpy_(int *n, double *a, double *x, int *incx, double *y, int *incy);
|
||||
daxpy_(&i_n, &a, x, &i_incx, y, &i_incy);
|
||||
#else
|
||||
extern void saxpy_(int *n, float *a, float *x, int *incx, float *y, int *incy);
|
||||
saxpy_(&i_n, &a, x, &i_incx, y, &i_incy);
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
{
|
||||
long i;
|
||||
for(i = 0; i < n; i++)
|
||||
y[i*incy] += a*x[i*incx];
|
||||
}
|
||||
}
|
||||
|
||||
real THBlas_(dot)(long n, real *x, long incx, real *y, long incy)
|
||||
{
|
||||
if(n == 1)
|
||||
{
|
||||
incx = 1;
|
||||
incy = 1;
|
||||
}
|
||||
|
||||
#if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
|
||||
if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
|
||||
{
|
||||
int i_n = (int)n;
|
||||
int i_incx = (int)incx;
|
||||
int i_incy = (int)incy;
|
||||
|
||||
#if defined(TH_REAL_IS_DOUBLE)
|
||||
extern double ddot_(int *n, double *x, int *incx, double *y, int *incy);
|
||||
return ddot_(&i_n, x, &i_incx, y, &i_incy);
|
||||
#else
|
||||
extern float sdot_(int *n, float *x, int *incx, float *y, int *incy);
|
||||
return sdot_(&i_n, x, &i_incx, y, &i_incy);
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
{
|
||||
long i;
|
||||
real sum = 0;
|
||||
for(i = 0; i < n; i++)
|
||||
sum += x[i*incx]*y[i*incy];
|
||||
return sum;
|
||||
}
|
||||
}
|
||||
|
||||
void THBlas_(gemv)(char trans, long m, long n, real alpha, real *a, long lda, real *x, long incx, real beta, real *y, long incy)
|
||||
{
|
||||
if(n == 1)
|
||||
lda = m;
|
||||
|
||||
#if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
|
||||
if( (m <= INT_MAX) && (n <= INT_MAX) &&
|
||||
(lda > 0) && (lda <= INT_MAX) &&
|
||||
(incx > 0) && (incx <= INT_MAX) &&
|
||||
(incy > 0) && (incy <= INT_MAX) )
|
||||
{
|
||||
int i_m = (int)m;
|
||||
int i_n = (int)n;
|
||||
int i_lda = (int)lda;
|
||||
int i_incx = (int)incx;
|
||||
int i_incy = (int)incy;
|
||||
|
||||
#if defined(TH_REAL_IS_DOUBLE)
|
||||
extern void dgemv_(char *trans, int *m, int *n, double *alpha, double *a, int *lda, double *x, int *incx, double *beta, double *y, int *incy);
|
||||
dgemv_(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy);
|
||||
#else
|
||||
extern void sgemv_(char *trans, int *m, int *n, float *alpha, float *a, int *lda, float *x, int *incx, float *beta, float *y, int *incy);
|
||||
sgemv_(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy);
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
{
|
||||
long i, j;
|
||||
|
||||
if( (trans == 'T') || (trans == 't') )
|
||||
{
|
||||
for(i = 0; i < n; i++)
|
||||
{
|
||||
real sum = 0;
|
||||
real *row_ = a+lda*i;
|
||||
for(j = 0; j < m; j++)
|
||||
sum += x[j*incx]*row_[j];
|
||||
y[i*incy] = beta*y[i*incy] + alpha*sum;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(beta != 1)
|
||||
THBlas_(scal)(m, beta, y, incy);
|
||||
|
||||
for(j = 0; j < n; j++)
|
||||
{
|
||||
real *column_ = a+lda*j;
|
||||
real z = alpha*x[j*incx];
|
||||
for(i = 0; i < m; i++)
|
||||
y[i*incy] += z*column_[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void THBlas_(ger)(long m, long n, real alpha, real *x, long incx, real *y, long incy, real *a, long lda)
|
||||
{
|
||||
if(n == 1)
|
||||
lda = m;
|
||||
|
||||
#if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
|
||||
if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
|
||||
{
|
||||
int i_m = (int)m;
|
||||
int i_n = (int)n;
|
||||
int i_lda = (int)lda;
|
||||
int i_incx = (int)incx;
|
||||
int i_incy = (int)incy;
|
||||
|
||||
#if defined(TH_REAL_IS_DOUBLE)
|
||||
extern void dger_(int *m, int *n, double *alpha, double *x, int *incx, real *y, int *incy, double *a, int *lda);
|
||||
dger_(&i_m, &i_n, &alpha, x, &i_incx, y, &i_incy, a, &i_lda);
|
||||
#else
|
||||
extern void sger_(int *m, int *n, float *alpha, float *x, int *incx, real *y, int *incy, float *a, int *lda);
|
||||
sger_(&i_m, &i_n, &alpha, x, &i_incx, y, &i_incy, a, &i_lda);
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
{
|
||||
long i, j;
|
||||
for(j = 0; j < n; j++)
|
||||
{
|
||||
real *column_ = a+j*lda;
|
||||
real z = alpha*y[j*incy];
|
||||
for(i = 0; i < m; i++)
|
||||
column_[i] += z*x[i*incx] ;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void THBlas_(gemm)(char transa, char transb, long m, long n, long k, real alpha, real *a, long lda, real *b, long ldb, real beta, real *c, long ldc)
|
||||
{
|
||||
int transa_ = ((transa == 't') || (transa == 'T'));
|
||||
int transb_ = ((transb == 't') || (transb == 'T'));
|
||||
|
||||
if(n == 1)
|
||||
ldc = m;
|
||||
|
||||
if(transa_)
|
||||
{
|
||||
if(m == 1)
|
||||
lda = k;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(k == 1)
|
||||
lda = m;
|
||||
}
|
||||
|
||||
if(transb_)
|
||||
{
|
||||
if(k == 1)
|
||||
ldb = n;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(n == 1)
|
||||
ldb = k;
|
||||
}
|
||||
|
||||
#if defined(USE_LAPACK) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
|
||||
if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) )
|
||||
{
|
||||
int i_m = (int)m;
|
||||
int i_n = (int)n;
|
||||
int i_k = (int)k;
|
||||
int i_lda = (int)lda;
|
||||
int i_ldb = (int)ldb;
|
||||
int i_ldc = (int)ldc;
|
||||
|
||||
#if defined(TH_REAL_IS_DOUBLE)
|
||||
extern void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, double *a, int *lda, double *b, int *ldb, double *beta, double *c, int *ldc);
|
||||
dgemm_(&transa, &transb, &i_m, &i_n, &i_k, &alpha, a, &i_lda, b, &i_ldb, &beta, c, &i_ldc);
|
||||
#else
|
||||
extern void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, float *a, int *lda, float *b, int *ldb, float *beta, float *c, int *ldc);
|
||||
sgemm_(&transa, &transb, &i_m, &i_n, &i_k, &alpha, a, &i_lda, b, &i_ldb, &beta, c, &i_ldc);
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
{
|
||||
long i, j, l;
|
||||
if(!transa_ && !transb_)
|
||||
{
|
||||
real *a_ = a;
|
||||
for(i = 0; i < m; i++)
|
||||
{
|
||||
real *b_ = b;
|
||||
for(j = 0; j < n; j++)
|
||||
{
|
||||
real sum = 0;
|
||||
for(l = 0; l < k; l++)
|
||||
sum += a_[l*lda]*b_[l];
|
||||
b_ += ldb;
|
||||
c[j*ldc+i] = beta*c[j*ldc+i]+alpha*sum;
|
||||
}
|
||||
a_++;
|
||||
}
|
||||
}
|
||||
else if(transa_ && !transb_)
|
||||
{
|
||||
real *a_ = a;
|
||||
for(i = 0; i < m; i++)
|
||||
{
|
||||
real *b_ = b;
|
||||
for(j = 0; j < n; j++)
|
||||
{
|
||||
real sum = 0;
|
||||
for(l = 0; l < k; l++)
|
||||
sum += a_[l]*b_[l];
|
||||
b_ += ldb;
|
||||
c[j*ldc+i] = beta*c[j*ldc+i]+alpha*sum;
|
||||
}
|
||||
a_ += lda;
|
||||
}
|
||||
}
|
||||
else if(!transa_ && transb_)
|
||||
{
|
||||
real *a_ = a;
|
||||
for(i = 0; i < m; i++)
|
||||
{
|
||||
real *b_ = b;
|
||||
for(j = 0; j < n; j++)
|
||||
{
|
||||
real sum = 0;
|
||||
for(l = 0; l < k; l++)
|
||||
sum += a_[l*lda]*b_[l*ldb];
|
||||
b_++;
|
||||
c[j*ldc+i] = beta*c[j*ldc+i]+alpha*sum;
|
||||
}
|
||||
a_++;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
real *a_ = a;
|
||||
for(i = 0; i < m; i++)
|
||||
{
|
||||
real *b_ = b;
|
||||
for(j = 0; j < n; j++)
|
||||
{
|
||||
real sum = 0;
|
||||
for(l = 0; l < k; l++)
|
||||
sum += a_[l]*b_[l*ldb];
|
||||
b_++;
|
||||
c[j*ldc+i] = beta*c[j*ldc+i]+alpha*sum;
|
||||
}
|
||||
a_ += lda;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
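The fallback loops above compute C = beta*C + alpha*op(A)*op(B) with Fortran-style column-major storage, so element (i, j) of a matrix with leading dimension ld is addressed as m[j*ld + i]. The small self-contained C check below exercises exactly that indexing for a 2x2 product (illustrative only, not the TH code path).

#include <stdio.h>

int main(void)
{
  /* column-major 2x2 matrices: A = [1 2; 3 4], B = [5 6; 7 8] */
  double a[4] = {1, 3, 2, 4};        /* columns of A: (1,3) and (2,4) */
  double b[4] = {5, 7, 6, 8};
  double c[4] = {0, 0, 0, 0};
  long m = 2, n = 2, k = 2, lda = 2, ldb = 2, ldc = 2;
  double alpha = 1, beta = 0;

  for(long i = 0; i < m; i++)
    for(long j = 0; j < n; j++)
    {
      double sum = 0;
      for(long l = 0; l < k; l++)
        sum += a[l*lda + i] * b[j*ldb + l];       /* A(i,l) * B(l,j) */
      c[j*ldc + i] = beta*c[j*ldc + i] + alpha*sum;
    }

  /* expect A*B = [19 22; 43 50] */
  printf("%g %g\n%g %g\n", c[0*ldc+0], c[1*ldc+0], c[0*ldc+1], c[1*ldc+1]);
  return 0;
}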
19
lib/TH/generic/THBlas.h
Normal file
@@ -0,0 +1,19 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THBlas.h"
#else

/* Level 1 */
void THBlas_(swap)(long n, real *x, long incx, real *y, long incy);
void THBlas_(scal)(long n, real a, real *x, long incx);
void THBlas_(copy)(long n, real *x, long incx, real *y, long incy);
void THBlas_(axpy)(long n, real a, real *x, long incx, real *y, long incy);
real THBlas_(dot)(long n, real *x, long incx, real *y, long incy);

/* Level 2 */
void THBlas_(gemv)(char trans, long m, long n, real alpha, real *a, long lda, real *x, long incx, real beta, real *y, long incy);
void THBlas_(ger)(long m, long n, real alpha, real *x, long incx, real *y, long incy, real *a, long lda);

/* Level 3 */
void THBlas_(gemm)(char transa, char transb, long m, long n, long k, real alpha, real *a, long lda, real *b, long ldb, real beta, real *c, long ldc);

#endif
66
lib/TH/generic/THLapack.c
Normal file
@@ -0,0 +1,66 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THLapack.c"
#else

void THLapack_(gesv)(int n, int nrhs, real *a, int lda, int *ipiv, real *b, int ldb, int* info)
{
#ifdef __LAPACK__
#if defined(TH_REAL_IS_DOUBLE)
  extern void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info);
  dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
#else
  extern void sgesv_(int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info);
  sgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
#endif
#else
  THError("gesv : Lapack library not found at compile time\n");
#endif
  return;
}

void THLapack_(gels)(char trans, int m, int n, int nrhs, real *a, int lda, real *b, int ldb, real *work, int lwork, int *info)
{
#ifdef __LAPACK__
#if defined(TH_REAL_IS_DOUBLE)
  extern void dgels_(char *trans, int *m, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, double *work, int *lwork, int *info);
  dgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info);
#else
  extern void sgels_(char *trans, int *m, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, float *work, int *lwork, int *info);
  sgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info);
#endif
#else
  THError("gels : Lapack library not found at compile time\n");
#endif
}

void THLapack_(syev)(char jobz, char uplo, int n, real *a, int lda, real *w, real *work, int lwork, int *info)
{
#ifdef __LAPACK__
#if defined(TH_REAL_IS_DOUBLE)
  extern void dsyev_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *info);
  dsyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info);
#else
  extern void ssyev_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *info);
  ssyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info);
#endif
#else
  THError("syev : Lapack library not found at compile time\n");
#endif
}

void THLapack_(gesvd)(char jobu, char jobvt, int m, int n, real *a, int lda, real *s, real *u, int ldu, real *vt, int ldvt, real *work, int lwork, int *info)
{
#ifdef __LAPACK__
#if defined(TH_REAL_IS_DOUBLE)
  extern void dgesvd_(char *jobu, char *jobvt, int *m, int *n, double *a, int *lda, double *s, double *u, int *ldu, double *vt, int *ldvt, double *work, int *lwork, int *info);
  dgesvd_(&jobu, &jobvt, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, info);
#else
  extern void sgesvd_(char *jobu, char *jobvt, int *m, int *n, float *a, int *lda, float *s, float *u, int *ldu, float *vt, int *ldvt, float *work, int *lwork, int *info);
  sgesvd_(&jobu, &jobvt, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, info);
#endif
#else
  THError("gesvd : Lapack library not found at compile time\n");
#endif
}

#endif
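
Each wrapper above forwards directly to the corresponding Fortran symbol (dgesv_/sgesv_, dgels_/sgels_, and so on), so arguments follow LAPACK conventions: column-major matrices, every scalar passed by pointer, and an info output reporting success or failure. Below is a standalone sketch of the same dgesv_ call the double-precision branch makes, solving a 2x2 linear system; illustration only, and it assumes a LAPACK library is linked (for example with -llapack).

/* Illustration only: raw dgesv_ call, as used by THLapack_(gesv) above. */
#include <stdio.h>

extern void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv,
                   double *b, int *ldb, int *info);

int main(void)
{
  /* A = [[3 1],[1 2]] in column-major order, b = [9 8]^T  =>  x = [2 3]^T */
  double a[4] = {3, 1, 1, 2};
  double b[2] = {9, 8};
  int ipiv[2];
  int n = 2, nrhs = 1, lda = 2, ldb = 2, info = 0;

  dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, &info);
  if(info != 0)
    printf("dgesv failed: info = %d\n", info);
  else
    printf("x = [%g, %g]\n", b[0], b[1]);   /* expected: [2, 3] */
  return 0;
}
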
15
lib/TH/generic/THLapack.h
Normal file
@@ -0,0 +1,15 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THLapack.h"
#else

/* AX=B */
void THLapack_(gesv)(int n, int nrhs, real *a, int lda, int *ipiv, real *b, int ldb, int* info);
/* ||AX-B|| */
void THLapack_(gels)(char trans, int m, int n, int nrhs, real *a, int lda, real *b, int ldb, real *work, int lwork, int *info);
/* Eigenvals */
void THLapack_(syev)(char jobz, char uplo, int n, real *a, int lda, real *w, real *work, int lwork, int *info);
/* svd */
void THLapack_(gesvd)(char jobu, char jobvt, int m, int n, real *a, int lda, real *s, real *u, int ldu, real *vt, int ldvt, real *work, int lwork, int *info);

#endif
259
lib/TH/generic/THStorage.c
Normal file
@@ -0,0 +1,259 @@
#ifndef TH_GENERIC_FILE
|
||||
#define TH_GENERIC_FILE "generic/THStorage.c"
|
||||
#else
|
||||
|
||||
THStorage* THStorage_(new)(void)
|
||||
{
|
||||
return THStorage_(newWithSize)(0);
|
||||
}
|
||||
|
||||
THStorage* THStorage_(newWithSize)(long size)
|
||||
{
|
||||
THStorage *storage = THAlloc(sizeof(THStorage));
|
||||
storage->data = THAlloc(sizeof(real)*size);
|
||||
storage->size = size;
|
||||
storage->refcount = 1;
|
||||
storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM;
|
||||
return storage;
|
||||
}
|
||||
|
||||
THStorage* THStorage_(newWithSize1)(real data0)
|
||||
{
|
||||
THStorage *self = THStorage_(newWithSize)(1);
|
||||
self->data[0] = data0;
|
||||
return self;
|
||||
}
|
||||
|
||||
THStorage* THStorage_(newWithSize2)(real data0, real data1)
|
||||
{
|
||||
THStorage *self = THStorage_(newWithSize)(2);
|
||||
self->data[0] = data0;
|
||||
self->data[1] = data1;
|
||||
return self;
|
||||
}
|
||||
|
||||
THStorage* THStorage_(newWithSize3)(real data0, real data1, real data2)
|
||||
{
|
||||
THStorage *self = THStorage_(newWithSize)(3);
|
||||
self->data[0] = data0;
|
||||
self->data[1] = data1;
|
||||
self->data[2] = data2;
|
||||
return self;
|
||||
}
|
||||
|
||||
THStorage* THStorage_(newWithSize4)(real data0, real data1, real data2, real data3)
|
||||
{
|
||||
THStorage *self = THStorage_(newWithSize)(4);
|
||||
self->data[0] = data0;
|
||||
self->data[1] = data1;
|
||||
self->data[2] = data2;
|
||||
self->data[3] = data3;
|
||||
return self;
|
||||
}
|
||||
|
||||
#if defined(_WIN32) || defined(HAVE_MMAP)
|
||||
|
||||
THStorage* THStorage_(newWithMapping)(const char *fileName, int isShared)
|
||||
{
|
||||
THStorage *storage = THAlloc(sizeof(THStorage));
|
||||
long size;
|
||||
|
||||
/* check size */
|
||||
FILE *f = fopen(fileName, "rb");
|
||||
if(f == NULL)
|
||||
THError("unable to open file <%s> for mapping (read-only mode)", fileName);
|
||||
fseek(f, 0, SEEK_END);
|
||||
size = ftell(f);
|
||||
fclose(f);
|
||||
size /= sizeof(real);
|
||||
|
||||
#ifdef _WIN32
|
||||
{
|
||||
HANDLE hfile;
|
||||
HANDLE hmfile;
|
||||
DWORD size_hi, size_lo;
|
||||
|
||||
/* open file */
|
||||
if(isShared)
|
||||
{
|
||||
hfile = CreateFileA(fileName, GENERIC_READ|GENERIC_WRITE, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0);
|
||||
if (hfile == INVALID_HANDLE_VALUE)
|
||||
THError("could not open file <%s> in read-write mode", fileName);
|
||||
}
|
||||
else
|
||||
{
|
||||
hfile = CreateFileA(fileName, GENERIC_READ, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0);
|
||||
if (hfile == INVALID_HANDLE_VALUE)
|
||||
THError("could not open file <%s> in read-only mode", fileName);
|
||||
}
|
||||
|
||||
#if SIZEOF_SIZE_T > 4
|
||||
size_hi = (DWORD)((size*sizeof(real)) >> 32);
|
||||
size_lo = (DWORD)((size*sizeof(real)) & 0xFFFFFFFF);
|
||||
#else
|
||||
size_hi = 0;
|
||||
size_lo = (DWORD)(size*sizeof(real));
|
||||
#endif
|
||||
|
||||
/* get map handle */
|
||||
if(isShared)
|
||||
{
|
||||
if( (hmfile = CreateFileMapping(hfile, NULL, PAGE_READWRITE, size_hi, size_lo, NULL)) == NULL )
|
||||
THError("could not create a map on file <%s>", fileName);
|
||||
}
|
||||
else
|
||||
{
|
||||
if( (hmfile = CreateFileMapping(hfile, NULL, PAGE_WRITECOPY, size_hi, size_lo, NULL)) == NULL )
|
||||
THError("could not create a map on file <%s>", fileName);
|
||||
}
|
||||
|
||||
/* map the stuff */
|
||||
storage = THStorage_(new)();
|
||||
if(isShared)
|
||||
storage->data = MapViewOfFile(hmfile, FILE_MAP_ALL_ACCESS, 0, 0, 0);
|
||||
else
|
||||
storage->data = MapViewOfFile(hmfile, FILE_MAP_COPY, 0, 0, 0);
|
||||
|
||||
storage->size = size;
|
||||
if(storage->data == NULL)
|
||||
{
|
||||
THStorage_(free)(storage);
|
||||
THError("memory map failed on file <%s>", fileName);
|
||||
}
|
||||
CloseHandle(hfile);
|
||||
CloseHandle(hmfile);
|
||||
}
|
||||
#else
|
||||
{
|
||||
/* open file */
|
||||
int fd;
|
||||
if(isShared)
|
||||
{
|
||||
fd = open(fileName, O_RDWR);
|
||||
if(fd == -1)
|
||||
THError("unable to open file <%s> in read-write mode", fileName);
|
||||
}
|
||||
else
|
||||
{
|
||||
fd = open(fileName, O_RDONLY);
|
||||
if(fd == -1)
|
||||
THError("unable to open file <%s> in read-only mode", fileName);
|
||||
}
|
||||
|
||||
/* map it */
|
||||
storage = THStorage_(new)();
|
||||
if(isShared)
|
||||
storage->data = mmap(NULL, size*sizeof(real), PROT_READ|PROT_WRITE, MAP_SHARED, fd, 0);
|
||||
else
|
||||
storage->data = mmap(NULL, size*sizeof(real), PROT_READ|PROT_WRITE, MAP_PRIVATE, fd, 0);
|
||||
|
||||
storage->size = size;
|
||||
if(storage->data == MAP_FAILED)
|
||||
{
|
||||
storage->data = NULL; /* let's be sure it is NULL before calling free() */
|
||||
THStorage_(free)(storage);
|
||||
THError("memory map failed on file <%s>", fileName);
|
||||
}
|
||||
close (fd);
|
||||
}
|
||||
#endif
|
||||
|
||||
storage->refcount = 1;
|
||||
storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_MAPPED | TH_STORAGE_FREEMEM;
|
||||
return storage;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
THStorage* THStorage_(newWithMapping)(const char *fileName, int isShared)
|
||||
{
|
||||
THError("Mapped file Storages are not supported on your system");
return NULL;
}
|
||||
|
||||
#endif
|
||||
|
||||
void THStorage_(setFlag)(THStorage *storage, const char flag)
|
||||
{
|
||||
storage->flag |= flag;
|
||||
}
|
||||
|
||||
void THStorage_(clearFlag)(THStorage *storage, const char flag)
|
||||
{
|
||||
storage->flag &= ~flag;
|
||||
}
|
||||
|
||||
void THStorage_(retain)(THStorage *storage)
|
||||
{
|
||||
if(storage && (storage->flag & TH_STORAGE_REFCOUNTED))
|
||||
++storage->refcount;
|
||||
}
|
||||
|
||||
void THStorage_(free)(THStorage *storage)
|
||||
{
|
||||
if(!storage)
|
||||
return;
|
||||
|
||||
if((storage->flag & TH_STORAGE_REFCOUNTED) && (storage->refcount > 0))
|
||||
{
|
||||
if(--storage->refcount == 0)
|
||||
{
|
||||
if(storage->flag & TH_STORAGE_FREEMEM)
|
||||
{
|
||||
#if defined(_WIN32) || defined(HAVE_MMAP)
|
||||
if(storage->flag & TH_STORAGE_MAPPED)
|
||||
{
|
||||
#ifdef _WIN32
|
||||
if(!UnmapViewOfFile((LPINT)storage->data))
|
||||
#else
|
||||
if (munmap(storage->data, storage->size*sizeof(real)))
|
||||
#endif
|
||||
THError("could not unmap the shared memory file");
|
||||
}
|
||||
else
|
||||
#endif
|
||||
THFree(storage->data);
|
||||
}
|
||||
THFree(storage);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
THStorage* THStorage_(newWithData)(real *data, long size)
|
||||
{
|
||||
THStorage *storage = THAlloc(sizeof(THStorage));
|
||||
storage->data = data;
|
||||
storage->size = size;
|
||||
storage->refcount = 1;
|
||||
storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM;
|
||||
return storage;
|
||||
}
|
||||
|
||||
void THStorage_(resize)(THStorage *storage, long size)
|
||||
{
|
||||
if(storage->flag & TH_STORAGE_RESIZABLE)
|
||||
{
|
||||
storage->data = THRealloc(storage->data, sizeof(real)*size);
|
||||
storage->size = size;
|
||||
}
|
||||
}
|
||||
|
||||
void THStorage_(fill)(THStorage *storage, real value)
|
||||
{
|
||||
long i;
|
||||
for(i = 0; i < storage->size; i++)
|
||||
storage->data[i] = value;
|
||||
}
|
||||
|
||||
void THStorage_(set)(THStorage *self, long idx, real value)
|
||||
{
|
||||
THArgCheck((idx >= 0) && (idx < self->size), 2, "out of bounds");
|
||||
self->data[idx] = value;
|
||||
}
|
||||
|
||||
real THStorage_(get)(THStorage *self, long idx)
|
||||
{
|
||||
THArgCheck((idx >= 0) && (idx < self->size), 2, "out of bounds");
|
||||
return self->data[idx];
|
||||
}
|
||||
|
||||
#endif
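
THStorage above is a reference-counted, resizable block of real values: retain bumps the count, and free releases the memory only when the last owner lets go (and only when the storage owns its buffer, i.e. TH_STORAGE_FREEMEM is set). The following is a simplified, self-contained restatement of that refcount pattern in plain C; illustration only, and the Storage struct and helper names merely mirror the real ones.

/* Illustration only: the retain/free ownership pattern used by THStorage. */
#include <stdio.h>
#include <stdlib.h>

typedef struct Storage { double *data; long size; int refcount; } Storage;

static Storage *storage_new(long size)
{
  Storage *s = malloc(sizeof(Storage));
  s->data = malloc(sizeof(double)*size);
  s->size = size;
  s->refcount = 1;              /* the creator holds one reference */
  return s;
}

static void storage_retain(Storage *s) { if(s) ++s->refcount; }

static void storage_free(Storage *s)
{
  if(s && --s->refcount == 0)   /* last owner releases the memory */
  {
    free(s->data);
    free(s);
  }
}

int main(void)
{
  Storage *s = storage_new(4);
  storage_retain(s);            /* e.g. a tensor starts sharing the storage */
  storage_free(s);              /* tensor goes away: refcount 2 -> 1 */
  storage_free(s);              /* original owner goes away: memory released */
  printf("done\n");
  return 0;
}
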
59
lib/TH/generic/THStorage.h
Normal file
@@ -0,0 +1,59 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THStorage.h"
#else

/* we could have a linked list
   that initializes math, lab structures (or more).
   hmm -- complicated.

   Problem: THMapStorage is kind of a class
   THLab_()... how do I get out of this?

   with templates, I would have to instantiate them all!!! oh boy!
   And how do I know it is for Cuda? The float type is the same inside the <>

   in the end, it would all be about float/double pointers... etc... = easy.
   primitives??
*/

#define TH_STORAGE_REFCOUNTED 1
#define TH_STORAGE_RESIZABLE  2
#define TH_STORAGE_MAPPED     4
#define TH_STORAGE_FREEMEM    8

typedef struct THStorage
{
    real *data;
    long size;
    int refcount;
    char flag;

} THStorage;

TH_API real* THStorage_(data)(THStorage*);
TH_API long THStorage_(size)(THStorage*);

/* slow access -- checks everything */
TH_API void THStorage_(set)(THStorage*, long, real);
TH_API real THStorage_(get)(THStorage*, long);

TH_API THStorage* THStorage_(new)(void);
TH_API THStorage* THStorage_(newWithSize)(long size);
TH_API THStorage* THStorage_(newWithSize1)(real);
TH_API THStorage* THStorage_(newWithSize2)(real, real);
TH_API THStorage* THStorage_(newWithSize3)(real, real, real);
TH_API THStorage* THStorage_(newWithSize4)(real, real, real, real);
TH_API THStorage* THStorage_(newWithMapping)(const char *fileName, int isShared);
TH_API THStorage* THStorage_(newWithData)(real *data, long size);

/* should not differ with API */
TH_API void THStorage_(setFlag)(THStorage *storage, const char flag);
TH_API void THStorage_(clearFlag)(THStorage *storage, const char flag);
TH_API void THStorage_(retain)(THStorage *storage);

/* might differ with other API (like CUDA) */
TH_API void THStorage_(free)(THStorage *storage);
TH_API void THStorage_(resize)(THStorage *storage, long size);
TH_API void THStorage_(fill)(THStorage *storage, real value);

#endif
36
lib/TH/generic/THStorageCopy.c
Normal file
@@ -0,0 +1,36 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THStorageCopy.c"
#else

void THStorage_(rawCopy)(THStorage *storage, real *src)
{
  long i;
  for(i = 0; i < storage->size; i++)
    storage->data[i] = src[i];
}

void THStorage_(copy)(THStorage *storage, THStorage *src)
{
  THArgCheck(storage->size == src->size, 2, "size mismatch");
  THStorage_(rawCopy)(storage, src->data);
}

#define IMPLEMENT_THStorage_COPY(TYPENAMESRC) \
void THStorage_(copy##TYPENAMESRC)(THStorage *storage, TH##TYPENAMESRC##Storage *src) \
{ \
  long i; \
  THArgCheck(storage->size == src->size, 2, "size mismatch"); \
  for(i = 0; i < storage->size; i++) \
    storage->data[i] = (real)src->data[i]; \
}

IMPLEMENT_THStorage_COPY(Byte)
IMPLEMENT_THStorage_COPY(Char)
IMPLEMENT_THStorage_COPY(Short)
IMPLEMENT_THStorage_COPY(Int)
IMPLEMENT_THStorage_COPY(Long)
IMPLEMENT_THStorage_COPY(Float)
IMPLEMENT_THStorage_COPY(Double)

#endif
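
IMPLEMENT_THStorage_COPY above stamps out one converting copy function per source storage type, each performing an element-wise cast to real. Below is a standalone sketch of the same macro-generation idea in plain C; illustration only, with names invented for this note.

/* Illustration only: macro-stamped per-type converting copies. */
#include <stdio.h>

#define IMPLEMENT_COPY(NAME, SRCTYPE)                                 \
  static void copyFrom##NAME(double *dst, const SRCTYPE *src, long n) \
  {                                                                   \
    long i;                                                           \
    for(i = 0; i < n; i++)                                            \
      dst[i] = (double)src[i];  /* element-wise cast, as above */     \
  }

IMPLEMENT_COPY(Int, int)
IMPLEMENT_COPY(Float, float)

int main(void)
{
  int    a[3] = {1, 2, 3};
  float  b[3] = {0.5f, 1.5f, 2.5f};
  double out[3];

  copyFromInt(out, a, 3);
  printf("%g %g %g\n", out[0], out[1], out[2]);
  copyFromFloat(out, b, 3);
  printf("%g %g %g\n", out[0], out[1], out[2]);
  return 0;
}
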
17
lib/TH/generic/THStorageCopy.h
Normal file
@@ -0,0 +1,17 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THStorageCopy.h"
#else

/* Support for copy between different Storage types */

TH_API void THStorage_(rawCopy)(THStorage *storage, real *src);
TH_API void THStorage_(copy)(THStorage *storage, THStorage *src);
TH_API void THStorage_(copyByte)(THStorage *storage, struct THByteStorage *src);
TH_API void THStorage_(copyChar)(THStorage *storage, struct THCharStorage *src);
TH_API void THStorage_(copyShort)(THStorage *storage, struct THShortStorage *src);
TH_API void THStorage_(copyInt)(THStorage *storage, struct THIntStorage *src);
TH_API void THStorage_(copyLong)(THStorage *storage, struct THLongStorage *src);
TH_API void THStorage_(copyFloat)(THStorage *storage, struct THFloatStorage *src);
TH_API void THStorage_(copyDouble)(THStorage *storage, struct THDoubleStorage *src);

#endif
728
lib/TH/generic/THTensor.c
Normal file
@@ -0,0 +1,728 @@
#ifndef TH_GENERIC_FILE
|
||||
#define TH_GENERIC_FILE "generic/THTensor.c"
|
||||
#else
|
||||
|
||||
/**** access methods ****/
|
||||
THStorage *THTensor_(storage)(THTensor *self)
|
||||
{
|
||||
return self->storage;
|
||||
}
|
||||
|
||||
long THTensor_(storageOffset)(THTensor *self)
|
||||
{
|
||||
return self->storageOffset;
|
||||
}
|
||||
|
||||
int THTensor_(nDimension)(THTensor *self)
|
||||
{
|
||||
return self->nDimension;
|
||||
}
|
||||
|
||||
long THTensor_(size)(THTensor *self, int dim)
|
||||
{
|
||||
THArgCheck((dim >= 0) && (dim < self->nDimension), 2, "out of range");
|
||||
return self->size[dim];
|
||||
}
|
||||
|
||||
long THTensor_(stride)(THTensor *self, int dim)
|
||||
{
|
||||
THArgCheck((dim >= 0) && (dim < self->nDimension), 2, "out of range");
|
||||
return self->stride[dim];
|
||||
}
|
||||
|
||||
THLongStorage *THTensor_(newSizeOf)(THTensor *self)
|
||||
{
|
||||
THLongStorage *size = THLongStorage_newWithSize(self->nDimension);
|
||||
THLongStorage_rawCopy(size, self->size);
|
||||
return size;
|
||||
}
|
||||
|
||||
THLongStorage *THTensor_(newStrideOf)(THTensor *self)
|
||||
{
|
||||
THLongStorage *stride = THLongStorage_newWithSize(self->nDimension);
|
||||
THLongStorage_rawCopy(stride, self->stride);
|
||||
return stride;
|
||||
}
|
||||
|
||||
real *THTensor_(data)(THTensor *self)
|
||||
{
|
||||
if(self->storage)
|
||||
return (self->storage->data+self->storageOffset);
|
||||
else
|
||||
return NULL;
|
||||
}
|
||||
|
||||
void THTensor_(setFlag)(THTensor *self, const char flag)
|
||||
{
|
||||
self->flag |= flag;
|
||||
}
|
||||
|
||||
void THTensor_(clearFlag)(THTensor *self, const char flag)
|
||||
{
|
||||
self->flag &= ~flag;
|
||||
}
|
||||
|
||||
/**** creation methods ****/
|
||||
|
||||
static void THTensor_(rawInit)(THTensor *self);
|
||||
static void THTensor_(rawSet)(THTensor *self, THStorage *storage, long storageOffset, int nDimension, long *size, long *stride);
|
||||
static void THTensor_(rawResize)(THTensor *self, int nDimension, long *size, long *stride);
|
||||
|
||||
|
||||
/* Empty init */
|
||||
THTensor *THTensor_(new)(void)
|
||||
{
|
||||
THTensor *self = THAlloc(sizeof(THTensor));
|
||||
THTensor_(rawInit)(self);
|
||||
return self;
|
||||
}
|
||||
|
||||
/* Pointer-copy init */
|
||||
THTensor *THTensor_(newWithTensor)(THTensor *tensor)
|
||||
{
|
||||
THTensor *self = THAlloc(sizeof(THTensor));
|
||||
THTensor_(rawInit)(self);
|
||||
THTensor_(rawSet)(self,
|
||||
tensor->storage,
|
||||
tensor->storageOffset,
|
||||
tensor->nDimension,
|
||||
tensor->size,
|
||||
tensor->stride);
|
||||
return self;
|
||||
}
|
||||
|
||||
/* Storage init */
|
||||
THTensor *THTensor_(newWithStorage)(THStorage *storage, long storageOffset, THLongStorage *size, THLongStorage *stride)
|
||||
{
|
||||
THTensor *self = THAlloc(sizeof(THTensor));
|
||||
if(size && stride)
|
||||
THArgCheck(size->size == stride->size, 4, "inconsistent size");
|
||||
|
||||
THTensor_(rawInit)(self);
|
||||
THTensor_(rawSet)(self,
|
||||
storage,
|
||||
storageOffset,
|
||||
(size ? size->size : (stride ? stride->size : 0)),
|
||||
(size ? size->data : NULL),
|
||||
(stride ? stride->data : NULL));
|
||||
|
||||
return self;
|
||||
}
|
||||
THTensor *THTensor_(newWithStorage1d)(THStorage *storage, long storageOffset,
|
||||
long size0, long stride0)
|
||||
{
|
||||
return THTensor_(newWithStorage4d)(storage, storageOffset, size0, stride0, -1, -1, -1, -1, -1, -1);
|
||||
}
|
||||
|
||||
THTensor *THTensor_(newWithStorage2d)(THStorage *storage, long storageOffset,
|
||||
long size0, long stride0,
|
||||
long size1, long stride1)
|
||||
{
|
||||
return THTensor_(newWithStorage4d)(storage, storageOffset, size0, stride0, size1, stride1, -1, -1, -1, -1);
|
||||
}
|
||||
|
||||
THTensor *THTensor_(newWithStorage3d)(THStorage *storage, long storageOffset,
|
||||
long size0, long stride0,
|
||||
long size1, long stride1,
|
||||
long size2, long stride2)
|
||||
{
|
||||
return THTensor_(newWithStorage4d)(storage, storageOffset, size0, stride0, size1, stride1, size2, stride2, -1, -1);
|
||||
}
|
||||
|
||||
THTensor *THTensor_(newWithStorage4d)(THStorage *storage, long storageOffset,
|
||||
long size0, long stride0,
|
||||
long size1, long stride1,
|
||||
long size2, long stride2,
|
||||
long size3, long stride3)
|
||||
{
|
||||
long size[4] = {size0, size1, size2, size3};
|
||||
long stride[4] = {stride0, stride1, stride2, stride3};
|
||||
|
||||
THTensor *self = THAlloc(sizeof(THTensor));
|
||||
THTensor_(rawInit)(self);
|
||||
THTensor_(rawSet)(self, storage, storageOffset, 4, size, stride);
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
THTensor *THTensor_(newWithSize)(THLongStorage *size, THLongStorage *stride)
|
||||
{
|
||||
return THTensor_(newWithStorage)(NULL, 0, size, stride);
|
||||
}
|
||||
|
||||
THTensor *THTensor_(newWithSize1d)(long size0)
|
||||
{
|
||||
return THTensor_(newWithSize4d)(size0, -1, -1, -1);
|
||||
}
|
||||
|
||||
THTensor *THTensor_(newWithSize2d)(long size0, long size1)
|
||||
{
|
||||
return THTensor_(newWithSize4d)(size0, size1, -1, -1);
|
||||
}
|
||||
|
||||
THTensor *THTensor_(newWithSize3d)(long size0, long size1, long size2)
|
||||
{
|
||||
return THTensor_(newWithSize4d)(size0, size1, size2, -1);
|
||||
}
|
||||
|
||||
THTensor *THTensor_(newWithSize4d)(long size0, long size1, long size2, long size3)
|
||||
{
|
||||
long size[4] = {size0, size1, size2, size3};
|
||||
|
||||
THTensor *self = THAlloc(sizeof(THTensor));
|
||||
THTensor_(rawInit)(self);
|
||||
THTensor_(rawResize)(self, 4, size, NULL);
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
THTensor *THTensor_(newClone)(THTensor *self)
|
||||
{
|
||||
THTensor *tensor = THTensor_(new)();
|
||||
THTensor_(resizeAs)(tensor, self);
|
||||
THTensor_(copy)(tensor, self);
|
||||
return tensor;
|
||||
}
|
||||
|
||||
THTensor *THTensor_(newContiguous)(THTensor *self)
|
||||
{
|
||||
if(!THTensor_(isContiguous)(self))
|
||||
return THTensor_(newClone)(self);
|
||||
else
|
||||
{
|
||||
THTensor_(retain)(self);
|
||||
return self;
|
||||
}
|
||||
}
|
||||
|
||||
THTensor *THTensor_(newSelect)(THTensor *tensor, int dimension_, long sliceIndex_)
|
||||
{
|
||||
THTensor *self = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(select)(self, NULL, dimension_, sliceIndex_);
|
||||
return self;
|
||||
}
|
||||
|
||||
THTensor *THTensor_(newNarrow)(THTensor *tensor, int dimension_, long firstIndex_, long size_)
|
||||
{
|
||||
THTensor *self = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(narrow)(self, NULL, dimension_, firstIndex_, size_);
|
||||
return self;
|
||||
}
|
||||
|
||||
THTensor *THTensor_(newTranspose)(THTensor *tensor, int dimension1_, int dimension2_)
|
||||
{
|
||||
THTensor *self = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(transpose)(self, NULL, dimension1_, dimension2_);
|
||||
return self;
|
||||
}
|
||||
|
||||
THTensor *THTensor_(newUnfold)(THTensor *tensor, int dimension_, long size_, long step_)
|
||||
{
|
||||
THTensor *self = THTensor_(newWithTensor)(tensor);
|
||||
THTensor_(unfold)(self, NULL, dimension_, size_, step_);
|
||||
return self;
|
||||
}
|
||||
|
||||
/* Resize */
|
||||
void THTensor_(resize)(THTensor *self, THLongStorage *size, THLongStorage *stride)
|
||||
{
|
||||
THArgCheck(size != NULL, 2, "invalid size");
|
||||
if(stride)
|
||||
THArgCheck(stride->size == size->size, 3, "invalid stride");
|
||||
|
||||
THTensor_(rawResize)(self, size->size, size->data, (stride ? stride->data : NULL));
|
||||
}
|
||||
|
||||
void THTensor_(resizeAs)(THTensor *self, THTensor *src)
|
||||
{
|
||||
int isSame = 0;
|
||||
int d;
|
||||
if(self->nDimension == src->nDimension)
|
||||
{
|
||||
isSame = 1;
|
||||
for(d = 0; d < self->nDimension; d++)
|
||||
{
|
||||
if(self->size[d] != src->size[d])
|
||||
{
|
||||
isSame = 0;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(!isSame)
|
||||
THTensor_(rawResize)(self, src->nDimension, src->size, NULL);
|
||||
}
|
||||
|
||||
void THTensor_(resize1d)(THTensor *tensor, long size0)
|
||||
{
|
||||
THTensor_(resize4d)(tensor, size0, -1, -1, -1);
|
||||
}
|
||||
|
||||
void THTensor_(resize2d)(THTensor *tensor, long size0, long size1)
|
||||
{
|
||||
THTensor_(resize4d)(tensor, size0, size1, -1, -1);
|
||||
}
|
||||
|
||||
void THTensor_(resize3d)(THTensor *tensor, long size0, long size1, long size2)
|
||||
{
|
||||
THTensor_(resize4d)(tensor, size0, size1, size2, -1);
|
||||
}
|
||||
|
||||
void THTensor_(resize4d)(THTensor *self, long size0, long size1, long size2, long size3)
|
||||
{
|
||||
long size[4] = {size0, size1, size2, size3};
|
||||
|
||||
THTensor_(rawResize)(self, 4, size, NULL);
|
||||
}
|
||||
|
||||
void THTensor_(resize5d)(THTensor *self, long size0, long size1, long size2, long size3, long size4)
|
||||
{
|
||||
long size[5] = {size0, size1, size2, size3, size4};
|
||||
|
||||
THTensor_(rawResize)(self, 5, size, NULL);
|
||||
}
|
||||
|
||||
void THTensor_(set)(THTensor *self, THTensor *src)
|
||||
{
|
||||
if(self != src)
|
||||
THTensor_(rawSet)(self,
|
||||
src->storage,
|
||||
src->storageOffset,
|
||||
src->nDimension,
|
||||
src->size,
|
||||
src->stride);
|
||||
}
|
||||
|
||||
void THTensor_(setStorage)(THTensor *self, THStorage *storage_, long storageOffset_, THLongStorage *size_, THLongStorage *stride_)
|
||||
{
|
||||
if(size_ && stride_)
|
||||
THArgCheck(size_->size == stride_->size, 5, "inconsistent size/stride sizes");
|
||||
|
||||
THTensor_(rawSet)(self,
|
||||
storage_,
|
||||
storageOffset_,
|
||||
(size_ ? size_->size : (stride_ ? stride_->size : 0)),
|
||||
(size_ ? size_->data : NULL),
|
||||
(stride_ ? stride_->data : NULL));
|
||||
}
|
||||
|
||||
void THTensor_(setStorage1d)(THTensor *self, THStorage *storage_, long storageOffset_,
|
||||
long size0_, long stride0_)
|
||||
{
|
||||
THTensor_(setStorage4d)(self, storage_, storageOffset_,
|
||||
size0_, stride0_,
|
||||
-1, -1,
|
||||
-1, -1,
|
||||
-1, -1);
|
||||
}
|
||||
|
||||
void THTensor_(setStorage2d)(THTensor *self, THStorage *storage_, long storageOffset_,
|
||||
long size0_, long stride0_,
|
||||
long size1_, long stride1_)
|
||||
{
|
||||
THTensor_(setStorage4d)(self, storage_, storageOffset_,
|
||||
size0_, stride0_,
|
||||
size1_, stride1_,
|
||||
-1, -1,
|
||||
-1, -1);
|
||||
}
|
||||
|
||||
void THTensor_(setStorage3d)(THTensor *self, THStorage *storage_, long storageOffset_,
|
||||
long size0_, long stride0_,
|
||||
long size1_, long stride1_,
|
||||
long size2_, long stride2_)
|
||||
{
|
||||
THTensor_(setStorage4d)(self, storage_, storageOffset_,
|
||||
size0_, stride0_,
|
||||
size1_, stride1_,
|
||||
size2_, stride2_,
|
||||
-1, -1);
|
||||
}
|
||||
|
||||
void THTensor_(setStorage4d)(THTensor *self, THStorage *storage_, long storageOffset_,
|
||||
long size0_, long stride0_,
|
||||
long size1_, long stride1_,
|
||||
long size2_, long stride2_,
|
||||
long size3_, long stride3_)
|
||||
{
|
||||
|
||||
long size[4] = {size0_, size1_, size2_, size3_};
|
||||
long stride[4] = {stride0_, stride1_, stride2_, stride3_};
|
||||
|
||||
THTensor_(rawSet)(self, storage_, storageOffset_, 4, size, stride);
|
||||
}
|
||||
|
||||
|
||||
void THTensor_(narrow)(THTensor *self, THTensor *src, int dimension, long firstIndex, long size)
|
||||
{
|
||||
if(!src)
|
||||
src = self;
|
||||
|
||||
THArgCheck( (dimension >= 0) && (dimension < src->nDimension), 3, "out of range");
|
||||
THArgCheck( (firstIndex >= 0) && (firstIndex < src->size[dimension]), 4, "out of range");
|
||||
THArgCheck( (size > 0) && (firstIndex+size <= src->size[dimension]), 5, "out of range");
|
||||
|
||||
THTensor_(set)(self, src);
|
||||
|
||||
if(firstIndex > 0)
|
||||
self->storageOffset += firstIndex*self->stride[dimension];
|
||||
|
||||
self->size[dimension] = size;
|
||||
}
|
||||
|
||||
void THTensor_(select)(THTensor *self, THTensor *src, int dimension, long sliceIndex)
|
||||
{
|
||||
int d;
|
||||
|
||||
if(!src)
|
||||
src = self;
|
||||
|
||||
THArgCheck(src->nDimension > 1, 1, "cannot select on a vector");
|
||||
THArgCheck((dimension >= 0) && (dimension < src->nDimension), 3, "out of range");
|
||||
THArgCheck((sliceIndex >= 0) && (sliceIndex < src->size[dimension]), 4, "out of range");
|
||||
|
||||
THTensor_(set)(self, src);
|
||||
THTensor_(narrow)(self, NULL, dimension, sliceIndex, 1);
|
||||
for(d = dimension; d < self->nDimension-1; d++)
|
||||
{
|
||||
self->size[d] = self->size[d+1];
|
||||
self->stride[d] = self->stride[d+1];
|
||||
}
|
||||
self->nDimension--;
|
||||
}
|
||||
|
||||
void THTensor_(transpose)(THTensor *self, THTensor *src, int dimension1, int dimension2)
|
||||
{
|
||||
long z;
|
||||
|
||||
if(!src)
|
||||
src = self;
|
||||
|
||||
THArgCheck( (dimension1 >= 0) && (dimension1 < src->nDimension), 1, "out of range");
|
||||
THArgCheck( (dimension2 >= 0) && (dimension2 < src->nDimension), 2, "out of range");
|
||||
|
||||
THTensor_(set)(self, src);
|
||||
|
||||
if(dimension1 == dimension2)
|
||||
return;
|
||||
|
||||
z = self->stride[dimension1];
|
||||
self->stride[dimension1] = self->stride[dimension2];
|
||||
self->stride[dimension2] = z;
|
||||
z = self->size[dimension1];
|
||||
self->size[dimension1] = self->size[dimension2];
|
||||
self->size[dimension2] = z;
|
||||
}
|
||||
|
||||
void THTensor_(unfold)(THTensor *self, THTensor *src, int dimension, long size, long step)
|
||||
{
|
||||
long *newSize;
|
||||
long *newStride;
|
||||
int d;
|
||||
|
||||
if(!src)
|
||||
src = self;
|
||||
|
||||
THArgCheck( (src->nDimension > 0), 1, "cannot unfold an empty tensor");
|
||||
THArgCheck(dimension < src->nDimension, 2, "out of range");
|
||||
THArgCheck(size <= src->size[dimension], 3, "out of range");
|
||||
THArgCheck(step > 0, 4, "invalid step");
|
||||
|
||||
THTensor_(set)(self, src);
|
||||
|
||||
newSize = THAlloc(sizeof(long)*(self->nDimension+1));
|
||||
newStride = THAlloc(sizeof(long)*(self->nDimension+1));
|
||||
|
||||
newSize[self->nDimension] = size;
|
||||
newStride[self->nDimension] = self->stride[dimension];
|
||||
for(d = 0; d < self->nDimension; d++)
|
||||
{
|
||||
if(d == dimension)
|
||||
{
|
||||
newSize[d] = (self->size[d] - size) / step + 1;
|
||||
newStride[d] = step*self->stride[d];
|
||||
}
|
||||
else
|
||||
{
|
||||
newSize[d] = self->size[d];
|
||||
newStride[d] = self->stride[d];
|
||||
}
|
||||
}
|
||||
|
||||
THFree(self->size);
|
||||
THFree(self->stride);
|
||||
|
||||
self->size = newSize;
|
||||
self->stride = newStride;
|
||||
self->nDimension++;
|
||||
}
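
unfold above builds a view in which dimension `dimension` is replaced by windows of length `size` taken every `step` elements, the window contents showing up as a new trailing dimension; no data is copied, only sizes and strides change. Below is a plain-C sketch of the same index arithmetic for a length-7 vector unfolded with size 3 and step 2; illustration only, no TH types involved.

/* Illustration only: unfold index arithmetic on a 1-d array. */
#include <stdio.h>

int main(void)
{
  double data[7] = {0, 1, 2, 3, 4, 5, 6};
  long size = 3, step = 2, n = 7;

  long rows = (n - size) / step + 1;  /* new size along the unfolded dim */
  long row_stride = step;             /* old stride (1) times step */
  long col_stride = 1;                /* trailing dim keeps the old stride */
  long i, j;

  for(i = 0; i < rows; i++)
  {
    for(j = 0; j < size; j++)
      printf("%g ", data[i*row_stride + j*col_stride]);
    printf("\n");                     /* prints: 0 1 2 / 2 3 4 / 4 5 6 */
  }
  return 0;
}
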
|
||||
|
||||
/* we have to handle the case where the result is a number */
|
||||
void THTensor_(squeeze)(THTensor *self, THTensor *src)
|
||||
{
|
||||
int ndim = 0;
|
||||
int d;
|
||||
|
||||
if(!src)
|
||||
src = self;
|
||||
|
||||
THTensor_(set)(self, src);
|
||||
|
||||
for(d = 0; d < src->nDimension; d++)
|
||||
{
|
||||
if(src->size[d] != 1)
|
||||
{
|
||||
if(d != ndim)
|
||||
{
|
||||
self->size[ndim] = src->size[d];
|
||||
self->stride[ndim] = src->stride[d];
|
||||
}
|
||||
ndim++;
|
||||
}
|
||||
}
|
||||
|
||||
/* right now, we do not handle 0-dimension tensors */
|
||||
if(ndim == 0 && src->nDimension > 0)
|
||||
{
|
||||
self->size[0] = 1;
|
||||
self->stride[0] = 1;
|
||||
ndim = 1;
|
||||
}
|
||||
self->nDimension = ndim;
|
||||
}
|
||||
|
||||
void THTensor_(squeeze1d)(THTensor *self, THTensor *src, int dimension)
|
||||
{
|
||||
int d;
|
||||
|
||||
if(!src)
|
||||
src = self;
|
||||
|
||||
THArgCheck(dimension < src->nDimension, 3, "dimension out of range");
|
||||
|
||||
THTensor_(set)(self, src);
|
||||
|
||||
if(src->size[dimension] == 1 && src->nDimension > 1)
|
||||
{
|
||||
for(d = dimension; d < self->nDimension-1; d++)
|
||||
{
|
||||
self->size[d] = self->size[d+1];
|
||||
self->stride[d] = self->stride[d+1];
|
||||
}
|
||||
self->nDimension--;
|
||||
}
|
||||
}
|
||||
|
||||
int THTensor_(isContiguous)(THTensor *self)
|
||||
{
|
||||
long z = 1;
|
||||
int d;
|
||||
for(d = self->nDimension-1; d >= 0; d--)
|
||||
{
|
||||
if(self->size[d] != 1)
|
||||
{
|
||||
if(self->stride[d] == z)
|
||||
z *= self->size[d];
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
long THTensor_(nElement)(THTensor *self)
|
||||
{
|
||||
if(self->nDimension == 0)
|
||||
return 0;
|
||||
else
|
||||
{
|
||||
long nElement = 1;
|
||||
int d;
|
||||
for(d = 0; d < self->nDimension; d++)
|
||||
nElement *= self->size[d];
|
||||
return nElement;
|
||||
}
|
||||
}
|
||||
|
||||
void THTensor_(retain)(THTensor *self)
|
||||
{
|
||||
if(self->flag & TH_TENSOR_REFCOUNTED)
|
||||
++self->refcount;
|
||||
}
|
||||
|
||||
void THTensor_(free)(THTensor *self)
|
||||
{
|
||||
if(!self)
|
||||
return;
|
||||
|
||||
if(self->flag & TH_TENSOR_REFCOUNTED)
|
||||
{
|
||||
if(--self->refcount == 0)
|
||||
{
|
||||
THFree(self->size);
|
||||
THFree(self->stride);
|
||||
if(self->storage)
|
||||
THStorage_(free)(self->storage);
|
||||
THFree(self);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void THTensor_(freeCopyTo)(THTensor *self, THTensor *dst)
|
||||
{
|
||||
if(self != dst)
|
||||
THTensor_(copy)(dst, self);
|
||||
|
||||
THTensor_(free)(self);
|
||||
}
|
||||
|
||||
/*******************************************************************************/
|
||||
|
||||
static void THTensor_(rawInit)(THTensor *self)
|
||||
{
|
||||
self->refcount = 1;
|
||||
self->storage = NULL;
|
||||
self->storageOffset = 0;
|
||||
self->size = NULL;
|
||||
self->stride = NULL;
|
||||
self->nDimension = 0;
|
||||
self->flag = TH_TENSOR_REFCOUNTED;
|
||||
}
|
||||
|
||||
static void THTensor_(rawSet)(THTensor *self, THStorage *storage, long storageOffset, int nDimension, long *size, long *stride)
|
||||
{
|
||||
/* storage */
|
||||
if(self->storage != storage)
|
||||
{
|
||||
if(self->storage)
|
||||
THStorage_(free)(self->storage);
|
||||
|
||||
if(storage)
|
||||
{
|
||||
self->storage = storage;
|
||||
THStorage_(retain)(self->storage);
|
||||
}
|
||||
else
|
||||
self->storage = NULL;
|
||||
}
|
||||
|
||||
/* storageOffset */
|
||||
if(storageOffset < 0)
|
||||
THError("Tensor: invalid storage offset");
|
||||
self->storageOffset = storageOffset;
|
||||
|
||||
/* size and stride */
|
||||
THTensor_(rawResize)(self, nDimension, size, stride);
|
||||
}
|
||||
|
||||
static void THTensor_(rawResize)(THTensor *self, int nDimension, long *size, long *stride)
|
||||
{
|
||||
int d;
|
||||
int nDimension_;
|
||||
long totalSize;
|
||||
|
||||
nDimension_ = 0;
|
||||
for(d = 0; d < nDimension; d++)
|
||||
{
|
||||
if(size[d] > 0)
|
||||
nDimension_++;
|
||||
else
|
||||
break;
|
||||
}
|
||||
nDimension = nDimension_;
|
||||
|
||||
if(nDimension > 0)
|
||||
{
|
||||
if(nDimension != self->nDimension)
|
||||
{
|
||||
self->size = THRealloc(self->size, sizeof(long)*nDimension);
|
||||
self->stride = THRealloc(self->stride, sizeof(long)*nDimension);
|
||||
self->nDimension = nDimension;
|
||||
}
|
||||
|
||||
totalSize = 1;
|
||||
for(d = self->nDimension-1; d >= 0; d--)
|
||||
{
|
||||
self->size[d] = size[d];
|
||||
if(stride && (stride[d] >= 0) )
|
||||
self->stride[d] = stride[d];
|
||||
else
|
||||
{
|
||||
if(d == self->nDimension-1)
|
||||
self->stride[d] = 1;
|
||||
else
|
||||
self->stride[d] = self->size[d+1]*self->stride[d+1];
|
||||
}
|
||||
totalSize += (self->size[d]-1)*self->stride[d];
|
||||
}
|
||||
|
||||
if(totalSize+self->storageOffset > 0)
|
||||
{
|
||||
if(!self->storage)
|
||||
self->storage = THStorage_(new)();
|
||||
if(totalSize+self->storageOffset > self->storage->size)
|
||||
THStorage_(resize)(self->storage, totalSize+self->storageOffset);
|
||||
}
|
||||
}
|
||||
else
|
||||
self->nDimension = 0;
|
||||
}
|
||||
|
||||
void THTensor_(set1d)(THTensor *tensor, long x0, real value)
|
||||
{
|
||||
THArgCheck(tensor->nDimension == 1, 1, "tensor must have one dimension");
|
||||
THArgCheck( (x0 >= 0) && (x0 < tensor->size[0]), 2, "out of range");
|
||||
THStorage_(set)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0], value);
|
||||
}
|
||||
|
||||
real THTensor_(get1d)(THTensor *tensor, long x0)
|
||||
{
|
||||
THArgCheck(tensor->nDimension == 1, 1, "tensor must have one dimension");
|
||||
THArgCheck( (x0 >= 0) && (x0 < tensor->size[0]), 2, "out of range");
|
||||
return THStorage_(get)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]);
|
||||
}
|
||||
|
||||
void THTensor_(set2d)(THTensor *tensor, long x0, long x1, real value)
|
||||
{
|
||||
THArgCheck(tensor->nDimension == 2, 1, "tensor must have two dimensions");
|
||||
THArgCheck((x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]), 2, "out of range");
|
||||
THStorage_(set)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1], value);
|
||||
}
|
||||
|
||||
real THTensor_(get2d)(THTensor *tensor, long x0, long x1)
|
||||
{
|
||||
THArgCheck(tensor->nDimension == 2, 1, "tensor must have two dimensions");
|
||||
THArgCheck((x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]), 2, "out of range");
|
||||
return THStorage_(get)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1]);
|
||||
}
|
||||
|
||||
void THTensor_(set3d)(THTensor *tensor, long x0, long x1, long x2, real value)
|
||||
{
|
||||
THArgCheck(tensor->nDimension == 3, 1, "tensor must have three dimensions");
|
||||
THArgCheck( (x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]) && (x2 >= 0) && (x2 < tensor->size[2]), 2, "out of range");
|
||||
THStorage_(set)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1]+x2*tensor->stride[2], value);
|
||||
}
|
||||
|
||||
real THTensor_(get3d)(THTensor *tensor, long x0, long x1, long x2)
|
||||
{
|
||||
THArgCheck(tensor->nDimension == 3, 1, "tensor must have three dimensions");
|
||||
THArgCheck( (x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]) && (x2 >= 0) && (x2 < tensor->size[2]), 2, "out of range");
|
||||
return THStorage_(get)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1]+x2*tensor->stride[2]);
|
||||
}
|
||||
|
||||
void THTensor_(set4d)(THTensor *tensor, long x0, long x1, long x2, long x3, real value)
|
||||
{
|
||||
THArgCheck(tensor->nDimension == 4, 1, "tensor must have four dimensions");
|
||||
THArgCheck((x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]) && (x2 >= 0) && (x2 < tensor->size[2]) && (x3 >= 0) && (x3 < tensor->size[3]), 2, "out of range");
|
||||
THStorage_(set)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1]+x2*tensor->stride[2]+x3*tensor->stride[3], value);
|
||||
}
|
||||
|
||||
real THTensor_(get4d)(THTensor *tensor, long x0, long x1, long x2, long x3)
|
||||
{
|
||||
THArgCheck(tensor->nDimension == 4, 1, "tensor must have four dimensions");
|
||||
THArgCheck((x0 >= 0) && (x0 < tensor->size[0]) && (x1 >= 0) && (x1 < tensor->size[1]) && (x2 >= 0) && (x2 < tensor->size[2]) && (x3 >= 0) && (x3 < tensor->size[3]), 2, "out of range");
|
||||
return THStorage_(get)(tensor->storage, tensor->storageOffset+x0*tensor->stride[0]+x1*tensor->stride[1]+x2*tensor->stride[2]+x3*tensor->stride[3]);
|
||||
}
|
||||
|
||||
#endif
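
When no strides are supplied, rawResize above fills them in with the contiguous rule (the last dimension gets stride 1, and stride[d] = size[d+1]*stride[d+1] otherwise) and then makes sure the storage holds at least 1 + sum over d of (size[d]-1)*stride[d] elements past the storage offset. A small standalone check of that arithmetic for a 2x3x4 tensor follows; plain C, illustration only.

/* Illustration only: contiguous strides and required element count. */
#include <stdio.h>

int main(void)
{
  long size[3] = {2, 3, 4};
  long stride[3];
  long totalSize = 1;
  int d;

  for(d = 2; d >= 0; d--)
  {
    stride[d] = (d == 2) ? 1 : size[d+1]*stride[d+1];
    totalSize += (size[d]-1)*stride[d];
  }
  /* strides: 12 4 1; elements needed: 1 + 1*12 + 2*4 + 3*1 = 24 */
  printf("strides: %ld %ld %ld, elements needed: %ld\n",
         stride[0], stride[1], stride[2], totalSize);
  return 0;
}
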
123
lib/TH/generic/THTensor.h
Normal file
@@ -0,0 +1,123 @@
#ifndef TH_GENERIC_FILE
|
||||
#define TH_GENERIC_FILE "generic/THTensor.h"
|
||||
#else
|
||||
|
||||
/* a la lua? dim, storageoffset, ... and the methods? */
|
||||
|
||||
#define TH_TENSOR_REFCOUNTED 1
|
||||
|
||||
typedef struct THTensor
|
||||
{
|
||||
long *size;
|
||||
long *stride;
|
||||
int nDimension;
|
||||
|
||||
THStorage *storage;
|
||||
long storageOffset;
|
||||
int refcount;
|
||||
|
||||
char flag;
|
||||
|
||||
} THTensor;
|
||||
|
||||
|
||||
/**** access methods ****/
|
||||
TH_API THStorage* THTensor_(storage)(THTensor *self);
|
||||
TH_API long THTensor_(storageOffset)(THTensor *self);
|
||||
TH_API int THTensor_(nDimension)(THTensor *self);
|
||||
TH_API long THTensor_(size)(THTensor *self, int dim);
|
||||
TH_API long THTensor_(stride)(THTensor *self, int dim);
|
||||
TH_API THLongStorage *THTensor_(newSizeOf)(THTensor *self);
|
||||
TH_API THLongStorage *THTensor_(newStrideOf)(THTensor *self);
|
||||
TH_API real *THTensor_(data)(THTensor *self);
|
||||
|
||||
TH_API void THTensor_(setFlag)(THTensor *self, const char flag);
|
||||
TH_API void THTensor_(clearFlag)(THTensor *self, const char flag);
|
||||
|
||||
|
||||
/**** creation methods ****/
|
||||
TH_API THTensor *THTensor_(new)(void);
|
||||
TH_API THTensor *THTensor_(newWithTensor)(THTensor *tensor);
|
||||
/* stride might be NULL */
|
||||
TH_API THTensor *THTensor_(newWithStorage)(THStorage *storage_, long storageOffset_, THLongStorage *size_, THLongStorage *stride_);
|
||||
TH_API THTensor *THTensor_(newWithStorage1d)(THStorage *storage_, long storageOffset_,
|
||||
long size0_, long stride0_);
|
||||
TH_API THTensor *THTensor_(newWithStorage2d)(THStorage *storage_, long storageOffset_,
|
||||
long size0_, long stride0_,
|
||||
long size1_, long stride1_);
|
||||
TH_API THTensor *THTensor_(newWithStorage3d)(THStorage *storage_, long storageOffset_,
|
||||
long size0_, long stride0_,
|
||||
long size1_, long stride1_,
|
||||
long size2_, long stride2_);
|
||||
TH_API THTensor *THTensor_(newWithStorage4d)(THStorage *storage_, long storageOffset_,
|
||||
long size0_, long stride0_,
|
||||
long size1_, long stride1_,
|
||||
long size2_, long stride2_,
|
||||
long size3_, long stride3_);
|
||||
|
||||
/* stride might be NULL */
|
||||
TH_API THTensor *THTensor_(newWithSize)(THLongStorage *size_, THLongStorage *stride_);
|
||||
TH_API THTensor *THTensor_(newWithSize1d)(long size0_);
|
||||
TH_API THTensor *THTensor_(newWithSize2d)(long size0_, long size1_);
|
||||
TH_API THTensor *THTensor_(newWithSize3d)(long size0_, long size1_, long size2_);
|
||||
TH_API THTensor *THTensor_(newWithSize4d)(long size0_, long size1_, long size2_, long size3_);
|
||||
|
||||
TH_API THTensor *THTensor_(newClone)(THTensor *self);
|
||||
TH_API THTensor *THTensor_(newContiguous)(THTensor *tensor);
|
||||
TH_API THTensor *THTensor_(newSelect)(THTensor *tensor, int dimension_, long sliceIndex_);
|
||||
TH_API THTensor *THTensor_(newNarrow)(THTensor *tensor, int dimension_, long firstIndex_, long size_);
|
||||
TH_API THTensor *THTensor_(newTranspose)(THTensor *tensor, int dimension1_, int dimension2_);
|
||||
TH_API THTensor *THTensor_(newUnfold)(THTensor *tensor, int dimension_, long size_, long step_);
|
||||
|
||||
TH_API void THTensor_(resize)(THTensor *tensor, THLongStorage *size, THLongStorage *stride);
|
||||
TH_API void THTensor_(resizeAs)(THTensor *tensor, THTensor *src);
|
||||
TH_API void THTensor_(resize1d)(THTensor *tensor, long size0_);
|
||||
TH_API void THTensor_(resize2d)(THTensor *tensor, long size0_, long size1_);
|
||||
TH_API void THTensor_(resize3d)(THTensor *tensor, long size0_, long size1_, long size2_);
|
||||
TH_API void THTensor_(resize4d)(THTensor *tensor, long size0_, long size1_, long size2_, long size3_);
|
||||
TH_API void THTensor_(resize5d)(THTensor *tensor, long size0_, long size1_, long size2_, long size3_, long size4_);
|
||||
|
||||
TH_API void THTensor_(set)(THTensor *self, THTensor *src);
|
||||
TH_API void THTensor_(setStorage)(THTensor *self, THStorage *storage_, long storageOffset_, THLongStorage *size_, THLongStorage *stride_);
|
||||
TH_API void THTensor_(setStorage1d)(THTensor *self, THStorage *storage_, long storageOffset_,
|
||||
long size0_, long stride0_);
|
||||
TH_API void THTensor_(setStorage2d)(THTensor *self, THStorage *storage_, long storageOffset_,
|
||||
long size0_, long stride0_,
|
||||
long size1_, long stride1_);
|
||||
TH_API void THTensor_(setStorage3d)(THTensor *self, THStorage *storage_, long storageOffset_,
|
||||
long size0_, long stride0_,
|
||||
long size1_, long stride1_,
|
||||
long size2_, long stride2_);
|
||||
TH_API void THTensor_(setStorage4d)(THTensor *self, THStorage *storage_, long storageOffset_,
|
||||
long size0_, long stride0_,
|
||||
long size1_, long stride1_,
|
||||
long size2_, long stride2_,
|
||||
long size3_, long stride3_);
|
||||
|
||||
TH_API void THTensor_(narrow)(THTensor *self, THTensor *src, int dimension_, long firstIndex_, long size_);
|
||||
TH_API void THTensor_(select)(THTensor *self, THTensor *src, int dimension_, long sliceIndex_);
|
||||
TH_API void THTensor_(transpose)(THTensor *self, THTensor *src, int dimension1_, int dimension2_);
|
||||
TH_API void THTensor_(unfold)(THTensor *self, THTensor *src, int dimension_, long size_, long step_);
|
||||
|
||||
TH_API void THTensor_(squeeze)(THTensor *self, THTensor *src);
|
||||
TH_API void THTensor_(squeeze1d)(THTensor *self, THTensor *src, int dimension_);
|
||||
|
||||
TH_API int THTensor_(isContiguous)(THTensor *self);
|
||||
TH_API long THTensor_(nElement)(THTensor *self);
|
||||
|
||||
TH_API void THTensor_(retain)(THTensor *self);
|
||||
TH_API void THTensor_(free)(THTensor *self);
|
||||
TH_API void THTensor_(freeCopyTo)(THTensor *self, THTensor *dst);
|
||||
|
||||
/* Slow access methods [check everything] */
|
||||
TH_API void THTensor_(set1d)(THTensor *tensor, long x0, real value);
|
||||
TH_API void THTensor_(set2d)(THTensor *tensor, long x0, long x1, real value);
|
||||
TH_API void THTensor_(set3d)(THTensor *tensor, long x0, long x1, long x2, real value);
|
||||
TH_API void THTensor_(set4d)(THTensor *tensor, long x0, long x1, long x2, long x3, real value);
|
||||
|
||||
TH_API real THTensor_(get1d)(THTensor *tensor, long x0);
|
||||
TH_API real THTensor_(get2d)(THTensor *tensor, long x0, long x1);
|
||||
TH_API real THTensor_(get3d)(THTensor *tensor, long x0, long x1, long x2);
|
||||
TH_API real THTensor_(get4d)(THTensor *tensor, long x0, long x1, long x2, long x3);
|
||||
|
||||
#endif
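
A short usage sketch of the API declared above, written against the double instantiation of the generic header (the THDoubleTensor_* names). It assumes the TH library is built and linked and that "TH.h" is its umbrella header; illustration only.

/* Illustration only: creating, filling and reading a 2x3 double tensor. */
#include <stdio.h>
#include "TH.h"   /* assumed umbrella header for the TH library */

int main(void)
{
  THDoubleTensor *t = THDoubleTensor_newWithSize2d(2, 3);
  long i, j;

  for(i = 0; i < 2; i++)
    for(j = 0; j < 3; j++)
      THDoubleTensor_set2d(t, i, j, (double)(10*i + j));

  printf("t[1][2] = %g, nElement = %ld, contiguous = %d\n",
         THDoubleTensor_get2d(t, 1, 2),
         THDoubleTensor_nElement(t),
         THDoubleTensor_isContiguous(t));

  THDoubleTensor_free(t);
  return 0;
}
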
1489
lib/TH/generic/THTensorConv.c
Normal file
File diff suppressed because it is too large
78
lib/TH/generic/THTensorConv.h
Normal file
@@ -0,0 +1,78 @@
#ifndef TH_GENERIC_FILE
|
||||
#define TH_GENERIC_FILE "generic/THTensorConv.h"
|
||||
#else
|
||||
|
||||
|
||||
TH_API void THTensor_(validXCorr2Dptr)(real *r_,
|
||||
real alpha,
|
||||
real *t_, long ir, long ic,
|
||||
real *k_, long kr, long kc,
|
||||
long sr, long sc);
|
||||
|
||||
TH_API void THTensor_(validConv2Dptr)(real *r_,
|
||||
real alpha,
|
||||
real *t_, long ir, long ic,
|
||||
real *k_, long kr, long kc,
|
||||
long sr, long sc);
|
||||
|
||||
TH_API void THTensor_(fullXCorr2Dptr)(real *r_,
|
||||
real alpha,
|
||||
real *t_, long ir, long ic,
|
||||
real *k_, long kr, long kc,
|
||||
long sr, long sc);
|
||||
|
||||
TH_API void THTensor_(fullConv2Dptr)(real *r_,
|
||||
real alpha,
|
||||
real *t_, long ir, long ic,
|
||||
real *k_, long kr, long kc,
|
||||
long sr, long sc);
|
||||
|
||||
TH_API void THTensor_(validXCorr2DRevptr)(real *r_,
|
||||
real alpha,
|
||||
real *t_, long ir, long ic,
|
||||
real *k_, long kr, long kc,
|
||||
long sr, long sc);
|
||||
|
||||
TH_API void THTensor_(conv2DRevger)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol);
|
||||
TH_API void THTensor_(conv2Dger)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol, const char *vf, const char *xc);
|
||||
TH_API void THTensor_(conv2Dmv)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol, const char *vf, const char *xc);
|
||||
TH_API void THTensor_(conv2Dmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol, const char *vf, const char *xc);
|
||||
TH_API void THTensor_(conv2Dcmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long srow, long scol, const char *vf, const char *xc);
|
||||
|
||||
TH_API void THTensor_(validXCorr3Dptr)(real *r_,
|
||||
real alpha,
|
||||
real *t_, long it, long ir, long ic,
|
||||
real *k_, long kt, long kr, long kc,
|
||||
long st, long sr, long sc);
|
||||
|
||||
TH_API void THTensor_(validConv3Dptr)(real *r_,
|
||||
real alpha,
|
||||
real *t_, long it, long ir, long ic,
|
||||
real *k_, long kt, long kr, long kc,
|
||||
long st, long sr, long sc);
|
||||
|
||||
TH_API void THTensor_(fullXCorr3Dptr)(real *r_,
|
||||
real alpha,
|
||||
real *t_, long it, long ir, long ic,
|
||||
real *k_, long kt, long kr, long kc,
|
||||
long st, long sr, long sc);
|
||||
|
||||
TH_API void THTensor_(fullConv3Dptr)(real *r_,
|
||||
real alpha,
|
||||
real *t_, long it, long ir, long ic,
|
||||
real *k_, long kt, long kr, long kc,
|
||||
long st, long sr, long sc);
|
||||
|
||||
TH_API void THTensor_(validXCorr3DRevptr)(real *r_,
|
||||
real alpha,
|
||||
real *t_, long it, long ir, long ic,
|
||||
real *k_, long kt, long kr, long kc,
|
||||
long st, long sr, long sc);
|
||||
|
||||
TH_API void THTensor_(conv3DRevger)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long sdepth, long srow, long scol);
|
||||
TH_API void THTensor_(conv3Dger)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long sdepth, long srow, long scol, const char *vf, const char *xc);
|
||||
TH_API void THTensor_(conv3Dmv)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long sdepth, long srow, long scol, const char *vf, const char *xc);
|
||||
TH_API void THTensor_(conv3Dmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long sdepth, long srow, long scol, const char *vf, const char *xc);
|
||||
TH_API void THTensor_(conv3Dcmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long sdepth, long srow, long scol, const char *vf, const char *xc);
|
||||
|
||||
#endif
21
lib/TH/generic/THTensorCopy.c
Normal file
@@ -0,0 +1,21 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THTensorCopy.c"
#else

#define IMPLEMENT_THTensor_COPY(TYPENAMESRC, TYPE_SRC) \
void THTensor_(copy##TYPENAMESRC)(THTensor *tensor, TH##TYPENAMESRC##Tensor *src) \
{ \
  TH_TENSOR_APPLY2(real, tensor, TYPE_SRC, src, *tensor_data = (real)(*src_data);) \
}

IMPLEMENT_THTensor_COPY(, real)

IMPLEMENT_THTensor_COPY(Byte, unsigned char)
IMPLEMENT_THTensor_COPY(Char, char)
IMPLEMENT_THTensor_COPY(Short, short)
IMPLEMENT_THTensor_COPY(Int, int)
IMPLEMENT_THTensor_COPY(Long, long)
IMPLEMENT_THTensor_COPY(Float, float)
IMPLEMENT_THTensor_COPY(Double, double)

#endif
16
lib/TH/generic/THTensorCopy.h
Normal file
@@ -0,0 +1,16 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THTensorCopy.h"
#else

/* Support for copy between different Tensor types */

TH_API void THTensor_(copy)(THTensor *tensor, THTensor *src);
TH_API void THTensor_(copyByte)(THTensor *tensor, struct THByteTensor *src);
TH_API void THTensor_(copyChar)(THTensor *tensor, struct THCharTensor *src);
TH_API void THTensor_(copyShort)(THTensor *tensor, struct THShortTensor *src);
TH_API void THTensor_(copyInt)(THTensor *tensor, struct THIntTensor *src);
TH_API void THTensor_(copyLong)(THTensor *tensor, struct THLongTensor *src);
TH_API void THTensor_(copyFloat)(THTensor *tensor, struct THFloatTensor *src);
TH_API void THTensor_(copyDouble)(THTensor *tensor, struct THDoubleTensor *src);

#endif
343
lib/TH/generic/THTensorLapack.c
Normal file
@@ -0,0 +1,343 @@
#ifndef TH_GENERIC_FILE
|
||||
#define TH_GENERIC_FILE "generic/THTensorLapack.c"
|
||||
#else
|
||||
|
||||
static int THTensor_(lapackClone)(THTensor *r_, THTensor *m, int forced)
|
||||
{
|
||||
int clone;
|
||||
|
||||
if (!forced && m->stride[0] == 1 && m->stride[1] == m->size[0])
|
||||
{
|
||||
clone = 0;
|
||||
THTensor_(set)(r_,m);
|
||||
}
|
||||
else
|
||||
{
|
||||
clone = 1;
|
||||
/* we need to copy */
|
||||
THTensor_(resize2d)(r_,m->size[1],m->size[0]);
|
||||
THTensor_(transpose)(r_,NULL,0,1);
|
||||
THTensor_(copy)(r_,m);
|
||||
}
|
||||
return clone;
|
||||
}
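
lapackClone above prepares a matrix for Fortran LAPACK, which expects column-major data: if the input is not already laid out that way (stride[0] == 1 and stride[1] == size[0]), it allocates the result with swapped sizes, transposes the view and copies, so the result's memory ends up column-major for the original shape. Below is a plain-C illustration of the row-major to column-major reordering this achieves; it is not TH code.

/* Illustration only: copying a row-major 2x3 matrix into column-major order. */
#include <stdio.h>

int main(void)
{
  double rowmajor[6] = {1, 2, 3,
                        4, 5, 6};
  double colmajor[6];
  long m = 2, n = 3, i, j;

  for(j = 0; j < n; j++)
    for(i = 0; i < m; i++)
      colmajor[i + j*m] = rowmajor[i*n + j];   /* column by column */

  for(i = 0; i < 6; i++)
    printf("%g ", colmajor[i]);                /* prints: 1 4 2 5 3 6 */
  printf("\n");
  return 0;
}
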
|
||||
|
||||
TH_API void THTensor_(gesv)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a)
|
||||
{
|
||||
int n, nrhs, lda, ldb, info;
|
||||
THIntTensor *ipiv;
|
||||
THTensor *ra__;
|
||||
THTensor *rb__;
|
||||
|
||||
int clonea;
|
||||
int cloneb;
|
||||
int destroya;
|
||||
int destroyb;
|
||||
|
||||
|
||||
if (a == NULL || ra_ == a) /* possibly destroy the inputs */
|
||||
{
|
||||
ra__ = THTensor_(new)();
|
||||
clonea = THTensor_(lapackClone)(ra__,ra_,0);
|
||||
destroya = 1;
|
||||
}
|
||||
else /*we want to definitely clone and use ra_ and rb_ as computational space*/
|
||||
{
|
||||
clonea = THTensor_(lapackClone)(ra_,a,1);
|
||||
ra__ = ra_;
|
||||
destroya = 0;
|
||||
}
|
||||
if (b == NULL || rb_ == b) /* possibly destroy the inputs */
|
||||
{
|
||||
rb__ = THTensor_(new)();
|
||||
cloneb = THTensor_(lapackClone)(rb__,rb_,0);
|
||||
destroyb = 1;
|
||||
}
|
||||
else /*we want to definitely clone and use ra_ and rb_ as computational space*/
|
||||
{
|
||||
cloneb = THTensor_(lapackClone)(rb_,b,1);
|
||||
rb__ = rb_;
|
||||
destroyb = 0;
|
||||
}
|
||||
|
||||
THArgCheck(ra__->nDimension == 2, 1, "A should be 2 dimensional");
|
||||
THArgCheck(rb__->nDimension == 2, 2, "b should be 2 dimensional");
|
||||
THArgCheck(ra__->size[0] == ra__->size[1], 1, "A should be square");
|
||||
THArgCheck(rb__->size[0] == ra__->size[0], 2, "A,b size incompatible");
|
||||
|
||||
n = (int)ra__->size[0];
|
||||
nrhs = (int)rb__->size[1];
|
||||
lda = n;
|
||||
ldb = n;
|
||||
|
||||
ipiv = THIntTensor_newWithSize1d((long)n);
|
||||
THLapack_(gesv)(n, nrhs,
|
||||
THTensor_(data)(ra__), lda, THIntTensor_data(ipiv),
|
||||
THTensor_(data)(rb__), ldb, &info);
|
||||
|
||||
/* clean up */
|
||||
if (destroya)
|
||||
{
|
||||
if (clonea)
|
||||
{
|
||||
THTensor_(copy)(ra_,ra__);
|
||||
}
|
||||
THTensor_(free)(ra__);
|
||||
}
|
||||
if (destroyb)
|
||||
{
|
||||
if (cloneb)
|
||||
{
|
||||
THTensor_(copy)(rb_,rb__);
|
||||
}
|
||||
THTensor_(free)(rb__);
|
||||
}
|
||||
|
||||
if (info < 0)
|
||||
{
|
||||
THError("Lapack gesv : Argument %d : illegal value", -info);
|
||||
}
|
||||
else if (info > 0)
|
||||
{
|
||||
THError("Lapack gesv : U(%d,%d) is zero, singular U.", info,info);
|
||||
}
|
||||
|
||||
THIntTensor_free(ipiv);
|
||||
}
|
||||
|
||||
TH_API void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a)
{
  int m, n, nrhs, lda, ldb, info, lwork;
  char transpose;
  THTensor *work = NULL;
  real wkopt = 0;

  THTensor *ra__;
  THTensor *rb__;

  int clonea;
  int cloneb;
  int destroya;
  int destroyb;


  if (a == NULL || ra_ == a) /* possibly destroy the inputs */
  {
    ra__ = THTensor_(new)();
    clonea = THTensor_(lapackClone)(ra__,ra_,0);
    destroya = 1;
  }
  else /* we definitely want to clone and use ra_ and rb_ as computational space */
  {
    clonea = THTensor_(lapackClone)(ra_,a,1);
    ra__ = ra_;
    destroya = 0;
  }
  if (b == NULL || rb_ == b) /* possibly destroy the inputs */
  {
    rb__ = THTensor_(new)();
    cloneb = THTensor_(lapackClone)(rb__,rb_,0);
    destroyb = 1;
  }
  else /* we definitely want to clone and use ra_ and rb_ as computational space */
  {
    cloneb = THTensor_(lapackClone)(rb_,b,1);
    rb__ = rb_;
    destroyb = 0;
  }

  THArgCheck(ra__->nDimension == 2, 1, "A should be 2 dimensional");
  THArgCheck(ra_->size[0] == rb__->size[0], 2, "A,b size incompatible");

  m = ra__->size[0];
  n = ra__->size[1];
  nrhs = rb__->size[1];
  lda = m;
  ldb = m;
  info = 0;

  /* get optimal workspace size */
  THLapack_(gels)('N', m, n, nrhs, THTensor_(data)(ra__), lda,
                  THTensor_(data)(rb__), ldb,
                  &wkopt, -1, &info);
  lwork = (int)wkopt;
  work = THTensor_(newWithSize1d)(lwork);
  THLapack_(gels)('N', m, n, nrhs, THTensor_(data)(ra__), lda,
                  THTensor_(data)(rb__), ldb,
                  THTensor_(data)(work), lwork, &info);

  /* printf("lwork = %d,%g\n", lwork, THTensor_(data)(work)[0]); */
  if (info < 0)
  {
    THError("Lapack gels : Argument %d : illegal value", -info);
  }
  else if (info > 0)
  {
    THError("Lapack gels : The %d-th diagonal element of the triangular factor of A is zero", info);
  }
  /* clean up */
  if (destroya)
  {
    if (clonea)
    {
      THTensor_(copy)(ra_,ra__);
    }
    THTensor_(free)(ra__);
  }
  if (destroyb)
  {
    if (cloneb)
    {
      THTensor_(copy)(rb_,rb__);
    }
    THTensor_(free)(rb__);
  }
  THTensor_(free)(work);
}

TH_API void THTensor_(syev)(THTensor *re_, THTensor *rv_, THTensor *a, const char *jobz, const char *uplo)
{
  int n, lda, lwork, info;
  THTensor *work;
  real wkopt;

  THTensor *rv__;

  int clonea;
  int destroy;

  if (a == NULL) /* possibly destroy the inputs */
  {
    rv__ = THTensor_(new)();
    clonea = THTensor_(lapackClone)(rv__,rv_,0);
    destroy = 1;
  }
  else /* we definitely want to clone and use rv_ as computational space */
  {
    clonea = THTensor_(lapackClone)(rv_,a,1);
    rv__ = rv_;
    destroy = 0;
  }

  THArgCheck(rv__->nDimension == 2, 2, "A should be 2 dimensional");

  n = rv__->size[0];
  lda = n;

  THTensor_(resize1d)(re_,n);

  /* get optimal workspace size */
  THLapack_(syev)(jobz[0], uplo[0], n, THTensor_(data)(rv__), lda,
                  THTensor_(data)(re_), &wkopt, -1, &info);
  lwork = (int)wkopt;
  work = THTensor_(newWithSize1d)(lwork);
  THLapack_(syev)(jobz[0], uplo[0], n, THTensor_(data)(rv__), lda,
                  THTensor_(data)(re_), THTensor_(data)(work), lwork, &info);

  if (info > 0)
  {
    THError("Lapack syev : Failed to converge. %d off-diagonal elements of an intermediate tridiagonal form did not converge to zero", info);
  }
  else if (info < 0)
  {
    THError("Lapack syev : Argument %d : illegal value", -info);
  }
  /* clean up */
  if (destroy)
  {
    if (clonea)
    {
      THTensor_(copy)(rv_,rv__);
    }
    THTensor_(free)(rv__);
  }
  THTensor_(free)(work);
}

TH_API void THTensor_(gesvd)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *a, const char* jobu)
{
  THTensor *ra_ = THTensor_(new)();
  THTensor_(gesvd2)(ru_, rs_, rv_, ra_, a, jobu);
  THTensor_(free)(ra_);
}

TH_API void THTensor_(gesvd2)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *ra_, THTensor *a, const char* jobu)
{
  int k, m, n, lda, ldu, ldvt, lwork, info;
  THTensor *work;
  real wkopt;

  THTensor *ra__;

  int clonea;
  int destroy;

  if (a == NULL) /* possibly destroy the inputs */
  {
    ra__ = THTensor_(new)();
    clonea = THTensor_(lapackClone)(ra__,ra_,0);
    destroy = 1;
  }
  else /* we definitely want to clone */
  {
    clonea = THTensor_(lapackClone)(ra_,a,1);
    ra__ = ra_;
    destroy = 0;
  }

  THArgCheck(ra__->nDimension == 2, 2, "A should be 2 dimensional");

  m = ra__->size[0];
  n = ra__->size[1];
  k = (m < n ? m : n);

  lda = m;
  ldu = m;
  ldvt = n;
  THTensor_(resize1d)(rs_,k);
  THTensor_(resize2d)(rv_,ldvt,n);
  if (*jobu == 'A')
  {
    THTensor_(resize2d)(ru_,m,ldu);
  }
  else
  {
    THTensor_(resize2d)(ru_,k,ldu);
  }
  THTensor_(transpose)(ru_,NULL,0,1);
  THTensor_(transpose)(rv_,NULL,0,1);

  /* get optimal workspace size */
  THLapack_(gesvd)(jobu[0],jobu[0],
                   m,n,THTensor_(data)(ra__),lda,
                   THTensor_(data)(rs_),
                   THTensor_(data)(ru_),
                   ldu,
                   THTensor_(data)(rv_), ldvt,
                   &wkopt, -1, &info);
  lwork = (int)wkopt;
  work = THTensor_(newWithSize1d)(lwork);
  THLapack_(gesvd)(jobu[0],jobu[0],
                   m,n,THTensor_(data)(ra__),lda,
                   THTensor_(data)(rs_),
                   THTensor_(data)(ru_),
                   ldu,
                   THTensor_(data)(rv_), ldvt,
                   THTensor_(data)(work),lwork, &info);
  if (info > 0)
  {
    THError("Lapack gesvd : %d superdiagonals failed to converge.", info);
  }
  else if (info < 0)
  {
    THError("Lapack gesvd : Argument %d : illegal value", -info);
  }

  /* clean up */
  if (destroy)
  {
    if (clonea)
    {
      THTensor_(copy)(ra_,ra__);
    }
    THTensor_(free)(ra__);
  }
  THTensor_(free)(work);
}

#endif
11
lib/TH/generic/THTensorLapack.h
Normal file

@@ -0,0 +1,11 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THTensorLapack.h"
#else

TH_API void THTensor_(gesv)(THTensor *rb_, THTensor *ra_, THTensor *b_, THTensor *a_);
TH_API void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b_, THTensor *a_);
TH_API void THTensor_(syev)(THTensor *re_, THTensor *rv_, THTensor *a_, const char *jobz, const char *uplo);
TH_API void THTensor_(gesvd)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *a, const char *jobu);
TH_API void THTensor_(gesvd2)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *ra_, THTensor *a, const char *jobu);

#endif
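As a rough usage sketch for the prototypes above (not part of this commit): it assumes the double instantiation of the generic macros is exposed under the usual THDoubleTensor_* names and that the TH headers are reachable as "TH.h". The matrix is chosen symmetric so the row-major versus column-major layout question does not affect the result.

#include "TH.h"
#include <stdio.h>

/* Solve A x = b for the 2x2 symmetric system A = [[3,1],[1,2]], b = (9,8). */
int main(void)
{
  THDoubleTensor *a  = THDoubleTensor_newWithSize2d(2, 2);
  THDoubleTensor *b  = THDoubleTensor_newWithSize2d(2, 1);
  THDoubleTensor *ra = THDoubleTensor_new();  /* receives the LU factors */
  THDoubleTensor *rb = THDoubleTensor_new();  /* receives the solution   */

  THDoubleTensor_set2d(a, 0, 0, 3); THDoubleTensor_set2d(a, 0, 1, 1);
  THDoubleTensor_set2d(a, 1, 0, 1); THDoubleTensor_set2d(a, 1, 1, 2);
  THDoubleTensor_set2d(b, 0, 0, 9); THDoubleTensor_set2d(b, 1, 0, 8);

  THDoubleTensor_gesv(rb, ra, b, a);  /* expected solution: x = (2, 3) */

  printf("x = (%g, %g)\n",
         THDoubleTensor_get2d(rb, 0, 0),
         THDoubleTensor_get2d(rb, 1, 0));

  THDoubleTensor_free(a);
  THDoubleTensor_free(b);
  THDoubleTensor_free(ra);
  THDoubleTensor_free(rb);
  return 0;
}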
1063
lib/TH/generic/THTensorMath.c
Normal file
File diff suppressed because it is too large

90
lib/TH/generic/THTensorMath.h
Normal file

@@ -0,0 +1,90 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THTensorMath.h"
#else

TH_API void THTensor_(fill)(THTensor *r_, real value);
TH_API void THTensor_(zero)(THTensor *r_);

TH_API accreal THTensor_(dot)(THTensor *t, THTensor *src);

TH_API real THTensor_(minall)(THTensor *t);
TH_API real THTensor_(maxall)(THTensor *t);
TH_API accreal THTensor_(sumall)(THTensor *t);

TH_API void THTensor_(add)(THTensor *r_, THTensor *t, real value);
TH_API void THTensor_(mul)(THTensor *r_, THTensor *t, real value);
TH_API void THTensor_(div)(THTensor *r_, THTensor *t, real value);

TH_API void THTensor_(cadd)(THTensor *r_, THTensor *t, real value, THTensor *src);
TH_API void THTensor_(cmul)(THTensor *r_, THTensor *t, THTensor *src);
TH_API void THTensor_(cdiv)(THTensor *r_, THTensor *t, THTensor *src);

TH_API void THTensor_(addcmul)(THTensor *r_, THTensor *t, real value, THTensor *src1, THTensor *src2);
TH_API void THTensor_(addcdiv)(THTensor *r_, THTensor *t, real value, THTensor *src1, THTensor *src2);

TH_API void THTensor_(addmv)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *mat, THTensor *vec);
TH_API void THTensor_(addmm)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *mat1, THTensor *mat2);
TH_API void THTensor_(addr)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *vec1, THTensor *vec2);

TH_API long THTensor_(numel)(THTensor *t);
TH_API void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension);
TH_API void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension);
TH_API void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension);
TH_API void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension);
TH_API void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension);
TH_API void THTensor_(cumprod)(THTensor *r_, THTensor *t, int dimension);
TH_API void THTensor_(sign)(THTensor *r_, THTensor *t);
TH_API accreal THTensor_(trace)(THTensor *t);
TH_API void THTensor_(cross)(THTensor *r_, THTensor *a, THTensor *b, int dimension);

TH_API void THTensor_(zeros)(THTensor *r_, THLongStorage *size);
TH_API void THTensor_(ones)(THTensor *r_, THLongStorage *size);
TH_API void THTensor_(diag)(THTensor *r_, THTensor *t, int k);
TH_API void THTensor_(eye)(THTensor *r_, long n, long m);
TH_API void THTensor_(range)(THTensor *r_, real xmin, real xmax, real step);
TH_API void THTensor_(randperm)(THTensor *r_, long n);

TH_API void THTensor_(reshape)(THTensor *r_, THTensor *t, THLongStorage *size);
TH_API void THTensor_(sort)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int dimension, int descendingOrder);
TH_API void THTensor_(tril)(THTensor *r_, THTensor *t, long k);
TH_API void THTensor_(triu)(THTensor *r_, THTensor *t, long k);
TH_API void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension);

#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)

TH_API void THTensor_(log)(THTensor *r_, THTensor *t);
TH_API void THTensor_(log1p)(THTensor *r_, THTensor *t);
TH_API void THTensor_(exp)(THTensor *r_, THTensor *t);
TH_API void THTensor_(cos)(THTensor *r_, THTensor *t);
TH_API void THTensor_(acos)(THTensor *r_, THTensor *t);
TH_API void THTensor_(cosh)(THTensor *r_, THTensor *t);
TH_API void THTensor_(sin)(THTensor *r_, THTensor *t);
TH_API void THTensor_(asin)(THTensor *r_, THTensor *t);
TH_API void THTensor_(sinh)(THTensor *r_, THTensor *t);
TH_API void THTensor_(tan)(THTensor *r_, THTensor *t);
TH_API void THTensor_(atan)(THTensor *r_, THTensor *t);
TH_API void THTensor_(tanh)(THTensor *r_, THTensor *t);
TH_API void THTensor_(pow)(THTensor *r_, THTensor *t, real value);
TH_API void THTensor_(sqrt)(THTensor *r_, THTensor *t);
TH_API void THTensor_(ceil)(THTensor *r_, THTensor *t);
TH_API void THTensor_(floor)(THTensor *r_, THTensor *t);
TH_API void THTensor_(abs)(THTensor *r_, THTensor *t);

TH_API void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension);
TH_API void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int flag);
TH_API void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int flag);
TH_API accreal THTensor_(norm)(THTensor *t, real value);
TH_API accreal THTensor_(dist)(THTensor *a, THTensor *b, real value);

TH_API accreal THTensor_(meanall)(THTensor *self);
TH_API accreal THTensor_(varall)(THTensor *self);
TH_API accreal THTensor_(stdall)(THTensor *self);

TH_API void THTensor_(linspace)(THTensor *r_, real a, real b, long n);
TH_API void THTensor_(logspace)(THTensor *r_, real a, real b, long n);
TH_API void THTensor_(rand)(THTensor *r_, THLongStorage *size);
TH_API void THTensor_(randn)(THTensor *r_, THLongStorage *size);

#endif

#endif
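A small, hedged sketch of the float instantiation of the math API declared above (not part of this commit; it assumes the generated names follow the THFloatTensor_* convention): it computes y = beta*y + alpha*(M x) with addmv and then reduces with sumall.

#include "TH.h"
#include <stdio.h>

int main(void)
{
  THFloatTensor *m = THFloatTensor_newWithSize2d(2, 3);  /* 2x3 matrix      */
  THFloatTensor *x = THFloatTensor_newWithSize1d(3);     /* length-3 vector */
  THFloatTensor *y = THFloatTensor_newWithSize1d(2);     /* length-2 result */

  THFloatTensor_fill(m, 1);  /* all ones      */
  THFloatTensor_fill(x, 2);  /* x = (2, 2, 2) */
  THFloatTensor_zero(y);

  /* y = 0*y + 1*(m x): every entry of y becomes 6 */
  THFloatTensor_addmv(y, 0, y, 1, m, x);

  printf("sum(y) = %g\n", THFloatTensor_sumall(y));  /* prints 12 */

  THFloatTensor_free(m);
  THFloatTensor_free(x);
  THFloatTensor_free(y);
  return 0;
}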
65
lib/TH/generic/THTensorRandom.c
Normal file

@@ -0,0 +1,65 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THTensorRandom.c"
#else

TH_API void THTensor_(random)(THTensor *self)
{
#if defined(TH_REAL_IS_BYTE)
  TH_TENSOR_APPLY(real, self, *self_data = (unsigned char)(THRandom_random() % (UCHAR_MAX+1)););
#elif defined(TH_REAL_IS_CHAR)
  TH_TENSOR_APPLY(real, self, *self_data = (char)(THRandom_random() % (CHAR_MAX+1)););
#elif defined(TH_REAL_IS_SHORT)
  TH_TENSOR_APPLY(real, self, *self_data = (short)(THRandom_random() % (SHRT_MAX+1)););
#elif defined(TH_REAL_IS_INT)
  TH_TENSOR_APPLY(real, self, *self_data = (int)(THRandom_random() % (INT_MAX+1UL)););
#elif defined(TH_REAL_IS_LONG)
  TH_TENSOR_APPLY(real, self, *self_data = (long)(THRandom_random() % (LONG_MAX+1UL)););
#elif defined(TH_REAL_IS_FLOAT)
  TH_TENSOR_APPLY(real, self, *self_data = (float)(THRandom_random() % ((1UL << FLT_MANT_DIG)+1)););
#elif defined(TH_REAL_IS_DOUBLE)
  TH_TENSOR_APPLY(real, self, *self_data = (double)(THRandom_random() % ((1UL << DBL_MANT_DIG)+1)););
#else
#error "Unknown type"
#endif
}

TH_API void THTensor_(geometric)(THTensor *self, double p)
{
  TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_geometric(p););
}

TH_API void THTensor_(bernoulli)(THTensor *self, double p)
{
  TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_bernoulli(p););
}

#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)

TH_API void THTensor_(uniform)(THTensor *self, double a, double b)
{
  TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_uniform(a, b););
}

TH_API void THTensor_(normal)(THTensor *self, double mean, double stdv)
{
  TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_normal(mean, stdv););
}

TH_API void THTensor_(exponential)(THTensor *self, double lambda)
{
  TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_exponential(lambda););
}

TH_API void THTensor_(cauchy)(THTensor *self, double median, double sigma)
{
  TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_cauchy(median, sigma););
}

TH_API void THTensor_(logNormal)(THTensor *self, double mean, double stdv)
{
  TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_logNormal(mean, stdv););
}

#endif

#endif
17
lib/TH/generic/THTensorRandom.h
Normal file

@@ -0,0 +1,17 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THTensorRandom.h"
#else

TH_API void THTensor_(random)(THTensor *self);
TH_API void THTensor_(geometric)(THTensor *self, double p);
TH_API void THTensor_(bernoulli)(THTensor *self, double p);

#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
TH_API void THTensor_(uniform)(THTensor *self, double a, double b);
TH_API void THTensor_(normal)(THTensor *self, double mean, double stdv);
TH_API void THTensor_(exponential)(THTensor *self, double lambda);
TH_API void THTensor_(cauchy)(THTensor *self, double median, double sigma);
TH_API void THTensor_(logNormal)(THTensor *self, double mean, double stdv);
#endif

#endif
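Again as an illustrative sketch rather than part of the commit (assuming the double instantiation exposes THDoubleTensor_* names and that the generator's default seeding is acceptable), the samplers declared above fill a tensor in place:

#include "TH.h"
#include <stdio.h>

int main(void)
{
  THDoubleTensor *t = THDoubleTensor_newWithSize1d(10000);

  THDoubleTensor_uniform(t, 0, 1);   /* U(0,1): mean should be near 0.5 */
  printf("uniform mean  = %g\n", THDoubleTensor_meanall(t));

  THDoubleTensor_normal(t, 0, 1);    /* N(0,1): mean should be near 0   */
  printf("gaussian mean = %g\n", THDoubleTensor_meanall(t));

  THDoubleTensor_free(t);
  return 0;
}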
84
lib/TH/generic/THVector.c
Normal file

@@ -0,0 +1,84 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THVector.c"
#else

static inline void THVector_(fill)(real *x, const real c, const long n) {
  long i = 0;

  for(; i < n-4; i += 4)
  {
    x[i] = c;
    x[i+1] = c;
    x[i+2] = c;
    x[i+3] = c;
  }

  for(; i < n; i++)
    x[i] = c;
}

static inline void THVector_(add)(real *y, const real *x, const real c, const long n)
{
  long i = 0;

  for(; i < n-4; i += 4)
  {
    y[i] += c * x[i];
    y[i+1] += c * x[i+1];
    y[i+2] += c * x[i+2];
    y[i+3] += c * x[i+3];
  }

  for(; i < n; i++)
    y[i] += c * x[i];
}

static inline void THVector_(diff)(real *z, const real *x, const real *y, const long n)
{
  long i = 0;

  for(; i < n-4; i += 4)
  {
    z[i] = x[i] - y[i];
    z[i+1] = x[i+1] - y[i+1];
    z[i+2] = x[i+2] - y[i+2];
    z[i+3] = x[i+3] - y[i+3];
  }

  for(; i < n; i++)
    z[i] = x[i] - y[i];
}

static inline void THVector_(scale)(real *y, const real c, const long n)
{
  long i = 0;

  for(; i < n-4; i += 4)
  {
    y[i] *= c;
    y[i+1] *= c;
    y[i+2] *= c;
    y[i+3] *= c;
  }

  for(; i < n; i++)
    y[i] *= c;
}

static inline void THVector_(mul)(real *y, const real *x, const long n)
{
  long i = 0;

  for(; i < n-4; i += 4)
  {
    y[i] *= x[i];
    y[i+1] *= x[i+1];
    y[i+2] *= x[i+2];
    y[i+3] *= x[i+3];
  }

  for(; i < n; i++)
    y[i] *= x[i];
}

#endif
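All of the kernels above share the same hand-unrolled pattern: a main loop that processes four elements per iteration (note the i < n-4 bound, so when n is an exact multiple of four the last block falls through to the tail) and a scalar tail loop for whatever remains. Below is a self-contained sketch of the same idea on a plain float array, with hypothetical names, purely to illustrate the technique; it is not part of this commit.

#include <stdio.h>

/* Four-way unrolled in-place scale, mirroring the structure of THVector_(scale). */
static void scale4(float *y, float c, long n)
{
  long i = 0;

  /* main loop: four elements per iteration */
  for (; i < n - 4; i += 4)
  {
    y[i]   *= c;
    y[i+1] *= c;
    y[i+2] *= c;
    y[i+3] *= c;
  }

  /* tail loop: the remaining 0 to 4 elements */
  for (; i < n; i++)
    y[i] *= c;
}

int main(void)
{
  long i;
  float v[7] = {1, 2, 3, 4, 5, 6, 7};

  scale4(v, 2.0f, 7);
  for (i = 0; i < 7; i++)
    printf("%g ", v[i]);  /* 2 4 6 8 10 12 14 */
  printf("\n");
  return 0;
}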
28
lib/luaT/CMakeLists.txt
Normal file

@@ -0,0 +1,28 @@
# -*- cmake -*-

FIND_PACKAGE(Lua REQUIRED)

INCLUDE_DIRECTORIES(${LUA_INCLUDE_DIR})

ADD_LIBRARY(luaT SHARED luaT.h luaT.c)
TARGET_LINK_LIBRARIES(luaT ${LUA_LIBRARIES})

INSTALL(TARGETS luaT
        RUNTIME DESTINATION "${Torch_INSTALL_BIN_SUBDIR}"
        LIBRARY DESTINATION "${Torch_INSTALL_LIB_SUBDIR}"
        ARCHIVE DESTINATION "${Torch_INSTALL_LIB_SUBDIR}")

INSTALL(FILES luaT.h
        DESTINATION "${Torch_INSTALL_INCLUDE_SUBDIR}")

# Create luaTConfig.cmake
GET_TARGET_PROPERTY(LUAT_OUTPUT_NAME luaT LOCATION)
GET_FILENAME_COMPONENT(LUAT_OUTPUT_NAME ${LUAT_OUTPUT_NAME} NAME)
SET(LUAT_LIBRARIES "${Torch_INSTALL_LIB}/${LUAT_OUTPUT_NAME}")
SET(LUAT_INCLUDE_DIR "${Torch_INSTALL_INCLUDE}")
CONFIGURE_FILE(luaTConfig.cmake.in "${Torch_BINARY_DIR}/cmake-external/luaTConfig.cmake")
INSTALL(FILES "${Torch_BINARY_DIR}/cmake-external/luaTConfig.cmake"
        DESTINATION "${Torch_INSTALL_CMAKE_SUBDIR}")

# luaT help
ADD_TORCH_DOK(dok luaT "Torch C Libraries" "luaT" 5.1)
Some files were not shown because too many files have changed in this diff.