pytorch/torch/nn/modules/container.py
Sam Gross f45213a276 Simplify nn.Container and nn.Sequential
- nn.Container.modules is just a Python list and is used by nn.Sequential

 - Every module in nn.Sequential has a name. This fixes Module.type()

 - nn.Sequential constructor accepts either a list or an OrderedDict. With a
   list, the modules are named "0", "1", "2", ...
2016-08-31 11:15:31 -07:00

46 lines
1.2 KiB
Python

from torch.autograd import Variable
from .module import Module
from collections import OrderedDict
class Container(Module):
    """Module that owns a collection of named child modules.

    Each child is reachable as an attribute under its name and is also
    recorded in the ``modules`` list, in the order it was assigned.
    """

    def __init__(self, **kwargs):
        """Create the container and register every keyword argument as a child.

        Args:
            **kwargs: child modules, keyed by the attribute name each one
                should be reachable under.

        Raises:
            AssertionError: if a name is already an attribute of the
                container.
        """
        super(Container, self).__init__()
        # Plain list of children; appended to by _assign_module.
        self.modules = []
        for name, module in kwargs.items():
            self._assign_module(name, module)

    def _assign_module(self, name, module):
        """Attach ``module`` as attribute ``name`` and record it in ``modules``.

        Args:
            name: attribute name to expose the child under.
            module: the child to register.

        Raises:
            AssertionError: if the container already has an attribute called
                ``name``. Nothing is modified in that case.
        """
        if hasattr(self, name):
            # Raised explicitly instead of through an ``assert`` statement so
            # the check is not stripped under ``python -O``. The exception
            # type is unchanged, so existing callers keep working.
            raise AssertionError(
                "cannot assign module '{}': {} already has an attribute "
                "with that name".format(name, type(self).__name__))
        setattr(self, name, module)
        self.modules.append(module)

    def parameters(self):
        """Yield the parameters of every child, one child after another."""
        for module in self.modules:
            for param in module.parameters():
                yield param
class Sequential(Container):
    """Container whose children are applied one after another.

    Build it either from a single ``OrderedDict`` mapping names to modules,
    or from positional modules, which are named "0", "1", "2", ...
    """

    def __init__(self, *args):
        """Register the given modules in order.

        Args:
            *args: either exactly one ``OrderedDict`` of ``name -> module``,
                or any number of modules passed positionally.
        """
        super(Sequential, self).__init__()
        given_mapping = len(args) == 1 and isinstance(args[0], OrderedDict)
        if given_mapping:
            named_layers = args[0].items()
        else:
            # Positional modules are named after their position.
            named_layers = (
                (str(position), layer) for position, layer in enumerate(args)
            )
        for label, layer in named_layers:
            self._assign_module(label, layer)

    def __getitem__(self, idx):
        """Return the child stored at ``idx`` in the ``modules`` list."""
        children = self.modules
        return children[idx]

    def _forward(self, input):
        """Feed ``input`` through every child in turn.

        Returns:
            A one-element tuple holding the output of the last child (or
            ``input`` itself when there are no children).
        """
        result = input
        for layer in self.modules:
            result = layer(result)
        return (result,)