mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Return only unique variables from parameters()
This commit is contained in:
parent
5030d76acf
commit
4cdeae3283
|
|
@ -203,6 +203,21 @@ class TestNN(NNTestCase):
|
||||||
module.__repr__()
|
module.__repr__()
|
||||||
str(module)
|
str(module)
|
||||||
|
|
||||||
|
def test_parameters(self):
|
||||||
|
def num_params(module):
|
||||||
|
return len(list(module.parameters()))
|
||||||
|
class Net(nn.Container):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__(
|
||||||
|
l1=l,
|
||||||
|
l2=l
|
||||||
|
)
|
||||||
|
l = nn.Linear(10, 20)
|
||||||
|
n = Net()
|
||||||
|
s = nn.Sequential(l, l, l, l)
|
||||||
|
self.assertEqual(num_params(l), 2)
|
||||||
|
self.assertEqual(num_params(n), 2)
|
||||||
|
self.assertEqual(num_params(s), 2)
|
||||||
|
|
||||||
def test_Dropout(self):
|
def test_Dropout(self):
|
||||||
input = torch.Tensor(1000)
|
input = torch.Tensor(1000)
|
||||||
|
|
|
||||||
|
|
@ -24,10 +24,20 @@ del old_flags
|
||||||
################################################################################
|
################################################################################
|
||||||
|
|
||||||
def typename(o):
|
def typename(o):
|
||||||
module = o.__module__ + '.'
|
module = ''
|
||||||
if module == '__builtin__.':
|
class_name = ''
|
||||||
module = ''
|
if hasattr(o, '__module__') and o.__module__ != 'builtins' \
|
||||||
return module + o.__class__.__name__
|
and o.__module__ != '__builtin__' and o.__module__ is not None:
|
||||||
|
module = o.__module__ + '.'
|
||||||
|
|
||||||
|
if hasattr(o, '__qualname__'):
|
||||||
|
class_name = o.__qualname__
|
||||||
|
elif hasattr(o, '__name__'):
|
||||||
|
class_name = o.__name__
|
||||||
|
else:
|
||||||
|
class_name = o.__class__.__name__
|
||||||
|
|
||||||
|
return module + class_name
|
||||||
|
|
||||||
|
|
||||||
def is_tensor(obj):
|
def is_tensor(obj):
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,17 @@
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import torch
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
from .module import Module
|
from .module import Module
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
|
|
||||||
class Container(Module):
|
class Container(Module):
|
||||||
"""This is the base container class for all neural networks you would define.
|
"""This is the base container class for all neural networks you would define.
|
||||||
You will subclass your container from this class.
|
You will subclass your container from this class.
|
||||||
In the constructor you define the modules that you would want to use,
|
In the constructor you define the modules that you would want to use,
|
||||||
and in the __call__ function you use the constructed modules in
|
and in the __call__ function you use the constructed modules in
|
||||||
your operations.
|
your operations.
|
||||||
|
|
||||||
To make it easier to understand, given is a small example.
|
To make it easier to understand, given is a small example.
|
||||||
```
|
```
|
||||||
# Example of using Container
|
# Example of using Container
|
||||||
|
|
@ -23,16 +25,16 @@ class Container(Module):
|
||||||
output = self.relu(self.conv1(x))
|
output = self.relu(self.conv1(x))
|
||||||
return output
|
return output
|
||||||
model = Net()
|
model = Net()
|
||||||
```
|
```
|
||||||
|
|
||||||
One can also add new modules to a container after construction.
|
One can also add new modules to a container after construction.
|
||||||
You can do this with the add_module function.
|
You can do this with the add_module function.
|
||||||
|
|
||||||
```
|
```
|
||||||
# one can add modules to the container after construction
|
# one can add modules to the container after construction
|
||||||
model.add_module('pool1', nn.MaxPool2d(2, 2)
|
model.add_module('pool1', nn.MaxPool2d(2, 2))
|
||||||
```
|
```
|
||||||
|
|
||||||
The container has one additional method `parameters()` which
|
The container has one additional method `parameters()` which
|
||||||
returns the list of learnable parameters in the container instance.
|
returns the list of learnable parameters in the container instance.
|
||||||
"""
|
"""
|
||||||
|
|
@ -45,13 +47,18 @@ class Container(Module):
|
||||||
def add_module(self, name, module):
|
def add_module(self, name, module):
|
||||||
if hasattr(self, name):
|
if hasattr(self, name):
|
||||||
raise KeyError("attribute already exists '{}'".format(name))
|
raise KeyError("attribute already exists '{}'".format(name))
|
||||||
|
if not isinstance(module, Module):
|
||||||
|
raise ValueError("{} is not a Module subclass".format(
|
||||||
|
torch.typename(module)))
|
||||||
setattr(self, name, module)
|
setattr(self, name, module)
|
||||||
if module is not None:
|
if module is not None:
|
||||||
self.modules.append(module)
|
self.modules.append(module)
|
||||||
|
|
||||||
def parameters(self):
|
def parameters(self, memo=None):
|
||||||
|
if memo is None:
|
||||||
|
memo = set()
|
||||||
for module in self.modules:
|
for module in self.modules:
|
||||||
for p in module.parameters():
|
for p in module.parameters(memo):
|
||||||
yield p
|
yield p
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -75,10 +75,16 @@ class Module(object):
|
||||||
fn.register_hook(key, lambda gi,go,hook=hook: hook(self, gi, go))
|
fn.register_hook(key, lambda gi,go,hook=hook: hook(self, gi, go))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def parameters(self):
|
def parameters(self, memo=None):
|
||||||
if hasattr(self, 'weight') and self.weight is not None:
|
if memo is None:
|
||||||
|
memo = set()
|
||||||
|
if hasattr(self, 'weight') and self.weight is not None \
|
||||||
|
and self.weight not in memo:
|
||||||
|
memo.add(self.weight)
|
||||||
yield self.weight
|
yield self.weight
|
||||||
if hasattr(self, 'bias') and self.bias is not None:
|
if hasattr(self, 'bias') and self.bias is not None \
|
||||||
|
and self.bias not in memo:
|
||||||
|
memo.add(self.bias)
|
||||||
yield self.bias
|
yield self.bias
|
||||||
|
|
||||||
def zero_grad(self):
|
def zero_grad(self):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user