Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
* _forward is renamed to forward, since users should override it (see the usage sketch below the file stats)
* some __call__ overrides are changed to forward
* functions which return a single variable are changed to return that variable instead of a one-element tuple
83 lines · 2.8 KiB · Python
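
A minimal sketch of the convention this commit establishes, assuming the early torch and torch.autograd.Variable APIs of this era; the Scale module and its factor attribute are hypothetical, invented for illustration. Subclasses now override forward (not _forward or __call__), return the Variable itself rather than a one-element tuple, and are invoked by calling the instance:

    import torch
    from torch.autograd import Variable
    from torch.nn.modules.module import Module  # assumed import path for this file

    class Scale(Module):
        def __init__(self, factor):
            super(Scale, self).__init__()
            self.factor = factor

        def forward(self, input):
            # override forward; Module.__call__ dispatches here and runs hooks
            return input * self.factor  # a single Variable, not a 1-tuple

    m = Scale(2)
    y = m(Variable(torch.ones(5)))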
from collections import OrderedDict

import torch
from ..backends.thnn import backend as thnn_backend
from torch.autograd import Variable


class Module(object):
    def __init__(self):
        self._backend = thnn_backend
        # Hooks are stored in insertion order, keyed by a user-supplied name
        self.backward_hooks = OrderedDict()
        self.forward_hooks = OrderedDict()
        self.train = True

    def forward(self, *input):
        # Subclasses override this (rather than _forward or __call__) with
        # the module's actual computation.
        raise NotImplementedError

    def type(self, type, *forwarded_args):
        # Find all tensors and convert them
        for key, value in self.__dict__.items():
            if isinstance(value, Variable):
                # Variables stored in modules are graph leaves,
                # and we don't want to create copy nodes.
                value._data = value.data.type(type, *forwarded_args)
            elif torch.isTensor(value):
                setattr(self, key, value.type(type, *forwarded_args))
            elif isinstance(value, Module):
                value.type(type, *forwarded_args)
        return self

    def cuda(self, device_id=None):
        import torch.cuda
        if device_id is not None:
            return self.type(torch.cuda.FloatTensor, device_id)
        else:
            return self.type(torch.cuda.FloatTensor)

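    # cuda() above and float()/double() below are thin wrappers over type(),
    # which walks __dict__ and recursively converts tensors, Variables,
    # and child Modules.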
    def float(self):
        return self.type(torch.FloatTensor)

    def double(self):
        return self.type(torch.DoubleTensor)

    def register_backward_hook(self, name, hook):
        assert name not in self.backward_hooks, \
            "Trying to register a second backward hook with name {}".format(name)
        self.backward_hooks[name] = hook

    def remove_backward_hook(self, name):
        assert name in self.backward_hooks, \
            "Trying to remove a nonexistent backward hook with name {}".format(name)
        del self.backward_hooks[name]

    def register_forward_hook(self, name, hook):
        assert name not in self.forward_hooks, \
            "Trying to register a second forward hook with name {}".format(name)
        self.forward_hooks[name] = hook

    def remove_forward_hook(self, name):
        assert name in self.forward_hooks, \
            "Trying to remove a nonexistent forward hook with name {}".format(name)
        del self.forward_hooks[name]

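    # Hook signatures, as wired in __call__ below:
    #   forward hooks:  hook(module, input, output), called after forward()
    #   backward hooks: hook(module, grad_input, grad_output), registered on
    #                   the creator of the output Variable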
    def __call__(self, *input):
        result = self.forward(*input)
        for hook in self.forward_hooks.values():
            hook(self, input, result)
        fn = result.creator
        for key, hook in self.backward_hooks.items():
            # hook=hook binds the current hook as a default argument, so each
            # lambda captures its own hook instead of the loop variable
            fn.register_hook(key, lambda gi, go, hook=hook: hook(self, gi, go))
        return result

    def parameters(self):
        # The base implementation only knows about the conventional weight
        # and bias attributes; submodules aren't traversed here.
        if hasattr(self, 'weight') and self.weight is not None:
            yield self.weight
        if hasattr(self, 'bias') and self.bias is not None:
            yield self.bias

    def zero_grad(self):
        for p in self.parameters():
            p.grad.zero_()
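
A second sketch, under the same assumptions as above (early torch/autograd API; the Double module is hypothetical), showing the hook and conversion machinery defined in this file: __call__ fires forward hooks with (module, input, output), and double() converts stored tensors via type():

    import torch
    from torch.autograd import Variable
    from torch.nn.modules.module import Module  # assumed import path

    class Double(Module):
        def forward(self, input):
            return input * 2

    m = Double().double()  # no tensors stored here, so a no-op, but it
                           # would convert any weight/bias Variables in place
    m.register_forward_hook('log', lambda mod, inp, out: print(out))
    out = m(Variable(torch.ones(3)))  # the forward hook prints the output
    m.remove_forward_hook('log')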