Return only unique variables from parameters()

This commit is contained in:
Adam Paszke 2016-09-24 10:36:11 -07:00
parent 5030d76acf
commit 4cdeae3283
4 changed files with 55 additions and 17 deletions

View File

@ -203,6 +203,21 @@ class TestNN(NNTestCase):
module.__repr__() module.__repr__()
str(module) str(module)
def test_parameters(self):
def num_params(module):
return len(list(module.parameters()))
class Net(nn.Container):
def __init__(self):
super(Net, self).__init__(
l1=l,
l2=l
)
l = nn.Linear(10, 20)
n = Net()
s = nn.Sequential(l, l, l, l)
self.assertEqual(num_params(l), 2)
self.assertEqual(num_params(n), 2)
self.assertEqual(num_params(s), 2)
def test_Dropout(self): def test_Dropout(self):
input = torch.Tensor(1000) input = torch.Tensor(1000)

View File

@ -24,10 +24,20 @@ del old_flags
################################################################################ ################################################################################
def typename(o): def typename(o):
module = o.__module__ + '.' module = ''
if module == '__builtin__.': class_name = ''
module = '' if hasattr(o, '__module__') and o.__module__ != 'builtins' \
return module + o.__class__.__name__ and o.__module__ != '__builtin__' and o.__module__ is not None:
module = o.__module__ + '.'
if hasattr(o, '__qualname__'):
class_name = o.__qualname__
elif hasattr(o, '__name__'):
class_name = o.__name__
else:
class_name = o.__class__.__name__
return module + class_name
def is_tensor(obj): def is_tensor(obj):

View File

@ -1,6 +1,8 @@
from collections import OrderedDict
import torch
from torch.autograd import Variable from torch.autograd import Variable
from .module import Module from .module import Module
from collections import OrderedDict
class Container(Module): class Container(Module):
@ -30,7 +32,7 @@ class Container(Module):
``` ```
# one can add modules to the container after construction # one can add modules to the container after construction
model.add_module('pool1', nn.MaxPool2d(2, 2) model.add_module('pool1', nn.MaxPool2d(2, 2))
``` ```
The container has one additional method `parameters()` which The container has one additional method `parameters()` which
@ -45,13 +47,18 @@ class Container(Module):
def add_module(self, name, module): def add_module(self, name, module):
if hasattr(self, name): if hasattr(self, name):
raise KeyError("attribute already exists '{}'".format(name)) raise KeyError("attribute already exists '{}'".format(name))
if not isinstance(module, Module):
raise ValueError("{} is not a Module subclass".format(
torch.typename(module)))
setattr(self, name, module) setattr(self, name, module)
if module is not None: if module is not None:
self.modules.append(module) self.modules.append(module)
def parameters(self): def parameters(self, memo=None):
if memo is None:
memo = set()
for module in self.modules: for module in self.modules:
for p in module.parameters(): for p in module.parameters(memo):
yield p yield p

View File

@ -75,10 +75,16 @@ class Module(object):
fn.register_hook(key, lambda gi,go,hook=hook: hook(self, gi, go)) fn.register_hook(key, lambda gi,go,hook=hook: hook(self, gi, go))
return result return result
def parameters(self): def parameters(self, memo=None):
if hasattr(self, 'weight') and self.weight is not None: if memo is None:
memo = set()
if hasattr(self, 'weight') and self.weight is not None \
and self.weight not in memo:
memo.add(self.weight)
yield self.weight yield self.weight
if hasattr(self, 'bias') and self.bias is not None: if hasattr(self, 'bias') and self.bias is not None \
and self.bias not in memo:
memo.add(self.bias)
yield self.bias yield self.bias
def zero_grad(self): def zero_grad(self):