from collections import OrderedDict
import functools

import torch
from ..backends.thnn import backend as thnn_backend
from ..parameter import Parameter
from torch.autograd import Variable
import torch.utils.hooks as hooks


def _addindent(s_, numSpaces):
    s = s_.split('\n')
    # don't do anything for single-line stuff
    if len(s) == 1:
        return s_
    first = s.pop(0)
    s = [(numSpaces * ' ') + line for line in s]
    s = '\n'.join(s)
    s = first + '\n' + s
    return s


class Module(object):
    """Base class for all neural network modules.

    Your models should also subclass this class.

    Modules can also contain other Modules, allowing them to be nested in
    a tree structure. You can assign the submodules as regular attributes::

        import torch.nn as nn
        import torch.nn.functional as F

        class Model(nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.conv1 = nn.Conv2d(1, 20, 5)
                self.conv2 = nn.Conv2d(20, 20, 5)

            def forward(self, x):
                x = F.relu(self.conv1(x))
                return F.relu(self.conv2(x))

    Submodules assigned in this way will be registered, and will have their
    parameters converted too when you call ``.cuda()``, etc.
    """

    dump_patches = False

    def __init__(self):
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True

    def forward(self, *input):
        """Defines the computation performed at every call.

        Should be overridden by all subclasses.
        """
        raise NotImplementedError

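    # Usage sketch (illustrative; ``model`` and ``input`` are placeholder names):
    # call the module instance itself rather than forward(), because ``__call__``
    # (defined below) is what runs the registered hooks around forward().
    #
    #   >>> output = model(input)          # runs pre-hooks, forward(), forward hooks
    #   >>> output = model.forward(input)  # bypasses all registered hooks
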
    def register_buffer(self, name, tensor):
        """Adds a persistent buffer to the module.

        This is typically used to register a buffer that should not be
        considered a model parameter. For example, BatchNorm's ``running_mean``
        is not a parameter, but is part of the persistent state.

        Buffers can be accessed as attributes using the given names.

        Example:
            >>> self.register_buffer('running_mean', torch.zeros(num_features))
        """
        self._buffers[name] = tensor

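    # Illustrative sketch (``StepCounter`` is a made-up name): a registered buffer
    # moves with .cuda()/.cpu() and is included in state_dict(), unlike a plain
    # tensor attribute.
    #
    #   >>> class StepCounter(nn.Module):
    #   >>>     def __init__(self):
    #   >>>         super(StepCounter, self).__init__()
    #   >>>         self.register_buffer('steps', torch.zeros(1))
    #   >>>
    #   >>>     def forward(self, x):
    #   >>>         self.steps += 1
    #   >>>         return x
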
    def register_parameter(self, name, param):
        """Adds a parameter to the module.

        The parameter can be accessed as an attribute using the given name.
        """
        if '_parameters' not in self.__dict__:
            raise AttributeError(
                "cannot assign parameter before Module.__init__() call")
        if param is None:
            self._parameters[name] = None
        elif not isinstance(param, Parameter):
            raise TypeError("cannot assign '{}' object to parameter '{}' "
                            "(torch.nn.Parameter or None required)"
                            .format(torch.typename(param), name))
        elif param.grad_fn:
            raise ValueError(
                "Cannot assign non-leaf Variable to parameter '{0}'. Model "
                "parameters must be created explicitly. To express '{0}' "
                "as a function of another variable, compute the value in "
                "the forward() method.".format(name))
        else:
            self._parameters[name] = param

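    # Illustrative sketch (``MyBias`` and ``dim`` are placeholder names):
    # register_parameter() is the explicit form of assigning a Parameter attribute;
    # both routes end up in self._parameters (see __setattr__ below). The forward()
    # assumes ``x`` has the same size as ``bias``.
    #
    #   >>> class MyBias(nn.Module):
    #   >>>     def __init__(self, dim):
    #   >>>         super(MyBias, self).__init__()
    #   >>>         self.register_parameter('bias', Parameter(torch.zeros(dim)))
    #   >>>         # equivalent: self.bias = Parameter(torch.zeros(dim))
    #   >>>
    #   >>>     def forward(self, x):
    #   >>>         return x + self.bias
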
    def add_module(self, name, module):
        """Adds a child module to the current module.

        The module can be accessed as an attribute using the given name.
        """
        if hasattr(self, name):
            raise KeyError("attribute '{}' already exists".format(name))
        if not isinstance(module, Module) and module is not None:
            raise TypeError("{} is not a Module subclass".format(
                torch.typename(module)))
        self._modules[name] = module

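    # Illustrative sketch (``Stack`` and ``num_layers`` are placeholder names):
    # add_module() is handy when submodule names are built dynamically, e.g. in a
    # loop, so the children are still registered and show up in parameters().
    #
    #   >>> class Stack(nn.Module):
    #   >>>     def __init__(self, num_layers):
    #   >>>         super(Stack, self).__init__()
    #   >>>         for i in range(num_layers):
    #   >>>             self.add_module('linear{}'.format(i), nn.Linear(10, 10))
    #   >>>
    #   >>>     def forward(self, x):
    #   >>>         for module in self.children():
    #   >>>             x = module(x)
    #   >>>         return x
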
    def _apply(self, fn):
        for module in self.children():
            module._apply(fn)

        for param in self._parameters.values():
            if param is not None:
                # Variables stored in modules are graph leaves, and we don't
                # want to create copy nodes, so we have to unpack the data.
                param.data = fn(param.data)
                if param._grad is not None:
                    param._grad.data = fn(param._grad.data)

        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = fn(buf)

        return self

    def apply(self, fn):
        """Applies ``fn`` recursively to every submodule (as returned by
        :func:`children`) as well as self.
        """
        for module in self.children():
            module.apply(fn)
        fn(self)
        return self

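    # Illustrative sketch (``init_weights`` and ``net`` are placeholder names):
    # apply() is commonly used for weight initialization, since ``fn`` sees every
    # submodule exactly once.
    #
    #   >>> def init_weights(m):
    #   >>>     if isinstance(m, nn.Linear):
    #   >>>         m.weight.data.fill_(1.0)
    #   >>>
    #   >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
    #   >>> net.apply(init_weights)
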
    def cuda(self, device_id=None):
        """Moves all model parameters and buffers to the GPU.

        Arguments:
            device_id (int, optional): if specified, all parameters will be
                copied to that device
        """
        return self._apply(lambda t: t.cuda(device_id))

    def cpu(self):
        """Moves all model parameters and buffers to the CPU."""
        return self._apply(lambda t: t.cpu())

    def type(self, dst_type):
        """Casts all parameters and buffers to :attr:`dst_type`."""
        return self._apply(lambda t: t.type(dst_type))

    def float(self):
        """Casts all parameters and buffers to float datatype."""
        return self._apply(lambda t: t.float())

    def double(self):
        """Casts all parameters and buffers to double datatype."""
        return self._apply(lambda t: t.double())

    def half(self):
        """Casts all parameters and buffers to half datatype."""
        return self._apply(lambda t: t.half())

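    # Illustrative sketch (``model`` is a placeholder name; the GPU call assumes
    # CUDA is available): the conversion methods above all funnel through _apply(),
    # modify the module in place, and return ``self``, so they can be chained and
    # are typically called before the optimizer is constructed.
    #
    #   >>> model = model.cuda()          # move parameters and buffers to the GPU
    #   >>> model = model.float()         # cast back to 32-bit floats
    #   >>> model = model.cpu().double()
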
    def register_backward_hook(self, hook):
        """Registers a backward hook on the module.

        The hook will be called every time the gradients with respect to module
        inputs are computed. The hook should have the following signature::

            hook(module, grad_input, grad_output) -> Tensor or None

        The :attr:`grad_input` and :attr:`grad_output` may be tuples if the
        module has multiple inputs or outputs. The hook should not modify its
        arguments, but it can optionally return a new gradient with respect to
        input that will be used in place of :attr:`grad_input` in subsequent
        computations.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the module.
        """
        handle = hooks.RemovableHandle(self._backward_hooks)
        self._backward_hooks[handle.id] = hook
        return handle

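    # Illustrative sketch (``model`` is a placeholder name): a backward hook used
    # purely for inspection; it returns None, so the gradients are left unchanged.
    #
    #   >>> def report_backward(module, grad_input, grad_output):
    #   >>>     print('backward through', type(module).__name__)
    #   >>>
    #   >>> handle = model.register_backward_hook(report_backward)
    #   >>> # ... run forward and backward passes ...
    #   >>> handle.remove()
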
    def register_forward_pre_hook(self, hook):
        """Registers a forward pre-hook on the module.

        The hook will be called before :func:`forward` is invoked.
        It should have the following signature::

            hook(module, input) -> None

        The hook should not modify the input.
        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the module.
        """
        handle = hooks.RemovableHandle(self._forward_pre_hooks)
        self._forward_pre_hooks[handle.id] = hook
        return handle

    def register_forward_hook(self, hook):
        """Registers a forward hook on the module.

        The hook will be called every time :func:`forward` computes an output.
        It should have the following signature::

            hook(module, input, output) -> None

        The hook should not modify the input or output.
        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the module.
        """
        handle = hooks.RemovableHandle(self._forward_hooks)
        self._forward_hooks[handle.id] = hook
        return handle

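    # Illustrative sketch (``model.conv1``, ``activations`` and ``input`` are
    # placeholder names): a forward hook that captures an intermediate activation
    # without modifying it.
    #
    #   >>> activations = {}
    #   >>> def save_output(module, input, output):
    #   >>>     activations['conv1'] = output
    #   >>>
    #   >>> handle = model.conv1.register_forward_hook(save_output)
    #   >>> output = model(input)
    #   >>> handle.remove()
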
    def __call__(self, *input, **kwargs):
        for hook in self._forward_pre_hooks.values():
            hook(self, input)
        result = self.forward(*input, **kwargs)
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                raise RuntimeError(
                    "forward hooks should never return any values, but '{}' "
                    "didn't return None".format(hook))
        if len(self._backward_hooks) > 0:
            var = result
            while not isinstance(var, Variable):
                var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in self._backward_hooks.values():
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result

    def __setstate__(self, state):
        self.__dict__.update(state)
        if '_forward_pre_hooks' not in self.__dict__:
            self._forward_pre_hooks = OrderedDict()

    def __getattr__(self, name):
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in _parameters:
                return _parameters[name]
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                return _buffers[name]
        if '_modules' in self.__dict__:
            modules = self.__dict__['_modules']
            if name in modules:
                return modules[name]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, name))

    def __setattr__(self, name, value):
        def remove_from(*dicts):
            for d in dicts:
                if name in d:
                    del d[name]

        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call")
            remove_from(self.__dict__, self._buffers, self._modules)
            self.register_parameter(name, value)
        elif params is not None and name in params:
            if value is not None:
                raise TypeError("cannot assign '{}' as parameter '{}' "
                                "(torch.nn.Parameter or None expected)"
                                .format(torch.typename(value), name))
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call")
                remove_from(self.__dict__, self._parameters, self._buffers)
                modules[name] = value
            elif modules is not None and name in modules:
                if value is not None:
                    raise TypeError("cannot assign '{}' as child module '{}' "
                                    "(torch.nn.Module or None expected)"
                                    .format(torch.typename(value), name))
                modules[name] = value
            else:
                buffers = self.__dict__.get('_buffers')
                if buffers is not None and name in buffers:
                    if value is not None and not torch.is_tensor(value):
                        raise TypeError("cannot assign '{}' as buffer '{}' "
                                        "(torch.Tensor or None expected)"
                                        .format(torch.typename(value), name))
                    buffers[name] = value
                else:
                    object.__setattr__(self, name, value)

    def __delattr__(self, name):
        if name in self._parameters:
            del self._parameters[name]
        elif name in self._buffers:
            del self._buffers[name]
        elif name in self._modules:
            del self._modules[name]
        else:
            object.__delattr__(self, name)

    def state_dict(self, destination=None, prefix=''):
        """Returns a dictionary containing the whole state of the module.

        Both parameters and persistent buffers (e.g. running averages) are
        included. Keys are the corresponding parameter and buffer names.

        Example:
            >>> module.state_dict().keys()
            ['bias', 'weight']
        """
        if destination is None:
            destination = OrderedDict()
        for name, param in self._parameters.items():
            if param is not None:
                destination[prefix + name] = param.data
        for name, buf in self._buffers.items():
            if buf is not None:
                destination[prefix + name] = buf
        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination, prefix + name + '.')
        return destination

    def load_state_dict(self, state_dict):
        """Copies parameters and buffers from :attr:`state_dict` into
        this module and its descendants. The keys of :attr:`state_dict` must
        exactly match the keys returned by this module's :func:`state_dict()`
        function.

        Arguments:
            state_dict (dict): A dict containing parameters and
                persistent buffers.
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                raise KeyError('unexpected key "{}" in state_dict'
                               .format(name))
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except:
                print('While copying the parameter named {}, whose dimensions in the model are'
                      ' {} and whose dimensions in the checkpoint are {}, ...'.format(
                          name, own_state[name].size(), param.size()))
                raise

        missing = set(own_state.keys()) - set(state_dict.keys())
        if len(missing) > 0:
            raise KeyError('missing keys in state_dict: "{}"'.format(missing))

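    # Illustrative sketch (``model`` and the file name are placeholder names):
    # state_dict()/load_state_dict() are the usual checkpointing pair, combined
    # with torch.save() and torch.load().
    #
    #   >>> torch.save(model.state_dict(), 'checkpoint.pth')
    #   >>> model.load_state_dict(torch.load('checkpoint.pth'))
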
    def parameters(self):
        """Returns an iterator over module parameters.

        This is typically passed to an optimizer.

        Example:
            >>> for param in model.parameters():
            >>>     print(type(param.data), param.size())
            <class 'torch.FloatTensor'> (20L,)
            <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
        """
        for name, param in self.named_parameters():
            yield param

    def named_parameters(self, memo=None, prefix=''):
        """Returns an iterator over module parameters, yielding both the
        name of the parameter and the parameter itself.

        Example:
            >>> for name, param in self.named_parameters():
            >>>     if name in ['bias']:
            >>>         print(param.size())
        """
        if memo is None:
            memo = set()
        for name, p in self._parameters.items():
            if p is not None and p not in memo:
                memo.add(p)
                yield prefix + ('.' if prefix else '') + name, p
        for mname, module in self.named_children():
            submodule_prefix = prefix + ('.' if prefix else '') + mname
            for name, p in module.named_parameters(memo, submodule_prefix):
                yield name, p

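    # Illustrative sketch (``model`` is a placeholder name): named_parameters()
    # makes it easy to build per-parameter optimizer settings, e.g. disabling
    # weight decay for biases.
    #
    #   >>> decay, no_decay = [], []
    #   >>> for name, param in model.named_parameters():
    #   >>>     (no_decay if name.endswith('bias') else decay).append(param)
    #   >>> optimizer = torch.optim.SGD([
    #   >>>     {'params': decay, 'weight_decay': 1e-4},
    #   >>>     {'params': no_decay, 'weight_decay': 0},
    #   >>> ], lr=0.1)
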
    def _all_buffers(self, memo=None):
        if memo is None:
            memo = set()
        for name, b in self._buffers.items():
            if b is not None and b not in memo:
                memo.add(b)
                yield b
        for module in self.children():
            for b in module._all_buffers(memo):
                yield b

    def children(self):
        """Returns an iterator over immediate children modules."""
        for name, module in self.named_children():
            yield module

    def named_children(self):
        """Returns an iterator over immediate children modules, yielding both
        the name of the module and the module itself.

        Example:
            >>> for name, module in model.named_children():
            >>>     if name in ['conv4', 'conv5']:
            >>>         print(module)
        """
        memo = set()
        for name, module in self._modules.items():
            if module is not None and module not in memo:
                memo.add(module)
                yield name, module

    def modules(self):
        """Returns an iterator over all modules in the network.

        Note:
            Duplicate modules are returned only once. In the following
            example, ``l`` will be returned only once.

            >>> l = nn.Linear(2, 2)
            >>> net = nn.Sequential(l, l)
            >>> for idx, m in enumerate(net.modules()):
            >>>     print(idx, '->', m)
            0 -> Sequential (
              (0): Linear (2 -> 2)
              (1): Linear (2 -> 2)
            )
            1 -> Linear (2 -> 2)
        """
        for name, module in self.named_modules():
            yield module

    def named_modules(self, memo=None, prefix=''):
        """Returns an iterator over all modules in the network, yielding
        both the name of the module and the module itself.

        Note:
            Duplicate modules are returned only once. In the following
            example, ``l`` will be returned only once.

            >>> l = nn.Linear(2, 2)
            >>> net = nn.Sequential(l, l)
            >>> for idx, m in enumerate(net.named_modules()):
            >>>     print(idx, '->', m)
            0 -> ('', Sequential (
              (0): Linear (2 -> 2)
              (1): Linear (2 -> 2)
            ))
            1 -> ('0', Linear (2 -> 2))
        """
        if memo is None:
            memo = set()
        if self not in memo:
            memo.add(self)
            yield prefix, self
            for name, module in self._modules.items():
                if module is None:
                    continue
                submodule_prefix = prefix + ('.' if prefix else '') + name
                for m in module.named_modules(memo, submodule_prefix):
                    yield m

    def train(self, mode=True):
        """Sets the module in training mode.

        This has an effect only on modules such as Dropout or BatchNorm.
        """
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

    def eval(self):
        """Sets the module in evaluation mode.

        This has an effect only on modules such as Dropout or BatchNorm.
        """
        return self.train(False)

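    # Illustrative sketch (``model`` and ``val_input`` are placeholder names):
    # switching to evaluation mode around a validation pass and back again; only
    # modules that check ``self.training`` (e.g. Dropout, BatchNorm) change
    # behaviour.
    #
    #   >>> model.eval()
    #   >>> val_output = model(val_input)
    #   >>> model.train()
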
    def zero_grad(self):
        """Sets gradients of all model parameters to zero."""
        for p in self.parameters():
            if p.grad is not None:
                if p.grad.volatile:
                    p.grad.data.zero_()
                else:
                    data = p.grad.data
                    p.grad = Variable(data.new().resize_as_(data).zero_())

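    # Illustrative sketch (``model``, ``optimizer``, ``criterion``, ``input`` and
    # ``target`` are placeholder names): gradients are cleared before each backward
    # pass because backward() accumulates into ``.grad``.
    #
    #   >>> optimizer.zero_grad()          # or model.zero_grad()
    #   >>> loss = criterion(model(input), target)
    #   >>> loss.backward()
    #   >>> optimizer.step()
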
    def share_memory(self):
        """Moves the underlying storage of all parameters and buffers to shared memory."""
        return self._apply(lambda t: t.share_memory_())

    def __repr__(self):
        tmpstr = self.__class__.__name__ + ' (\n'
        for key, module in self._modules.items():
            modstr = module.__repr__()
            modstr = _addindent(modstr, 2)
            tmpstr = tmpstr + ' (' + key + '): ' + modstr + '\n'
        tmpstr = tmpstr + ')'
        return tmpstr

    def __dir__(self):
        module_attrs = dir(self.__class__)
        attrs = list(self.__dict__.keys())
        parameters = list(self._parameters.keys())
        modules = list(self._modules.keys())
        buffers = list(self._buffers.keys())
        keys = module_attrs + attrs + parameters + modules + buffers
        return sorted(keys)