From 4cdeae32836f002c20bd3e93543a22b45e198300 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Sat, 24 Sep 2016 10:36:11 -0700
Subject: [PATCH] Return only unique variables from parameters()

---
 test/test_nn.py               | 15 +++++++++++++++
 torch/__init__.py             | 18 ++++++++++++++----
 torch/nn/modules/container.py | 27 +++++++++++++++++----------
 torch/nn/modules/module.py    | 12 +++++++++---
 4 files changed, 55 insertions(+), 17 deletions(-)

diff --git a/test/test_nn.py b/test/test_nn.py
index 2d430beb161..75e346ddc91 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -203,6 +203,21 @@ class TestNN(NNTestCase):
         module.__repr__()
         str(module)
 
+    def test_parameters(self):
+        def num_params(module):
+            return len(list(module.parameters()))
+        class Net(nn.Container):
+            def __init__(self):
+                super(Net, self).__init__(
+                    l1=l,
+                    l2=l
+                )
+        l = nn.Linear(10, 20)
+        n = Net()
+        s = nn.Sequential(l, l, l, l)
+        self.assertEqual(num_params(l), 2)
+        self.assertEqual(num_params(n), 2)
+        self.assertEqual(num_params(s), 2)
 
     def test_Dropout(self):
         input = torch.Tensor(1000)
diff --git a/torch/__init__.py b/torch/__init__.py
index 81be2887e69..a14d5eb2b3a 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -24,10 +24,20 @@ del old_flags
 ################################################################################
 
 def typename(o):
-    module = o.__module__ + '.'
-    if module == '__builtin__.':
-        module = ''
-    return module + o.__class__.__name__
+    module = ''
+    class_name = ''
+    if hasattr(o, '__module__') and o.__module__ != 'builtins' \
+            and o.__module__ != '__builtin__' and o.__module__ is not None:
+        module = o.__module__ + '.'
+
+    if hasattr(o, '__qualname__'):
+        class_name = o.__qualname__
+    elif hasattr(o, '__name__'):
+        class_name = o.__name__
+    else:
+        class_name = o.__class__.__name__
+
+    return module + class_name
 
 
 def is_tensor(obj):
diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py
index f30267abf3d..dd34b8ada34 100644
--- a/torch/nn/modules/container.py
+++ b/torch/nn/modules/container.py
@@ -1,38 +1,40 @@
+from collections import OrderedDict
+
+import torch
 from torch.autograd import Variable
 from .module import Module
-from collections import OrderedDict
 
 
 class Container(Module):
     """This is the base container class for all neural networks you would define.
     You will subclass your container from this class.
-    In the constructor you define the modules that you would want to use, 
-    and in the __call__ function you use the constructed modules in 
+    In the constructor you define the modules that you would want to use,
+    and in the __call__ function you use the constructed modules in
     your operations.
-    
+
     To make it easier to understand, given is a small example.
     ```
     # Example of using Container
     class Net(nn.Container):
         def __init__(self):
             # you can use the module constructors
             super(Net, self).__init__(
                 conv1 = nn.Conv2d(1, 16, 5, 5),
                 relu = nn.ReLU()
             )
         def __call__(self, x):
             output = self.relu(self.conv1(x))
             return output
     model = Net()
-    ``` 
-    
+    ```
+
     One can also add new modules to a container after construction.
     You can do this with the add_module function.
     ```
     # one can add modules to the container after construction
-    model.add_module('pool1', nn.MaxPool2d(2, 2)
+    model.add_module('pool1', nn.MaxPool2d(2, 2))
     ```
-    
+
     The container has one additional method `parameters()` which returns
     the list of learnable parameters in the container instance.
""" @@ -45,13 +47,18 @@ class Container(Module): def add_module(self, name, module): if hasattr(self, name): raise KeyError("attribute already exists '{}'".format(name)) + if not isinstance(module, Module): + raise ValueError("{} is not a Module subclass".format( + torch.typename(module))) setattr(self, name, module) if module is not None: self.modules.append(module) - def parameters(self): + def parameters(self, memo=None): + if memo is None: + memo = set() for module in self.modules: - for p in module.parameters(): + for p in module.parameters(memo): yield p diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 56f5ae45abb..69a74dc3f6c 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -75,10 +75,16 @@ class Module(object): fn.register_hook(key, lambda gi,go,hook=hook: hook(self, gi, go)) return result - def parameters(self): - if hasattr(self, 'weight') and self.weight is not None: + def parameters(self, memo=None): + if memo is None: + memo = set() + if hasattr(self, 'weight') and self.weight is not None \ + and self.weight not in memo: + memo.add(self.weight) yield self.weight - if hasattr(self, 'bias') and self.bias is not None: + if hasattr(self, 'bias') and self.bias is not None \ + and self.bias not in memo: + memo.add(self.bias) yield self.bias def zero_grad(self):