Return only unique variables from parameters()

This commit is contained in:
Adam Paszke 2016-09-24 10:36:11 -07:00
parent 5030d76acf
commit 4cdeae3283
4 changed files with 55 additions and 17 deletions

View File

@ -203,6 +203,21 @@ class TestNN(NNTestCase):
module.__repr__() module.__repr__()
str(module) str(module)
def test_parameters(self):
def num_params(module):
return len(list(module.parameters()))
class Net(nn.Container):
def __init__(self):
super(Net, self).__init__(
l1=l,
l2=l
)
l = nn.Linear(10, 20)
n = Net()
s = nn.Sequential(l, l, l, l)
self.assertEqual(num_params(l), 2)
self.assertEqual(num_params(n), 2)
self.assertEqual(num_params(s), 2)
def test_Dropout(self): def test_Dropout(self):
input = torch.Tensor(1000) input = torch.Tensor(1000)

View File

@ -24,10 +24,20 @@ del old_flags
################################################################################ ################################################################################
def typename(o): def typename(o):
module = o.__module__ + '.' module = ''
if module == '__builtin__.': class_name = ''
module = '' if hasattr(o, '__module__') and o.__module__ != 'builtins' \
return module + o.__class__.__name__ and o.__module__ != '__builtin__' and o.__module__ is not None:
module = o.__module__ + '.'
if hasattr(o, '__qualname__'):
class_name = o.__qualname__
elif hasattr(o, '__name__'):
class_name = o.__name__
else:
class_name = o.__class__.__name__
return module + class_name
def is_tensor(obj): def is_tensor(obj):

View File

@ -1,15 +1,17 @@
from collections import OrderedDict
import torch
from torch.autograd import Variable from torch.autograd import Variable
from .module import Module from .module import Module
from collections import OrderedDict
class Container(Module): class Container(Module):
"""This is the base container class for all neural networks you would define. """This is the base container class for all neural networks you would define.
You will subclass your container from this class. You will subclass your container from this class.
In the constructor you define the modules that you would want to use, In the constructor you define the modules that you would want to use,
and in the __call__ function you use the constructed modules in and in the __call__ function you use the constructed modules in
your operations. your operations.
To make it easier to understand, given is a small example. To make it easier to understand, given is a small example.
``` ```
# Example of using Container # Example of using Container
@ -23,16 +25,16 @@ class Container(Module):
output = self.relu(self.conv1(x)) output = self.relu(self.conv1(x))
return output return output
model = Net() model = Net()
``` ```
One can also add new modules to a container after construction. One can also add new modules to a container after construction.
You can do this with the add_module function. You can do this with the add_module function.
``` ```
# one can add modules to the container after construction # one can add modules to the container after construction
model.add_module('pool1', nn.MaxPool2d(2, 2) model.add_module('pool1', nn.MaxPool2d(2, 2))
``` ```
The container has one additional method `parameters()` which The container has one additional method `parameters()` which
returns the list of learnable parameters in the container instance. returns the list of learnable parameters in the container instance.
""" """
@ -45,13 +47,18 @@ class Container(Module):
def add_module(self, name, module): def add_module(self, name, module):
if hasattr(self, name): if hasattr(self, name):
raise KeyError("attribute already exists '{}'".format(name)) raise KeyError("attribute already exists '{}'".format(name))
if not isinstance(module, Module):
raise ValueError("{} is not a Module subclass".format(
torch.typename(module)))
setattr(self, name, module) setattr(self, name, module)
if module is not None: if module is not None:
self.modules.append(module) self.modules.append(module)
def parameters(self): def parameters(self, memo=None):
if memo is None:
memo = set()
for module in self.modules: for module in self.modules:
for p in module.parameters(): for p in module.parameters(memo):
yield p yield p

View File

@ -75,10 +75,16 @@ class Module(object):
fn.register_hook(key, lambda gi,go,hook=hook: hook(self, gi, go)) fn.register_hook(key, lambda gi,go,hook=hook: hook(self, gi, go))
return result return result
def parameters(self): def parameters(self, memo=None):
if hasattr(self, 'weight') and self.weight is not None: if memo is None:
memo = set()
if hasattr(self, 'weight') and self.weight is not None \
and self.weight not in memo:
memo.add(self.weight)
yield self.weight yield self.weight
if hasattr(self, 'bias') and self.bias is not None: if hasattr(self, 'bias') and self.bias is not None \
and self.bias not in memo:
memo.add(self.bias)
yield self.bias yield self.bias
def zero_grad(self): def zero_grad(self):