Return only unique variables from parameters()
parent 5030d76acf
commit 4cdeae3283
@@ -203,6 +203,21 @@ class TestNN(NNTestCase):
         module.__repr__()
         str(module)
 
+    def test_parameters(self):
+        def num_params(module):
+            return len(list(module.parameters()))
+        class Net(nn.Container):
+            def __init__(self):
+                super(Net, self).__init__(
+                    l1=l,
+                    l2=l
+                )
+        l = nn.Linear(10, 20)
+        n = Net()
+        s = nn.Sequential(l, l, l, l)
+        self.assertEqual(num_params(l), 2)
+        self.assertEqual(num_params(n), 2)
+        self.assertEqual(num_params(s), 2)
 
     def test_Dropout(self):
         input = torch.Tensor(1000)
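The new test pins down the deduplication contract: the same `nn.Linear` is registered twice in `Net` and four times in the `Sequential`, yet every container must report exactly two parameter Variables (one weight, one bias). A minimal standalone illustration of what the commit changes (a sketch, not part of the diff):

```
import torch.nn as nn

l = nn.Linear(10, 20)          # owns one weight and one bias Variable
s = nn.Sequential(l, l, l, l)  # the very same module registered four times

# Before this commit, parameters() yielded l's weight and bias once per
# registration, so the count came out as 8; with the shared memo set the
# count is the number of distinct Variables:
assert len(list(s.parameters())) == 2
```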
@@ -24,10 +24,20 @@ del old_flags
 ################################################################################
 
 def typename(o):
-    module = o.__module__ + '.'
-    if module == '__builtin__.':
-        module = ''
-    return module + o.__class__.__name__
+    module = ''
+    class_name = ''
+    if hasattr(o, '__module__') and o.__module__ != 'builtins' \
+            and o.__module__ != '__builtin__' and o.__module__ is not None:
+        module = o.__module__ + '.'
+
+    if hasattr(o, '__qualname__'):
+        class_name = o.__qualname__
+    elif hasattr(o, '__name__'):
+        class_name = o.__name__
+    else:
+        class_name = o.__class__.__name__
+
+    return module + class_name
 
 
 def is_tensor(obj):
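The old `typename` assumed its argument was an instance with a `__module__` attribute and always reported the class of whatever it was given; the rewrite also handles classes and functions, preferring `__qualname__`, then `__name__`, and only then falling back to `__class__.__name__`, while suppressing the `builtins`/`__builtin__` prefix. A quick sketch of the intended behavior (the exact module path printed for `nn.Linear` is illustrative):

```
import torch
import torch.nn as nn

print(torch.typename(3))          # 'int' -- builtin module prefix suppressed
print(torch.typename(nn.Linear))  # class objects now work via __qualname__,
                                  # e.g. 'torch.nn.modules.linear.Linear'
```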
@@ -1,6 +1,8 @@
+from collections import OrderedDict
+
+import torch
 from torch.autograd import Variable
 from .module import Module
-from collections import OrderedDict
 
 
 class Container(Module):
@@ -30,7 +32,7 @@ class Container(Module):
 
     ```
     # one can add modules to the container after construction
-    model.add_module('pool1', nn.MaxPool2d(2, 2)
+    model.add_module('pool1', nn.MaxPool2d(2, 2))
     ```
 
     The container has one additional method `parameters()` which
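Besides fixing the docstring's missing closing parenthesis, the next hunk makes `add_module` reject non-Module arguments outright, so a slip like passing the class instead of an instance fails immediately with a readable type name (illustrative; the printed module path is an assumption):

```
# continuing the docstring's example model:
model.add_module('pool1', nn.MaxPool2d)   # forgot to instantiate
# ValueError: torch.nn.modules.pooling.MaxPool2d is not a Module subclass
```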
@@ -45,13 +47,18 @@ class Container(Module):
     def add_module(self, name, module):
         if hasattr(self, name):
             raise KeyError("attribute already exists '{}'".format(name))
+        if not isinstance(module, Module):
+            raise ValueError("{} is not a Module subclass".format(
+                torch.typename(module)))
         setattr(self, name, module)
         if module is not None:
             self.modules.append(module)
 
-    def parameters(self):
+    def parameters(self, memo=None):
+        if memo is None:
+            memo = set()
         for module in self.modules:
-            for p in module.parameters():
+            for p in module.parameters(memo):
                 yield p
 
 
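Threading one `memo` set through the recursive generators is what makes the deduplication global: the set is created once at the top-level call, and every nested container and leaf module filters against, and adds to, that same set. The pattern in isolation, with generic names (a sketch, not the `nn` code itself):

```
def unique_items(tree, memo=None):
    # The memo set is created only by the outermost call; recursive
    # calls receive it, so an item reachable along several paths is
    # yielded exactly once.
    if memo is None:
        memo = set()
    for node in tree:
        if isinstance(node, list):
            for item in unique_items(node, memo):
                yield item
        elif node not in memo:
            memo.add(node)
            yield node

shared = object()
assert len(list(unique_items([shared, [shared, object()]]))) == 2
```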
@@ -75,10 +75,16 @@ class Module(object):
             fn.register_hook(key, lambda gi,go,hook=hook: hook(self, gi, go))
         return result
 
-    def parameters(self):
-        if hasattr(self, 'weight') and self.weight is not None:
+    def parameters(self, memo=None):
+        if memo is None:
+            memo = set()
+        if hasattr(self, 'weight') and self.weight is not None \
+                and self.weight not in memo:
+            memo.add(self.weight)
             yield self.weight
-        if hasattr(self, 'bias') and self.bias is not None:
+        if hasattr(self, 'bias') and self.bias is not None \
+                and self.bias not in memo:
+            memo.add(self.bias)
             yield self.bias
 
     def zero_grad(self):
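With both the `Module` and `Container` sides of `parameters()` accepting the shared `memo`, code that iterates over parameters, a hand-written SGD step for instance, touches each shared Variable once per pass instead of once per registration. A sketch of the kind of loop this protects (hypothetical helper; the `.data`/`.grad` access follows this era's Variable API):

```
def sgd_step(model, lr=0.01):
    # parameters() yields each Variable exactly once, so a weight that
    # is reused by several submodules gets a single update per step
    # rather than one update per registration.
    for p in model.parameters():
        p.data -= lr * p.grad.data
```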