pytorch/torch/nn/modules/module.py

from itertools import chain
from collections import OrderedDict
import torch
from ..backends.thnn import backend as thnn_backend
from torch.autograd import Variable

class Module(object):

    def __init__(self, **parameters):
        self._backend = thnn_backend
        self._parameters = OrderedDict(parameters)
        self._buffers = {}
        self.backward_hooks = OrderedDict()
        self.forward_hooks = OrderedDict()
        self.train = True
        for name, param in self._parameters.items():
            if param is not None and not isinstance(param, Variable):
                param = Variable(param, requires_grad=True)
            self._parameters[name] = param
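
    # A minimal subclass sketch (hypothetical names): raw tensors passed as
    # keyword arguments above are wrapped in leaf Variables with
    # requires_grad=True, so gradients accumulate into them.
    #
    #     class Affine(Module):
    #         def __init__(self, in_features, out_features):
    #             super(Affine, self).__init__(
    #                 weight=torch.randn(out_features, in_features),
    #                 bias=torch.randn(out_features))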

    def forward(self, *input):
        raise NotImplementedError

    def register_buffer(self, name, tensor):
        self._buffers[name] = tensor
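
    # Buffers hold non-parameter state that should travel with the module
    # through cuda()/cpu()/float()/double() (e.g. running statistics); a
    # hypothetical use:
    #
    #     self.register_buffer('running_mean', torch.zeros(num_features))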

    def _apply(self, fn):
        for param in self._parameters.values():
            if param is not None:
                # Variables stored in modules are graph leaves, and we don't
                # want to create copy nodes, so we have to unpack the data.
                param.data = fn(param.data)
        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = fn(buf)
        return self

    def cuda(self, device_id=None):
        return self._apply(lambda t: t.cuda(device_id))

    def cpu(self, device_id=None):
        return self._apply(lambda t: t.cpu())

    def float(self):
        return self._apply(lambda t: t.float())

    def double(self):
        return self._apply(lambda t: t.double())
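
    # Because _apply returns self, conversions chain; a usage sketch with a
    # hypothetical subclass:
    #
    #     model = MyModule().double()
    #     model = MyModule().float().cuda()  # move params and buffers to GPU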

    def register_backward_hook(self, name, hook):
        assert name not in self.backward_hooks, \
            "Trying to register a second backward hook with name {}".format(name)
        self.backward_hooks[name] = hook

    def remove_backward_hook(self, name):
        assert name in self.backward_hooks, \
            "Trying to remove a non-existent backward hook with name {}".format(name)
        del self.backward_hooks[name]

    def register_forward_hook(self, name, hook):
        assert name not in self.forward_hooks, \
            "Trying to register a second forward hook with name {}".format(name)
        self.forward_hooks[name] = hook

    def remove_forward_hook(self, name):
        assert name in self.forward_hooks, \
            "Trying to remove a non-existent forward hook with name {}".format(name)
        del self.forward_hooks[name]
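
    # Hook usage sketch (hypothetical hook): forward hooks are keyed by name
    # and called as hook(module, input, output) after every forward pass.
    #
    #     def print_output(module, input, output):
    #         print(type(module).__name__)
    #     model.register_forward_hook('print_output', print_output)
    #     model.remove_forward_hook('print_output')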

    def __call__(self, *input):
        result = self.forward(*input)
        for hook in self.forward_hooks.values():
            hook(self, input, result)
        # forward may return nested lists/tuples of Variables; dig down to
        # the first Variable to reach the creator that backward hooks need.
        var = result
        while not isinstance(var, Variable):
            var = var[0]
        creator = var.creator
        for key, hook in self.backward_hooks.items():
            # Bind hook via a default argument so each closure keeps its own.
            creator.register_hook(key, lambda gi, go, hook=hook: hook(self, gi, go))
        return result

    def __getattr__(self, name):
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in _parameters:
                return _parameters[name]
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                return _buffers[name]
        return object.__getattribute__(self, name)

    def __setattr__(self, name, value):
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in _parameters:
                if not isinstance(value, Variable):
                    raise RuntimeError(("assigning a {} object as parameter {} "
                                        "- you can only assign Variables as "
                                        "model parameters").format(type(value), name))
                if value.creator:
                    raise RuntimeError(("All parameters should be leaf "
                                        "variables - they should be created explicitly, "
                                        "not as a result of computation on other "
                                        "variables. If you want to express {} as a "
                                        "function of another variable, simply repeat the "
                                        "computation at every forward pass.").format(name))
        return object.__setattr__(self, name, value)
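
    # Sketch of the invariant enforced above (hypothetical module m with a
    # 'weight' parameter): only leaf Variables may be assigned.
    #
    #     m.weight = Variable(torch.randn(5, 5), requires_grad=True)  # ok
    #     m.weight = m.weight * 2  # raises: value has a creator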

    def parameter_dict(self, destination=None, prefix=''):
        if destination is None:
            destination = OrderedDict()
        for name, param in self._parameters.items():
            if param is not None:
                destination[prefix + name] = param
        return destination

    def load_parameter_dict(self, param_dict):
        for name, param in self._parameters.items():
            self._parameters[name] = param_dict.get(name, param)
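
    # Round-trip sketch: names absent from param_dict keep their old values.
    #
    #     saved = model.parameter_dict()   # OrderedDict of name -> Variable
    #     model.load_parameter_dict(saved)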

    def parameters(self, memo=None):
        if memo is None:
            memo = set()
        for p in self._parameters.values():
            if p is not None and p not in memo:
                memo.add(p)
                yield p
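
    # The memo set means each shared parameter is yielded exactly once, so
    # consumers such as optimizers deduplicate tied weights for free; a sketch:
    #
    #     num_tensors = sum(1 for p in model.parameters())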

    def children(self):
        # An empty generator; subclasses that hold submodules override this.
        if False:
            yield

    def modules(self, memo=None):
        if memo is None:
            memo = set()
        if self not in memo:
            memo.add(self)
            yield self

    def zero_grad(self):
        for p in self.parameters():
            p.grad.zero_()
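
    # Typical step sketch (assuming `loss` is a Variable from a forward pass):
    #
    #     model.zero_grad()
    #     loss.backward()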