from itertools import chain
from collections import OrderedDict

import torch
from ..backends.thnn import backend as thnn_backend
from torch.autograd import Variable


class Module(object):

    def __init__(self, **parameters):
        self._backend = thnn_backend
        self._parameters = OrderedDict(parameters)
        self._buffers = {}
        self.backward_hooks = OrderedDict()
        self.forward_hooks = OrderedDict()
        self.train = True
        # Wrap any raw tensors passed as keyword arguments into leaf
        # Variables that require gradients.
        for name, param in self._parameters.items():
            if param is not None and not isinstance(param, Variable):
                param = Variable(param, requires_grad=True)
                self._parameters[name] = param

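    # Illustrative sketch (not from the original file; names are hypothetical):
    # a subclass typically hands its raw tensors to Module.__init__ via
    # **parameters so they get wrapped in Variables above, e.g.::
    #
    #     class Linear(Module):
    #         def __init__(self, in_features, out_features):
    #             super(Linear, self).__init__(
    #                 weight=torch.Tensor(out_features, in_features),
    #                 bias=torch.Tensor(out_features),
    #             )
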
    def forward(self, *input):
        raise NotImplementedError

    def register_buffer(self, name, tensor):
        # Buffers hold persistent, non-parameter state (e.g. running
        # statistics) and are converted together with parameters in _apply.
        self._buffers[name] = tensor

    def _apply(self, fn):
        for param in self._parameters.values():
            if param is not None:
                # Variables stored in modules are graph leaves, and we don't
                # want to create copy nodes, so we have to unpack the data.
                param.data = fn(param.data)
        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = fn(buf)
        return self

    def cuda(self, device_id=None):
        return self._apply(lambda t: t.cuda(device_id))

    def cpu(self, device_id=None):
        return self._apply(lambda t: t.cpu())

    def float(self):
        return self._apply(lambda t: t.float())

    def double(self):
        return self._apply(lambda t: t.double())

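    # Illustrative usage sketch (assumption, not from the original file): the
    # conversions above all funnel through _apply, modify the module in place,
    # and also return it, so calls can be chained, e.g.::
    #
    #     model = MyModule().float().cuda()   # 'MyModule' is hypothetical
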
    def register_backward_hook(self, name, hook):
        assert name not in self.backward_hooks, \
            "Trying to register a second backward hook with name {}".format(name)
        self.backward_hooks[name] = hook

    def remove_backward_hook(self, name):
        assert name in self.backward_hooks, \
            "Trying to remove a nonexistent backward hook with name {}".format(name)
        del self.backward_hooks[name]

    def register_forward_hook(self, name, hook):
        assert name not in self.forward_hooks, \
            "Trying to register a second forward hook with name {}".format(name)
        self.forward_hooks[name] = hook

    def remove_forward_hook(self, name):
        assert name in self.forward_hooks, \
            "Trying to remove a nonexistent forward hook with name {}".format(name)
        del self.forward_hooks[name]

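    # Illustrative usage sketch (hypothetical names, not from the original
    # file): hooks in this API are keyed by a user-chosen name, e.g.::
    #
    #     def print_output(module, input, output):
    #         print(output)
    #
    #     model.register_forward_hook('debug', print_output)
    #     output = model(some_input)   # runs forward(), then the hook
    #     model.remove_forward_hook('debug')
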
    def __call__(self, *input):
        result = self.forward(*input)
        for hook in self.forward_hooks.values():
            hook(self, input, result)
        # Unwrap tuple results until we reach a Variable, so backward hooks
        # can be attached to the function that created the output.
        var = result
        while not isinstance(var, Variable):
            var = var[0]
        creator = var.creator
        for key, hook in self.backward_hooks.items():
            creator.register_hook(key, lambda gi, go, hook=hook: hook(self, gi, go))
        return result

    def __getattr__(self, name):
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in _parameters:
                return _parameters[name]
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                return _buffers[name]
        return object.__getattribute__(self, name)

    def __setattr__(self, name, value):
        _parameters = self.__dict__.get('_parameters')
        if isinstance(value, Variable):
            if _parameters is None:
                raise AttributeError(
                    "cannot assign parameter before Module.__init__() call")
            if value.creator:
                raise ValueError(
                    "Cannot assign non-leaf Variable to parameter '{0}'. Model "
                    "parameters must be created explicitly. To express '{0}' "
                    "as a function of another variable, compute the value in "
                    "the forward() method.".format(name))
            _parameters[name] = value
        elif _parameters and name in _parameters:
            if value is not None:
                raise TypeError("cannot assign '{}' object to parameter '{}' "
                                "(torch.autograd.Variable or None required)"
                                .format(torch.typename(value), name))
            _parameters[name] = value
        else:
            object.__setattr__(self, name, value)

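    # Illustrative note (assumption, not from the original file): assigning a
    # leaf Variable to an attribute goes through __setattr__ above and lands
    # in self._parameters, e.g.::
    #
    #     self.weight = Variable(torch.Tensor(10, 20), requires_grad=True)
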
    def __delattr__(self, name):
        if name in self._parameters:
            del self._parameters[name]
        else:
            object.__delattr__(self, name)

    def parameter_dict(self, destination=None, prefix=''):
        if destination is None:
            destination = OrderedDict()
        for name, param in self._parameters.items():
            if param is not None:
                destination[prefix + name] = param
        return destination

    def load_parameter_dict(self, param_dict):
        for name, param in self._parameters.items():
            self._parameters[name] = param_dict.get(name, param)

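    # Illustrative usage sketch (assumption, not from the original file): the
    # two methods above can round-trip parameters through torch.save/torch.load,
    # with 'model' and the file name being hypothetical::
    #
    #     torch.save(model.parameter_dict(), 'params.pt')
    #     model.load_parameter_dict(torch.load('params.pt'))
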
    def parameters(self, memo=None):
        if memo is None:
            memo = set()
        for p in self._parameters.values():
            if p is not None and p not in memo:
                memo.add(p)
                yield p

    def children(self):
        # The base Module has no children; 'if False: yield' turns this into
        # an empty generator.
        if False:
            yield

    def modules(self, memo=None):
        if memo is None:
            memo = set()
        if self not in memo:
            memo.add(self)
            yield self

    def zero_grad(self):
        for p in self.parameters():
            p.grad.zero_()
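

# Illustrative sketch (assumption, not from the original file): a minimal
# manual gradient-descent step over a module's parameters using the API
# defined above; 'model', 'input', and the learning rate are hypothetical::
#
#     model.zero_grad()
#     loss = model(input).sum()
#     loss.backward()
#     for p in model.parameters():
#         p.data.sub_(p.grad * 0.01)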