Return only unique variables from parameters()

This commit is contained in:
Adam Paszke 2016-09-24 10:36:11 -07:00
parent 5030d76acf
commit 4cdeae3283
4 changed files with 55 additions and 17 deletions

View File

@ -203,6 +203,21 @@ class TestNN(NNTestCase):
module.__repr__()
str(module)
def test_parameters(self):
def num_params(module):
return len(list(module.parameters()))
class Net(nn.Container):
def __init__(self):
super(Net, self).__init__(
l1=l,
l2=l
)
l = nn.Linear(10, 20)
n = Net()
s = nn.Sequential(l, l, l, l)
self.assertEqual(num_params(l), 2)
self.assertEqual(num_params(n), 2)
self.assertEqual(num_params(s), 2)
def test_Dropout(self):
input = torch.Tensor(1000)

View File

@ -24,10 +24,20 @@ del old_flags
################################################################################
def typename(o):
module = o.__module__ + '.'
if module == '__builtin__.':
module = ''
return module + o.__class__.__name__
module = ''
class_name = ''
if hasattr(o, '__module__') and o.__module__ != 'builtins' \
and o.__module__ != '__builtin__' and o.__module__ is not None:
module = o.__module__ + '.'
if hasattr(o, '__qualname__'):
class_name = o.__qualname__
elif hasattr(o, '__name__'):
class_name = o.__name__
else:
class_name = o.__class__.__name__
return module + class_name
def is_tensor(obj):

View File

@ -1,15 +1,17 @@
from collections import OrderedDict
import torch
from torch.autograd import Variable
from .module import Module
from collections import OrderedDict
class Container(Module):
"""This is the base container class for all neural networks you would define.
You will subclass your container from this class.
In the constructor you define the modules that you would want to use,
and in the __call__ function you use the constructed modules in
In the constructor you define the modules that you would want to use,
and in the __call__ function you use the constructed modules in
your operations.
To make it easier to understand, given is a small example.
```
# Example of using Container
@ -23,16 +25,16 @@ class Container(Module):
output = self.relu(self.conv1(x))
return output
model = Net()
```
```
One can also add new modules to a container after construction.
You can do this with the add_module function.
```
# one can add modules to the container after construction
model.add_module('pool1', nn.MaxPool2d(2, 2)
model.add_module('pool1', nn.MaxPool2d(2, 2))
```
The container has one additional method `parameters()` which
returns the list of learnable parameters in the container instance.
"""
@ -45,13 +47,18 @@ class Container(Module):
def add_module(self, name, module):
if hasattr(self, name):
raise KeyError("attribute already exists '{}'".format(name))
if not isinstance(module, Module):
raise ValueError("{} is not a Module subclass".format(
torch.typename(module)))
setattr(self, name, module)
if module is not None:
self.modules.append(module)
def parameters(self):
def parameters(self, memo=None):
if memo is None:
memo = set()
for module in self.modules:
for p in module.parameters():
for p in module.parameters(memo):
yield p

View File

@ -75,10 +75,16 @@ class Module(object):
fn.register_hook(key, lambda gi,go,hook=hook: hook(self, gi, go))
return result
def parameters(self):
if hasattr(self, 'weight') and self.weight is not None:
def parameters(self, memo=None):
if memo is None:
memo = set()
if hasattr(self, 'weight') and self.weight is not None \
and self.weight not in memo:
memo.add(self.weight)
yield self.weight
if hasattr(self, 'bias') and self.bias is not None:
if hasattr(self, 'bias') and self.bias is not None \
and self.bias not in memo:
memo.add(self.bias)
yield self.bias
def zero_grad(self):