import unittest
import functools
from copy import deepcopy

import torch
import torch.optim as optim
import torch.legacy.optim as old_optim
from torch.autograd import Variable

from common import TestCase


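# The Rosenbrock "banana" function has a single global minimum at (1, 1),
# sitting in a long, narrow, curved valley, which makes it a standard
# stress test for optimizers. drosenbrock below is its analytic gradient.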
def rosenbrock(tensor):
    x, y = tensor
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2


def drosenbrock(tensor):
    x, y = tensor
    return torch.DoubleTensor((-400 * x * (y - x ** 2) - 2 * (1 - x),
                               200 * (y - x ** 2)))


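# Adapts the legacy functional interface old_fn(opfunc, x, config, state)
# to the (closure, params, state) signature used by _test_rosenbrock, by
# closing over a fixed config dict.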
def wrap_old_fn(old_fn, **config):
    def wrapper(closure, params, state):
        return old_fn(closure, params, config, state)
    return wrapper


class TestOptim(TestCase):

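    # Runs the new torch.optim interface and the equivalent legacy function
    # side by side on the Rosenbrock function, checking after every step
    # that both produce identical parameters, and that the final point is
    # no farther from the optimum than the starting point.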
    def _test_rosenbrock(self, constructor, old_fn):
        params_t = torch.Tensor([1.5, 1.5])
        state = {}

        params = Variable(torch.Tensor([1.5, 1.5]), requires_grad=True)
        optimizer = constructor([params])

        solution = torch.Tensor([1, 1])
        initial_dist = params.data.dist(solution)

        def eval():
            optimizer.zero_grad()
            loss = rosenbrock(params)
            loss.backward()
            # loss.backward() will give slightly different gradients than
            # drosenbrock, because of a different ordering of floating point
            # operations. In most cases it doesn't matter, but some optimizers
            # are so sensitive that they can temporarily diverge by up to
            # 1e-4, only to converge again. Copying in the analytic gradient
            # makes the comparison with the legacy implementation stable.
            params.grad.data.copy_(drosenbrock(params.data))
            return loss

        for i in range(2000):
            optimizer.step(eval)
            old_fn(lambda _: (rosenbrock(params_t), drosenbrock(params_t)),
                   params_t, state)
            self.assertEqual(params.data, params_t)

        self.assertLessEqual(params.data.dist(solution), initial_dist)

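    # Checks that a couple hundred optimizer steps on a simple
    # least-squares objective decrease the loss below its initial value.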
    def _test_basic_cases_template(self, weight, bias, input, constructor):
        weight = Variable(weight, requires_grad=True)
        bias = Variable(bias, requires_grad=True)
        input = Variable(input)
        optimizer = constructor(weight, bias)

        def fn():
            optimizer.zero_grad()
            y = weight.mv(input)
            if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
                y = y.cuda(bias.get_device())
            loss = (y + bias).pow(2).sum()
            loss.backward()
            return loss

        initial_value = fn().data[0]
        for i in range(200):
            optimizer.step(fn)
        self.assertLess(fn().data[0], initial_value)

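    # Checks that an optimizer restored via load_state_dict() behaves
    # identically to the original: both copies are stepped in parallel on
    # cloned weights and must stay in sync.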
    def _test_state_dict(self, weight, bias, input, constructor):
        weight = Variable(weight, requires_grad=True)
        bias = Variable(bias, requires_grad=True)
        input = Variable(input)

        def fn_base(optimizer, weight, bias):
            optimizer.zero_grad()
            loss = (weight.mv(input) + bias).pow(2).sum()
            loss.backward()
            return loss

        optimizer = constructor(weight, bias)
        fn = functools.partial(fn_base, optimizer, weight, bias)

        # Prime the optimizer
        for i in range(20):
            optimizer.step(fn)
        # Clone the weights and construct a new optimizer for them
        weight_c = Variable(weight.data.clone(), requires_grad=True)
        bias_c = Variable(bias.data.clone(), requires_grad=True)
        optimizer_c = constructor(weight_c, bias_c)
        fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
        # Load the state dict
        state_dict = deepcopy(optimizer.state_dict())
        state_dict_c = deepcopy(optimizer.state_dict())
        optimizer_c.load_state_dict(state_dict_c)
        # Run both optimizations in parallel
        for i in range(20):
            optimizer.step(fn)
            optimizer_c.step(fn_c)
            self.assertEqual(weight, weight_c)
            self.assertEqual(bias, bias_c)
        # Make sure the loaded state dict wasn't modified
        self.assertEqual(state_dict, state_dict_c)

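    # Exercises the template above with contiguous, non-contiguous, CUDA,
    # and multi-GPU parameters (the latter two only when the hardware is
    # available), plus a state-dict round trip.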
    def _test_basic_cases(self, constructor, ignore_multidevice=False):
        self._test_state_dict(
            torch.randn(10, 5),
            torch.randn(10),
            torch.randn(5),
            constructor
        )
        self._test_basic_cases_template(
            torch.randn(10, 5),
            torch.randn(10),
            torch.randn(5),
            constructor
        )
        # non-contiguous parameters
        self._test_basic_cases_template(
            torch.randn(10, 5, 2)[..., 0],
            torch.randn(10, 2)[..., 0],
            torch.randn(5),
            constructor
        )
        # CUDA
        if not torch.cuda.is_available():
            return
        self._test_basic_cases_template(
            torch.randn(10, 5).cuda(),
            torch.randn(10).cuda(),
            torch.randn(5).cuda(),
            constructor
        )
        # Multi-GPU
        if torch.cuda.device_count() <= 1 or ignore_multidevice:
            return
        self._test_basic_cases_template(
            torch.randn(10, 5).cuda(0),
            torch.randn(10).cuda(1),
            torch.randn(5).cuda(0),
            constructor
        )

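    # Produces two parameter groups, e.g. [dict(params=[weight]),
    # dict(params=[bias], lr=1e-2)]; options given in the second dict
    # override the optimizer-wide defaults for that group only.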
    def _build_params_dict(self, weight, bias, **kwargs):
        return [dict(params=[weight]), dict(params=[bias], **kwargs)]

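    # Each test pairs a torch.optim constructor with its torch.legacy.optim
    # counterpart; keyword names map as lr -> learningRate,
    # weight_decay -> weightDecay, lr_decay -> learningRateDecay, and so on.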
    def test_sgd(self):
        self._test_rosenbrock(
            lambda params: optim.SGD(params, lr=1e-3),
            wrap_old_fn(old_optim.sgd, learningRate=1e-3)
        )
        self._test_rosenbrock(
            lambda params: optim.SGD(params, lr=1e-3, momentum=0.9,
                                     dampening=0, weight_decay=1e-4),
            wrap_old_fn(old_optim.sgd, learningRate=1e-3, momentum=0.9,
                        dampening=0, weightDecay=1e-4)
        )
        self._test_basic_cases(
            lambda weight, bias: optim.SGD([weight, bias], lr=1e-3)
        )
        self._test_basic_cases(
            lambda weight, bias: optim.SGD(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-3)
        )

    def test_adam(self):
        self._test_rosenbrock(
            lambda params: optim.Adam(params, lr=1e-2),
            wrap_old_fn(old_optim.adam, learningRate=1e-2)
        )
        self._test_rosenbrock(
            lambda params: optim.Adam(params, lr=1e-2, weight_decay=1e-2),
            wrap_old_fn(old_optim.adam, learningRate=1e-2, weightDecay=1e-2)
        )
        self._test_basic_cases(
            lambda weight, bias: optim.Adam([weight, bias], lr=1e-3)
        )
        self._test_basic_cases(
            lambda weight, bias: optim.Adam(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-3)
        )

    def test_adadelta(self):
        self._test_rosenbrock(
            lambda params: optim.Adadelta(params),
            wrap_old_fn(old_optim.adadelta)
        )
        self._test_rosenbrock(
            lambda params: optim.Adadelta(params, rho=0.95),
            wrap_old_fn(old_optim.adadelta, rho=0.95)
        )
        self._test_rosenbrock(
            lambda params: optim.Adadelta(params, weight_decay=1e-2),
            wrap_old_fn(old_optim.adadelta, weightDecay=1e-2)
        )
        self._test_basic_cases(
            lambda weight, bias: optim.Adadelta([weight, bias])
        )
        self._test_basic_cases(
            lambda weight, bias: optim.Adadelta(
                self._build_params_dict(weight, bias, rho=0.95))
        )

    def test_adagrad(self):
        self._test_rosenbrock(
            lambda params: optim.Adagrad(params, lr=1e-1),
            wrap_old_fn(old_optim.adagrad, learningRate=1e-1)
        )
        self._test_rosenbrock(
            lambda params: optim.Adagrad(params, lr=1e-1, lr_decay=1e-3),
            wrap_old_fn(old_optim.adagrad, learningRate=1e-1, learningRateDecay=1e-3)
        )
        self._test_rosenbrock(
            lambda params: optim.Adagrad(params, lr=1e-1, weight_decay=1e-2),
            wrap_old_fn(old_optim.adagrad, learningRate=1e-1, weightDecay=1e-2)
        )
        self._test_basic_cases(
            lambda weight, bias: optim.Adagrad([weight, bias], lr=1e-1)
        )
        self._test_basic_cases(
            lambda weight, bias: optim.Adagrad(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-1)
        )

    def test_adamax(self):
        self._test_rosenbrock(
            lambda params: optim.Adamax(params, lr=1e-1),
            wrap_old_fn(old_optim.adamax, learningRate=1e-1)
        )
        self._test_rosenbrock(
            lambda params: optim.Adamax(params, lr=1e-1, weight_decay=1e-2),
            wrap_old_fn(old_optim.adamax, learningRate=1e-1, weightDecay=1e-2)
        )
        self._test_rosenbrock(
            lambda params: optim.Adamax(params, lr=1e-1, betas=(0.95, 0.998)),
            wrap_old_fn(old_optim.adamax, learningRate=1e-1, beta1=0.95, beta2=0.998)
        )
        self._test_basic_cases(
            lambda weight, bias: optim.Adamax([weight, bias], lr=1e-1)
        )
        self._test_basic_cases(
            lambda weight, bias: optim.Adamax(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-1)
        )

    def test_rmsprop(self):
        self._test_rosenbrock(
            lambda params: optim.RMSprop(params, lr=1e-2),
            wrap_old_fn(old_optim.rmsprop, learningRate=1e-2)
        )
        self._test_rosenbrock(
            lambda params: optim.RMSprop(params, lr=1e-2, weight_decay=1e-2),
            wrap_old_fn(old_optim.rmsprop, learningRate=1e-2, weightDecay=1e-2)
        )
        self._test_rosenbrock(
            lambda params: optim.RMSprop(params, lr=1e-2, alpha=0.95),
            wrap_old_fn(old_optim.rmsprop, learningRate=1e-2, alpha=0.95)
        )
        self._test_basic_cases(
            lambda weight, bias: optim.RMSprop([weight, bias], lr=1e-2)
        )
        self._test_basic_cases(
            lambda weight, bias: optim.RMSprop(
                self._build_params_dict(weight, bias, lr=1e-3),
                lr=1e-2)
        )

    def test_asgd(self):
        self._test_rosenbrock(
            lambda params: optim.ASGD(params, lr=1e-3),
            wrap_old_fn(old_optim.asgd, eta0=1e-3)
        )
        self._test_rosenbrock(
            lambda params: optim.ASGD(params, lr=1e-3, alpha=0.8),
            wrap_old_fn(old_optim.asgd, eta0=1e-3, alpha=0.8)
        )
        self._test_rosenbrock(
            lambda params: optim.ASGD(params, lr=1e-3, t0=1e3),
            wrap_old_fn(old_optim.asgd, eta0=1e-3, t0=1e3)
        )
        self._test_basic_cases(
            lambda weight, bias: optim.ASGD([weight, bias], lr=1e-3, t0=100)
        )
        self._test_basic_cases(
            lambda weight, bias: optim.ASGD(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-3, t0=100)
        )

    def test_rprop(self):
        self._test_rosenbrock(
            lambda params: optim.Rprop(params, lr=1e-3),
            wrap_old_fn(old_optim.rprop, stepsize=1e-3)
        )
        self._test_rosenbrock(
            lambda params: optim.Rprop(params, lr=1e-3, etas=(0.6, 1.1)),
            wrap_old_fn(old_optim.rprop, stepsize=1e-3, etaminus=0.6, etaplus=1.1)
        )
        self._test_rosenbrock(
            lambda params: optim.Rprop(params, lr=1e-3, step_sizes=(1e-4, 3)),
            wrap_old_fn(old_optim.rprop, stepsize=1e-3, stepsizemin=1e-4, stepsizemax=3)
        )
        self._test_basic_cases(
            lambda weight, bias: optim.Rprop([weight, bias], lr=1e-3)
        )
        self._test_basic_cases(
            lambda weight, bias: optim.Rprop(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-3)
        )

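    # ignore_multidevice=True skips the multi-GPU case, presumably because
    # L-BFGS flattens all gradients into a single buffer and so assumes the
    # parameters live on one device.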
    def test_lbfgs(self):
        self._test_rosenbrock(
            lambda params: optim.LBFGS(params),
            wrap_old_fn(old_optim.lbfgs)
        )
        self._test_rosenbrock(
            lambda params: optim.LBFGS(params, lr=5e-2, max_iter=5),
            wrap_old_fn(old_optim.lbfgs, learningRate=5e-2, maxIter=5)
        )
        self._test_basic_cases(
            lambda weight, bias: optim.LBFGS([weight, bias]),
            ignore_multidevice=True
        )

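    # Optimizers expect an iterable of parameters (or of parameter-group
    # dicts); passing a bare Variable should raise a TypeError.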
    def test_invalid_param_type(self):
        with self.assertRaises(TypeError):
            optim.SGD(Variable(torch.randn(5, 5)), lr=3)


if __name__ == '__main__':
    unittest.main()