A module that returns a non-standard data structure currently breaks due to the checks for backward hooks. This refactors the code slightly so that it only breaks when backward hooks are actually registered.
from itertools import chain
from collections import OrderedDict
import functools

import torch
from ..backends.thnn import backend as thnn_backend
from ..parameter import Parameter
from torch.autograd import Variable
import torch.utils.hooks as hooks

def _addindent(s_, numSpaces):
    s = s_.split('\n')
    # don't do anything for single-line stuff
    if len(s) == 1:
        return s_
    first = s.pop(0)
    s = [(numSpaces * ' ') + line for line in s]
    s = '\n'.join(s)
    s = first + '\n' + s
    return s

class Module(object):
    """Base class for all neural network modules.

    Your models should also subclass this class.

    Modules can also contain other Modules, allowing them to be nested in
    a tree structure. You can assign the submodules as regular attributes::

        import torch.nn as nn
        import torch.nn.functional as F

        class Model(nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.conv1 = nn.Conv2d(1, 20, 5)
                self.conv2 = nn.Conv2d(20, 20, 5)

            def forward(self, x):
                x = F.relu(self.conv1(x))
                return F.relu(self.conv2(x))

    Submodules assigned in this way will be registered, and will have their
    parameters converted too when you call .cuda(), etc.
    """

    dump_patches = False

    def __init__(self):
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True

    def forward(self, *input):
        """Defines the computation performed at every call.

        Should be overridden by all subclasses.
        """
        raise NotImplementedError

    def register_buffer(self, name, tensor):
        """Adds a persistent buffer to the module.

        This is typically used to register a buffer that should not be
        considered a model parameter. For example, BatchNorm's ``running_mean``
        is not a parameter, but is part of the persistent state.

        Buffers can be accessed as attributes using given names.

        Example:
            >>> self.register_buffer('running_mean', torch.zeros(num_features))
        """
        self._buffers[name] = tensor

    def register_parameter(self, name, param):
        """Adds a parameter to the module.

        The parameter can be accessed as an attribute using the given name.
        """
        if '_parameters' not in self.__dict__:
            raise AttributeError(
                "cannot assign parameter before Module.__init__() call")
        if param is None:
            self._parameters[name] = None
        elif not isinstance(param, Parameter):
            raise TypeError("cannot assign '{}' object to parameter '{}' "
                            "(torch.nn.Parameter or None required)"
                            .format(torch.typename(param), name))
        elif param.grad_fn:
            raise ValueError(
                "Cannot assign non-leaf Variable to parameter '{0}'. Model "
                "parameters must be created explicitly. To express '{0}' "
                "as a function of another variable, compute the value in "
                "the forward() method.".format(name))
        else:
            self._parameters[name] = param

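    # Usage sketch for register_parameter (illustrative only; ``MyLinear`` and
    # the tensor shapes are assumptions, not part of this module):
    #
    #   class MyLinear(Module):
    #       def __init__(self, in_features, out_features):
    #           super(MyLinear, self).__init__()
    #           self.register_parameter(
    #               'weight', Parameter(torch.randn(out_features, in_features)))
    #           self.register_parameter('bias', None)  # parameter left unset
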
    def add_module(self, name, module):
        """Adds a child module to the current module.

        The module can be accessed as an attribute using the given name.
        """
        if hasattr(self, name):
            raise KeyError("attribute '{}' already exists".format(name))
        if not isinstance(module, Module) and module is not None:
            raise TypeError("{} is not a Module subclass".format(
                torch.typename(module)))
        self._modules[name] = module

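    # Usage sketch for add_module (illustrative; the layer type and sizes are
    # assumptions). It is handy when submodule names are built dynamically and
    # plain attribute assignment is inconvenient:
    #
    #   for i in range(3):
    #       self.add_module('conv{}'.format(i), nn.Conv2d(16, 16, 3))
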
    def _apply(self, fn):
        for module in self.children():
            module._apply(fn)

        for param in self._parameters.values():
            if param is not None:
                # Variables stored in modules are graph leaves, and we don't
                # want to create copy nodes, so we have to unpack the data.
                param.data = fn(param.data)
                if param._grad is not None:
                    param._grad.data = fn(param._grad.data)

        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = fn(buf)

        return self

    def apply(self, fn):
        for module in self.children():
            module.apply(fn)
        fn(self)
        return self

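    # apply() calls ``fn`` on every submodule (children first) and then on the
    # module itself, and is typically used for custom weight initialization.
    # A minimal sketch (``net`` and the init rule are illustrative assumptions):
    #
    #   def init_weights(m):
    #       if isinstance(m, nn.Linear):
    #           m.weight.data.fill_(0.01)
    #
    #   net.apply(init_weights)
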
    def cuda(self, device_id=None):
        """Moves all model parameters and buffers to the GPU.

        Arguments:
            device_id (int, optional): if specified, all parameters will be
                copied to that device
        """
        return self._apply(lambda t: t.cuda(device_id))

    def cpu(self, device_id=None):
        """Moves all model parameters and buffers to the CPU."""
        return self._apply(lambda t: t.cpu())

    def type(self, dst_type):
        return self._apply(lambda t: t.type(dst_type))

    def float(self):
        """Casts all parameters and buffers to float datatype."""
        return self._apply(lambda t: t.float())

    def double(self):
        """Casts all parameters and buffers to double datatype."""
        return self._apply(lambda t: t.double())

    def half(self):
        """Casts all parameters and buffers to half datatype."""
        return self._apply(lambda t: t.half())

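    # These conversion helpers all funnel through _apply(), so parameters,
    # their gradients, and buffers stay in sync. A typical sketch (``MyModel``
    # is an illustrative assumption); the conversion should happen before an
    # optimizer is constructed over model.parameters():
    #
    #   model = MyModel()
    #   model.cuda()    # move parameters and buffers to the GPU
    #   model.float()   # cast parameters and buffers to float
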
    def register_backward_hook(self, hook):
        """Registers a backward hook on the module.

        The hook will be called every time the gradients with respect to module
        inputs are computed. The hook should have the following signature::

            hook(module, grad_input, grad_output) -> Tensor or None

        The :attr:`grad_input` and :attr:`grad_output` may be tuples if the
        module has multiple inputs or outputs. The hook should not modify its
        arguments, but it can optionally return a new gradient with respect to
        input that will be used in place of :attr:`grad_input` in subsequent
        computations.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the module.
        """
        handle = hooks.RemovableHandle(self._backward_hooks)
        self._backward_hooks[handle.id] = hook
        return handle

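    # A hedged sketch of a backward hook (``model`` and the printing are
    # illustrative assumptions; returning None leaves grad_input unchanged):
    #
    #   def print_grad_norms(module, grad_input, grad_output):
    #       print([g.data.norm() for g in grad_input if g is not None])
    #
    #   handle = model.register_backward_hook(print_grad_norms)
    #   handle.remove()  # later, to detach the hook
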
    def register_forward_hook(self, hook):
        """Registers a forward hook on the module.

        The hook will be called every time :func:`forward` computes an output.
        It should have the following signature::

            hook(module, input, output) -> None

        The hook should not modify the input or output.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the module.
        """
        handle = hooks.RemovableHandle(self._forward_hooks)
        self._forward_hooks[handle.id] = hook
        return handle

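    # A hedged sketch of a forward hook (``model`` and the recording logic are
    # illustrative assumptions), e.g. for capturing intermediate activations:
    #
    #   activations = []
    #
    #   def save_output(module, input, output):
    #       activations.append(output.data)
    #
    #   handle = model.register_forward_hook(save_output)
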
    def __call__(self, *input, **kwargs):
        result = self.forward(*input, **kwargs)
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                raise RuntimeError(
                    "forward hooks should never return any values, but '{}' "
                    "didn't return None".format(hook))
        if len(self._backward_hooks) > 0:
            var = result
            while not isinstance(var, Variable):
                var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in self._backward_hooks.values():
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result

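    # Calling a module as ``output = model(input)`` goes through __call__: it
    # runs forward(), then the forward hooks, and only if backward hooks are
    # registered does it unwrap the output and attach them to its grad_fn.
    # A sketch (``model`` and ``input`` are illustrative assumptions):
    #
    #   output = model(input)   # preferred over calling model.forward(input)
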
    def __getattr__(self, name):
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in _parameters:
                return _parameters[name]
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                return _buffers[name]
        if '_modules' in self.__dict__:
            modules = self.__dict__['_modules']
            if name in modules:
                return modules[name]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, name))

    def __setattr__(self, name, value):
        def remove_from(*dicts):
            for d in dicts:
                if name in d:
                    del d[name]

        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call")
            remove_from(self.__dict__, self._buffers, self._modules)
            self.register_parameter(name, value)
        elif params is not None and name in params:
            if value is not None:
                raise TypeError("cannot assign '{}' as parameter '{}' "
                                "(torch.nn.Parameter or None expected)"
                                .format(torch.typename(value), name))
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call")
                remove_from(self.__dict__, self._parameters, self._buffers)
                modules[name] = value
            elif modules is not None and name in modules:
                if value is not None:
                    raise TypeError("cannot assign '{}' as child module '{}' "
                                    "(torch.nn.Module or None expected)"
                                    .format(torch.typename(value), name))
                modules[name] = value
            else:
                buffers = self.__dict__.get('_buffers')
                if buffers is not None and name in buffers:
                    if value is not None and not torch.is_tensor(value):
                        raise TypeError("cannot assign '{}' as buffer '{}' "
                                        "(torch.Tensor or None expected)"
                                        .format(torch.typename(value), name))
                    buffers[name] = value
                else:
                    object.__setattr__(self, name, value)

    def __delattr__(self, name):
        if name in self._parameters:
            del self._parameters[name]
        elif name in self._buffers:
            del self._buffers[name]
        elif name in self._modules:
            del self._modules[name]
        else:
            object.__delattr__(self, name)

    def state_dict(self, destination=None, prefix=''):
        """Returns a dictionary containing the whole state of the module.

        Both parameters and persistent buffers (e.g. running averages) are
        included. Keys are the corresponding parameter and buffer names.

        Example:
            >>> module.state_dict().keys()
            ['bias', 'weight']
        """
        if destination is None:
            destination = OrderedDict()
        for name, param in self._parameters.items():
            if param is not None:
                destination[prefix + name] = param.data
        for name, buf in self._buffers.items():
            if buf is not None:
                destination[prefix + name] = buf
        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination, prefix + name + '.')
        return destination

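    # state_dict() pairs naturally with torch.save for checkpointing. A sketch
    # (the file name is an illustrative assumption):
    #
    #   torch.save(model.state_dict(), 'checkpoint.pth')
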
    def load_state_dict(self, state_dict):
        """Copies parameters and buffers from :attr:`state_dict` into
        this module and its descendants. The keys of :attr:`state_dict` must
        exactly match the keys returned by this module's :func:`state_dict()`
        function.

        Arguments:
            state_dict (dict): A dict containing parameters and
                persistent buffers.
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                raise KeyError('unexpected key "{}" in state_dict'
                               .format(name))
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            own_state[name].copy_(param)

        missing = set(own_state.keys()) - set(state_dict.keys())
        if len(missing) > 0:
            raise KeyError('missing keys in state_dict: "{}"'.format(missing))

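    # The counterpart of the saving sketch above (the file name is an
    # illustrative assumption); the model must already be constructed with
    # matching parameter names and shapes:
    #
    #   model.load_state_dict(torch.load('checkpoint.pth'))
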
    def parameters(self):
        """Returns an iterator over module parameters.

        This is typically passed to an optimizer.

        Example:
            >>> for param in model.parameters():
            >>>     print(type(param.data), param.size())
            <class 'torch.FloatTensor'> (20L,)
            <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
        """
        for name, param in self.named_parameters():
            yield param

    def named_parameters(self, memo=None, prefix=''):
        """Returns an iterator over module parameters, yielding both the
        name of the parameter and the parameter itself.

        Example:
            >>> for name, param in self.named_parameters():
            >>>     if name in ['bias']:
            >>>         print(param.size())
        """
        if memo is None:
            memo = set()
        for name, p in self._parameters.items():
            if p is not None and p not in memo:
                memo.add(p)
                yield prefix + ('.' if prefix else '') + name, p
        for mname, module in self.named_children():
            submodule_prefix = prefix + ('.' if prefix else '') + mname
            for name, p in module.named_parameters(memo, submodule_prefix):
                yield name, p

    def children(self):
        """Returns an iterator over immediate children modules."""
        for name, module in self.named_children():
            yield module

    def named_children(self):
        """Returns an iterator over immediate children modules, yielding both
        the name of the module and the module itself.

        Example:
            >>> for name, module in model.named_children():
            >>>     if name in ['conv4', 'conv5']:
            >>>         print(module)
        """
        memo = set()
        for name, module in self._modules.items():
            if module is not None and module not in memo:
                memo.add(module)
                yield name, module

    def modules(self):
        """Returns an iterator over all modules in the network.

        Note:
            Duplicate modules are returned only once. In the following
            example, ``l`` will be returned only once.

            >>> l = nn.Linear(2, 2)
            >>> net = nn.Sequential(l, l)
            >>> for idx, m in enumerate(net.modules()):
            >>>     print(idx, '->', m)
            0 -> Sequential (
              (0): Linear (2 -> 2)
              (1): Linear (2 -> 2)
            )
            1 -> Linear (2 -> 2)
        """
        for name, module in self.named_modules():
            yield module

    def named_modules(self, memo=None, prefix=''):
        """Returns an iterator over all modules in the network, yielding
        both the name of the module and the module itself.

        Note:
            Duplicate modules are returned only once. In the following
            example, ``l`` will be returned only once.

            >>> l = nn.Linear(2, 2)
            >>> net = nn.Sequential(l, l)
            >>> for idx, m in enumerate(net.named_modules()):
            >>>     print(idx, '->', m)
            0 -> ('', Sequential (
              (0): Linear (2 -> 2)
              (1): Linear (2 -> 2)
            ))
            1 -> ('0', Linear (2 -> 2))
        """
        if memo is None:
            memo = set()
        if self not in memo:
            memo.add(self)
            yield prefix, self
            for name, module in self._modules.items():
                submodule_prefix = prefix + ('.' if prefix else '') + name
                for m in module.named_modules(memo, submodule_prefix):
                    yield m

    def train(self, mode=True):
        """Sets the module in training mode.

        This has an effect only on modules such as Dropout or BatchNorm.
        """
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

    def eval(self):
        """Sets the module in evaluation mode.

        This has an effect only on modules such as Dropout or BatchNorm.
        """
        return self.train(False)

    def zero_grad(self):
        """Sets gradients of all model parameters to zero."""
        for p in self.parameters():
            if p.grad is not None:
                if p.grad.volatile:
                    p.grad.data.zero_()
                else:
                    data = p.grad.data
                    p.grad = Variable(data.new().resize_as_(data).zero_())

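    # A minimal training-step sketch showing where zero_grad() fits (``model``,
    # ``criterion``, ``input`` and ``target`` are illustrative assumptions):
    #
    #   model.zero_grad()
    #   output = model(input)
    #   loss = criterion(output, target)
    #   loss.backward()
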
    def share_memory(self):
        return self._apply(lambda t: t.share_memory_())

    def __repr__(self):
        tmpstr = self.__class__.__name__ + ' (\n'
        for key, module in self._modules.items():
            modstr = module.__repr__()
            modstr = _addindent(modstr, 2)
            tmpstr = tmpstr + '  (' + key + '): ' + modstr + '\n'
        tmpstr = tmpstr + ')'
        return tmpstr

    def __dir__(self):
        module_attrs = dir(self.__class__)
        attrs = list(self.__dict__.keys())
        parameters = list(self._parameters.keys())
        modules = list(self._modules.keys())
        buffers = list(self._buffers.keys())
        keys = module_attrs + attrs + parameters + modules + buffers
        return sorted(keys)