mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
- `nn.Container.modules` is now a plain Python list, and `nn.Sequential` uses it.
- Every module in `nn.Sequential` has a name; this fixes `Module.type()`.
- The `nn.Sequential` constructor accepts either modules passed positionally or a single `OrderedDict` of named modules. Positional modules are named "0", "1", "2", ... in order.
46 lines
1.2 KiB
Python
46 lines
1.2 KiB
Python
from torch.autograd import Variable
|
|
from .module import Module
|
|
from collections import OrderedDict
|
|
|
|
|
|
class Container(Module):
    """A module that owns a collection of named child modules.

    Each child is stored twice: as an attribute on the container (so it can
    be reached as ``container.<name>``) and, in registration order, in the
    ``modules`` list that ``parameters()`` walks.

    Args:
        **kwargs: child modules to register, keyed by attribute name. They
            are registered in ``kwargs`` iteration order (arbitrary before
            Python 3.6).
    """

    def __init__(self, **kwargs):
        super(Container, self).__init__()
        # Children in registration order; read by parameters() and by
        # subclasses such as Sequential.
        self.modules = []
        for key, value in kwargs.items():
            self._assign_module(key, value)

    def _assign_module(self, name, module):
        """Register ``module`` as the child called ``name``.

        The module is set as attribute ``name`` and appended to
        ``self.modules``.

        Raises:
            KeyError: if the container already has an attribute called
                ``name``. Allowing it would overwrite an existing child or
                one of the container's own attributes/methods (for example
                the ``modules`` list itself).
        """
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would let the setattr below clobber state.
        if hasattr(self, name):
            raise KeyError("attribute '{}' already exists".format(name))
        setattr(self, name, module)
        self.modules.append(module)

    def parameters(self):
        """Yield every parameter of every child, child by child in
        registration order."""
        for module in self.modules:
            for p in module.parameters():
                yield p
|
|
|
|
|
|
class Sequential(Container):
    """A container that applies its child modules one after another.

    Children can be supplied in two ways:

    * positionally, e.g. ``Sequential(a, b, c)``; they are then named
      ``"0"``, ``"1"``, ``"2"``, ... in the order given;
    * as a single ``OrderedDict`` mapping names to modules, e.g.
      ``Sequential(OrderedDict([("fc", fc), ("act", act)]))``; the dict's
      names and order are used as-is.
    """

    def __init__(self, *args):
        super(Sequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self._assign_module(key, module)
        else:
            for idx, module in enumerate(args):
                self._assign_module(str(idx), module)

    def __getitem__(self, idx):
        """Return ``self.modules[idx]``, i.e. the child at that position in
        registration order. A slice yields a plain list, not a Sequential."""
        return self.modules[idx]

    def _forward(self, input):
        """Pass ``input`` through every child in order.

        Each child is called on the previous child's output. The final
        output is returned wrapped in a 1-tuple (presumably the
        tuple-of-outputs convention expected by ``Module`` — confirm).
        """
        for module in self.modules:
            input = module(input)
        return (input,)
|