from itertools import chain
from collections import OrderedDict
import functools

import torch
from ..backends.thnn import backend as thnn_backend
from ..parameter import Parameter
from torch.autograd import Variable
import torch.utils.hooks as hooks


def _addindent(s_, numSpaces):
    s = s_.split('\n')
    # don't do anything for single-line strings
    if len(s) == 1:
        return s_
    first = s.pop(0)
    s = [(numSpaces * ' ') + line for line in s]
    s = '\n'.join(s)
    s = first + '\n' + s
    return s


class Module(object):
    """Base class for all neural network modules.

    Your models should also subclass this class.

    Modules can also contain other Modules, allowing them to be nested in a
    tree structure. You can assign the submodules as regular attributes::

        import torch.nn as nn
        import torch.nn.functional as F

        class Model(nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.conv1 = nn.Conv2d(1, 20, 5)
                self.conv2 = nn.Conv2d(20, 20, 5)

            def forward(self, x):
                x = F.relu(self.conv1(x))
                return F.relu(self.conv2(x))

    Submodules assigned in this way will be registered, and will have their
    parameters converted too when you call .cuda(), etc.
    """

    dump_patches = False

    def __init__(self):
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True
        for name, param in self._parameters.items():
            if not isinstance(param, Parameter):
                if isinstance(param, Variable):
                    raise TypeError("can't use a Variable as a module "
                                    "parameter. Convert it to "
                                    "torch.nn.Parameter first.")
                if param is not None:
                    param = Parameter(param)
                self._parameters[name] = param

    def forward(self, *input):
        """Defines the computation performed at every call.

        Should be overridden by all subclasses.
        """
        raise NotImplementedError

    def register_buffer(self, name, tensor):
        """Adds a persistent buffer to the module.

        This is typically used to register a buffer that should not be
        considered a model parameter. For example, BatchNorm's ``running_mean``
        is not a parameter, but is part of the persistent state.

        Buffers can be accessed as attributes using the given names.

        Example:
            >>> self.register_buffer('running_mean', torch.zeros(num_features))
        """
        self._buffers[name] = tensor

    def register_parameter(self, name, param):
        """Adds a parameter to the module.

        The parameter can be accessed as an attribute using the given name.
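
        Example (an illustrative call from a subclass's ``__init__``; the name
        ``'scale'`` and the tensor value are assumptions made up for this
        example):
            >>> self.register_parameter('scale', Parameter(torch.ones(1)))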
        """
        if '_parameters' not in self.__dict__:
            raise AttributeError(
                "cannot assign parameter before Module.__init__() call")
        if param is None:
            self._parameters[name] = None
        elif not isinstance(param, Parameter):
            raise TypeError("cannot assign '{}' object to parameter '{}' "
                            "(torch.nn.Parameter or None required)"
                            .format(torch.typename(param), name))
        elif param.creator:
            raise ValueError(
                "Cannot assign non-leaf Variable to parameter '{0}'. Model "
                "parameters must be created explicitly. To express '{0}' "
                "as a function of another variable, compute the value in "
                "the forward() method.".format(name))
        else:
            self._parameters[name] = param

    def add_module(self, name, module):
        """Adds a child module to the current module.

        The module can be accessed as an attribute using the given name.
        """
        if hasattr(self, name):
            raise KeyError("attribute already exists '{}'".format(name))
        if not isinstance(module, Module) and module is not None:
            raise TypeError("{} is not a Module subclass".format(
                torch.typename(module)))
        self._modules[name] = module

    def _apply(self, fn):
        for module in self.children():
            module._apply(fn)

        for param in self._parameters.values():
            if param is not None:
                # Variables stored in modules are graph leaves, and we don't
                # want to create copy nodes, so we have to unpack the data.
                param.data = fn(param.data)
                if param.grad is not None:
                    param._grad.data = fn(param._grad.data)

        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = fn(buf)

        return self

    def apply(self, fn):
        """Applies ``fn`` recursively to every child module, then to self."""
        for module in self.children():
            module.apply(fn)
        fn(self)
        return self

    def cuda(self, device_id=None):
        """Moves all model parameters and buffers to the GPU.

        Arguments:
            device_id (int, optional): if specified, all parameters will be
                copied to that device
        """
        return self._apply(lambda t: t.cuda(device_id))

    def cpu(self, device_id=None):
        """Moves all model parameters and buffers to the CPU."""
        return self._apply(lambda t: t.cpu())

    def type(self, dst_type):
        """Casts all parameters and buffers to :attr:`dst_type`."""
        return self._apply(lambda t: t.type(dst_type))

    def float(self):
        """Casts all parameters and buffers to float datatype."""
        return self._apply(lambda t: t.float())

    def double(self):
        """Casts all parameters and buffers to double datatype."""
        return self._apply(lambda t: t.double())

    def half(self):
        """Casts all parameters and buffers to half datatype."""
        return self._apply(lambda t: t.half())

    def register_backward_hook(self, hook):
        """Registers a backward hook on the module.

        The hook will be called every time the gradients with respect to module
        inputs are computed. The hook should have the following signature::

            hook(module, grad_input, grad_output) -> Tensor or None

        The :attr:`grad_input` and :attr:`grad_output` may be tuples if the
        module has multiple inputs or outputs. The hook should not modify its
        arguments, but it can optionally return a new gradient with respect to
        input that will be used in place of :attr:`grad_input` in subsequent
        computations.

        This function returns a handle with a method ``handle.remove()`` that
        removes the hook from the module.
        """
        handle = hooks.RemovableHandle(self._backward_hooks)
        self._backward_hooks[id(handle)] = hook
        return handle

    def register_forward_hook(self, hook):
        """Registers a forward hook on the module.

        The hook will be called every time :func:`forward` computes an output.
        It should have the following signature::

            hook(module, input, output) -> None

        The hook should not modify the input or output.

        This function returns a handle with a method ``handle.remove()`` that
        removes the hook from the module.
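
        Example (an illustrative sketch; ``model`` stands for any ``Module``
        instance whose ``forward`` returns a single Variable, and the hook
        below only prints the output size):
            >>> def print_size_hook(module, input, output):
            >>>     print(type(module).__name__, output.size())
            >>> handle = model.register_forward_hook(print_size_hook)
            >>> # ... run some forward passes ...
            >>> handle.remove()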
""" handle = hooks.RemovableHandle(self._forward_hooks) self._forward_hooks[id(handle)] = hook return handle def __call__(self, *input, **kwargs): result = self.forward(*input, **kwargs) for hook in self._forward_hooks.values(): hook_result = hook(self, input, result) if hook_result is not None: raise RuntimeError( "forward hooks should never return any values, but '{}'" "didn't return None".format(hook)) var = result while not isinstance(var, Variable): var = var[0] creator = var.creator if creator is not None and len(self._backward_hooks) > 0: if creator._backward_hooks is None: creator._backward_hooks = OrderedDict() for hook in self._backward_hooks.values(): wrapper = functools.partial(hook, self) functools.update_wrapper(wrapper, hook) creator._backward_hooks[id(wrapper)] = wrapper return result def __getattr__(self, name): if '_parameters' in self.__dict__: _parameters = self.__dict__['_parameters'] if name in _parameters: return _parameters[name] if '_buffers' in self.__dict__: _buffers = self.__dict__['_buffers'] if name in _buffers: return _buffers[name] if '_modules' in self.__dict__: modules = self.__dict__['_modules'] if name in modules: return modules[name] return object.__getattribute__(self, name) def __setattr__(self, name, value): params = self.__dict__.get('_parameters') if isinstance(value, Parameter): if params is None: raise AttributeError( "cannot assign parameters before Module.__init__() call") self.register_parameter(name, value) elif params is not None and name in params: if value is not None: raise TypeError("cannot assign '{}' as parameter '{}' " "(torch.nn.Parameter or None expected)" .format(torch.typename(value), name)) self.register_parameter(name, value) else: modules = self.__dict__.get('_modules') if isinstance(value, Module): if modules is None: raise AttributeError( "cannot assign module before Module.__init__() call") modules[name] = value elif modules is not None and name in modules: if value is not None: raise TypeError("cannot assign '{}' as child module '{}' " "(torch.nn.Module or None expected)" .format(torch.typename(value), name)) modules[name] = value else: object.__setattr__(self, name, value) def __delattr__(self, name): if name in self._parameters: del self._parameters[name] elif name in self._modules: del self._modules[name] else: object.__delattr__(self, name) def state_dict(self, destination=None, prefix=''): """Returns a dictionary containing a whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Example: >>> module.state_dict().keys() ['bias', 'weight'] """ if destination is None: destination = OrderedDict() for name, param in self._parameters.items(): if param is not None: destination[prefix + name] = param.data for name, buf in self._buffers.items(): if buf is not None: destination[prefix + name] = buf for name, module in self._modules.items(): if module is not None: module.state_dict(destination, prefix + name + '.') return destination def load_state_dict(self, state_dict): """Copies parameters and buffers from :attr:`state_dict` into this module and its descendants. The keys of :attr:`state_dict` must exactly match the keys returned by this module's :func:`state_dict()` function. Arguments: state_dict (dict): A dict containing parameters and persistent buffers. 
""" own_state = self.state_dict() for name, param in state_dict.items(): if name not in own_state: raise KeyError('unexpected key "{}" in state_dict' .format(name)) if isinstance(param, Parameter): # backwards compatibility for serialized parameters param = param.data own_state[name].copy_(param) missing = set(own_state.keys()) - set(state_dict.keys()) if len(missing) > 0: raise KeyError('missing keys in state_dict: "{}"'.format(missing)) def parameters(self, memo=None): """Returns an iterator over module parameters. This is typically passed to an optimizer. Example: >>> for param in model.parameters(): >>> print(type(param.data), param.size()) (20L,) (20L, 1L, 5L, 5L) """ if memo is None: memo = set() for p in self._parameters.values(): if p is not None and p not in memo: memo.add(p) yield p for module in self.children(): for p in module.parameters(memo): yield p def children(self): """Returns an iterator over children modules.""" memo = set() for module in self._modules.values(): if module is not None and module not in memo: memo.add(module) yield module def modules(self, memo=None): if memo is None: memo = set() if self not in memo: memo.add(self) yield self for module in self.children(): for m in module.modules(memo): yield m def train(self): """Sets the module in training mode. This has any effect only on modules such as Dropout or BatchNorm. """ self.training = True for module in self.children(): module.train() return self def eval(self): """Sets the module in evaluation mode. This has any effect only on modules such as Dropout or BatchNorm. """ self.training = False for module in self.children(): module.eval() return self def zero_grad(self): """Sets gradients of all model parameters to zero.""" for p in self.parameters(): p.grad.data.zero_() def share_memory(self): return self._apply(lambda t: t.share_memory_()) def __repr__(self): tmpstr = self.__class__.__name__ + ' (\n' for key, module in self._modules.items(): modstr = module.__repr__() modstr = _addindent(modstr, 2) tmpstr = tmpstr + ' (' + key + '): ' + modstr + '\n' tmpstr = tmpstr + ')' return tmpstr