from itertools import chain
from collections import OrderedDict

import torch
from ..backends.thnn import backend as thnn_backend
from ..parameter import Parameter
from torch.autograd import Variable


class Module(object):
    """Base class for all Modules defined in the nn package.

    Even the Container class derives from it.
    """

    def __init__(self):
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self.training = True
        for name, param in self._parameters.items():
            if not isinstance(param, Parameter):
                if isinstance(param, Variable):
                    raise TypeError("can't use a Variable as a module "
                                    "parameter. Convert it to torch.nn.Parameter first.")
                if param is not None:
                    param = Parameter(param)
                self._parameters[name] = param

    def forward(self, *input):
        """Defines the computation performed at every call.

        Should be overridden by all subclasses.
        """
        raise NotImplementedError

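    # Usage sketch (illustrative only, not executed): a minimal subclass that
    # overrides forward(). `MyLinear`, `in_features` and `out_features` are
    # hypothetical names, and the input is assumed to be a Variable that
    # supports mm().
    #
    #     class MyLinear(Module):
    #         def __init__(self, in_features, out_features):
    #             super(MyLinear, self).__init__()
    #             # stored as (in_features, out_features) so forward() is a plain mm
    #             self.weight = Parameter(torch.randn(in_features, out_features))
    #
    #         def forward(self, input):
    #             return input.mm(self.weight)
    #
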
    def register_buffer(self, name, tensor):
        """Adds a persistent buffer to the module.

        This is typically used to register a buffer that should not be
        considered a model parameter. For example, BatchNorm's ``running_mean``
        is not a parameter, but is part of the persistent state.

        Buffers can be accessed as attributes using given names.

        Example:
            self.register_buffer('running_mean', torch.zeros(num_features))
        """
        self._buffers[name] = tensor

    def register_parameter(self, name, param):
        """Adds a parameter to the module.

        The parameter can be accessed as an attribute using the given name.
        """
        if '_parameters' not in self.__dict__:
            raise AttributeError(
                "cannot assign parameter before Module.__init__() call")
        if param is None:
            self._parameters[name] = None
        elif not isinstance(param, Parameter):
            raise TypeError("cannot assign '{}' object to parameter '{}' "
                            "(torch.nn.Parameter or None required)"
                            .format(torch.typename(param), name))
        elif param.creator:
            raise ValueError(
                "Cannot assign non-leaf Variable to parameter '{0}'. Model "
                "parameters must be created explicitly. To express '{0}' "
                "as a function of another variable, compute the value in "
                "the forward() method.".format(name))
        else:
            self._parameters[name] = param

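    # Usage sketch (illustrative only, not executed): inside a subclass's
    # __init__, the two lines below are equivalent, since __setattr__ routes
    # Parameter assignments through register_parameter().
    #
    #     self.register_parameter('weight', Parameter(torch.randn(20, 10)))
    #     self.weight = Parameter(torch.randn(20, 10))
    #
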
    def _apply(self, fn):
        for param in self._parameters.values():
            if param is not None:
                # Variables stored in modules are graph leaves, and we don't
                # want to create copy nodes, so we have to unpack the data.
                param.data = fn(param.data)
                if param.grad is not None:
                    param._grad = fn(param._grad)

        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = fn(buf)
        return self

    def apply(self, fn):
        """Applies ``fn`` to the module and returns the module itself."""
        fn(self)
        return self

    def cuda(self, device_id=None):
        """Moves all model parameters and buffers to the GPU.

        Arguments:
            device_id (int, optional): if specified, all parameters will be
                copied to that device
        """
        return self._apply(lambda t: t.cuda(device_id))

    def cpu(self, device_id=None):
        """Moves all model parameters and buffers to the CPU."""
        return self._apply(lambda t: t.cpu())

    def type(self, dst_type):
        """Casts all parameters and buffers to ``dst_type``."""
        return self._apply(lambda t: t.type(dst_type))

    def float(self):
        """Casts all parameters and buffers to float datatype."""
        return self._apply(lambda t: t.float())

    def double(self):
        """Casts all parameters and buffers to double datatype."""
        return self._apply(lambda t: t.double())

    def half(self):
        """Casts all parameters and buffers to half datatype."""
        return self._apply(lambda t: t.half())

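    # Usage sketch (illustrative only, not executed): the conversion helpers
    # above all go through _apply() and return the module, so they compose.
    # `model` is a hypothetical Module instance and a CUDA device is assumed
    # to be available.
    #
    #     model.cuda()      # parameters and buffers become CUDA tensors
    #     model.float()     # cast everything to float
    #     model.cpu()       # move everything back to the CPU
    #
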
    def register_backward_hook(self, name, hook):
        """Registers a backward hook on the module, under a given name.

        The hook will be called every time the gradient w.r.t. module inputs
        is computed. The callable will be passed the module, the gradient
        w.r.t. the input and the gradient w.r.t. the output, where the
        gradients can be tuples if the module had multiple inputs or outputs.
        The hook should never modify its arguments in-place, but it can
        optionally return a new gradient w.r.t. the input that will be used
        in subsequent computation.
        """
        assert name not in self._backward_hooks, \
            "Trying to register a second backward hook with name {}".format(name)
        self._backward_hooks[name] = lambda gi, go: hook(self, gi, go)

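    # Usage sketch (illustrative only, not executed): registering and later
    # removing a backward hook under a name. `model` and `print_grads` are
    # hypothetical names.
    #
    #     def print_grads(module, grad_input, grad_output):
    #         # grad_input / grad_output may be single Variables or tuples,
    #         # depending on how many inputs/outputs the module has.
    #         print(type(module).__name__, 'backward pass')
    #
    #     model.register_backward_hook('print_grads', print_grads)
    #     # ... later, when no longer needed:
    #     model.remove_backward_hook('print_grads')
    #
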
    def remove_backward_hook(self, name):
        """Removes the backward hook with the given name.

        If no such hook exists, an error is raised.
        """
        assert name in self._backward_hooks, \
            "Trying to remove an inexistent backward hook with name {}".format(name)
        del self._backward_hooks[name]

    def register_forward_hook(self, name, hook):
        """Registers a forward hook on the module, under a given name.

        The hook will be called every time :func:`forward` computes an output.
        The callable will be passed the module, the module's input and its
        output. None of them should be modified by the hook.
        """
        assert name not in self._forward_hooks, \
            "Trying to register a second forward hook with name {}".format(name)
        self._forward_hooks[name] = hook

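    # Usage sketch (illustrative only, not executed): a forward hook that
    # records intermediate outputs. `model`, `activations` and `some_input`
    # are hypothetical names.
    #
    #     activations = []
    #
    #     def save_output(module, input, output):
    #         activations.append(output)
    #
    #     model.register_forward_hook('save_output', save_output)
    #     model(some_input)          # the hook fires after forward()
    #     model.remove_forward_hook('save_output')
    #
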
    def remove_forward_hook(self, name):
        """Removes the forward hook with the given name.

        If no such hook exists, an error is raised.
        """
        assert name in self._forward_hooks, \
            "Trying to remove an inexistent forward hook with name {}".format(name)
        del self._forward_hooks[name]

    def __call__(self, *input, **kwargs):
        result = self.forward(*input, **kwargs)
        for name, hook in self._forward_hooks.items():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                raise RuntimeError("forward hooks should never return any "
                                   "values, but '{}' didn't return None".format(name))
        # Attach the registered backward hooks to the creator of the output,
        # so they fire when gradients w.r.t. this module's inputs are computed.
        var = result
        while not isinstance(var, Variable):
            var = var[0]
        creator = var.creator
        if creator is not None:
            creator._backward_hooks = self._backward_hooks
        return result

    def __getattr__(self, name):
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in _parameters:
                return _parameters[name]
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                return _buffers[name]
        return object.__getattribute__(self, name)

    def __setattr__(self, name, value):
        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter) or (params and name in params):
            self.register_parameter(name, value)
        else:
            object.__setattr__(self, name, value)

    def __delattr__(self, name):
        if name in self._parameters:
            del self._parameters[name]
        else:
            object.__delattr__(self, name)

    def state_dict(self, destination=None, prefix=''):
        """Returns a dictionary containing the whole state of the module.

        Both parameters and persistent buffers (e.g. running averages) are
        included. Keys are the corresponding parameter and buffer names.

        Example:
            >>> print(module.state_dict().keys())
            ['bias', 'weight']
        """
        if destination is None:
            destination = OrderedDict()
        for name, param in chain(self._buffers.items(), self._parameters.items()):
            if param is not None:
                destination[prefix + name] = param
        return destination

    def load_state_dict(self, state_dict, prefix=''):
        """Replaces module parameters using values from a given state_dict.

        This will load all values from the state dict (including those that
        weren't registered before loading).

        Arguments:
            state_dict (dict): A dict containing loaded parameters and
                persistent buffers.
        """
        for name, param in self._parameters.items():
            new_param = state_dict.get(prefix + name, param)
            if not isinstance(new_param, Parameter) and new_param is not None:
                raise TypeError(
                    "expected torch.nn.Parameter for key '{}' (got {})"
                    .format(prefix + name, torch.typename(new_param)))
            self._parameters[name] = new_param
        for name, buf in self._buffers.items():
            self._buffers[name] = state_dict.get(prefix + name, buf)

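    # Usage sketch (illustrative only, not executed): copying the state of one
    # module into another with the same set of parameters and buffers.
    # `source_model` and `target_model` are hypothetical Module instances.
    #
    #     target_model.load_state_dict(source_model.state_dict())
    #
    # The same dict can also be handed to torch.save() for checkpointing,
    # assuming the stored values are serializable with it.
    #
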
    def parameters(self, memo=None):
        """Returns an iterator over module parameters.

        This is typically passed to an optimizer.

        Example:
            >>> for param in model.parameters():
            >>>     print(type(param.data), param.size())
            <class 'torch.FloatTensor'> (20L,)
            <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
        """
        if memo is None:
            memo = set()
        for p in self._parameters.values():
            if p is not None and p not in memo:
                memo.add(p)
                yield p

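    # Usage sketch (illustrative only, not executed): handing the parameter
    # iterator to an optimizer, which is the most common use of parameters().
    # `model` is a hypothetical Module; torch.optim.SGD is assumed available.
    #
    #     import torch.optim as optim
    #     optimizer = optim.SGD(model.parameters(), lr=0.01)
    #
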
    def children(self):
        """Returns an iterator over children modules."""
        # A plain Module has no children; the `if False` trick makes this an
        # empty generator. Subclasses that hold submodules override it.
        if False:
            yield

    def modules(self, memo=None):
        """Returns an iterator over modules.

        For a plain Module this yields only the module itself; ``memo`` tracks
        modules that were already yielded.
        """
        if memo is None:
            memo = set()
        if self not in memo:
            memo.add(self)
            yield self

    def train(self):
        """Sets the module in training mode.

        This has an effect only on modules such as Dropout or BatchNorm.
        """
        self.training = True
        return self

    def eval(self):
        """Sets the module in evaluation mode.

        This has an effect only on modules such as Dropout or BatchNorm.
        """
        self.training = False
        return self

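    # Usage sketch (illustrative only, not executed): switching modes around
    # evaluation so that Dropout/BatchNorm behave deterministically.
    # `model` and `validation_input` are hypothetical names.
    #
    #     model.eval()
    #     output = model(validation_input)
    #     model.train()
    #
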
    def zero_grad(self):
        """Sets gradients of all model parameters to zero."""
        for p in self.parameters():
            p.grad.zero_()

    def share_memory(self):
        """Moves the underlying storage of parameters and buffers to shared memory."""
        return self._apply(lambda t: t.share_memory_())