pytorch/TensorConvWrap.lua
2012-01-25 14:55:20 +01:00

128 lines
3.5 KiB
Lua

--
-- require 'wrap'
---
-- Shared C code generator for this script, published as the global
-- `interface`. Assumes the `wrap` package is already loaded (the
-- `require 'wrap'` line at the top of the file is commented out).
-- dispatchregistry records the generic dispatchers produced by wrap().
do
  local generator = wrap.CInterface.new()
  generator.dispatchregistry = {}
  interface = generator
end
--- Replacement for CInterface:wrap used by this script.
-- The default per-type wrapping is disabled (see the commented-out call
-- below). Instead, the first time a given `name` is seen, this emits a
-- generic C function `torch_NAME`. That function chooses a tensor type from
-- its arguments and forwards the call to the `NAME` entry of the `torch`
-- table stored in that type's metaclass. Later calls with the same name emit
-- nothing.
-- @param name function name, e.g. "conv2"
-- @param ... C function name and argument descriptions; currently unused
function interface:wrap(name, ...)
  -- usual stuff
  --wrap.CInterface.wrap(self, name, ...)

  -- dispatch function (emitted once per name)
  local registry = self.dispatchregistry
  if registry[name] then
    return
  end
  -- registry[name] is only a "seen" flag; the sequence entry is what
  -- dispatchregister turns into a luaL_Reg line.
  registry[name] = true
  table.insert(registry, {name=name, wrapname=string.format("torch_%s", name)})

  -- The generated C picks the tensor type from:
  --   1. argument 1, if it is a tensor;
  --   2. otherwise argument 2, if it is a tensor;
  --   3. otherwise a trailing string naming a tensor type (popped before
  --      the call);
  --   4. otherwise the default tensor type.
  -- lua_type() is used instead of lua_isstring() because lua_tostring()
  -- would convert a trailing numeric argument to a string in place.
  -- The metaclass and its `torch` table must be pushed before NAME is
  -- looked up; both are popped again before the call.
  -- string.gsub returns (text, count). The extra parentheses keep only the
  -- text, so print() receives a single argument.
  self:print((string.gsub([[
static int torch_NAME(lua_State *L)
{
  int narg = lua_gettop(L);
  const void *id;

  if(narg < 1 || !(id = torch_istensorid(L, luaT_id(L, 1)))) /* first argument is tensor? */
  {
    if(narg < 2 || !(id = torch_istensorid(L, luaT_id(L, 2)))) /* second? */
    {
      if(narg > 0 && lua_type(L, -1) == LUA_TSTRING && (id = torch_istensorid(L, luaT_typename2id(L, lua_tostring(L, -1))))) /* do we have a valid string then? */
        lua_pop(L, 1);
      else if(!(id = torch_istensorid(L, torch_getdefaulttensorid())))
        luaL_error(L, "internal error: the default tensor type does not seem to be an actual tensor");
    }
  }

  luaT_pushmetaclass(L, id);
  lua_pushstring(L, "torch");
  lua_rawget(L, -2);
  if(!lua_istable(L, -1))
    return luaL_error(L, "%s does not implement the torch.NAME() function", luaT_id2typename(L, id));

  lua_pushstring(L, "NAME");
  lua_rawget(L, -2);
  if(!lua_isfunction(L, -1))
    return luaL_error(L, "%s does not implement the torch.NAME() function", luaT_id2typename(L, id));

  lua_insert(L, 1);
  lua_pop(L, 2); /* the two tables we put on the stack above */
  lua_call(L, lua_gettop(L)-1, LUA_MULTRET);

  return lua_gettop(L);
}
]], 'NAME', name)))
end
--- Append to self.txt a C `luaL_Reg` array that lists every dispatcher
-- recorded in self.dispatchregistry, terminated by {NULL, NULL}, followed by
-- a blank line. The registry is then replaced with a fresh empty table.
-- @param name C identifier given to the generated array
function interface:dispatchregister(name)
  local out = self.txt
  local function emit(line)
    table.insert(out, line)
  end

  emit(string.format('static const struct luaL_Reg %s [] = {', name))
  for _, entry in ipairs(self.dispatchregistry) do
    emit(string.format('{"%s", %s},', entry.name, entry.wrapname))
  end
  -- sentinel entry, closing brace, blank separator line
  for _, line in ipairs({'{NULL, NULL}', '};', ''}) do
    emit(line)
  end

  self.dispatchregistry = {}
end
interface:print('/* WARNING: autogenerated file */')
interface:print('')

-- Tensor types and functions to generate. With the wrap() defined in this
-- file, only one generic dispatcher per function name is emitted; the
-- per-type values set below (luaname2wrapname, cname) are passed along but
-- not used by that wrap(). NOTE(review): presumably kept for when per-type
-- wrapping is re-enabled.
local tensor_types = {"FloatTensor", "DoubleTensor", "IntTensor", "LongTensor",
                      "ByteTensor", "CharTensor", "ShortTensor"}
local conv_functions = {"conv2", "xcorr2", "conv3", "xcorr3"}

for _, Tensor in ipairs(tensor_types) do
  -- Lua name -> C wrapper name for this tensor type, e.g. torch_FloatTensor_conv2
  function interface.luaname2wrapname(self, name)
    return string.format('torch_%s_%s', Tensor, name)
  end

  -- C implementation name for this tensor type, e.g. THFloatTensor_conv2
  local function cname(name)
    return string.format('TH%s_%s', Tensor, name)
  end

  for _, name in ipairs(conv_functions) do
    interface:wrap(name,
                   cname(name),
                   {{name=Tensor, default=true, returned=true},
                    {name=Tensor, default=true, returned=true},
                    {name=Tensor},
                    {name=Tensor}})
  end

  -- Disabled registration/initialisation template (it refers to
  -- TensorLapack, not TensorConv):
  --interface:register(string.format("torch_%sLapack__", Tensor))
  -- interface:print(string.gsub([[
  -- static void torch_TensorLapack_init(lua_State *L)
  -- {
  -- torch_Tensor_id = luaT_checktypename2id(L, "torch.Tensor");
  -- torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage");
  -- luaT_pushmetaclass(L, torch_Tensor_id);
  -- lua_getfield(L,-1,"torch");
  -- luaL_register(L, NULL, torch_TensorLapack__);
  -- lua_pop(L, 2);
  -- }
  -- ]], 'Tensor', Tensor))
end
-- Finish the generated code with the registration array for all
-- dispatchers, then write it out. If a path is given as the first
-- command-line argument, write to that file; otherwise use
-- interface:tostdio().
interface:dispatchregister("torch_TensorConv__")

local outpath = arg[1]
if not outpath then
  interface:tostdio()
else
  interface:tofile(outpath)
end