Return only unique variables from parameters()

parent 5030d76acf
commit 4cdeae3283
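In short: `parameters()` now threads one shared `memo` set through every module it visits, so a Variable that is reachable along several paths (weight sharing, a module registered under two names) is yielded exactly once. A minimal, torch-free sketch of the pattern — `Leaf` and `Box` here are hypothetical stand-ins for `Module` and `Container`, not PyTorch classes:

```
# Sketch of the memo-set pattern this commit applies to parameters().
# Leaf/Box are illustrative only; they are not torch APIs.
class Leaf:
    def __init__(self):
        self.weight = object()   # stands in for a learnable Variable
        self.bias = object()

    def parameters(self, memo=None):
        if memo is None:
            memo = set()
        for p in (self.weight, self.bias):
            if p not in memo:    # skip anything already yielded elsewhere
                memo.add(p)
                yield p

class Box:
    def __init__(self, *children):
        self.children = children

    def parameters(self, memo=None):
        if memo is None:
            memo = set()
        for child in self.children:
            for p in child.parameters(memo):  # one memo shared by all children
                yield p

l = Leaf()
box = Box(l, l, l)                       # the same leaf registered three times
assert len(list(box.parameters())) == 2  # weight and bias surface once each
```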
@@ -203,6 +203,21 @@ class TestNN(NNTestCase):
-        module.__repr__()
+        str(module)
+
+    def test_parameters(self):
+        def num_params(module):
+            return len(list(module.parameters()))
+        class Net(nn.Container):
+            def __init__(self):
+                super(Net, self).__init__(
+                    l1=l,
+                    l2=l
+                )
+        l = nn.Linear(10, 20)
+        n = Net()
+        s = nn.Sequential(l, l, l, l)
+        self.assertEqual(num_params(l), 2)
+        self.assertEqual(num_params(n), 2)
+        self.assertEqual(num_params(s), 2)
 
     def test_Dropout(self):
         input = torch.Tensor(1000)
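The assertions pin down the deduplicated counts: `nn.Linear(10, 20)` owns one weight and one bias, and registering the same layer several times must not inflate the total. A quick check of the same behavior, assuming a build with this commit applied and the Sequential API exactly as the test exercises it:

```
import torch.nn as nn

l = nn.Linear(10, 20)
s = nn.Sequential(l, l, l, l)      # the same layer registered four times

# Before this commit each registration re-yielded weight and bias (8 total);
# with the shared memo set, each Variable is yielded once.
print(len(list(s.parameters())))   # -> 2
```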
@@ -24,10 +24,20 @@ del old_flags
 ################################################################################
 
 def typename(o):
-    module = o.__module__ + '.'
-    if module == '__builtin__.':
-        module = ''
-    return module + o.__class__.__name__
+    module = ''
+    class_name = ''
+    if hasattr(o, '__module__') and o.__module__ != 'builtins' \
+            and o.__module__ != '__builtin__' and o.__module__ is not None:
+        module = o.__module__ + '.'
+
+    if hasattr(o, '__qualname__'):
+        class_name = o.__qualname__
+    elif hasattr(o, '__name__'):
+        class_name = o.__name__
+    else:
+        class_name = o.__class__.__name__
+
+    return module + class_name
 
 
 def is_tensor(obj):
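The rewritten `typename` works on classes and instances alike, suppresses the builtin module prefix under both Python 2 (`__builtin__`) and Python 3 (`builtins`), and prefers `__qualname__` for nested names. A few illustrative calls, assuming the `typename` defined above is in scope and the snippet runs as a script:

```
class Outer:
    class Inner:
        pass

print(typename(int))            # -> 'int' (builtin module prefix dropped)
print(typename(Outer.Inner))    # -> '__main__.Outer.Inner' via __qualname__
print(typename(Outer.Inner()))  # -> '__main__.Inner' (instances have no
                                #    __qualname__, so this falls back to
                                #    o.__class__.__name__)
```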
@@ -1,15 +1,17 @@
+from collections import OrderedDict
+
 import torch
 from torch.autograd import Variable
 from .module import Module
-from collections import OrderedDict
 
 
 class Container(Module):
     """This is the base container class for all neural networks you would define.
     You will subclass your container from this class.
     In the constructor you define the modules that you would want to use,
     and in the __call__ function you use the constructed modules in
     your operations.
 
     To make it easier to understand, given is a small example.
     ```
     # Example of using Container
@@ -23,16 +25,16 @@ class Container(Module):
             output = self.relu(self.conv1(x))
             return output
     model = Net()
     ```
 
     One can also add new modules to a container after construction.
     You can do this with the add_module function.
 
     ```
     # one can add modules to the container after construction
-    model.add_module('pool1', nn.MaxPool2d(2, 2)
+    model.add_module('pool1', nn.MaxPool2d(2, 2))
     ```
 
     The container has one additional method `parameters()` which
     returns the list of learnable parameters in the container instance.
     """
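The docstring hunk above also fixes the unbalanced parenthesis in the `add_module` example. Note that `add_module` (shown in the next hunk) refuses duplicate names, so the corrected example only works once per name. A hypothetical session, with `Net` standing in for any Container subclass like the docstring's example:

```
model = Net()
model.add_module('pool1', nn.MaxPool2d(2, 2))  # registers and appends the module
model.add_module('pool1', nn.MaxPool2d(2, 2))  # KeyError: attribute already
                                               # exists 'pool1'
```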
@@ -45,13 +47,18 @@ class Container(Module):
     def add_module(self, name, module):
         if hasattr(self, name):
             raise KeyError("attribute already exists '{}'".format(name))
         if not isinstance(module, Module):
             raise ValueError("{} is not a Module subclass".format(
                 torch.typename(module)))
         setattr(self, name, module)
         if module is not None:
             self.modules.append(module)
 
-    def parameters(self):
+    def parameters(self, memo=None):
+        if memo is None:
+            memo = set()
         for module in self.modules:
-            for p in module.parameters():
+            for p in module.parameters(memo):
                 yield p
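Because `Container.parameters` forwards the same `memo` to every child, deduplication also works across nesting levels, not only between siblings. A sketch, assuming the Container/Sequential constructors exercised by the test above:

```
import torch.nn as nn

l = nn.Linear(10, 20)
inner = nn.Sequential(l, l)        # l shared inside the inner container

class Outer(nn.Container):
    def __init__(self):
        super(Outer, self).__init__(seq=inner, lin=l)  # and shared again here

# One memo travels through Outer -> inner -> l, so weight and bias
# still surface exactly once:
print(len(list(Outer().parameters())))  # -> 2
```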
@@ -75,10 +75,16 @@ class Module(object):
         fn.register_hook(key, lambda gi,go,hook=hook: hook(self, gi, go))
         return result
 
-    def parameters(self):
-        if hasattr(self, 'weight') and self.weight is not None:
+    def parameters(self, memo=None):
+        if memo is None:
+            memo = set()
+        if hasattr(self, 'weight') and self.weight is not None \
+                and self.weight not in memo:
+            memo.add(self.weight)
             yield self.weight
-        if hasattr(self, 'bias') and self.bias is not None:
+        if hasattr(self, 'bias') and self.bias is not None \
+                and self.bias not in memo:
+            memo.add(self.bias)
             yield self.bias
 
     def zero_grad(self):
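A side effect of the new signature: a caller can pre-seed `memo` to skip parameters it has already consumed, for example when two modules tie a weight. A hedged sketch — it assumes plain attribute assignment ties the Variables and relies on the same set membership the code above uses:

```
import torch.nn as nn

l1 = nn.Linear(10, 20)
l2 = nn.Linear(10, 20)
l2.weight = l1.weight               # tie the weight Variables

seen = set()
first = list(l1.parameters(seen))   # yields l1.weight and l1.bias
second = list(l2.parameters(seen))  # yields only l2.bias; the tied weight
                                    # is already in the memo
print(len(first), len(second))      # -> 2 1
```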