mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
Summary: 1. Adds a function that returns the auxiliary parameters of each optimizer. This function can be used to serialize the optimizers so that they can be recovered later. 2. Fixes a bug where the iteration blob was not incremented by exactly one in each iteration: if k parameters used the Adam learning rate optimizer, the original implementation incremented the iteration blob by k per iteration. Reviewed By: azzolini Differential Revision: D4872397 fbshipit-source-id: d86711feedda2ba83af5f2a18141b06a6a473733
106 lines
4.0 KiB
Python
106 lines
4.0 KiB
Python
## @package optimizer_test_util
|
|
# Module caffe2.python.optimizer_test_util
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
import numpy as np
|
|
from caffe2.python import core, workspace, cnn
|
|
|
|
|
|
class OptimizerTestBase(object):
    """
    Abstract base holding convergence tests that are shared by optimizers.

    Do not make this class inherit from unittest.TestCase, and do not give
    it a name starting with 'Test*'.  The classes that inherit from this
    one, on the other hand, should do both of those things.

    The tests below call self.build_optimizer(model),
    self.check_optimizer(optimizer) and self.assertIsInstance(...); none of
    them is defined here, so the inheriting class has to make them
    available.
    """

    def testDense(self):
        """Train an FC weight on a known linear model and check the result.

        Builds random 0/1 samples whose labels are their dot product with a
        fixed weight vector, runs 2000 single-sample steps of the net set up
        by self.build_optimizer, and asserts that blob 'fc_w' ends up within
        1e-2 (absolute) of that weight vector.
        """
        target_weights = np.array([2, 6, 5, 0, 1]).astype(np.float32)
        num_features = target_weights.size
        np.random.seed(123)  # make test deterministic
        samples = np.random.randint(2, size=(20, num_features))
        samples = samples.astype(np.float32)
        # Keep the labels as a column so that indexing one row yields a
        # one-element array rather than a scalar.
        targets = np.dot(samples, target_weights)[:, np.newaxis]

        model = cnn.CNNModelHelper("NCHW", name="test")
        prediction = model.FC(
            'data',
            'fc',
            num_features,
            1,
            ('ConstantFill', {}),
            ('ConstantFill', {}),
            axis=0
        )
        squared_error = model.SquaredL2Distance([prediction, 'label'])
        avg_loss = model.AveragedLoss(squared_error, "avg_loss")
        gradients = model.AddGradientOperators([avg_loss])
        self.assertIsInstance(gradients['fc_w'], core.BlobReference)
        optimizer = self.build_optimizer(model)

        # The input blobs are fed once before the nets are initialized and
        # created, then re-fed with a random sample on every step.
        workspace.FeedBlob('data', samples[0])
        workspace.FeedBlob('label', targets[0])
        workspace.RunNetOnce(model.param_init_net)
        workspace.CreateNet(model.net)
        num_samples = samples.shape[0]
        for _ in range(2000):
            row = np.random.randint(num_samples)
            workspace.FeedBlob('data', samples[row])
            workspace.FeedBlob('label', targets[row])
            workspace.RunNet(model.net.Proto().name)

        learned_weights = workspace.FetchBlob('fc_w')
        np.testing.assert_allclose(
            target_weights[np.newaxis, :],
            learned_weights,
            atol=1e-2
        )
        self.check_optimizer(optimizer)

    def testSparse(self):
        """Train a gathered weight vector and check the result.

        The weight blob 'w' is read through Gather, and the test asserts
        that its gradient is a core.GradientSlice.  Training is repeated
        with int32 and int64 index blobs; 'w' must end up within 1e-2
        (absolute) of the fixed weight vector.
        """
        # to test duplicated indices we assign two indices to each weight and
        # thus each weight might count once or twice
        duplication = 2
        target_weights = np.array([2, 6, 5, 0, 1]).astype(np.float32)
        num_weights = target_weights.size
        np.random.seed(123)  # make test deterministic
        samples = np.random.randint(2, size=(20, num_weights * duplication))
        samples = samples.astype(np.float32)
        targets = np.dot(samples, np.repeat(target_weights, duplication))

        model = cnn.CNNModelHelper("NCHW", name="test")
        # imitate what model wrapper does
        weight_blob = model.param_init_net.ConstantFill(
            [], 'w', shape=[num_weights], value=0.0
        )
        model.params.append(weight_blob)
        gathered = model.net.Gather([weight_blob, 'indices'], 'gather')
        prediction = model.ReduceFrontSum(gathered, 'sum')

        squared_error = model.SquaredL2Distance([prediction, 'label'])
        avg_loss = model.AveragedLoss(squared_error, "avg_loss")
        gradients = model.AddGradientOperators([avg_loss])
        self.assertIsInstance(gradients['w'], core.GradientSlice)
        optimizer = self.build_optimizer(model)

        workspace.CreateBlob('indices')
        workspace.CreateBlob('label')

        # Weight index owning each binary feature slot: 0, 0, 1, 1, 2, 2, ...
        slot_to_weight = np.repeat(np.arange(num_weights), duplication)
        num_samples = samples.shape[0]

        for index_dtype in (np.int32, np.int64):
            workspace.RunNetOnce(model.param_init_net)
            workspace.CreateNet(model.net)
            for _ in range(2000):
                row = np.random.randint(num_samples)
                # transform into indices of binary features
                active = slot_to_weight[samples[row] == 1]
                if not active.size:
                    # No feature is set in this sample; skip the step.
                    continue
                workspace.FeedBlob(
                    'indices',
                    active.reshape((active.size,)).astype(index_dtype)
                )
                workspace.FeedBlob(
                    'label',
                    np.array(targets[row]).astype(np.float32)
                )
                workspace.RunNet(model.net.Proto().name)

            # NOTE(review): nesting was reconstructed from a copy that had
            # lost its indentation; this check is assumed to run once per
            # index dtype -- confirm against the upstream file.
            learned_weights = workspace.FetchBlob('w')
            np.testing.assert_allclose(
                target_weights,
                learned_weights,
                atol=1e-2
            )
        self.check_optimizer(optimizer)
|