-- mirror of https://github.com/zebrajr/pytorch.git
-- synced 2025-12-06 12:20:52 +01:00
--
-- Autogenerator script: emits the C glue for the torch Lapack Tensor
-- bindings using the 'wrap' (cwrap) code-generation package.
--

-- 'wrap' provides wrap.CInterface, used for all code generation below.
-- (This line had been commented out by a scrape artifact; the call to
-- wrap.CInterface.new() immediately below requires it.)
require 'wrap'

-- Global code-generation interface shared by the whole script.
interface = wrap.CInterface.new()

-- Registry of dispatch wrappers generated so far.  The hash part
-- (dispatchregistry[name] = true) marks names already emitted; the
-- array part holds {name=..., wrapname=...} records consumed by
-- interface:dispatchregister().
interface.dispatchregistry = {}
-- Wrap a function through wrap.CInterface as usual and, the first time a
-- given `name` is seen, additionally emit a C dispatch function
-- `torch_NAME` that locates a tensor id among the call arguments (first
-- argument, then second, then a trailing type-name string, then the
-- default tensor type) and forwards the call to that type's "NAME" entry.
-- @param name  Lua-visible function name (also used in the C identifier)
-- @param ...   forwarded verbatim to wrap.CInterface.wrap
function interface:wrap(name, ...)
   -- usual stuff
   wrap.CInterface.wrap(self, name, ...)

   -- dispatch function: emit only once per name
   if not interface.dispatchregistry[name] then
      interface.dispatchregistry[name] = true
      table.insert(interface.dispatchregistry, {name=name, wrapname=string.format("torch_%s", name)})

      -- NOTE(review): the template below does lua_rawget(L, -2) and later
      -- pops "the two tables we put on the stack above", but no pushes of
      -- those tables are visible in this snippet -- confirm the template is
      -- complete against the original generator.
      interface:print(string.gsub([[
static int torch_NAME(lua_State *L)
{
  int narg = lua_gettop(L);
  const void *id;

  if(narg < 1 || !(id = torch_istensorid(L, luaT_id(L, 1)))) /* first argument is tensor? */
  {
    if(narg < 2 || !(id = torch_istensorid(L, luaT_id(L, 2)))) /* second? */
    {
      if(lua_isstring(L, -1) && (id = torch_istensorid(L, luaT_typename2id(L, lua_tostring(L, -1))))) /* do we have a valid string then? */
        lua_pop(L, 1);
      else if(!(id = torch_istensorid(L, torch_getdefaulttensorid())))
        luaL_error(L, "internal error: the default tensor type does not seem to be an actual tensor");
    }
  }

  lua_pushstring(L, "NAME");
  lua_rawget(L, -2);
  if(lua_isfunction(L, -1))
  {
    lua_insert(L, 1);
    lua_pop(L, 2); /* the two tables we put on the stack above */
    lua_call(L, lua_gettop(L)-1, LUA_MULTRET);
  }
  else
    return luaL_error(L, "%s does not implement the torch.NAME() function", luaT_id2typename(L, id));

  return lua_gettop(L);
}
]], 'NAME', name))
   end
end
-- Emit a C `luaL_Reg` array named `name` listing every dispatch wrapper
-- recorded in self.dispatchregistry (by interface:wrap), terminated with
-- the {NULL, NULL} sentinel, then reset the registry for the next group.
-- @param name  C identifier for the generated registration array
function interface:dispatchregister(name)
   local txt = self.txt
   table.insert(txt, string.format('static const struct luaL_Reg %s [] = {', name))
   -- ipairs visits only the array part (the wrapper records); the
   -- hash-part "already seen" flags are skipped.
   for _,reg in ipairs(self.dispatchregistry) do
      table.insert(txt, string.format('{"%s", %s},', reg.name, reg.wrapname))
   end
   table.insert(txt, '{NULL, NULL}')
   table.insert(txt, '};')
   table.insert(txt, '')
   self.dispatchregistry = {}
end
-- Banner at the top of the generated C file.
interface:print('/* WARNING: autogenerated file */')
interface:print('')

-- Map from torch Tensor type name to its underlying C element type.
local reals = {ByteTensor='byte',
               CharTensor='char',
               ShortTensor='short',
               IntTensor='int',
               LongTensor='long',
               FloatTensor='float',
               DoubleTensor='double'}
-- Generate wrappers only for the floating-point tensor types.
for _,Tensor in ipairs({"FloatTensor", "DoubleTensor"}) do

   -- C element type for this tensor type (e.g. 'float').
   -- NOTE(review): real, cname and lastdim are referenced only by the
   -- commented-out wrapper specs below; they are kept so those specs can
   -- be re-enabled unchanged.
   local real = reals[Tensor]

   -- Name of the generated C wrapper function for `name` on this type.
   function interface.luaname2wrapname(self, name)
      return string.format('torch_%s_%s', Tensor, name)
   end

   -- TH C library function name for `name` on this tensor type.
   local function cname(name)
      return string.format('TH%s_%s', Tensor, name)
   end

   -- Returns a closure that, given a wrapped arg spec, produces a C
   -- expression for the number of dimensions of argument `argn`.
   local function lastdim(argn)
      return function(arg)
         return string.format("TH%s_nDimension(%s)", Tensor, arg.args[argn]:carg())
      end
   end

   -- The Lapack wrapper specs below are currently disabled; retained
   -- verbatim.
   -- for _,name in ipairs({"gesv","gels","eig","svd"}) do
   --    interface:wrap(name,
   --                   cname(name),
   --                   {{name=Tensor, returned=true},
   --                    {name=Tensor, returned=true},
   --                    {name=Tensor},
   --                    {name=Tensor}},
   --                   cname(name),
   --                   {{name=Tensor, default=true, returned=true, invisible=true},
   --                    {name=Tensor, default=true, returned=true, invisible=true},
   --                    {name=Tensor},
   --                    {name=Tensor}}
   --                   )
   -- end

   --interface:register(string.format("torch_%sLapack__", Tensor))

   -- interface:print(string.gsub([[
   -- static void torch_TensorLapack_init(lua_State *L)
   -- {
   --   torch_Tensor_id = luaT_checktypename2id(L, "torch.Tensor");
   --   torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage");
   --
   --   luaT_pushmetaclass(L, torch_Tensor_id);
   --   lua_getfield(L,-1,"torch");
   --   luaL_register(L, NULL, torch_TensorLapack__);
   --   lua_pop(L, 2);
   -- }
   -- ]], 'Tensor', Tensor))
end
-- Flush the accumulated dispatch wrappers into the registration array.
interface:dispatchregister("torch_TensorLapack__")

-- Write the generated C to the file named by the first command-line
-- argument, or to stdout when none is given.
if arg[1] then
   interface:tofile(arg[1])
else
   interface:tostdio()
end