Applies the remaining flake8-comprehensions fixes and checks. This change replaces all remaining unnecessary generator expressions with list/dict/set comprehensions, which are more succinct, performant, and better supported by our torch.jit compiler. It also removes useless generators such as `set(a for a in b)`, resolving them into just the set call. Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676 Approved by: https://github.com/ezyang
# Owner(s): ["module: optimizer"]

import warnings
import math
import unittest
import functools
import itertools
import pickle
from copy import deepcopy
import weakref

import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter
from torch.optim import Adam, SGD, Optimizer
from torch import sparse
from torch.optim.lr_scheduler import (
    LambdaLR,
    MultiplicativeLR,
    SequentialLR,
    StepLR,
    MultiStepLR,
    ConstantLR,
    LinearLR,
    ExponentialLR,
    CosineAnnealingLR,
    ReduceLROnPlateau,
    LRScheduler,
    CyclicLR,
    CosineAnnealingWarmRestarts,
    OneCycleLR,
    ChainedScheduler,
    PolynomialLR,
    EPOCH_DEPRECATION_WARNING,
)
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
    TEST_WITH_UBSAN,
    load_tests,
    parametrize,
    instantiate_parametrized_tests,
    gradcheck,
    skipIfRocm,
    skipIfTorchDynamo,
)
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from typing import Dict, Any, Tuple
from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests


def rosenbrock(tensor):
    x, y = tensor
    return (1 - x) ** 2 + 100 * (y - x**2) ** 2


def drosenbrock(tensor):
    x, y = tensor
    return torch.tensor((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2)))
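
# Note: drosenbrock is the analytic gradient of rosenbrock above:
#   df/dx = -400 * x * (y - x**2) - 2 * (1 - x)
#   df/dy = 200 * (y - x**2)
# which the sparse and dense Rosenbrock tests below feed to the optimizers.
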
class TestOptim(TestCase):
|
|
exact_dtype = True
|
|
|
|
def _test_rosenbrock_sparse(
|
|
self,
|
|
constructor,
|
|
scheduler_constructors=None,
|
|
sparse_only=False,
|
|
maximize=False,
|
|
):
|
|
if scheduler_constructors is None:
|
|
scheduler_constructors = []
|
|
params_t = torch.tensor([1.5, 1.5])
|
|
|
|
params = Parameter(params_t)
|
|
optimizer = constructor([params])
|
|
schedulers = []
|
|
for scheduler_constructor in scheduler_constructors:
|
|
schedulers.append(scheduler_constructor(optimizer))
|
|
|
|
if not sparse_only:
|
|
params_c = Parameter(params_t.clone())
|
|
optimizer_c = constructor([params_c])
|
|
|
|
solution = torch.tensor([1, 1])
|
|
with torch.no_grad():
|
|
initial_dist = params.dist(solution)
|
|
|
|
def eval(params, sparse_grad, w):
|
|
# Depending on w, provide only the x or y gradient
|
|
optimizer.zero_grad()
|
|
loss = rosenbrock(params)
|
|
loss.backward()
|
|
grad = drosenbrock(params.data)
|
|
# NB: We torture test the optimizer by returning an
|
|
# uncoalesced sparse tensor
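            # ("Uncoalesced" here means the index tensor repeats the same
            # coordinate, so the two values must be summed to recover the
            # actual gradient entry, e.g. [y - y/4, y/4] both land at index 1.)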
|
|
if w:
|
|
i = torch.LongTensor([[0, 0]])
|
|
x = grad[0]
|
|
v = torch.tensor([x / 4.0, x - x / 4.0])
|
|
else:
|
|
i = torch.LongTensor([[1, 1]])
|
|
y = grad[1]
|
|
v = torch.tensor([y - y / 4.0, y / 4.0])
|
|
x = sparse.DoubleTensor(i, v, torch.Size([2])).to(dtype=v.dtype)
|
|
with torch.no_grad():
|
|
if sparse_grad:
|
|
params.grad = x
|
|
else:
|
|
params.grad = x.to_dense()
|
|
return loss
|
|
|
|
for i in range(2000):
|
|
# Do cyclic coordinate descent
|
|
w = i % 2
|
|
optimizer.step(functools.partial(eval, params, True, w))
|
|
for scheduler in schedulers:
|
|
if isinstance(scheduler, ReduceLROnPlateau):
|
|
scheduler.step(rosenbrock(params))
|
|
else:
|
|
scheduler.step()
|
|
if not sparse_only:
|
|
optimizer_c.step(functools.partial(eval, params_c, False, w))
|
|
self.assertEqual(params, params_c)
|
|
|
|
if not maximize:
|
|
self.assertLessEqual(params.data.dist(solution), initial_dist)
|
|
else:
|
|
self.assertGreaterEqual(rosenbrock(params), rosenbrock(params_t))
|
|
|
|
def _test_basic_cases_template(
|
|
self,
|
|
weight_tensor,
|
|
bias_tensor,
|
|
input_tensor,
|
|
constructor,
|
|
scheduler_constructors,
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=False,
|
|
):
|
|
maximize_options = {False, constructor_accepts_maximize}
|
|
foreach_options = {False, constructor_accepts_foreach}
|
|
|
|
four_arg_constructor = constructor
|
|
if constructor_accepts_maximize and constructor_accepts_foreach:
|
|
pass
|
|
elif constructor_accepts_maximize:
|
|
|
|
def four_arg_constructor(weight, bias, maximize, foreach):
|
|
self.assertFalse(foreach)
|
|
return constructor(weight, bias, maximize)
|
|
|
|
elif constructor_accepts_foreach:
|
|
|
|
def four_arg_constructor(weight, bias, maximize, foreach):
|
|
self.assertFalse(maximize)
|
|
return constructor(weight, bias, foreach)
|
|
|
|
else:
|
|
|
|
def four_arg_constructor(weight, bias, maximize, foreach):
|
|
self.assertFalse(maximize or foreach)
|
|
return constructor(weight, bias)
|
|
|
|
for maximize, foreach in itertools.product(maximize_options, foreach_options):
|
|
with torch.no_grad():
|
|
weight = Parameter(weight_tensor.clone().detach())
|
|
bias = Parameter(bias_tensor.clone().detach())
|
|
input = input_tensor.clone().detach().requires_grad_()
|
|
optimizer = four_arg_constructor(weight, bias, maximize, foreach)
|
|
schedulers = []
|
|
for scheduler_constructor in scheduler_constructors:
|
|
schedulers.append(scheduler_constructor(optimizer))
|
|
|
|
# to check if the optimizer can be printed as a string
|
|
optimizer.__repr__()
|
|
|
|
def fn():
|
|
optimizer.zero_grad()
|
|
y = weight.mv(input)
|
|
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
|
|
y = y.cuda(bias.get_device())
|
|
loss = (y + bias).pow(2).sum()
|
|
loss.backward()
|
|
return loss
|
|
|
|
initial_value = fn().item()
|
|
for _ in range(200):
|
|
for scheduler in schedulers:
|
|
if isinstance(scheduler, ReduceLROnPlateau):
|
|
val_loss = fn()
|
|
scheduler.step(val_loss)
|
|
else:
|
|
scheduler.step()
|
|
optimizer.step(fn)
|
|
if maximize:
|
|
self.assertGreater(fn().item(), initial_value)
|
|
else:
|
|
self.assertLess(fn().item(), initial_value)
|
|
|
|
def _test_state_dict(self, weight, bias, input, constructor, atol=None, rtol=None):
|
|
weight = Parameter(weight)
|
|
bias = Parameter(bias)
|
|
with torch.no_grad():
|
|
input = input.clone().detach().requires_grad_()
|
|
|
|
def fn_base(optimizer, weight, bias):
|
|
optimizer.zero_grad()
|
|
i = input_cuda if weight.is_cuda else input
|
|
loss = (weight.mv(i) + bias).pow(2).sum()
|
|
loss.backward()
|
|
return loss
|
|
|
|
optimizer = constructor(weight, bias)
|
|
fn = functools.partial(fn_base, optimizer, weight, bias)
|
|
|
|
# Prime the optimizer
|
|
for _i in range(20):
|
|
optimizer.step(fn)
|
|
# Clone the weights and construct new optimizer for them
|
|
with torch.no_grad():
|
|
weight_c = Parameter(weight.clone().detach())
|
|
bias_c = Parameter(bias.clone().detach())
|
|
optimizer_c = constructor(weight_c, bias_c)
|
|
fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
|
|
# Load state dict
|
|
state_dict = deepcopy(optimizer.state_dict())
|
|
state_dict_c = deepcopy(optimizer.state_dict())
|
|
optimizer_c.load_state_dict(state_dict_c)
|
|
# Run both optimizations in parallel
|
|
for _ in range(20):
|
|
optimizer.step(fn)
|
|
optimizer_c.step(fn_c)
|
|
self.assertEqual(weight, weight_c)
|
|
self.assertEqual(bias, bias_c)
|
|
# Make sure state dict wasn't modified
|
|
self.assertEqual(state_dict, state_dict_c)
|
|
# Make sure state dict is deterministic with equal but not identical parameters
|
|
self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
|
|
# Make sure repeated parameters have identical representation in state dict
|
|
optimizer_c.param_groups.extend(optimizer_c.param_groups)
|
|
self.assertEqual(
|
|
optimizer.state_dict()["param_groups"][-1],
|
|
optimizer_c.state_dict()["param_groups"][-1],
|
|
)
|
|
|
|
# Make sure that optimizers that support maximize can load older models
|
|
state_dict = optimizer.state_dict()
|
|
if "maximize" in state_dict["param_groups"][0]:
|
|
for group in state_dict["param_groups"]:
|
|
del group["maximize"]
|
|
optimizer.load_state_dict(state_dict)
|
|
# Make sure we can still step
|
|
optimizer.step()
|
|
# Make sure that optimizers that support foreach can load older models
|
|
state_dict = optimizer.state_dict()
|
|
if "foreach" in state_dict["param_groups"][0]:
|
|
for group in state_dict["param_groups"]:
|
|
del group["foreach"]
|
|
optimizer.load_state_dict(state_dict)
|
|
# Make sure we can still step
|
|
optimizer.step()
|
|
|
|
# Make sure that loading optimizers with step not wrapped in tensor can work
|
|
state_dict = optimizer.state_dict()
|
|
if "step" in state_dict["state"][0] and torch.is_tensor(
|
|
state_dict["state"][0]["step"]
|
|
):
|
|
for state in state_dict["state"].values():
|
|
state["step"] = state["step"].item()
|
|
optimizer.load_state_dict(state_dict)
|
|
optimizer.step()
|
|
|
|
# Check that state dict can be loaded even when we cast parameters
|
|
# to a different type and move to a different device.
|
|
if not torch.cuda.is_available():
|
|
return
|
|
|
|
with torch.no_grad():
|
|
input_cuda = input.clone().detach().to(dtype=torch.float32, device="cuda")
|
|
weight_cuda = Parameter(
|
|
weight.clone().detach().to(dtype=torch.float32, device="cuda")
|
|
)
|
|
bias_cuda = Parameter(
|
|
bias.clone().detach().to(dtype=torch.float32, device="cuda")
|
|
)
|
|
optimizer_cuda = constructor(weight_cuda, bias_cuda)
|
|
fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda, bias_cuda)
|
|
|
|
state_dict = deepcopy(optimizer.state_dict())
|
|
state_dict_c = deepcopy(optimizer.state_dict())
|
|
optimizer_cuda.load_state_dict(state_dict_c)
|
|
|
|
# Make sure state dict wasn't modified
|
|
self.assertEqual(state_dict, state_dict_c)
|
|
|
|
# Make sure that device of state['step'] is still CPU
|
|
new_state_dict = optimizer_cuda.state_dict()
|
|
if "step" in state_dict["state"][0] and torch.is_tensor(
|
|
state_dict["state"][0]["step"]
|
|
):
|
|
for state in new_state_dict["state"].values():
|
|
self.assertEqual(state["step"].device.type, "cpu")
|
|
|
|
for _i in range(20):
|
|
optimizer.step(fn)
|
|
optimizer_cuda.step(fn_cuda)
|
|
self.assertEqual(weight, weight_cuda)
|
|
self.assertEqual(bias, bias_cuda, atol=atol, rtol=rtol)
|
|
|
|
# validate deepcopy() copies all public attributes
|
|
def getPublicAttr(obj):
|
|
return {k for k in obj.__dict__ if not k.startswith("_")}
|
|
|
|
self.assertEqual(getPublicAttr(optimizer), getPublicAttr(deepcopy(optimizer)))
|
|
|
|
def _test_basic_cases(
|
|
self,
|
|
constructor,
|
|
scheduler_constructors=None,
|
|
ignore_multidevice=False,
|
|
constructor_accepts_maximize=False,
|
|
constructor_accepts_foreach=False,
|
|
atol=None,
|
|
rtol=None,
|
|
):
|
|
if scheduler_constructors is None:
|
|
scheduler_constructors = []
|
|
|
|
def make_two_arg_constructor(
|
|
constructor, maximize: bool = False, foreach: bool = False
|
|
):
|
|
if constructor_accepts_maximize and constructor_accepts_foreach:
|
|
return lambda weight, bias: constructor(weight, bias, maximize, foreach)
|
|
if constructor_accepts_maximize:
|
|
return lambda weight, bias: constructor(weight, bias, maximize)
|
|
if constructor_accepts_foreach:
|
|
return lambda weight, bias: constructor(weight, bias, foreach)
|
|
return constructor
|
|
|
|
for maximize, foreach in itertools.product(
|
|
{False, constructor_accepts_maximize},
|
|
{False, constructor_accepts_foreach},
|
|
):
|
|
self._test_state_dict(
|
|
torch.randn(10, 5),
|
|
torch.randn(10),
|
|
torch.randn(5),
|
|
make_two_arg_constructor(constructor, maximize, foreach),
|
|
atol=atol,
|
|
rtol=rtol,
|
|
)
|
|
self._test_basic_cases_template(
|
|
torch.randn(10, 5),
|
|
torch.randn(10),
|
|
torch.randn(5),
|
|
constructor,
|
|
scheduler_constructors,
|
|
constructor_accepts_maximize,
|
|
constructor_accepts_foreach,
|
|
)
|
|
# non-contiguous parameters
|
|
self._test_basic_cases_template(
|
|
torch.randn(10, 5, 2)[..., 0],
|
|
torch.randn(10, 2)[..., 0],
|
|
torch.randn(5),
|
|
constructor,
|
|
scheduler_constructors,
|
|
constructor_accepts_maximize,
|
|
constructor_accepts_foreach,
|
|
)
|
|
# CUDA
|
|
if not torch.cuda.is_available():
|
|
return
|
|
self._test_basic_cases_template(
|
|
torch.randn(10, 5).cuda(),
|
|
torch.randn(10).cuda(),
|
|
torch.randn(5).cuda(),
|
|
constructor,
|
|
scheduler_constructors,
|
|
constructor_accepts_maximize,
|
|
constructor_accepts_foreach,
|
|
)
|
|
# Multi-GPU
|
|
if not torch.cuda.device_count() > 1 or ignore_multidevice:
|
|
return
|
|
self._test_basic_cases_template(
|
|
torch.randn(10, 5).cuda(0),
|
|
torch.randn(10).cuda(1),
|
|
torch.randn(5).cuda(0),
|
|
constructor,
|
|
scheduler_constructors,
|
|
constructor_accepts_maximize,
|
|
constructor_accepts_foreach,
|
|
)
|
|
|
|
def _test_complex_optimizer(self, optimizer_constructor):
|
|
complex_param = torch.randn(5, 5, dtype=torch.complex64, requires_grad=True)
|
|
real_param = torch.view_as_real(complex_param).detach().clone().requires_grad_()
|
|
complex_opt = optimizer_constructor(complex_param)
|
|
real_opt = optimizer_constructor(real_param)
|
|
|
|
for _ in range(3):
|
|
complex_param.grad = torch.randn_like(complex_param)
|
|
real_param.grad = torch.view_as_real(complex_param.grad)
|
|
complex_opt.step()
|
|
real_opt.step()
|
|
|
|
self.assertEqual(torch.view_as_real(complex_param), real_param)
|
|
|
|
def _test_complex_2d(self, optimizer_constructor, f=None):
|
|
if f is None:
|
|
f = rosenbrock
|
|
a1 = torch.randn(2, dtype=torch.complex64, requires_grad=True)
|
|
a1_real = a1.real.clone().detach()
|
|
a1_imag = a1.imag.clone().detach()
|
|
a1_real.requires_grad_()
|
|
a1_imag.requires_grad_()
|
|
optim1 = optimizer_constructor([a1])
|
|
optim2 = optimizer_constructor([a1_real, a1_imag])
|
|
|
|
for _ in range(10):
|
|
optim1.zero_grad()
|
|
optim2.zero_grad()
|
|
a2 = torch.complex(a1_real, a1_imag)
|
|
f(a1).backward()
|
|
f(a2).backward()
|
|
|
|
self.assertEqual(a1.grad.real, a1_real.grad)
|
|
self.assertEqual(a1.grad.imag, a1_imag.grad)
|
|
|
|
optim1.step()
|
|
optim2.step()
|
|
self.assertEqual(a1.real, a1_real)
|
|
self.assertEqual(a1.imag, a1_imag)
|
|
|
|
def _build_params_dict(self, weight, bias, **kwargs):
|
|
return [{"params": [weight]}, dict(params=[bias], **kwargs)]
|
|
|
|
def _build_params_dict_single(self, weight, bias, **kwargs):
|
|
return [dict(params=bias, **kwargs)]
|
|
|
|
def test_sgd(self):
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
self._build_params_dict_single(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
self._build_params_dict_single(weight, bias, lr=1e-2),
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
|
|
),
|
|
[lambda opt: StepLR(opt, gamma=0.9, step_size=10)],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
|
|
),
|
|
[
|
|
lambda opt: LinearLR(
|
|
opt, start_factor=0.4, end_factor=0.8, total_iters=4
|
|
)
|
|
],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
|
|
),
|
|
[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
|
|
),
|
|
[
|
|
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
|
|
lambda opt: LinearLR(
|
|
opt, start_factor=0.4, end_factor=0.6, total_iters=4
|
|
),
|
|
],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
|
|
),
|
|
[
|
|
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
|
|
lambda opt: ReduceLROnPlateau(opt),
|
|
],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
|
|
),
|
|
[
|
|
lambda opt: StepLR(opt, gamma=0.99, step_size=10),
|
|
lambda opt: ExponentialLR(opt, gamma=0.99),
|
|
lambda opt: ReduceLROnPlateau(opt),
|
|
],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias],
|
|
lr=1e-3,
|
|
momentum=0.5,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias],
|
|
lr=1e-3,
|
|
momentum=0.5,
|
|
weight_decay=1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias],
|
|
nesterov=True,
|
|
lr=1e-3,
|
|
momentum=0.5,
|
|
weight_decay=1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.SGD(
|
|
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
|
|
),
|
|
[lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "Invalid momentum value: -0.5"):
|
|
optim.SGD(None, lr=1e-2, momentum=-0.5)
|
|
|
|
def test_sgd_sparse(self):
|
|
for foreach in (False, True):
|
|
self._test_rosenbrock_sparse(
|
|
lambda params: optim.SGD(params, lr=4.8e-3, foreach=foreach)
|
|
)
|
|
self._test_rosenbrock_sparse(
|
|
lambda params: optim.SGD(params, lr=0.0048, foreach=foreach),
|
|
[lambda opt: StepLR(opt, gamma=0.99999, step_size=300)],
|
|
)
|
|
|
|
def test_sgd_complex(self):
|
|
for foreach in (False, True):
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.SGD([param], lr=0.001, foreach=foreach)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.SGD([param], lr=0.001, momentum=1, foreach=foreach)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.SGD(
|
|
[param], lr=0.001, momentum=1, weight_decay=1, foreach=foreach
|
|
)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.SGD(
|
|
[param],
|
|
lr=0.001,
|
|
nesterov=True,
|
|
momentum=1,
|
|
weight_decay=1,
|
|
foreach=foreach,
|
|
)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.SGD(
|
|
[param],
|
|
lr=0.001,
|
|
momentum=1,
|
|
dampening=0.5,
|
|
weight_decay=1,
|
|
foreach=foreach,
|
|
)
|
|
)
|
|
|
|
|
|
def _test_derived_optimizers_varying_tensors(self, optimizer_with_kwargs, kwarg):
|
|
if not torch.cuda.is_available():
|
|
return
|
|
assert kwarg in ("foreach", "fused")
|
|
|
|
# Specifically test that inputting params of different dtypes and devices
|
|
# is handled equivalently on the foreach and fused implementations as the
|
|
# single tensor implementations. We need multiple GPUs (vs just a CPU and
|
|
# GPU) because fused adam only works on GPUs. (Thus we only run the tests
|
|
# that call into this helper when TEST_MULTIGPU.)
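        # The trailing int64 tensor is created with requires_grad=False and never
        # gets a gradient assigned below, exercising parameters that both the
        # single-tensor and multi-tensor code paths are expected to simply skip.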
|
|
params = [
|
|
torch.rand(2, 3, dtype=torch.float64, device='cuda:0', requires_grad=True),
|
|
torch.rand(2, 3, dtype=torch.float32, device='cuda:0', requires_grad=True),
|
|
torch.rand(2, 3, dtype=torch.float16, device='cuda:0', requires_grad=True),
|
|
torch.rand(2, 3, dtype=torch.bfloat16, device='cuda:0', requires_grad=True),
|
|
torch.rand(2, 3, dtype=torch.float64, device='cuda:1', requires_grad=True),
|
|
torch.rand(2, 3, dtype=torch.float32, device='cuda:1', requires_grad=True),
|
|
torch.rand(2, 3, dtype=torch.float16, device='cuda:1', requires_grad=True),
|
|
torch.rand(2, 3, dtype=torch.bfloat16, device='cuda:1', requires_grad=True),
|
|
torch.randint(1024, (2, 3), dtype=torch.int64, device='cuda:1', requires_grad=False),
|
|
]
|
|
|
|
for p in params:
|
|
if p.requires_grad:
|
|
p.grad = torch.rand_like(p, device=p.device, dtype=p.dtype)
|
|
|
|
kIterations = 7 if kwarg == "foreach" else 1
|
|
for optimizer_constructor, kwargs in optimizer_with_kwargs:
|
|
res, state = [], []
|
|
for enabled in (False, True):
|
|
kwargs_clone = deepcopy(kwargs)
|
|
kwargs_clone[kwarg] = enabled
|
|
|
|
params_clone = []
|
|
for p in params:
|
|
p_clone = p.clone().detach()
|
|
if p.requires_grad:
|
|
p_clone.requires_grad = True
|
|
p_clone.grad = p.grad.clone().detach()
|
|
params_clone.append(p_clone)
|
|
|
|
optimizer = optimizer_constructor(params_clone, **kwargs_clone)
|
|
for _ in range(kIterations):
|
|
optimizer.step()
|
|
|
|
state.append(optimizer.state)
|
|
res.append(params_clone)
|
|
|
|
st_state = state[0]
|
|
mt_state = state[1]
|
|
for st_p, mt_p in zip(res[0], res[1]):
|
|
self.assertEqual(st_p, mt_p)
|
|
|
|
# check that optimizer states are the same
|
|
st_p_state = st_state[st_p]
|
|
mt_p_state = mt_state[mt_p]
|
|
|
|
for k in st_p_state:
|
|
actual = mt_p_state[k]
|
|
# If `torch.optim.Adam` is `__init__`ed with either `fused=True` or `capturable=True`,
|
|
# `step` Tensor is 1D while usually it's 0D.
|
|
if (
|
|
k == "step"
|
|
and isinstance(actual, torch.Tensor)
|
|
and actual.ndim == 1
|
|
):
|
|
actual = actual[0]
|
|
self.assertEqual(st_p_state[k], actual)
|
|
|
|
|
|
def _test_derived_optimizers(self, optimizer_pairs_with_flags, flag):
|
|
if not torch.cuda.is_available():
|
|
return
|
|
assert flag in ("foreach", "fused")
|
|
|
|
# why 7? iteration 7 is where we start to see differences for RAdam
|
|
# params interacting with the small eps value, because that's right
|
|
# after rho_t becomes greater than 5 in step 6.
|
|
kIterations = 7
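        # (For reference, the RAdam paper defines
        #     rho_t = rho_inf - 2 * t * beta2**t / (1 - beta2**t),  rho_inf = 2 / (1 - beta2) - 1,
        # and with the default beta2 = 0.999 this first clearly exceeds the
        # rectification threshold of 5 around step 6.)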
|
|
device = "cuda"
|
|
for optimizer_constructor, params in optimizer_pairs_with_flags:
|
|
res, state = [], []
|
|
for flag_value in (False, True):
|
|
input = torch.tensor(
|
|
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=torch.float64, device=device
|
|
).reshape(3, 2)
|
|
|
|
torch.manual_seed(1)
|
|
model = torch.nn.Sequential(
|
|
torch.nn.Linear(2, 3),
|
|
torch.nn.Sigmoid(),
|
|
torch.nn.Linear(3, 1),
|
|
torch.nn.Sigmoid(),
|
|
)
|
|
model.to(dtype=torch.float64, device=device)
|
|
params_with_flags = deepcopy(params)
|
|
params_with_flags[flag] = flag_value
|
|
|
|
optimizer = optimizer_constructor(
|
|
model.parameters(), **params_with_flags
|
|
)
|
|
|
|
                for i in range(kIterations):
                    optimizer.zero_grad()
                    output = model(input)
                    loss = output.sum()
                    loss.backward()

                    # On the first iteration, test that step behaves as expected
                    # (a no-op) when grads are set to None
                    if i == 0:
                        optimizer.zero_grad(set_to_none=True)

                    optimizer.step()
|
|
|
|
state.append(optimizer.state)
|
|
res.append(model.parameters())
|
|
|
|
st_state = state[0]
|
|
mt_state = state[1]
|
|
for st_p, mt_p in zip(res[0], res[1]):
|
|
self.assertEqual(st_p, mt_p)
|
|
|
|
# check that optimizer states are the same
|
|
st_p_state = st_state[st_p]
|
|
mt_p_state = mt_state[mt_p]
|
|
|
|
for k in st_p_state:
|
|
actual = mt_p_state[k]
|
|
# If `torch.optim.Adam` is `__init__`ed with either `fused=True` or `capturable=True`,
|
|
# `step` Tensor is 1D while usually it's 0D.
|
|
if (
|
|
k == "step"
|
|
and isinstance(actual, torch.Tensor)
|
|
and actual.ndim == 1
|
|
):
|
|
actual = actual[0]
|
|
self.assertEqual(st_p_state[k], actual)
|
|
|
|
def test_multi_tensor_optimizers(self):
|
|
optimizer_pairs_with_flags = [
|
|
(optim.Adam, dict(weight_decay=1.0, amsgrad=True, fused=False)),
|
|
(optim.Adam, dict(weight_decay=1.0, amsgrad=False, fused=False)),
|
|
(optim.Adam, dict(weight_decay=0.0, amsgrad=True, fused=False)),
|
|
(optim.Adam, dict(weight_decay=0.0, amsgrad=False, fused=False)),
|
|
(optim.AdamW, dict(weight_decay=1.0, amsgrad=True)),
|
|
(optim.AdamW, dict(weight_decay=1.0, amsgrad=False)),
|
|
(optim.AdamW, dict(weight_decay=0.0, amsgrad=True)),
|
|
(optim.AdamW, dict(weight_decay=0.0, amsgrad=False)),
|
|
(optim.NAdam, dict(weight_decay=0.0, momentum_decay=6e-3)),
|
|
(optim.NAdam, dict(weight_decay=1.0, momentum_decay=6e-3)),
|
|
(optim.NAdam, dict(weight_decay=0.0, momentum_decay=4e-3)),
|
|
(optim.NAdam, dict(weight_decay=0.01, momentum_decay=4e-3)),
|
|
(
|
|
optim.SGD,
|
|
dict(lr=0.2, momentum=1, dampening=0, weight_decay=1, nesterov=True),
|
|
),
|
|
(
|
|
optim.SGD,
|
|
dict(lr=0.2, momentum=1, dampening=0.5, weight_decay=1, nesterov=False),
|
|
),
|
|
(optim.RAdam, dict(weight_decay=0, eps=1e-6)),
|
|
(optim.RAdam, dict(weight_decay=0)),
|
|
(optim.RAdam, dict(weight_decay=1, eps=1e-6)),
|
|
(optim.RAdam, dict(weight_decay=1)),
|
|
(optim.RMSprop, dict(weight_decay=1, momentum=1, centered=True)),
|
|
(optim.RMSprop, dict(weight_decay=1, momentum=0, centered=True)),
|
|
(optim.RMSprop, dict(weight_decay=1, momentum=1, centered=False)),
|
|
(optim.RMSprop, dict(weight_decay=0, momentum=1, centered=False)),
|
|
(optim.Rprop, dict(lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50))),
|
|
(optim.ASGD, dict(weight_decay=0)),
|
|
(optim.ASGD, dict(weight_decay=1)),
|
|
(optim.Adamax, dict(weight_decay=0)),
|
|
(optim.Adamax, dict(weight_decay=1)),
|
|
(optim.Adadelta, dict(weight_decay=0)),
|
|
(optim.Adadelta, dict(weight_decay=1)),
|
|
(optim.Adagrad, dict(weight_decay=0)),
|
|
(optim.Adagrad, dict(weight_decay=1)),
|
|
]
|
|
self._test_derived_optimizers(optimizer_pairs_with_flags, "foreach")
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
|
|
def test_multi_tensor_optimizers_with_varying_tensors(self):
|
|
optimizer_pairs_with_flags = [
|
|
(optim.Adam, dict(weight_decay=1.0, amsgrad=True, fused=False)),
|
|
(optim.Adam, dict(weight_decay=1.0, amsgrad=False, fused=False)),
|
|
(optim.Adam, dict(weight_decay=0.0, amsgrad=True, fused=False)),
|
|
(optim.Adam, dict(weight_decay=0.0, amsgrad=False, fused=False)),
|
|
(optim.AdamW, dict(weight_decay=1.0, amsgrad=True)),
|
|
(optim.AdamW, dict(weight_decay=1.0, amsgrad=False)),
|
|
(optim.AdamW, dict(weight_decay=0.0, amsgrad=True)),
|
|
(optim.AdamW, dict(weight_decay=0.0, amsgrad=False)),
|
|
(optim.NAdam, dict(weight_decay=0.0, momentum_decay=6e-3)),
|
|
(optim.NAdam, dict(weight_decay=1.0, momentum_decay=6e-3)),
|
|
(optim.NAdam, dict(weight_decay=0.0, momentum_decay=4e-3)),
|
|
(optim.NAdam, dict(weight_decay=0.01, momentum_decay=4e-3)),
|
|
(
|
|
optim.SGD,
|
|
dict(lr=0.2, momentum=1, dampening=0, weight_decay=1, nesterov=True),
|
|
),
|
|
(
|
|
optim.SGD,
|
|
dict(lr=0.2, momentum=1, dampening=0.5, weight_decay=1, nesterov=False),
|
|
),
|
|
(optim.RAdam, dict(weight_decay=0, eps=1e-6)),
|
|
(optim.RAdam, dict(weight_decay=0)),
|
|
(optim.RAdam, dict(weight_decay=1, eps=1e-6)),
|
|
(optim.RAdam, dict(weight_decay=1)),
|
|
(optim.RMSprop, dict(weight_decay=1, momentum=1, centered=True)),
|
|
(optim.RMSprop, dict(weight_decay=1, momentum=0, centered=True)),
|
|
(optim.RMSprop, dict(weight_decay=1, momentum=1, centered=False)),
|
|
(optim.RMSprop, dict(weight_decay=0, momentum=1, centered=False)),
|
|
(optim.Rprop, dict(lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50))),
|
|
(optim.ASGD, dict(weight_decay=0)),
|
|
(optim.ASGD, dict(weight_decay=1)),
|
|
(optim.Adamax, dict(weight_decay=0)),
|
|
(optim.Adamax, dict(weight_decay=1)),
|
|
(optim.Adadelta, dict(weight_decay=0)),
|
|
(optim.Adadelta, dict(weight_decay=1)),
|
|
(optim.Adagrad, dict(weight_decay=0)),
|
|
(optim.Adagrad, dict(weight_decay=1)),
|
|
]
|
|
self._test_derived_optimizers_varying_tensors(optimizer_pairs_with_flags, "foreach")
|
|
|
|
def test_fused_optimizers(self):
|
|
optimizer_pairs_with_flags = tuple(itertools.product(
|
|
(optim.Adam, optim.AdamW),
|
|
(
|
|
dict(weight_decay=1., amsgrad=False),
|
|
dict(weight_decay=1., amsgrad=True),
|
|
dict(weight_decay=0., amsgrad=False),
|
|
dict(weight_decay=0., amsgrad=True),
|
|
),
|
|
))
|
|
self._test_derived_optimizers(optimizer_pairs_with_flags, "fused")
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
|
|
def test_fused_optimizers_with_varying_tensors(self):
|
|
optimizer_pairs_with_flags = tuple(itertools.product(
|
|
(optim.Adam, optim.AdamW),
|
|
(
|
|
dict(weight_decay=1., amsgrad=False),
|
|
dict(weight_decay=1., amsgrad=True),
|
|
dict(weight_decay=0., amsgrad=False),
|
|
dict(weight_decay=0., amsgrad=True),
|
|
),
|
|
))
|
|
self._test_derived_optimizers_varying_tensors(optimizer_pairs_with_flags, "fused")
|
|
|
|
def test_adam(self):
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
[weight, bias],
|
|
lr=1e-3,
|
|
amsgrad=True,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
[weight, bias],
|
|
lr=1e-3,
|
|
weight_decay=0.1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
amsgrad=True,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
[lambda opt: ExponentialLR(opt, gamma=0.9)],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
[lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
[weight, bias],
|
|
lr=1e-3,
|
|
amsgrad=True,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
[
|
|
lambda opt: ConstantLR(opt, factor=0.4, total_iters=4),
|
|
lambda opt: ExponentialLR(opt, gamma=0.9),
|
|
],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
[weight, bias],
|
|
lr=1e-3,
|
|
amsgrad=True,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
[
|
|
lambda opt: ExponentialLR(opt, gamma=0.9),
|
|
lambda opt: ReduceLROnPlateau(opt),
|
|
],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
amsgrad=True,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
[
|
|
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
|
|
lambda opt: ReduceLROnPlateau(opt),
|
|
],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adam(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
[lambda opt: PolynomialLR(opt, total_iters=4, power=0.9)],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_complex_2d(optim.Adam)
|
|
self._test_complex_2d(functools.partial(optim.Adam, foreach=True))
|
|
|
|
with self.assertRaisesRegex(
|
|
ValueError, "Invalid beta parameter at index 0: 1.0"
|
|
):
|
|
optim.Adam(None, lr=1e-2, betas=(1.0, 0.0))
|
|
|
|
with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"):
|
|
optim.Adam(None, lr=1e-2, weight_decay=-1)
|
|
|
|
def test_adamw(self):
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.AdamW(
|
|
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.AdamW(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.AdamW(
|
|
[weight, bias],
|
|
lr=1e-3,
|
|
weight_decay=1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.AdamW(
|
|
[weight, bias],
|
|
lr=1e-3,
|
|
weight_decay=1,
|
|
amsgrad=True,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_complex_2d(optim.AdamW)
|
|
self._test_complex_2d(functools.partial(optim.AdamW, foreach=True))
|
|
with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"):
|
|
optim.AdamW(None, lr=1e-2, weight_decay=-1)
|
|
|
|
def test_sparse_adam(self):
|
|
self._test_rosenbrock_sparse(
|
|
lambda params: optim.SparseAdam(params, lr=4e-2), [], True
|
|
)
|
|
self._test_rosenbrock_sparse(
|
|
lambda params: optim.SparseAdam(params, lr=4e-2, maximize=True),
|
|
[],
|
|
True,
|
|
True,
|
|
)
|
|
with self.assertRaisesRegex(
|
|
ValueError, "Invalid beta parameter at index 0: 1.0"
|
|
):
|
|
optim.SparseAdam(None, lr=1e-2, betas=(1.0, 0.0))
|
|
with self.assertRaisesRegex(
|
|
ValueError, "SparseAdam requires dense parameter tensors"
|
|
):
|
|
optim.SparseAdam([torch.zeros(3, layout=torch.sparse_coo)])
|
|
with self.assertRaisesRegex(
|
|
ValueError, "SparseAdam requires dense parameter tensors"
|
|
):
|
|
optim.SparseAdam([{"params": [torch.zeros(3, layout=torch.sparse_coo)]}])
|
|
|
|
# ROCm precision is too low to pass this test
|
|
def test_adadelta(self):
|
|
# Handles https://github.com/pytorch/pytorch/issues/69698
|
|
self.rel_tol = 4e-3
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adadelta(
|
|
[weight, bias], maximize=maximize, foreach=foreach
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adadelta(
|
|
self._build_params_dict(weight, bias, rho=0.95),
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adadelta(
|
|
self._build_params_dict(weight, bias, rho=0.95),
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
[
|
|
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
|
|
lambda opt: ReduceLROnPlateau(opt),
|
|
],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adadelta(
|
|
[weight, bias], weight_decay=1, maximize=maximize, foreach=foreach
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "Invalid rho value: 1.1"):
|
|
optim.Adadelta(None, lr=1e-2, rho=1.1)
|
|
|
|
def test_adadelta_complex(self):
|
|
# Handles https://github.com/pytorch/pytorch/issues/69698
|
|
self.rel_tol = 2e-2
|
|
for optimizer in [optim.Adadelta]:
|
|
self._test_complex_optimizer(lambda weight: optimizer([weight]))
|
|
self._test_complex_optimizer(lambda weight: optimizer([weight], rho=0.95))
|
|
self._test_complex_optimizer(
|
|
lambda weight: optimizer([weight], rho=0.95, weight_decay=1)
|
|
)
|
|
|
|
def test_nadam(self):
|
|
self._test_basic_cases(
|
|
lambda weight, bias, foreach: optim.NAdam(
|
|
[weight, bias], lr=1e-3, foreach=foreach
|
|
),
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, foreach: optim.NAdam(
|
|
self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, foreach=foreach
|
|
),
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, foreach: optim.NAdam(
|
|
[weight, bias],
|
|
lr=1e-3,
|
|
weight_decay=0.1,
|
|
momentum_decay=6e-3,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, foreach: optim.NAdam(
|
|
[weight, bias],
|
|
lr=1e-3,
|
|
weight_decay=0.1,
|
|
momentum_decay=6e-3,
|
|
foreach=foreach,
|
|
),
|
|
[lambda opt: ExponentialLR(opt, gamma=0.9)],
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
with self.assertRaisesRegex(
|
|
ValueError, "Invalid beta parameter at index 0: 1.0"
|
|
):
|
|
optim.NAdam(None, lr=1e-2, betas=(1.0, 0.0))
|
|
with self.assertRaisesRegex(ValueError, "Invalid momentum_decay value: -0.2"):
|
|
optim.NAdam(None, lr=1e-2, momentum_decay=-0.2)
|
|
|
|
def test_adagrad(self):
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adagrad(
|
|
[weight, bias], lr=1e-1, maximize=maximize, foreach=foreach
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adagrad(
|
|
[weight, bias],
|
|
lr=1e-1,
|
|
initial_accumulator_value=0.1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adagrad(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adagrad(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
[lambda opt: ReduceLROnPlateau(opt)],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adagrad(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
[
|
|
lambda opt: ReduceLROnPlateau(opt),
|
|
lambda opt: ExponentialLR(opt, gamma=0.99),
|
|
],
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "Invalid lr_decay value: -0.5"):
|
|
optim.Adagrad(None, lr=1e-2, lr_decay=-0.5)
|
|
|
|
def test_adagrad_sparse(self):
|
|
for foreach in (False, True):
|
|
self._test_rosenbrock_sparse(
|
|
lambda params: optim.Adagrad(params, lr=1e-1, foreach=foreach)
|
|
)
|
|
self._test_rosenbrock_sparse(
|
|
lambda params: optim.Adagrad(params, lr=0.1, foreach=foreach),
|
|
[
|
|
lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500),
|
|
lambda opt: ReduceLROnPlateau(opt, threshold=1e-4),
|
|
],
|
|
)
|
|
|
|
def test_adagrad_complex(self):
|
|
for foreach in (False, True):
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.Adagrad([param], lr=1e-1, foreach=foreach)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.Adagrad(
|
|
[param],
|
|
lr=1e-1,
|
|
initial_accumulator_value=0.1,
|
|
foreach=foreach,
|
|
)
|
|
)
|
|
|
|
def test_adamax(self):
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adamax(
|
|
[weight, bias], lr=1e-1, maximize=maximize, foreach=foreach
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adamax(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Adamax(
|
|
[weight, bias],
|
|
lr=1e-1,
|
|
weight_decay=1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_complex_2d(optim.Adamax)
|
|
self._test_complex_2d(functools.partial(optim.Adamax, foreach=True))
|
|
with self.assertRaisesRegex(
|
|
ValueError, "Invalid beta parameter at index 1: 1.0"
|
|
):
|
|
optim.Adamax(None, lr=1e-2, betas=(0.0, 1.0))
|
|
|
|
def test_radam(self):
|
|
self._test_basic_cases(
|
|
lambda weight, bias, foreach: optim.RAdam(
|
|
[weight, bias], lr=1e-3, foreach=foreach
|
|
),
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, foreach: optim.RAdam(
|
|
self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, foreach=foreach
|
|
),
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, foreach: optim.RAdam(
|
|
[weight, bias], lr=1e-3, weight_decay=0.1, foreach=foreach
|
|
),
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, foreach: optim.RAdam(
|
|
[weight, bias], lr=1e-3, foreach=foreach
|
|
),
|
|
[
|
|
lambda opt: ExponentialLR(opt, gamma=0.9),
|
|
lambda opt: ReduceLROnPlateau(opt),
|
|
],
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
with self.assertRaisesRegex(
|
|
ValueError, "Invalid beta parameter at index 0: 1.0"
|
|
):
|
|
optim.RAdam(None, lr=1e-2, betas=(1.0, 0.0))
|
|
|
|
with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"):
|
|
optim.RAdam(None, lr=1e-2, weight_decay=-1)
|
|
|
|
def test_rmsprop(self):
|
|
for foreach in (False, True):
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.RMSprop(
|
|
[weight, bias], lr=1e-2, maximize=maximize, foreach=foreach
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.RMSprop(
|
|
self._build_params_dict(weight, bias, lr=1e-3),
|
|
lr=1e-2,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.RMSprop(
|
|
self._build_params_dict(weight, bias, lr=1e-3),
|
|
lr=1e-2,
|
|
centered=True,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.RMSprop(
|
|
self._build_params_dict(weight, bias, lr=1e-3),
|
|
lr=1e-2,
|
|
centered=True,
|
|
momentum=0.1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.RMSprop(
|
|
self._build_params_dict(weight, bias, lr=1e-3),
|
|
lr=1e-2,
|
|
momentum=0.1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.RMSprop(
|
|
self._build_params_dict(weight, bias, lr=1e-3),
|
|
lr=1e-2,
|
|
momentum=0.1,
|
|
weight_decay=1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_complex_2d(lambda param: optim.RMSprop(param, foreach=foreach))
|
|
self._test_complex_2d(
|
|
lambda param: optim.RMSprop(param, centered=True, foreach=foreach)
|
|
)
|
|
self._test_complex_2d(
|
|
lambda param: optim.RMSprop(param, momentum=0.1, foreach=foreach)
|
|
)
|
|
self._test_complex_2d(
|
|
lambda param: optim.RMSprop(param, maximize=True, foreach=foreach)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.RMSprop([param], foreach=foreach)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.RMSprop([param], centered=True, foreach=foreach)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.RMSprop([param], momentum=0.1, foreach=foreach)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.RMSprop([param], maximize=True, foreach=foreach)
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "Invalid momentum value: -1.0"):
|
|
optim.RMSprop(None, lr=1e-2, momentum=-1.0, foreach=foreach)
|
|
|
|
def test_asgd(self):
|
|
for foreach in (False, True):
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.ASGD(
|
|
[weight, bias], lr=1e-3, t0=100, maximize=maximize, foreach=foreach
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.ASGD(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
t0=100,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.ASGD(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=1e-3,
|
|
weight_decay=1,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
# Ref: https://github.com/pytorch/pytorch/issues/84560
|
|
# self._test_complex_2d(optimizer)
|
|
self._test_complex_optimizer(
|
|
lambda params: optim.ASGD([params], foreach=foreach)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda params: optim.ASGD([params], maximize=True, foreach=foreach)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda params: optim.ASGD(
|
|
[params], maximize=True, weight_decay=0.9, foreach=foreach
|
|
)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda params: optim.ASGD(
|
|
[params], maximize=False, weight_decay=0.9, foreach=foreach
|
|
)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda params: optim.ASGD([params], weight_decay=0.9, foreach=foreach)
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -0.5"):
|
|
optim.ASGD(None, lr=1e-2, weight_decay=-0.5, foreach=foreach)
|
|
|
|
@skipIfRocm
|
|
def test_rprop(self):
|
|
is_cuda_sm86 = torch.cuda.is_available() and torch.cuda.get_device_capability(
|
|
0
|
|
) == (8, 6)
|
|
for foreach in (False, True):
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Rprop(
|
|
[weight, bias], lr=2e-4, maximize=maximize, foreach=foreach
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias, maximize, foreach: optim.Rprop(
|
|
self._build_params_dict(weight, bias, lr=1e-2),
|
|
lr=2e-4,
|
|
maximize=maximize,
|
|
foreach=foreach,
|
|
),
|
|
constructor_accepts_maximize=True,
|
|
constructor_accepts_foreach=True,
|
|
atol=4e-5 if is_cuda_sm86 else None,
|
|
rtol=3e-5 if is_cuda_sm86 else None,
|
|
)
|
|
self._test_complex_2d(lambda param: optim.Rprop(param, foreach=foreach))
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.Rprop([param], lr=0.001, foreach=foreach)
|
|
)
|
|
self._test_complex_optimizer(
|
|
lambda param: optim.Rprop(
|
|
[param], lr=0.001, maximize=True, foreach=foreach
|
|
)
|
|
)
|
|
with self.assertRaisesRegex(ValueError, "Invalid eta values: 1.0, 0.5"):
|
|
optim.Rprop(None, lr=1e-2, etas=(1.0, 0.5), foreach=foreach)
|
|
|
|
def test_lbfgs(self):
|
|
self._test_basic_cases(
|
|
lambda weight, bias: optim.LBFGS([weight, bias]), ignore_multidevice=True
|
|
)
|
|
self._test_basic_cases(
|
|
lambda weight, bias: optim.LBFGS(
|
|
[weight, bias], line_search_fn="strong_wolfe"
|
|
),
|
|
ignore_multidevice=True,
|
|
)
|
|
|
|
@unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN")
|
|
def test_lbfgs_return_type(self):
|
|
params = [torch.randn(10, 5), torch.randn(10)]
|
|
opt1 = optim.LBFGS(params, 0.01, tolerance_grad=math.inf)
|
|
opt2 = optim.LBFGS(params, 0.01, tolerance_grad=-math.inf)
|
|
|
|
def closure():
|
|
return torch.tensor([10])
|
|
|
|
res1 = opt1.step(closure)
|
|
res2 = opt2.step(closure)
|
|
self.assertEqual(type(res1), type(res2))
|
|
|
|
def test_invalid_param_type(self):
|
|
with self.assertRaises(TypeError):
|
|
optim.SGD(Parameter(torch.randn(5, 5)), lr=3)
|
|
|
|
def test_duplicate_params_in_param_group(self):
|
|
param = Parameter(torch.randn(5, 5))
|
|
with warnings.catch_warnings(record=True) as w:
|
|
warnings.simplefilter("always")
|
|
optim.SGD([param, param], lr=0.1)
|
|
self.assertEqual(len(w), 1)
|
|
self.assertIn(
|
|
"a parameter group with duplicate parameters", str(w[0].message)
|
|
)
|
|
|
|
def test_no_grad_for_all_params(self):
|
|
params = [torch.randn(5, 5, requires_grad=False) for _ in range(2)]
|
|
|
|
optimizer_list = [
|
|
optim.Adadelta,
|
|
optim.AdamW,
|
|
optim.Adam,
|
|
optim.Adagrad,
|
|
optim.Adamax,
|
|
optim.RMSprop,
|
|
optim.SGD,
|
|
optim.SparseAdam,
|
|
optim.ASGD,
|
|
]
|
|
for optim_ctr in optimizer_list:
|
|
opt = optim_ctr(params, lr=0.1)
|
|
# make sure step can still run even if
|
|
# all params have no grad
|
|
opt.step()
|
|
|
|
# make sure that `state_steps` is correctly either updated or not updated when `found_inf`.
|
|
def test_functional_fused_optimizer_with_foundinf(self):
|
|
if not torch.cuda.is_available():
|
|
self.skipTest("CUDA is required.")
|
|
|
|
from torch.optim import adam, adamw
|
|
|
|
num_tensors = 5
|
|
for functional_optim, amsgrad in itertools.product((adam.adam, adamw.adamw), (False, True)):
|
|
params, grads, exp_avgs, exp_avg_sqs = [[torch.ones((1,), device="cuda") for _ in range(num_tensors)] for _ in range(4)]
|
|
max_exp_avg_sqs = [torch.ones((1,), device="cuda") for _ in range(num_tensors)] if amsgrad else []
|
|
state_steps = [torch.ones((1,), dtype=torch.float32, device="cuda") for _ in range(num_tensors)]
|
|
grad_scale = torch.ones((1,), dtype=torch.float32, device="cuda")
|
|
found_inf = torch.ones((1,), dtype=torch.float32, device="cuda")
|
|
|
|
functional_optim(
|
|
params,
|
|
grads,
|
|
exp_avgs,
|
|
exp_avg_sqs,
|
|
max_exp_avg_sqs,
|
|
state_steps,
|
|
foreach=False,
|
|
capturable=False,
|
|
fused=True,
|
|
amsgrad=amsgrad,
|
|
beta1=0.9,
|
|
beta2=0.99,
|
|
lr=1e-2,
|
|
weight_decay=0.0,
|
|
eps=1e-8,
|
|
maximize=False,
|
|
grad_scale=grad_scale,
|
|
found_inf=found_inf,
|
|
)
|
|
|
|
self.assertEqual(
|
|
state_steps,
|
|
[
|
|
torch.ones((1,), dtype=torch.float32, device="cuda")
|
|
for _ in range(num_tensors)
|
|
],
|
|
)
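            # state_steps started at 1 and would normally be incremented by the
            # step; with found_inf set to a nonzero value the fused kernel is
            # expected to leave them at (or restore them to) their initial value.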
|
|
|
|
def test_empty_grad(self):
|
|
optimizers = [
|
|
torch.optim.Adadelta,
|
|
torch.optim.Adagrad,
|
|
torch.optim.Adam,
|
|
torch.optim.AdamW,
|
|
torch.optim.Adamax,
|
|
torch.optim.ASGD,
|
|
torch.optim.NAdam,
|
|
torch.optim.RAdam,
|
|
torch.optim.RMSprop,
|
|
torch.optim.Rprop,
|
|
torch.optim.SGD,
|
|
torch.optim.SparseAdam,
|
|
]
|
|
|
|
for optimizer in optimizers:
|
|
net = torch.nn.Embedding(
|
|
5, 1, padding_idx=0, sparse=optimizer is torch.optim.SparseAdam
|
|
)
|
|
original_params = (param.detach().clone() for param in net.parameters())
|
|
# Simulate a batch that only indexes the embedding at padding_idx
|
|
x = torch.tensor([[0, 0]]).int()
|
|
y = torch.tensor([[[3.0], [4.0]]])
|
|
opt = optimizer(net.parameters(), lr=1e-5)
|
|
torch.nn.MSELoss()(net.forward(x), y).backward()
|
|
|
|
opt.step()
|
|
|
|
for original_param, param in zip(original_params, net.parameters()):
|
|
# assert that the parameters have not changed
|
|
self.assertEqual(original_param, param)
|
|
|
|
@skipIfTorchDynamo()
|
|
def test_post_hook(self):
|
|
def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
|
|
nonlocal data
|
|
data += 2
|
|
|
|
params = [torch.Tensor([1, 1])]
|
|
opt = SGD(params, lr=0.001)
|
|
data = 2
|
|
hook_handle = opt.register_step_post_hook(post_hook)
|
|
|
|
opt.step()
|
|
opt.step()
|
|
# check if pre hooks were registered
|
|
self.assertEqual(data, 6)
|
|
|
|
# remove handles, take step and verify that hook is no longer registered
|
|
hook_handle.remove()
|
|
|
|
opt.step()
|
|
self.assertEqual(data, 6)
|
|
|
|
@skipIfTorchDynamo()
|
|
def test_pre_hook(self):
|
|
def pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
|
|
nonlocal data
|
|
data += 2
|
|
|
|
params = [torch.Tensor([1, 1])]
|
|
opt = SGD(params, lr=0.001)
|
|
data = 5
|
|
hook_handle = opt.register_step_pre_hook(pre_hook)
|
|
|
|
opt.step()
|
|
opt.step()
|
|
# check if pre hooks were registered
|
|
self.assertEqual(data, 9)
|
|
|
|
# remove handles, take step and verify that hook is no longer registered
|
|
hook_handle.remove()
|
|
|
|
opt.step()
|
|
self.assertEqual(data, 9)
|
|
|
|
@skipIfTorchDynamo()
|
|
def test_pre_and_post_hook(self):
|
|
def global_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
|
|
nonlocal data
|
|
data.append(0)
|
|
|
|
def global_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
|
|
nonlocal data
|
|
data.append(5)
|
|
|
|
def local_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
|
|
nonlocal data
|
|
data.append(1)
|
|
|
|
def local_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
|
|
nonlocal data
|
|
data.append(2)
|
|
|
|
params = [torch.Tensor([1, 1])]
|
|
opt1 = SGD(params, lr=0.001)
|
|
opt2 = Adam(params, lr=0.01)
|
|
data = []
|
|
|
|
# register global hooks to both optimizers
|
|
global_pre_handle = register_optimizer_step_pre_hook(global_pre_hook)
|
|
global_post_handle = register_optimizer_step_post_hook(global_post_hook)
|
|
|
|
# register local hooks
|
|
first_pre_handle = opt1.register_step_pre_hook(local_pre_hook)
|
|
first_post_handle = opt1.register_step_post_hook(local_post_hook)
|
|
second_pre_handle = opt2.register_step_pre_hook(local_pre_hook)
|
|
second_post_handle = opt2.register_step_post_hook(local_post_hook)
|
|
|
|
opt1.step()
|
|
self.assertListEqual(data, [0, 1, 2, 5])
|
|
opt2.step()
|
|
self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5])
|
|
opt1.step()
|
|
self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])
|
|
|
|
# remove all hooks
|
|
global_pre_handle.remove()
|
|
global_post_handle.remove()
|
|
first_pre_handle.remove()
|
|
first_post_handle.remove()
|
|
second_pre_handle.remove()
|
|
second_post_handle.remove()
|
|
|
|
opt1.step()
|
|
opt2.step()
|
|
self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])
|
|
|
|
def test_fused_optimizer_raises(self):
|
|
if not torch.cuda.is_available():
|
|
self.skipTest("Requires CUDA devices")
|
|
for optimizer_ctor in (torch.optim.Adam, torch.optim.AdamW):
|
|
with self.assertRaisesRegex(RuntimeError, "`fused` and `foreach` cannot be `True` together."):
|
|
optimizer_ctor([torch.empty((), device="cuda")], foreach=True, fused=True)
|
|
with self.assertRaisesRegex(RuntimeError, "`fused` does not support `differentiable`"):
|
|
optimizer_ctor([torch.empty((), device="cuda")], differentiable=True, fused=True)
|
|
|
|
|
|
class SchedulerTestNet(torch.nn.Module):
|
|
def __init__(self):
|
|
        super().__init__()
|
|
self.conv1 = torch.nn.Conv2d(1, 1, 1)
|
|
self.conv2 = torch.nn.Conv2d(1, 1, 1)
|
|
|
|
def forward(self, x):
|
|
return self.conv2(F.relu(self.conv1(x)))
|
|
|
|
|
|
class LambdaLRTestObject:
|
|
def __init__(self, value):
|
|
self.value = value
|
|
|
|
def __call__(self, epoch):
|
|
return self.value * epoch
|
|
|
|
def __eq__(self, other):
|
|
if isinstance(other, self.__class__):
|
|
return self.__dict__ == other.__dict__
|
|
else:
|
|
return False
|
|
|
|
|
|
class TestLRScheduler(TestCase):
|
|
exact_dtype = True
|
|
|
|
def setUp(self):
|
|
        super().setUp()
|
|
self.net = SchedulerTestNet()
|
|
self.opt = SGD(
|
|
[
|
|
{"params": self.net.conv1.parameters()},
|
|
{"params": self.net.conv2.parameters(), "lr": 0.5},
|
|
],
|
|
lr=0.05,
|
|
)
|
|
|
|
def _check_warning_is_epoch_deprecation_warning(self, w, *, num_warnings: int = 1):
|
|
"""This function swallows the epoch deprecation warning which is produced when we
|
|
call `scheduler.step(epoch)` with some not `None` value of `epoch`.
|
|
this is deprecated, and this function will need to be removed/updated when
|
|
the schedulers no longer accept the parameter at all.
|
|
"""
|
|
self.assertEqual(len(w), num_warnings)
|
|
for warning in w:
|
|
self.assertEqual(len(warning.message.args), 1)
|
|
self.assertEqual(warning.message.args[0], EPOCH_DEPRECATION_WARNING)
|
|
|
|
def test_error_when_getlr_has_epoch(self):
|
|
class MultiStepLR(torch.optim.lr_scheduler.LRScheduler):
|
|
def __init__(self, optimizer, gamma, milestones, last_epoch=-1):
|
|
self.init_lr = [group["lr"] for group in optimizer.param_groups]
|
|
self.gamma = gamma
|
|
self.milestones = milestones
|
|
super().__init__(optimizer, last_epoch)
|
|
|
|
def get_lr(self, step):
|
|
global_step = self.last_epoch
|
|
gamma_power = (
|
|
[0]
|
|
+ [i + 1 for i, m in enumerate(self.milestones) if global_step >= m]
|
|
)[-1]
|
|
return [
|
|
init_lr * (self.gamma**gamma_power) for init_lr in self.init_lr
|
|
]
|
|
|
|
optimizer = torch.optim.SGD([torch.rand(1)], lr=1)
|
|
|
|
with self.assertRaises(TypeError):
|
|
scheduler = MultiStepLR(optimizer, gamma=1, milestones=[10, 20])
|
|
|
|
@skipIfTorchDynamo("Torchdynamo keeps references to optim in the guards and the stack of the graph break frames")
|
|
def test_no_cyclic_references(self):
|
|
import gc
|
|
|
|
param = Parameter(torch.empty(10))
|
|
optim = SGD([param], lr=0.5)
|
|
scheduler = LambdaLR(optim, lambda epoch: 1.0)
|
|
del scheduler
|
|
|
|
self.assertTrue(
|
|
len(gc.get_referrers(optim)) == 0,
|
|
"Optimizer should contain no cyclic references",
|
|
)
|
|
|
|
gc.collect()
|
|
del optim
|
|
self.assertEqual(
|
|
gc.collect(), 0, msg="Optimizer should be garbage-collected on __del__"
|
|
)
|
|
|
|
@skipIfTorchDynamo("Torchdynamo keeps references to optim in the guards and the stack of the graph break frames")
|
|
def test_no_cyclic_references_in_step(self):
|
|
import gc
|
|
import weakref
|
|
|
|
def run():
|
|
param = torch.empty(10, requires_grad=True)
|
|
optim = SGD(params=[param], lr=0.5)
|
|
scheduler = LambdaLR(optim, lambda epoch: 1.0)
|
|
param.sum().backward()
|
|
optim.step()
|
|
scheduler.step()
|
|
|
|
return weakref.ref(scheduler)
|
|
|
|
        # To ensure that there are no reference cycles in the scheduler, we turn off
        # the garbage collector, since gc would otherwise collect unreachable objects
        # automatically and hide a cycle.
|
|
gc.disable()
|
|
ref = run()
|
|
|
|
assert ref() is None
|
|
gc.enable() # restore
|
|
|
|
def test_old_pattern_warning(self):
|
|
epochs = 35
|
|
with warnings.catch_warnings(record=True) as ws:
|
|
warnings.simplefilter("always") # allow any warning to be raised
|
|
scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
|
|
self.assertTrue(len(ws) == 0, "No warning should be raised")
|
|
|
|
def old_pattern():
|
|
for _ in range(epochs):
|
|
scheduler.step()
|
|
self.opt.step()
|
|
|
|
self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern)
|
|
|
|
def test_old_pattern_warning_with_arg(self):
|
|
epochs = 35
|
|
with warnings.catch_warnings(record=True) as ws:
|
|
warnings.simplefilter("always") # allow any warning to be raised
|
|
scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
|
|
self.assertTrue(len(ws) == 0, "No warning should be raised")
|
|
|
|
def old_pattern2():
|
|
for _ in range(epochs):
|
|
scheduler.step()
|
|
self.opt.step()
|
|
|
|
self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2)
|
|
|
|
def test_old_pattern_warning_resuming(self):
|
|
epochs = 35
|
|
for i, group in enumerate(self.opt.param_groups):
|
|
group["initial_lr"] = 0.01
|
|
|
|
with warnings.catch_warnings(record=True) as ws:
|
|
warnings.simplefilter("always") # allow any warning to be raised
|
|
scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
|
|
self.assertTrue(len(ws) == 0, "No warning should be raised")
|
|
|
|
def old_pattern():
|
|
for _ in range(epochs):
|
|
scheduler.step()
|
|
self.opt.step()
|
|
|
|
self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern)
|
|
|
|
def test_old_pattern_warning_resuming_with_arg(self):
|
|
epochs = 35
|
|
for i, group in enumerate(self.opt.param_groups):
|
|
group["initial_lr"] = 0.01
|
|
|
|
with warnings.catch_warnings(record=True) as ws:
|
|
warnings.simplefilter("always") # allow any warning to be raised
|
|
scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
|
|
self.assertTrue(len(ws) == 0, "No warning should be raised")
|
|
|
|
def old_pattern2():
|
|
for _ in range(epochs):
|
|
scheduler.step()
|
|
self.opt.step()
|
|
|
|
self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2)
|
|
|
|
def test_old_pattern_warning_with_overridden_optim_step(self):
|
|
epochs = 35
|
|
for i, group in enumerate(self.opt.param_groups):
|
|
group["initial_lr"] = 0.01
|
|
|
|
with warnings.catch_warnings(record=True) as ws:
|
|
warnings.simplefilter("always") # allow any warning to be raised
|
|
scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
|
|
self.assertTrue(len(ws) == 0, "No warning should be raised")
|
|
|
|
# emulate use-case with optimizer.step overridden
|
|
import types
|
|
|
|
old_step = self.opt.step
|
|
|
|
def new_step(o, *args, **kwargs):
|
|
retval = old_step(*args, **kwargs)
|
|
return retval
|
|
|
|
self.opt.step = types.MethodType(new_step, self.opt)
|
|
|
|
def old_pattern2():
|
|
for _ in range(epochs):
|
|
scheduler.step()
|
|
self.opt.step()
|
|
|
|
self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2)
|
|
|
|
def test_new_pattern_no_warning(self):
|
|
epochs = 35
|
|
with warnings.catch_warnings(record=True) as ws:
|
|
warnings.simplefilter("always") # allow any warning to be raised
|
|
scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
|
|
self.assertTrue(len(ws) == 0, "No warning should be raised")
|
|
|
|
with warnings.catch_warnings(record=True) as ws:
|
|
warnings.simplefilter("always") # allow any warning to be raised
|
|
for _ in range(epochs):
|
|
self.opt.step()
|
|
scheduler.step()
|
|
self.assertTrue(len(ws) == 0, "No warning should be raised")
|
|
|
|
def test_new_pattern_no_warning_with_arg(self):
|
|
epochs = 35
|
|
with warnings.catch_warnings(record=True) as ws:
|
|
warnings.simplefilter("always") # allow any warning to be raised
|
|
scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
|
|
self.assertTrue(len(ws) == 0, "No warning should be raised")
|
|
|
|
with warnings.catch_warnings(record=True) as ws:
|
|
warnings.simplefilter("always") # allow any warning to be raised
|
|
for _ in range(epochs):
|
|
self.opt.step()
|
|
scheduler.step()
|
|
self.assertTrue(len(ws) == 0, "No warning should be raised")
|
|
|
|
def test_new_pattern_no_warning_with_overridden_optim_step(self):
|
|
epochs = 35
|
|
with warnings.catch_warnings(record=True) as ws:
|
|
warnings.simplefilter("always") # allow any warning to be raised
|
|
scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
|
|
self.assertTrue(len(ws) == 0, "No warning should be raised")
|
|
|
|
# emulate use-case with optimizer.step overridden
|
|
import types
|
|
|
|
old_step = self.opt.step
|
|
|
|
def new_step(o, *args, **kwargs):
|
|
retval = old_step(*args, **kwargs)
|
|
return retval
|
|
|
|
self.opt.step = types.MethodType(new_step, self.opt)
|
|
|
|
def new_pattern():
|
|
for e in range(epochs):
|
|
self.opt.step()
|
|
scheduler.step()
|
|
|
|
self.assertWarnsRegex(
|
|
UserWarning, r"`optimizer.step\(\)` has been overridden", new_pattern
|
|
)
|
|
|
|
def _test_lr_is_constant_for_constant_epoch(self, scheduler):
|
|
l = []
|
|
|
|
for _ in range(10):
|
|
scheduler.optimizer.step()
|
|
with warnings.catch_warnings(record=True) as w:
|
|
scheduler.step(2)
|
|
self._check_warning_is_epoch_deprecation_warning(w)
|
|
|
|
l.append(self.opt.param_groups[0]["lr"])
|
|
self.assertEqual(min(l), max(l))
|
|
|
|
def test_step_lr_is_constant_for_constant_epoch(self):
|
|
scheduler = StepLR(self.opt, 2)
|
|
self._test_lr_is_constant_for_constant_epoch(scheduler)
|
|
|
|
def test_exponential_lr_is_constant_for_constant_epoch(self):
|
|
scheduler = ExponentialLR(self.opt, gamma=0.9)
|
|
self._test_lr_is_constant_for_constant_epoch(scheduler)
|
|
|
|
def test_constantlr_is_constant_for_constant_epoch(self):
|
|
scheduler = ConstantLR(self.opt)
|
|
self._test_lr_is_constant_for_constant_epoch(scheduler)
|
|
|
|
def test_linear_linearlr_is_constant_for_constant_epoch(self):
|
|
scheduler = LinearLR(self.opt)
|
|
self._test_lr_is_constant_for_constant_epoch(scheduler)
|
|
|
|
def test_polynomial_lr_is_constant_for_constant_epoch(self):
|
|
scheduler = PolynomialLR(self.opt, power=0.9)
|
|
self._test_lr_is_constant_for_constant_epoch(scheduler)
|
|
|
|
def test_step_lr(self):
|
|
        # lr = 0.05     if epoch < 3
        # lr = 0.005    if 3 <= epoch < 6
        # lr = 0.0005   if 6 <= epoch < 9
        # lr = 0.00005  if epoch >= 9
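        # (For reference, StepLR follows the closed form
        # lr = base_lr * gamma ** (epoch // step_size).)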
|
|
epochs = 10
|
|
single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
|
|
self._test(scheduler, targets, epochs)
|
|
|
|
def test_get_last_lr_step_lr(self):
|
|
from torch.nn import Parameter
|
|
|
|
epochs = 10
|
|
optimizer = torch.optim.SGD(
|
|
[Parameter(torch.randn(2, 2, requires_grad=True))], 0.1
|
|
)
|
|
targets = [[0.1] * 3 + [0.01] * 3 + [0.001] * 3 + [0.0001]]
|
|
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1)
|
|
self._test_get_last_lr(scheduler, targets, epochs)
|
|
|
|
def test_get_last_lr_multi_step_lr(self):
|
|
# lr = 0.05 if epoch < 2
|
|
# lr = 0.005 if 2 <= epoch < 5
|
|
# lr = 0.0005 if 5 <= epoch < 9
|
|
# lr = 0.00005 if 9 <= epoch
|
|
epochs = 10
|
|
single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 1
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
|
|
self._test_get_last_lr(scheduler, targets, epochs)
|
|
|
|
def test_multi_step_lr(self):
|
|
# lr = 0.05 if epoch < 2
|
|
# lr = 0.005 if 2 <= epoch < 5
|
|
        # lr = 0.0005 if 5 <= epoch < 9
|
|
# lr = 0.00005 if epoch >= 9
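        # (MultiStepLR multiplies the lr by gamma each time the epoch count
        # passes one of the milestones.)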
|
|
epochs = 10
|
|
single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
|
|
self._test(scheduler, targets, epochs)
|
|
|
|
def test_multi_step_lr_with_epoch(self):
|
|
# lr = 0.05 if epoch < 2
|
|
# lr = 0.005 if 2 <= epoch < 5
|
|
        # lr = 0.0005 if 5 <= epoch < 9
|
|
# lr = 0.00005 if epoch >= 9
|
|
epochs = 10
|
|
single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
|
|
self._test_with_epoch(scheduler, targets, epochs)
|
|
|
|
def test_get_last_lr_constantlr(self):
|
|
# lr = 0.025 if epoch < 5
|
|
        # lr = 0.05 if 5 <= epoch
|
|
epochs = 10
|
|
single_targets = [0.025] * 5 + [0.05] * 5
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5)
|
|
self._test_get_last_lr(scheduler, targets, epochs)
|
|
|
|
def test_get_last_lr_linearlr(self):
|
|
        # lr = 0.0125    if epoch == 0
        # lr = 0.016875  if epoch == 1
        # lr = 0.02125   if epoch == 2
        # lr = 0.025625  if epoch == 3
        # lr = 0.03      if 4 <= epoch
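        # (LinearLR interpolates the lr factor from start_factor to end_factor over
        # total_iters steps and then holds it at end_factor.)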
|
|
epochs = 10
|
|
start_factor = 1.0 / 4
|
|
end_factor = 3.0 / 5
|
|
iters = 4
|
|
interpolation = [
|
|
start_factor + i * (end_factor - start_factor) / iters for i in range(iters)
|
|
]
|
|
single_targets = [x * 0.05 for x in interpolation] + [0.05 * end_factor] * (
|
|
epochs - iters
|
|
)
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
scheduler = LinearLR(
|
|
self.opt,
|
|
start_factor=start_factor,
|
|
end_factor=end_factor,
|
|
total_iters=iters,
|
|
)
|
|
self._test_get_last_lr(scheduler, targets, epochs)
|
|
|
|
def test_constantlr(self):
|
|
# lr = 0.025 if epoch < 5
|
|
        # lr = 0.05 if 5 <= epoch
|
|
epochs = 10
|
|
single_targets = [0.025] * 5 + [0.05] * 5
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5)
|
|
self._test(scheduler, targets, epochs)
|
|
|
|
def test_linearlr(self):
|
|
# lr = 0.025 if epoch == 0
|
|
# lr = 0.03125 if epoch == 1
|
|
# lr = 0.0375 if epoch == 2
|
|
# lr = 0.04375 if epoch == 3
|
|
        # lr = 0.05 if 4 <= epoch
|
|
epochs = 10
|
|
start_factor = 1.0 / 2
|
|
iters = 4
|
|
interpolation = [
|
|
start_factor + i * (1 - start_factor) / iters for i in range(iters)
|
|
]
|
|
single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters)
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
|
|
self._test(scheduler, targets, epochs)
|
|
|
|
def test_linearlr_start_factor_limits1(self):
|
|
start_factor = 0.0
|
|
iters = 4
|
|
with self.assertRaises(ValueError):
|
|
LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
|
|
|
|
def test_linearlr_start_factor_limits2(self):
|
|
start_factor = 1.1
|
|
iters = 4
|
|
with self.assertRaises(ValueError):
|
|
LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
|
|
|
|
def test_constantlr_with_epoch(self):
|
|
# lr = 0.025 if epoch < 5
|
|
        # lr = 0.05 if 5 <= epoch
|
|
epochs = 10
|
|
single_targets = [0.025] * 5 + [0.05] * 5
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5)
|
|
self._test_with_epoch(scheduler, targets, epochs)
|
|
|
|
def test_linearlr_with_epoch(self):
|
|
# lr = 0.025 if epoch == 0
|
|
# lr = 0.03125 if epoch == 1
|
|
# lr = 0.0375 if epoch == 2
|
|
# lr = 0.04375 if epoch == 3
|
|
        # lr = 0.05 if 4 <= epoch
|
|
epochs = 10
|
|
start_factor = 1.0 / 2
|
|
end_factor = 1.0
|
|
iters = 4
|
|
interpolation = [
|
|
start_factor + i * (end_factor - start_factor) / iters for i in range(iters)
|
|
]
|
|
single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters)
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
|
|
self._test_with_epoch(scheduler, targets, epochs)
|
|
|
|
def test_exp_lr(self):
|
|
epochs = 10
|
|
single_targets = [0.05 * (0.9**x) for x in range(epochs)]
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
scheduler = ExponentialLR(self.opt, gamma=0.9)
|
|
self._test(scheduler, targets, epochs)
|
|
|
|
def test_poly_lr(self):
|
|
epochs = 10
|
|
power = 0.9
|
|
total_iters = 5
|
|
single_targets = [
|
|
(1.0 - x / total_iters) ** power * 0.05 for x in range(total_iters)
|
|
] + [0.0] * (epochs - total_iters)
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
scheduler = PolynomialLR(self.opt, power=power, total_iters=total_iters)
|
|
self._test(scheduler, targets, epochs)
|
|
|
|
def test_cos_anneal_lr(self):
|
|
epochs = 10
|
|
eta_min = 1e-10
|
|
single_targets = [
|
|
eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
|
|
for x in range(epochs)
|
|
]
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
scheduler = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
|
|
self._test(scheduler, targets, epochs)
|
|
|
|
def test_closed_form_step_lr(self):
|
|
scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
|
|
closed_form_scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
|
|
self._test_against_closed_form(scheduler, closed_form_scheduler, 20)
|
|
|
|
def test_closed_form_linearlr(self):
|
|
scheduler = LinearLR(
|
|
self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4
|
|
)
|
|
closed_form_scheduler = LinearLR(
|
|
self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4
|
|
)
|
|
self._test_against_closed_form(scheduler, closed_form_scheduler, 20)
|
|
|
|
def test_closed_form_constantlr(self):
|
|
scheduler = ConstantLR(self.opt, factor=1.0 / 3, total_iters=4)
|
|
closed_form_scheduler = ConstantLR(self.opt, factor=1.0 / 3, total_iters=4)
|
|
self._test_against_closed_form(scheduler, closed_form_scheduler, 20)
|
|
|
|
def test_closed_form_multi_step_lr(self):
|
|
scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
|
|
closed_form_scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
|
|
self._test_against_closed_form(scheduler, closed_form_scheduler, 20)
|
|
|
|
def test_closed_form_exp_lr(self):
|
|
scheduler = ExponentialLR(self.opt, gamma=0.9)
|
|
closed_form_scheduler = ExponentialLR(self.opt, gamma=0.9)
|
|
self._test_against_closed_form(scheduler, closed_form_scheduler, 20)
|
|
|
|
def test_closed_form_poly_lr(self):
|
|
scheduler = PolynomialLR(self.opt, power=0.9)
|
|
closed_form_scheduler = PolynomialLR(self.opt, power=0.9)
|
|
self._test_against_closed_form(scheduler, closed_form_scheduler, 20)
|
|
|
|
def test_closed_form_cos_anneal_lr(self):
|
|
eta_min = 1e-10
|
|
epochs = 20
|
|
T_max = 5
|
|
scheduler = CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min)
|
|
closed_form_scheduler = CosineAnnealingLR(
|
|
self.opt, T_max=T_max, eta_min=eta_min
|
|
)
|
|
self._test_against_closed_form(scheduler, closed_form_scheduler, epochs)
|
|
|
|
def test_cos_anneal_lr_continue(self):
|
|
eta_min = 0.1
|
|
T_max = 5
|
|
scheduler = CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min)
|
|
self.opt.step()
|
|
scheduler.step()
|
|
original_lrs = scheduler._last_lr
|
|
new_scheduler = CosineAnnealingLR(
|
|
self.opt, T_max=T_max, eta_min=eta_min, last_epoch=0
|
|
)
|
|
new_lrs = new_scheduler._last_lr
|
|
torch.testing.assert_close(original_lrs, new_lrs, rtol=1e-4, atol=1e-5)
|
|
|
|
def test_reduce_lr_on_plateau1(self):
|
|
epochs = 10
|
|
for param_group in self.opt.param_groups:
|
|
param_group["lr"] = 0.5
|
|
targets = [[0.5] * 20]
|
|
metrics = [10 - i * 0.0167 for i in range(20)]
|
|
scheduler = ReduceLROnPlateau(
|
|
self.opt,
|
|
threshold_mode="abs",
|
|
mode="min",
|
|
threshold=0.01,
|
|
patience=5,
|
|
cooldown=5,
|
|
)
|
|
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
|
|
|
|
def test_reduce_lr_on_plateau2(self):
|
|
epochs = 22
|
|
for param_group in self.opt.param_groups:
|
|
param_group["lr"] = 0.5
|
|
targets = [[0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2]
|
|
metrics = [10 - i * 0.0165 for i in range(22)]
|
|
scheduler = ReduceLROnPlateau(
|
|
self.opt,
|
|
patience=5,
|
|
cooldown=0,
|
|
threshold_mode="abs",
|
|
mode="min",
|
|
threshold=0.1,
|
|
)
|
|
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
|
|
|
|
def test_reduce_lr_on_plateau3(self):
|
|
epochs = 22
|
|
for param_group in self.opt.param_groups:
|
|
param_group["lr"] = 0.5
|
|
targets = [[0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4]
|
|
metrics = [-0.8] * 2 + [-0.234] * 20
|
|
scheduler = ReduceLROnPlateau(
|
|
self.opt, mode="max", patience=5, cooldown=5, threshold_mode="abs"
|
|
)
|
|
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
|
|
|
|
def test_reduce_lr_on_plateau4(self):
|
|
epochs = 20
|
|
for param_group in self.opt.param_groups:
|
|
param_group["lr"] = 0.5
|
|
targets = [[0.5] * 20]
|
|
metrics = [1.5 * (1.025**i) for i in range(20)] # 1.025 > 1.1**0.25
|
|
scheduler = ReduceLROnPlateau(
|
|
self.opt, mode="max", patience=3, threshold_mode="rel", threshold=0.1
|
|
)
|
|
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
|
|
|
|
def test_reduce_lr_on_plateau5(self):
|
|
epochs = 20
|
|
for param_group in self.opt.param_groups:
|
|
param_group["lr"] = 0.5
|
|
targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4]
|
|
metrics = [1.5 * (1.005**i) for i in range(20)]
|
|
scheduler = ReduceLROnPlateau(
|
|
self.opt,
|
|
mode="max",
|
|
threshold_mode="rel",
|
|
threshold=0.1,
|
|
patience=5,
|
|
cooldown=5,
|
|
)
|
|
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
|
|
|
|
def test_reduce_lr_on_plateau6(self):
|
|
epochs = 20
|
|
for param_group in self.opt.param_groups:
|
|
param_group["lr"] = 0.5
|
|
targets = [[0.5] * 20]
|
|
metrics = [1.5 * (0.85**i) for i in range(20)]
|
|
scheduler = ReduceLROnPlateau(
|
|
self.opt, mode="min", threshold_mode="rel", threshold=0.1
|
|
)
|
|
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
|
|
|
|
def test_reduce_lr_on_plateau7(self):
|
|
epochs = 20
|
|
for param_group in self.opt.param_groups:
|
|
param_group["lr"] = 0.5
|
|
targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4]
|
|
metrics = [1] * 7 + [0.6] + [0.5] * 12
|
|
scheduler = ReduceLROnPlateau(
|
|
self.opt,
|
|
mode="min",
|
|
threshold_mode="rel",
|
|
threshold=0.1,
|
|
patience=5,
|
|
cooldown=5,
|
|
)
|
|
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
|
|
|
|
def test_reduce_lr_on_plateau8(self):
|
|
epochs = 20
|
|
for param_group in self.opt.param_groups:
|
|
param_group["lr"] = 0.5
|
|
targets = [[0.5] * 6 + [0.4] * 14, [0.5] * 6 + [0.3] * 14]
|
|
metrics = [1.5 * (1.005**i) for i in range(20)]
|
|
scheduler = ReduceLROnPlateau(
|
|
self.opt,
|
|
mode="max",
|
|
threshold_mode="rel",
|
|
min_lr=[0.4, 0.3],
|
|
threshold=0.1,
|
|
patience=5,
|
|
cooldown=5,
|
|
)
|
|
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
|
|
|
|
def test_sequentiallr1(self):
|
|
epochs = 19
|
|
schedulers = [None] * 2
|
|
        targets = [
            [0.05, 0.04, 0.032]
            + [0.05] * 4
            + [0.05 * 0.1] * 4
            + [0.05 * 0.01] * 4
            + [0.05 * 0.001] * 4
        ]
|
|
milestones = [3]
|
|
schedulers[0] = ExponentialLR(self.opt, gamma=0.8)
|
|
schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=4)
|
|
scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
|
|
self._test(scheduler, targets, epochs)
|
|
|
|
def test_sequentiallr2(self):
|
|
epochs = 13
|
|
schedulers = [None] * 2
|
|
targets = [[0.005, 0.005, 0.005] + [0.05 * 0.9**x for x in range(10)]]
|
|
milestones = [3]
|
|
schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3)
|
|
schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
|
|
scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
|
|
self._test(scheduler, targets, epochs)
|
|
|
|
def test_sequentiallr3(self):
|
|
epochs = 12
|
|
schedulers = [None] * 3
|
|
targets = [
|
|
[0.005, 0.005, 0.005]
|
|
+ [0.05, 0.04, 0.032]
|
|
+ [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005]
|
|
]
|
|
milestones = [3, 6]
|
|
schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3)
|
|
schedulers[1] = ExponentialLR(self.opt, gamma=0.8)
|
|
schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2)
|
|
scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
|
|
self._test(scheduler, targets, epochs)
|
|
|
|
def test_sequentiallr4(self):
|
|
optimizer = torch.optim.SGD([torch.tensor(0.5)], lr=0.1)
|
|
prev_lr = optimizer.param_groups[0]["lr"]
|
|
|
|
schedulers = [
|
|
torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1),
|
|
torch.optim.lr_scheduler.ConstantLR(optimizer, factor=0.1),
|
|
]
|
|
scheduler = torch.optim.lr_scheduler.SequentialLR(
|
|
optimizer, schedulers, milestones=[10]
|
|
)
|
|
|
|
new_lr = optimizer.param_groups[0]["lr"]
|
|
|
|
        # Ensure that having multiple schedulers does not affect the initial learning rate
|
|
self.assertEqual(prev_lr, new_lr)
|
|
|
|
def test_get_last_lr_sequentiallr(self):
|
|
epochs = 12
|
|
milestones = [3, 6]
|
|
schedulers = [None] * 3
|
|
schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3)
|
|
schedulers[1] = ExponentialLR(self.opt, gamma=0.8)
|
|
schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2)
|
|
scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
|
|
constant_lr_target = [0.005] * 3
|
|
exponential_lr_target = [0.05, 0.04, 0.032]
|
|
step_lr_target = [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005]
|
|
single_targets = constant_lr_target + exponential_lr_target + step_lr_target
|
|
targets = [single_targets, [x * 10 for x in single_targets]]
|
|
self._test_get_last_lr(scheduler, targets, epochs)
|
|
|
|
def test_chained_lr2_get_last_lr_before_step(self):
|
|
schedulers = [
|
|
LinearLR(self.opt, start_factor=0.4, total_iters=3),
|
|
MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1),
|
|
]
|
|
scheduler = ChainedScheduler(schedulers)
|
|
self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())
|
|
|
|
def test_chained_lr1(self):
|
|
epochs = 10
|
|
schedulers = [None] * 1
|
|
targets = [[0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3]
|
|
schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3)
|
|
scheduler = ChainedScheduler(schedulers)
|
|
self._test([scheduler], targets, epochs)
|
|
self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())
|
|
|
|
def test_chained_lr2(self):
|
|
epochs = 10
|
|
schedulers = [None] * 1
|
|
targets = [[0.02, 0.03, 0.04] + [0.05] * 9]
|
|
schedulers[0] = LinearLR(self.opt, start_factor=0.4, total_iters=3)
|
|
scheduler = ChainedScheduler(schedulers)
|
|
self._test([scheduler], targets, epochs)
|
|
self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())
|
|
|
|
def test_chained_lr3(self):
|
|
epochs = 10
|
|
schedulers = [None] * 2
|
|
targets = [
|
|
[0.02, 0.03, 0.04, 0.05] + [0.005] * 4 + [0.0005] * 3 + [0.00005] * 3
|
|
]
|
|
schedulers[0] = LinearLR(self.opt, start_factor=0.4, total_iters=3)
|
|
schedulers[1] = MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1)
|
|
scheduler = ChainedScheduler(schedulers)
|
|
self._test([scheduler], targets, epochs)
|
|
self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())
|
|
|
|
def test_chained_lr4(self):
|
|
epochs = 9
|
|
schedulers = [None] * 3
|
|
targets = [
|
|
[0.05 * 0.2 * 0.9**x for x in range(3)]
|
|
+ [0.05 * 0.2 * 0.9**3 * 0.1]
|
|
+ [0.05 * 0.9**x * 0.1 for x in range(4, 6)]
|
|
+ [0.05 * 0.9**x * 0.01 for x in range(6, 9)]
|
|
]
|
|
schedulers[0] = ExponentialLR(self.opt, gamma=0.9)
|
|
schedulers[1] = ConstantLR(self.opt, factor=0.2, total_iters=4)
|
|
schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=3)
|
|
scheduler = ChainedScheduler(schedulers)
|
|
self._test([scheduler], targets, epochs)
|
|
self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())
|
|
|
|
def test_chained_lr5(self):
|
|
def poly_lr(lr: float):
|
|
return [
|
|
(lr * ((1.0 - x / total_iters) ** power)) for x in range(total_iters)
|
|
] + [0.0] * (epochs - total_iters)
|
|
|
|
schedulers = [None] * 2
|
|
epochs = 10
|
|
power = 0.9
|
|
total_iters = 5
|
|
const_factor = 0.1
|
|
single_targets = [x * const_factor for x in poly_lr(lr=0.05)]
|
|
targets = [single_targets, [x * const_factor for x in poly_lr(0.5)]]
|
|
schedulers[0] = PolynomialLR(self.opt, power=power, total_iters=total_iters)
|
|
schedulers[1] = ConstantLR(self.opt, factor=const_factor)
|
|
scheduler = ChainedScheduler(schedulers)
|
|
self._test(scheduler, targets, epochs)
|
|
self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())
|
|
|
|
def test_compound_step_and_multistep_lr(self):
|
|
epochs = 10
|
|
schedulers = [None] * 2
|
|
schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3)
|
|
schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
|
|
targets = [[0.05] * 2 + [0.005] * 1 + [5e-4] * 2 + [5e-5] + [5e-6] * 3 + [5e-8]]
|
|
self._test(schedulers, targets, epochs)
|
|
|
|
def test_compound_step_and_exp_lr(self):
|
|
epochs = 10
|
|
schedulers = [None] * 2
|
|
single_targets = [0.05 * (0.9**x) for x in range(3)]
|
|
single_targets += [0.005 * (0.9**x) for x in range(3, 6)]
|
|
single_targets += [0.0005 * (0.9**x) for x in range(6, 9)]
|
|
single_targets += [0.00005 * (0.9**x) for x in range(9, 12)]
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3)
|
|
schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
|
|
self._test(schedulers, targets, epochs)
|
|
|
|
def test_compound_exp_and_multistep_lr(self):
|
|
epochs = 10
|
|
schedulers = [None] * 2
|
|
single_targets = [0.05 * (0.9**x) for x in range(2)]
|
|
single_targets += [0.005 * (0.9**x) for x in range(2, 5)]
|
|
single_targets += [0.0005 * (0.9**x) for x in range(5, 9)]
|
|
single_targets += [0.00005 * (0.9**x) for x in range(9, 11)]
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
|
|
schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
|
|
self._test(schedulers, targets, epochs)
|
|
|
|
def test_compound_exp_and_linearlr(self):
|
|
epochs = 10
|
|
iters = 4
|
|
start_factor = 0.4
|
|
end_factor = 0.9
|
|
schedulers = [None] * 2
|
|
single_targets = [0.05 * (0.9**x) for x in range(11)]
|
|
for i in range(iters):
|
|
single_targets[i] *= start_factor + i / iters * (end_factor - start_factor)
|
|
for i in range(iters, 11):
|
|
single_targets[i] *= end_factor
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
schedulers[0] = LinearLR(
|
|
self.opt,
|
|
start_factor=start_factor,
|
|
end_factor=end_factor,
|
|
total_iters=iters,
|
|
)
|
|
schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
|
|
self._test(schedulers, targets, epochs)
|
|
|
|
def test_compound_step_and_constantlr(self):
|
|
epochs = 10
|
|
iters = 4
|
|
factor = 0.4
|
|
schedulers = [None] * 2
|
|
single_targets = (
|
|
[0.05 * 0.4] * 3
|
|
+ [0.005 * 0.4]
|
|
+ [0.005] * 2
|
|
+ [0.0005] * 3
|
|
+ [0.00005] * 3
|
|
)
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3)
|
|
schedulers[1] = ConstantLR(self.opt, factor=0.4, total_iters=4)
|
|
self._test(schedulers, targets, epochs)
|
|
|
|
def test_compound_linearlr_and_multistep_lr(self):
|
|
epochs = 10
|
|
iters = 4
|
|
start_factor = 0.4
|
|
schedulers = [None] * 2
|
|
single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 2
|
|
for i in range(iters):
|
|
single_targets[i] *= start_factor + i / iters * (1 - start_factor)
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
|
|
schedulers[1] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
|
|
self._test(schedulers, targets, epochs)
|
|
|
|
def test_compound_cosanneal_and_step_lr(self):
|
|
epochs = 10
|
|
eta_min = 1e-10
|
|
single_targets = [
|
|
eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
|
|
for x in range(epochs)
|
|
]
|
|
single_targets = [x * 0.1 ** (i // 3) for i, x in enumerate(single_targets)]
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
schedulers = [None] * 2
|
|
schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
|
|
schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3)
|
|
self._test(schedulers, targets, epochs)
|
|
|
|
def test_compound_cosanneal_and_multistep_lr(self):
|
|
epochs = 10
|
|
eta_min = 1e-10
|
|
single_targets = [
|
|
eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
|
|
for x in range(epochs)
|
|
]
|
|
multipliers = [1] * 2 + [0.1] * 3 + [0.01] * 4 + [0.001]
|
|
single_targets = [x * y for x, y in zip(single_targets, multipliers)]
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
schedulers = [None] * 2
|
|
schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
|
|
schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
|
|
self._test(schedulers, targets, epochs)
|
|
|
|
def test_compound_cosanneal_and_linearlr(self):
|
|
epochs = 10
|
|
iters = 4
|
|
start_factor = 0.4
|
|
eta_min = 1e-10
|
|
schedulers = [None] * 2
|
|
single_targets = [
|
|
eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
|
|
for x in range(epochs)
|
|
]
|
|
for i in range(iters):
|
|
single_targets[i] *= start_factor + i / iters * (1 - start_factor)
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
schedulers[0] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
|
|
schedulers[1] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
|
|
self._test(schedulers, targets, epochs)
|
|
|
|
def test_compound_cosanneal_and_exp_lr(self):
|
|
epochs = 10
|
|
eta_min = 1e-10
|
|
single_targets = [
|
|
eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
|
|
for x in range(epochs)
|
|
]
|
|
multipliers = [0.1**i for i in range(epochs)]
|
|
single_targets = [x * y for x, y in zip(single_targets, multipliers)]
|
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
|
schedulers = [None] * 2
|
|
schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
|
|
schedulers[1] = ExponentialLR(self.opt, gamma=0.1)
|
|
self._test(schedulers, targets, epochs)
|
|
|
|
def test_compound_reduce_lr_on_plateau1(self):
|
|
epochs = 10
|
|
for param_group in self.opt.param_groups:
|
|
param_group["lr"] = 0.5
|
|
single_targets = [0.5] * 20
|
|
multipliers = [0.1 ** (i // 3) for i in range(20)]
|
|
single_targets = [x * y for x, y in zip(multipliers, single_targets)]
|
|
targets = [single_targets]
|
|
targets = targets[1:] # test runs step before checking lr
|
|
metrics = [10 - i * 0.0167 for i in range(20)]
|
|
schedulers = [None, None]
|
|
schedulers[0] = ReduceLROnPlateau(
|
|
self.opt,
|
|
threshold_mode="abs",
|
|
mode="min",
|
|
threshold=0.01,
|
|
patience=5,
|
|
cooldown=5,
|
|
)
|
|
schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3)
|
|
self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
|
|
|
|
def test_compound_reduce_lr_on_plateau2(self):
|
|
epochs = 22
|
|
for param_group in self.opt.param_groups:
|
|
param_group["lr"] = 0.5
|
|
single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2
|
|
multipliers = [1] * 3 + [0.1] * 5 + [0.01] * 4 + [0.001] * 10
|
|
single_targets = [x * y for x, y in zip(single_targets, multipliers)]
|
|
targets = [single_targets]
|
|
targets = targets[1:] # test runs step before checking lr
|
|
metrics = [10 - i * 0.0165 for i in range(22)]
|
|
schedulers = [None] * 2
|
|
schedulers[0] = ReduceLROnPlateau(
|
|
self.opt,
|
|
patience=5,
|
|
cooldown=0,
|
|
threshold_mode="abs",
|
|
mode="min",
|
|
threshold=0.1,
|
|
)
|
|
schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[3, 8, 12])
|
|
self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
|
|
|
|
def test_compound_reduce_lr_on_plateau3(self):
|
|
epochs = 22
|
|
for param_group in self.opt.param_groups:
|
|
param_group["lr"] = 0.5
|
|
single_targets = [0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4
|
|
multipliers = [0.1**i for i in range(epochs)]
|
|
single_targets = [x * y for x, y in zip(multipliers, single_targets)]
|
|
targets = [single_targets]
|
|
targets = targets[1:] # test runs step before checking lr
|
|
metrics = [-0.8] * 2 + [-0.234] * 20
|
|
schedulers = [None, None]
|
|
schedulers[0] = ReduceLROnPlateau(
|
|
self.opt, mode="max", patience=5, cooldown=5, threshold_mode="abs"
|
|
)
|
|
schedulers[1] = ExponentialLR(self.opt, gamma=0.1)
|
|
self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
|
|
|
|
def test_compound_reduce_lr_on_plateau4(self):
|
|
epochs = 20
|
|
for param_group in self.opt.param_groups:
|
|
param_group["lr"] = 0.05
|
|
epochs = 10
|
|
eta_min = 1e-10
|
|
single_targets = [
|
|
eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
|
|
for x in range(epochs)
|
|
]
|
|
targets = [single_targets]
|
|
targets = targets[1:] # test runs step before checking lr
|
|
metrics = [1.5 * (1.025**i) for i in range(20)] # 1.025 > 1.1**0.25
|
|
schedulers = [None, None]
|
|
schedulers[0] = ReduceLROnPlateau(
|
|
self.opt, mode="max", patience=3, threshold_mode="rel", threshold=0.1
|
|
)
|
|
schedulers[1] = CosineAnnealingLR(self.opt, epochs, eta_min)
|
|
self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
|
|
|
|
def test_compound_reduce_lr_on_plateau5(self):
|
|
iters = 4
|
|
start_factor = 0.4
|
|
epochs = 22
|
|
for param_group in self.opt.param_groups:
|
|
param_group["lr"] = 0.5
|
|
single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2
|
|
multipliers = [1] * 22
|
|
for i in range(iters):
|
|
multipliers[i] *= start_factor + i / iters * (1 - start_factor)
|
|
single_targets = [x * y for x, y in zip(single_targets, multipliers)]
|
|
targets = [single_targets]
|
|
targets = targets[1:] # test runs step before checking lr
|
|
metrics = [10 - i * 0.0165 for i in range(22)]
|
|
schedulers = [None] * 2
|
|
schedulers[0] = ReduceLROnPlateau(
|
|
self.opt,
|
|
patience=5,
|
|
cooldown=0,
|
|
threshold_mode="abs",
|
|
mode="min",
|
|
threshold=0.1,
|
|
)
|
|
schedulers[1] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
|
|
self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
|
|
|
|
def test_cycle_lr_invalid_mode(self):
|
|
with self.assertRaises(ValueError):
|
|
scheduler = CyclicLR(self.opt, base_lr=0, max_lr=0, mode="CATS")
|
|
|
|
def test_cycle_lr_triangular_mode_one_lr(self):
|
|
lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
|
|
momentum_target = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3]
|
|
lr_targets = [lr_target, lr_target]
|
|
momentum_targets = [momentum_target, momentum_target]
|
|
scheduler = CyclicLR(
|
|
self.opt,
|
|
base_lr=1,
|
|
max_lr=5,
|
|
step_size_up=4,
|
|
cycle_momentum=True,
|
|
base_momentum=1,
|
|
max_momentum=5,
|
|
mode="triangular",
|
|
)
|
|
self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
|
|
|
|
def test_cycle_lr_triangular_mode_one_lr_no_momentum(self):
|
|
lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
|
|
lr_targets = [lr_target, lr_target]
|
|
momentum_target = [self.opt.defaults["momentum"]] * len(lr_target)
|
|
momentum_targets = [momentum_target, momentum_target]
|
|
scheduler = CyclicLR(
|
|
self.opt,
|
|
base_lr=1,
|
|
max_lr=5,
|
|
step_size_up=4,
|
|
cycle_momentum=False,
|
|
mode="triangular",
|
|
)
|
|
self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
|
|
|
|
def test_cycle_lr_triangular2_mode_one_lr(self):
|
|
lr_target = [
|
|
1,
|
|
2,
|
|
3,
|
|
4,
|
|
5,
|
|
4,
|
|
3,
|
|
2,
|
|
1,
|
|
1.5,
|
|
2.0,
|
|
2.5,
|
|
3.0,
|
|
2.5,
|
|
2.0,
|
|
1.5,
|
|
1,
|
|
1.25,
|
|
1.50,
|
|
1.75,
|
|
2.00,
|
|
1.75,
|
|
]
|
|
momentum_target = [
|
|
5.0,
|
|
4.0,
|
|
3.0,
|
|
2.0,
|
|
1.0,
|
|
2.0,
|
|
3.0,
|
|
4.0,
|
|
5.0,
|
|
4.5,
|
|
4.0,
|
|
3.5,
|
|
3.0,
|
|
3.5,
|
|
4.0,
|
|
4.5,
|
|
5.0,
|
|
4.75,
|
|
4.5,
|
|
4.25,
|
|
4.0,
|
|
4.25,
|
|
]
|
|
lr_targets = [lr_target, lr_target]
|
|
momentum_targets = [momentum_target, momentum_target]
|
|
scheduler = CyclicLR(
|
|
self.opt,
|
|
base_lr=1,
|
|
max_lr=5,
|
|
step_size_up=4,
|
|
cycle_momentum=True,
|
|
base_momentum=1,
|
|
max_momentum=5,
|
|
mode="triangular2",
|
|
)
|
|
self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
|
|
|
|
def test_cycle_lr_exp_range_mode_one_lr(self):
|
|
base_lr, max_lr = 1, 5
|
|
diff_lr = max_lr - base_lr
|
|
gamma = 0.9
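        # In "exp_range" mode the triangular amplitude (diff_lr) is additionally
        # scaled by gamma ** iteration, which is what the targets below encode.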
|
|
xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1]
|
|
lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)]
|
|
momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)]
|
|
lr_targets = [lr_target, lr_target]
|
|
momentum_targets = [momentum_target, momentum_target]
|
|
scheduler = CyclicLR(
|
|
self.opt,
|
|
base_lr=base_lr,
|
|
max_lr=max_lr,
|
|
step_size_up=4,
|
|
cycle_momentum=True,
|
|
base_momentum=base_lr,
|
|
max_momentum=max_lr,
|
|
mode="exp_range",
|
|
gamma=gamma,
|
|
)
|
|
self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
|
|
|
|
def test_cycle_lr_triangular_mode(self):
|
|
lr_target_1 = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
|
|
lr_target_2 = [x + 1 for x in lr_target_1]
|
|
lr_targets = [lr_target_1, lr_target_2]
|
|
momentum_target_1 = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3]
|
|
momentum_target_2 = [x + 1 for x in momentum_target_1]
|
|
momentum_targets = [momentum_target_1, momentum_target_2]
|
|
scheduler = CyclicLR(
|
|
self.opt,
|
|
base_lr=[1, 2],
|
|
max_lr=[5, 6],
|
|
step_size_up=4,
|
|
cycle_momentum=True,
|
|
base_momentum=[1, 2],
|
|
max_momentum=[5, 6],
|
|
mode="triangular",
|
|
)
|
|
self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1))
|
|
|
|
def test_cycle_lr_triangular2_mode(self):
|
|
lr_target_1 = [
|
|
1,
|
|
2,
|
|
3,
|
|
4,
|
|
5,
|
|
4,
|
|
3,
|
|
2,
|
|
1,
|
|
1.5,
|
|
2.0,
|
|
2.5,
|
|
3.0,
|
|
2.5,
|
|
2.0,
|
|
1.5,
|
|
1,
|
|
1.25,
|
|
1.50,
|
|
1.75,
|
|
2.00,
|
|
1.75,
|
|
]
|
|
lr_target_2 = [x + 2 for x in lr_target_1]
|
|
lr_targets = [lr_target_1, lr_target_2]
|
|
momentum_target_1 = [
|
|
5.0,
|
|
4.0,
|
|
3.0,
|
|
2.0,
|
|
1.0,
|
|
2.0,
|
|
3.0,
|
|
4.0,
|
|
5.0,
|
|
4.5,
|
|
4.0,
|
|
3.5,
|
|
3.0,
|
|
3.5,
|
|
4.0,
|
|
4.5,
|
|
5.0,
|
|
4.75,
|
|
4.5,
|
|
4.25,
|
|
4.0,
|
|
4.25,
|
|
]
|
|
momentum_target_2 = [x + 2 for x in momentum_target_1]
|
|
momentum_targets = [momentum_target_1, momentum_target_2]
|
|
scheduler = CyclicLR(
|
|
self.opt,
|
|
base_lr=[1, 3],
|
|
max_lr=[5, 7],
|
|
step_size_up=4,
|
|
cycle_momentum=True,
|
|
base_momentum=[1, 3],
|
|
max_momentum=[5, 7],
|
|
mode="triangular2",
|
|
)
|
|
self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1))
|
|
|
|
def test_cycle_lr_exp_range_mode(self):
|
|
base_lr_1, max_lr_1 = 1, 5
|
|
base_lr_2, max_lr_2 = 5, 12
|
|
|
|
diff_lr_1 = max_lr_1 - base_lr_1
|
|
diff_lr_2 = max_lr_2 - base_lr_2
|
|
|
|
gamma = 0.9
|
|
xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1]
|
|
lr_target_1 = [base_lr_1 + x * diff_lr_1 * gamma**i for i, x in enumerate(xs)]
|
|
lr_target_2 = [base_lr_2 + x * diff_lr_2 * gamma**i for i, x in enumerate(xs)]
|
|
lr_targets = [lr_target_1, lr_target_2]
|
|
momentum_target_1 = [
|
|
max_lr_1 - x * diff_lr_1 * gamma**i for i, x in enumerate(xs)
|
|
]
|
|
momentum_target_2 = [
|
|
max_lr_2 - x * diff_lr_2 * gamma**i for i, x in enumerate(xs)
|
|
]
|
|
momentum_targets = [momentum_target_1, momentum_target_2]
|
|
scheduler = CyclicLR(
|
|
self.opt,
|
|
base_lr=[base_lr_1, base_lr_2],
|
|
max_lr=[max_lr_1, max_lr_2],
|
|
step_size_up=4,
|
|
cycle_momentum=True,
|
|
base_momentum=[base_lr_1, base_lr_2],
|
|
max_momentum=[max_lr_1, max_lr_2],
|
|
mode="exp_range",
|
|
gamma=gamma,
|
|
)
|
|
self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1))
|
|
|
|
def test_cycle_lr_triangular_mode_step_size_up_down(self):
|
|
lr_target = [
|
|
1.0,
|
|
2.0,
|
|
3.0,
|
|
4.0,
|
|
5.0,
|
|
13.0 / 3,
|
|
11.0 / 3,
|
|
9.0 / 3,
|
|
7.0 / 3,
|
|
5.0 / 3,
|
|
1.0,
|
|
]
|
|
lr_targets = [lr_target, lr_target]
|
|
momentum_target = [
|
|
5.0,
|
|
4.0,
|
|
3.0,
|
|
2.0,
|
|
1.0,
|
|
5.0 / 3,
|
|
7.0 / 3,
|
|
3.0,
|
|
11.0 / 3,
|
|
13.0 / 3,
|
|
5.0,
|
|
]
|
|
momentum_targets = [momentum_target, momentum_target]
|
|
|
|
scheduler = CyclicLR(
|
|
self.opt,
|
|
base_lr=1,
|
|
max_lr=5,
|
|
step_size_up=4,
|
|
step_size_down=6,
|
|
cycle_momentum=True,
|
|
base_momentum=1,
|
|
max_momentum=5,
|
|
mode="triangular",
|
|
)
|
|
self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
|
|
|
|
def test_cycle_lr_triangular2_mode_step_size_up_down(self):
|
|
lr_base_target = [
|
|
1.0,
|
|
3.0,
|
|
5.0,
|
|
13.0 / 3,
|
|
11.0 / 3,
|
|
9.0 / 3,
|
|
7.0 / 3,
|
|
5.0 / 3,
|
|
1.0,
|
|
2.0,
|
|
3.0,
|
|
8.0 / 3,
|
|
7.0 / 3,
|
|
6.0 / 3,
|
|
5.0 / 3,
|
|
4.0 / 3,
|
|
1.0,
|
|
3.0 / 2,
|
|
2.0,
|
|
11.0 / 6,
|
|
10.0 / 6,
|
|
9.0 / 6,
|
|
8.0 / 6,
|
|
7.0 / 6,
|
|
]
|
|
momentum_base_target = [
|
|
5.0,
|
|
3.0,
|
|
1.0,
|
|
5.0 / 3,
|
|
7.0 / 3,
|
|
3.0,
|
|
11.0 / 3,
|
|
13.0 / 3,
|
|
5.0,
|
|
4.0,
|
|
3.0,
|
|
10.0 / 3,
|
|
11.0 / 3,
|
|
4.0,
|
|
13.0 / 3,
|
|
14.0 / 3,
|
|
5.0,
|
|
4.5,
|
|
4.0,
|
|
25.0 / 6,
|
|
13.0 / 3,
|
|
4.5,
|
|
14.0 / 3,
|
|
29.0 / 6,
|
|
]
|
|
deltas = [2 * i for i in range(0, 2)]
|
|
base_lrs = [1 + delta for delta in deltas]
|
|
max_lrs = [5 + delta for delta in deltas]
|
|
lr_targets = [[x + delta for x in lr_base_target] for delta in deltas]
|
|
momentum_targets = [
|
|
[x + delta for x in momentum_base_target] for delta in deltas
|
|
]
|
|
scheduler = CyclicLR(
|
|
self.opt,
|
|
base_lr=base_lrs,
|
|
max_lr=max_lrs,
|
|
step_size_up=2,
|
|
step_size_down=6,
|
|
cycle_momentum=True,
|
|
base_momentum=base_lrs,
|
|
max_momentum=max_lrs,
|
|
mode="triangular2",
|
|
)
|
|
self._test_cycle_lr(
|
|
scheduler, lr_targets, momentum_targets, len(lr_base_target)
|
|
)
|
|
|
|
def test_cycle_lr_exp_range_mode_step_size_up_down(self):
|
|
base_lr, max_lr = 1, 5
|
|
diff_lr = max_lr - base_lr
|
|
gamma = 0.9
|
|
xs = [
|
|
0.0,
|
|
0.5,
|
|
1.0,
|
|
5.0 / 6,
|
|
4.0 / 6,
|
|
3.0 / 6,
|
|
2.0 / 6,
|
|
1.0 / 6,
|
|
0.0,
|
|
0.5,
|
|
1.0,
|
|
5.0 / 6,
|
|
4.0 / 6,
|
|
]
|
|
lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)]
|
|
lr_targets = [lr_target, lr_target]
|
|
momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)]
|
|
momentum_targets = [momentum_target, momentum_target]
|
|
scheduler = CyclicLR(
|
|
self.opt,
|
|
base_lr=base_lr,
|
|
max_lr=max_lr,
|
|
step_size_up=2,
|
|
step_size_down=6,
|
|
cycle_momentum=True,
|
|
base_momentum=base_lr,
|
|
max_momentum=max_lr,
|
|
mode="exp_range",
|
|
gamma=gamma,
|
|
)
|
|
self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
|
|
|
|
def test_cycle_lr_with_momentumless_optimizer(self):
|
|
# Note [Temporarily set optimizer to Adam]
|
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
# The TestLRScheduler object carries around an SGD optimizer to avoid having to
|
|
# instantiate one for every test. This gets in the way for our very specific case
|
|
# in which we need to use Adam (or really any optimizer that doesn't use momentum)
|
|
# in order to test that the momentum bug in CyclicLR is fixed (the bug is described
|
|
# in more detail in https://github.com/pytorch/pytorch/issues/19003 ).
|
|
old_opt = self.opt
|
|
self.opt = optim.Adam(
|
|
[
|
|
{"params": self.net.conv1.parameters()},
|
|
{"params": self.net.conv2.parameters(), "lr": 0.5},
|
|
],
|
|
lr=0.05,
|
|
)
|
|
|
|
lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
|
|
lr_targets = [lr_target, lr_target]
|
|
momentum_target = [None] * len(lr_target)
|
|
momentum_targets = [momentum_target, momentum_target]
|
|
scheduler = CyclicLR(
|
|
self.opt,
|
|
base_lr=1,
|
|
max_lr=5,
|
|
step_size_up=4,
|
|
cycle_momentum=False,
|
|
mode="triangular",
|
|
)
|
|
self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
|
|
|
|
self.opt = old_opt # set optimizer back to SGD
|
|
|
|
def test_cycle_lr_cycle_momentum_fail_with_momentumless_optimizer(self):
|
|
with self.assertRaises(ValueError):
|
|
adam_opt = optim.Adam(self.net.parameters())
|
|
scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True)
|
|
|
|
def test_cycle_lr_removed_after_out_of_scope(self):
|
|
import gc
|
|
import weakref
|
|
|
|
gc.disable()
|
|
|
|
def test():
|
|
adam_opt = optim.Adam(self.net.parameters())
|
|
scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False)
|
|
return weakref.ref(scheduler)
|
|
|
|
ref = test()
|
|
assert ref() is None
|
|
gc.enable()
|
|
|
|
def test_cycle_lr_state_dict_picklable(self):
|
|
adam_opt = optim.Adam(self.net.parameters())
|
|
scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False)
|
|
self.assertIsInstance(scheduler._scale_fn_ref, weakref.WeakMethod)
|
|
state = scheduler.state_dict()
|
|
self.assertNotIn("_scale_fn_ref", state)
|
|
pickle.dumps(state)
|
|
|
|
def test_cycle_lr_scale_fn_restored_from_state_dict(self):
|
|
adam_opt = optim.Adam(self.net.parameters())
|
|
|
|
# Case 1: Built-in mode
|
|
scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, mode="triangular2")
|
|
restored_scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False)
|
|
restored_scheduler.load_state_dict(scheduler.state_dict())
|
|
self.assertTrue(restored_scheduler.mode == scheduler.mode == "triangular2")
|
|
        self.assertIsNotNone(restored_scheduler._scale_fn_ref)
        self.assertIsNotNone(scheduler._scale_fn_ref)
|
|
self.assertIs(restored_scheduler._scale_fn_custom, None)
|
|
self.assertIs(scheduler._scale_fn_custom, None)
|
|
|
|
# Case 2: Custom `scale_fn`
|
|
def scale_fn(_):
|
|
return 0.5
|
|
|
|
scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn)
|
|
restored_scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn)
|
|
restored_scheduler.load_state_dict(scheduler.state_dict())
|
|
self.assertIs(scheduler._scale_fn_custom, scale_fn)
|
|
self.assertIs(restored_scheduler._scale_fn_custom, scale_fn)
|
|
|
|
def test_onecycle_lr_invalid_anneal_strategy(self):
|
|
with self.assertRaises(ValueError):
|
|
scheduler = OneCycleLR(
|
|
self.opt, max_lr=1e-3, total_steps=10, anneal_strategy="CATS"
|
|
)
|
|
|
|
def test_onecycle_lr_invalid_pct_start(self):
|
|
with self.assertRaises(ValueError):
|
|
scheduler = OneCycleLR(self.opt, max_lr=1e-3, total_steps=10, pct_start=1.1)
|
|
|
|
def test_onecycle_lr_cannot_calculate_total_steps(self):
|
|
with self.assertRaises(ValueError):
|
|
scheduler = OneCycleLR(self.opt, max_lr=1e-3)
|
|
|
|
def test_onecycle_lr_linear_annealing(self):
|
|
lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5]
|
|
momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22]
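        # OneCycleLR's default div_factor of 25 gives an initial lr of max_lr / 25 = 1
        # here, and final_div_factor=2 brings the final lr down to 1 / 2 = 0.5,
        # matching the endpoints of lr_target above.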
|
|
lr_targets = [lr_target, lr_target]
|
|
momentum_targets = [momentum_target, momentum_target]
|
|
scheduler = OneCycleLR(
|
|
self.opt,
|
|
max_lr=25,
|
|
final_div_factor=2,
|
|
base_momentum=1,
|
|
max_momentum=22,
|
|
total_steps=10,
|
|
anneal_strategy="linear",
|
|
)
|
|
self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)
|
|
|
|
def test_onecycle_lr_linear_annealing_three_phases(self):
|
|
lr_target = [1, 9, 17, 25, 17, 9, 1, 0.75, 0.5, 0.25]
|
|
momentum_target = [22, 15, 8, 1, 8, 15, 22, 22, 22, 22]
|
|
lr_targets = [lr_target, lr_target]
|
|
momentum_targets = [momentum_target, momentum_target]
|
|
scheduler = OneCycleLR(
|
|
self.opt,
|
|
max_lr=25,
|
|
div_factor=25,
|
|
base_momentum=1,
|
|
max_momentum=22,
|
|
total_steps=10,
|
|
anneal_strategy="linear",
|
|
pct_start=0.4,
|
|
final_div_factor=4,
|
|
three_phase=True,
|
|
)
|
|
self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)
|
|
|
|
def test_onecycle_lr_cosine_annealing(self):
|
|
def annealing_cos(start, end, pct):
|
|
cos_out = math.cos(math.pi * pct) + 1
|
|
return end + (start - end) / 2.0 * cos_out
|
|
|
|
lr_target = [
|
|
1,
|
|
13,
|
|
25,
|
|
annealing_cos(25, 0.5, 1 / 7.0),
|
|
annealing_cos(25, 0.5, 2 / 7.0),
|
|
annealing_cos(25, 0.5, 3 / 7.0),
|
|
annealing_cos(25, 0.5, 4 / 7.0),
|
|
annealing_cos(25, 0.5, 5 / 7.0),
|
|
annealing_cos(25, 0.5, 6 / 7.0),
|
|
0.5,
|
|
]
|
|
momentum_target = [
|
|
22,
|
|
11.5,
|
|
1,
|
|
annealing_cos(1, 22, 1 / 7.0),
|
|
annealing_cos(1, 22, 2 / 7.0),
|
|
annealing_cos(1, 22, 3 / 7.0),
|
|
annealing_cos(1, 22, 4 / 7.0),
|
|
annealing_cos(1, 22, 5 / 7.0),
|
|
annealing_cos(1, 22, 6 / 7.0),
|
|
22,
|
|
]
|
|
lr_targets = [lr_target, lr_target]
|
|
momentum_targets = [momentum_target, momentum_target]
|
|
scheduler = OneCycleLR(
|
|
self.opt,
|
|
max_lr=25,
|
|
final_div_factor=2,
|
|
base_momentum=1,
|
|
max_momentum=22,
|
|
total_steps=10,
|
|
)
|
|
self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)
|
|
|
|
def test_cycle_lr_with_adam(self):
|
|
old_opt = self.opt
|
|
self.opt = optim.Adam(
|
|
[
|
|
{"params": self.net.conv1.parameters()},
|
|
{"params": self.net.conv2.parameters(), "lr": 0.5},
|
|
],
|
|
lr=0.05,
|
|
)
|
|
|
|
lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5]
|
|
momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22]
|
|
lr_targets = [lr_target, lr_target]
|
|
momentum_targets = [momentum_target, momentum_target]
|
|
scheduler = OneCycleLR(
|
|
self.opt,
|
|
max_lr=25,
|
|
final_div_factor=2,
|
|
base_momentum=1,
|
|
max_momentum=22,
|
|
total_steps=10,
|
|
anneal_strategy="linear",
|
|
)
|
|
self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10, use_beta1=True)
|
|
self.opt = old_opt # set optimizer back to SGD
|
|
|
|
def test_lambda_lr(self):
|
|
epochs = 10
|
|
self.opt.param_groups[0]["lr"] = 0.05
|
|
self.opt.param_groups[1]["lr"] = 0.4
|
|
targets = [
|
|
[0.05 * (0.9**x) for x in range(epochs)],
|
|
[0.4 * (0.8**x) for x in range(epochs)],
|
|
]
|
|
scheduler = LambdaLR(
|
|
self.opt, lr_lambda=[lambda x1: 0.9**x1, lambda x2: 0.8**x2]
|
|
)
|
|
self._test(scheduler, targets, epochs)
|
|
|
|
def test_multiplicative_lr(self):
|
|
epochs = 10
|
|
self.opt.param_groups[0]["lr"] = 0.05
|
|
self.opt.param_groups[1]["lr"] = 0.4
|
|
targets = [
|
|
[0.05 * (0.9**x) for x in range(epochs)],
|
|
[0.4 * (0.8**x) for x in range(epochs)],
|
|
]
|
|
scheduler = MultiplicativeLR(
|
|
self.opt, lr_lambda=[lambda x1: 0.9, lambda x2: 0.8]
|
|
)
|
|
self._test(scheduler, targets, epochs)
|
|
|
|
@parametrize("T_mult", [1, 2, 4])
|
|
def test_CosineAnnealingWarmRestarts_lr1(self, T_mult):
|
|
iters = 100
|
|
eta_min = 1e-10
|
|
T_i = 10
|
|
T_cur = 0
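        # T_i is the length of the current restart cycle; it is multiplied by T_mult
        # after every restart, while T_cur counts the epochs since the last restart.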
|
|
targets = [[0.05], [0.5]]
|
|
scheduler = CosineAnnealingWarmRestarts(
|
|
self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min
|
|
)
|
|
for _ in range(1, iters, 1):
|
|
T_cur += 1
|
|
if T_cur >= T_i:
|
|
T_cur = T_cur - T_i
|
|
T_i = int(T_mult) * T_i
|
|
targets[0] += [
|
|
eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
|
|
]
|
|
targets[1] += [
|
|
eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
|
|
]
|
|
self._test(scheduler, targets, iters)
|
|
|
|
def test_CosineAnnealingWarmRestarts_lr2(self):
|
|
iters = 30
|
|
eta_min = 1e-10
|
|
T_mults = [1, 2, 4]
|
|
for T_mult in T_mults:
|
|
T_i = 10
|
|
T_cur = 0
|
|
targets = [[0.05], [0.5]]
|
|
scheduler = CosineAnnealingWarmRestarts(
|
|
self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min
|
|
)
|
|
for _ in torch.arange(0.1, iters, 0.1):
|
|
T_cur = round(T_cur + 0.1, 1)
|
|
if T_cur >= T_i:
|
|
T_cur = T_cur - T_i
|
|
T_i = int(T_mult) * T_i
|
|
targets[0] += [
|
|
eta_min
|
|
+ (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
|
|
]
|
|
targets[1] += [
|
|
eta_min
|
|
+ (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
|
|
]
|
|
self._test_CosineAnnealingWarmRestarts(scheduler, targets, iters)
|
|
|
|
def test_CosineAnnealingWarmRestarts_lr3(self):
|
|
epochs_for_T_mults = [
|
|
[0, 1, 2, 3, 4, 5, 12, 27, 3, 4, 5, 6, 13],
|
|
[0, 1, 2, 3, 4, 5, 25, 32, 33, 34, 80, 81, 3],
|
|
[0, 0.1, 0.2, 0.3, 1.3, 2.3, 17.5, 18.5, 19.5, 29.5, 30.5, 31.5, 50],
|
|
]
|
|
T_curs_for_T_mults = [
|
|
[1, 2, 3, 4, 5, 2, 7, 3, 4, 5, 6, 3],
|
|
[1, 2, 3, 4, 5, 15, 2, 3, 4, 10, 11, 3],
|
|
[0.1, 0.2, 0.3, 1.3, 2.3, 7.5, 8.5, 9.5, 19.5, 20.5, 21.5, 10],
|
|
]
|
|
T_is_for_T_mults = [
|
|
[10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],
|
|
[10, 10, 10, 10, 10, 20, 40, 40, 40, 80, 80, 10],
|
|
[10, 10, 10, 10, 10, 30, 30, 30, 30, 30, 30, 90],
|
|
]
|
|
eta_min = 1e-10
|
|
T_mults = [1, 2, 3]
|
|
for epochs, T_mult, T_curs, T_is in zip(
|
|
epochs_for_T_mults, T_mults, T_curs_for_T_mults, T_is_for_T_mults
|
|
):
|
|
targets = [[0.05], [0.5]]
|
|
scheduler = CosineAnnealingWarmRestarts(
|
|
self.opt, T_0=10, T_mult=T_mult, eta_min=eta_min
|
|
)
|
|
for T_cur, T_i in zip(T_curs, T_is):
|
|
targets[0] += [
|
|
eta_min
|
|
+ (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
|
|
]
|
|
targets[1] += [
|
|
eta_min
|
|
+ (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
|
|
]
|
|
self._test_interleaved_CosineAnnealingWarmRestarts(
|
|
scheduler, targets, epochs
|
|
)
|
|
|
|
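    # The SWALR tests below follow one pattern: run an (optional) regular
    # scheduler until swa_start, then hand control to SWALR and compare the
    # annealed learning rates against targets built from the same annealing
    # formulas (see _test_swalr for the stepping protocol).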
    def test_swalr_no_anneal(self):
        epochs, swa_start, swa_lr = 10, 5, 0.01
        initial_lrs = [group["lr"] for group in self.opt.param_groups]
        targets = [
            [lr] * (swa_start + 1) + [swa_lr] * (epochs - swa_start - 1)
            for lr in initial_lrs
        ]
        swa_scheduler = SWALR(self.opt, anneal_epochs=1, swa_lr=swa_lr)
        self._test_swalr(swa_scheduler, None, targets, swa_start, epochs)

    def test_swalr_cosine_anneal_after_multiplicative(self):
        # same swa_lr for different param_groups
        epochs, swa_start, swa_lr, anneal_epochs = 15, 5, 0.01, 5
        mult_factor = 0.9
        scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor)
        swa_scheduler = SWALR(self.opt, anneal_epochs=anneal_epochs, swa_lr=swa_lr)

        def anneal_coef(t):
            if t + 1 >= anneal_epochs:
                return 0.0
            return (1 + math.cos(math.pi * (t + 1) / anneal_epochs)) / 2

        initial_lrs = [group["lr"] for group in self.opt.param_groups]
        targets_before_swa = [
            [lr * mult_factor**i for i in range(swa_start + 1)] for lr in initial_lrs
        ]
        swa_epochs = epochs - swa_start - 1
        targets = [
            lrs
            + [
                lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t))
                for t in range(swa_epochs)
            ]
            for lrs in targets_before_swa
        ]

        self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs)

    def test_swalr_linear_anneal_after_multiplicative(self):
        # separate swa_lr for different param_groups
        epochs, swa_start, swa_lrs, anneal_epochs = 15, 5, [0.01, 0.02], 4
        mult_factor = 0.9
        scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor)
        swa_scheduler = SWALR(
            self.opt,
            anneal_epochs=anneal_epochs,
            anneal_strategy="linear",
            swa_lr=swa_lrs,
        )

        def anneal_coef(t):
            if t + 1 >= anneal_epochs:
                return 0.0
            return 1 - (t + 1) / anneal_epochs

        initial_lrs = [group["lr"] for group in self.opt.param_groups]
        targets_before_swa = [
            [lr * mult_factor**i for i in range(swa_start + 1)] for lr in initial_lrs
        ]
        swa_epochs = epochs - swa_start - 1
        targets = [
            lrs
            + [
                lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t))
                for t in range(swa_epochs)
            ]
            for lrs, swa_lr in zip(targets_before_swa, swa_lrs)
        ]

        self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs)

    def _test_swalr(self, swa_scheduler, scheduler, targets, swa_start, epochs):
        for epoch in range(epochs):
            for param_group, target in zip(self.opt.param_groups, targets):
                self.assertEqual(
                    target[epoch],
                    param_group["lr"],
                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
                        epoch, target[epoch], param_group["lr"]
                    ),
                    atol=1e-5,
                    rtol=0,
                )
            if epoch >= swa_start:
                self.opt.step()
                swa_scheduler.step()
            elif scheduler is not None:
                self.opt.step()
                scheduler.step()

    def test_swalr_hypers(self):
        # Test that SWALR raises errors for incorrect hyper-parameters
        with self.assertRaisesRegex(ValueError, "anneal_strategy must"):
            swa_scheduler = SWALR(self.opt, anneal_strategy="exponential", swa_lr=1.0)

        with self.assertRaisesRegex(ValueError, "anneal_epochs must"):
            swa_scheduler = SWALR(self.opt, anneal_epochs=-1, swa_lr=1.0)
        with self.assertRaisesRegex(ValueError, "anneal_epochs must"):
            swa_scheduler = SWALR(self.opt, anneal_epochs=1.7, swa_lr=1.0)
        with self.assertRaisesRegex(ValueError, "swa_lr must"):
            swa_scheduler = SWALR(self.opt, swa_lr=[1.0, 0.1, 0.01])

    def test_step_lr_state_dict(self):
        self._check_scheduler_state_dict(
            lambda: StepLR(self.opt, gamma=0.1, step_size=3),
            lambda: StepLR(self.opt, gamma=0.01 / 2, step_size=1),
        )

    def test_multi_step_lr_state_dict(self):
        self._check_scheduler_state_dict(
            lambda: MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]),
            lambda: MultiStepLR(self.opt, gamma=0.01, milestones=[1, 4, 6]),
        )

    def test_exp_step_lr_state_dict(self):
        self._check_scheduler_state_dict(
            lambda: ExponentialLR(self.opt, gamma=0.1),
            lambda: ExponentialLR(self.opt, gamma=0.01),
        )

    def test_cosine_lr_state_dict(self):
        epochs = 10
        eta_min = 1e-10
        self._check_scheduler_state_dict(
            lambda: CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min),
            lambda: CosineAnnealingLR(self.opt, T_max=epochs // 2, eta_min=eta_min / 2),
            epochs=epochs,
        )

    def test_reduce_lr_on_plateau_state_dict(self):
        scheduler = ReduceLROnPlateau(self.opt, mode="min", factor=0.1, patience=2)
        for score in [1.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 3.0, 2.0, 1.0]:
            scheduler.step(score)
        scheduler_copy = ReduceLROnPlateau(
            self.opt, mode="max", factor=0.5, patience=10
        )
        scheduler_copy.load_state_dict(scheduler.state_dict())
        for key in scheduler.__dict__.keys():
            if key not in {"optimizer", "is_better"}:
                self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])

    def test_lambda_lr_state_dict_fn(self):
        scheduler = LambdaLR(self.opt, lr_lambda=lambda x: x)
        state = scheduler.state_dict()
        self.assertIsNone(state["lr_lambdas"][0])

        scheduler_copy = LambdaLR(self.opt, lr_lambda=lambda x: x)
        scheduler_copy.load_state_dict(state)
        for key in scheduler.__dict__.keys():
            if key not in {"optimizer", "lr_lambdas"}:
                self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])

    def test_lambda_lr_state_dict_obj(self):
        scheduler = LambdaLR(self.opt, lr_lambda=LambdaLRTestObject(10))
        state = scheduler.state_dict()
        self.assertIsNotNone(state["lr_lambdas"][0])

        scheduler_copy = LambdaLR(self.opt, lr_lambda=LambdaLRTestObject(-1))
        scheduler_copy.load_state_dict(state)
        for key in scheduler.__dict__.keys():
            if key not in {"optimizer"}:
                self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])

    def test_CosineAnnealingWarmRestarts_lr_state_dict(self):
        self._check_scheduler_state_dict(
            lambda: CosineAnnealingWarmRestarts(self.opt, T_0=10, T_mult=2),
            lambda: CosineAnnealingWarmRestarts(self.opt, T_0=100),
        )

    def test_swa_lr_state_dict(self):
        self._check_scheduler_state_dict(
            lambda: SWALR(self.opt, anneal_epochs=3, swa_lr=0.5),
            lambda: SWALR(
                self.opt, anneal_epochs=10, anneal_strategy="linear", swa_lr=5.0
            ),
        )

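    # Shared helper for the *_state_dict tests above: step one scheduler for a
    # few epochs, load its state_dict into a scheduler built with different
    # hyper-parameters, and check that every attribute except the optimizer
    # reference (and the resulting get_last_lr) matches.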
    def _check_scheduler_state_dict(self, constr, constr2, epochs=10):
        scheduler = constr()
        for _ in range(epochs):
            scheduler.optimizer.step()
            scheduler.step()
        scheduler_copy = constr2()
        scheduler_copy.load_state_dict(scheduler.state_dict())
        for key in scheduler.__dict__.keys():
            if key != "optimizer":
                self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])
        self.assertEqual(scheduler.get_last_lr(), scheduler_copy.get_last_lr())

    def _test_get_last_lr(self, schedulers, targets, epochs=10):
        if isinstance(schedulers, LRScheduler):
            schedulers = [schedulers]
        optimizers = {scheduler.optimizer for scheduler in schedulers}
        for epoch in range(epochs):
            result = [scheduler.get_last_lr() for scheduler in schedulers]
            [optimizer.step() for optimizer in optimizers]
            [scheduler.step() for scheduler in schedulers]
            target = [[t[epoch] for t in targets]] * len(schedulers)
            for t, r in zip(target, result):
                self.assertEqual(
                    target,
                    result,
                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
                        epoch, t, r
                    ),
                    atol=1e-5,
                    rtol=0,
                )

    def _test_with_epoch(self, schedulers, targets, epochs=10):
        if isinstance(schedulers, LRScheduler):
            schedulers = [schedulers]
        optimizers = {scheduler.optimizer for scheduler in schedulers}
        for epoch in range(epochs):
            [optimizer.step() for optimizer in optimizers]
            with warnings.catch_warnings(record=True) as w:
                [
                    scheduler.step(epoch) for scheduler in schedulers
                ]  # step before assert: skip initial lr
                self._check_warning_is_epoch_deprecation_warning(
                    w, num_warnings=len(schedulers)
                )
            for param_group, target in zip(self.opt.param_groups, targets):
                self.assertEqual(
                    target[epoch],
                    param_group["lr"],
                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
                        epoch, target[epoch], param_group["lr"]
                    ),
                    atol=1e-5,
                    rtol=0,
                )

    def _test(self, schedulers, targets, epochs=10):
        if isinstance(schedulers, LRScheduler):
            schedulers = [schedulers]
        for epoch in range(epochs):
            for param_group, target in zip(self.opt.param_groups, targets):
                self.assertEqual(
                    target[epoch],
                    param_group["lr"],
                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
                        epoch, target[epoch], param_group["lr"]
                    ),
                    atol=1e-5,
                    rtol=0,
                )
            [scheduler.step() for scheduler in schedulers]

    def _test_CosineAnnealingWarmRestarts(self, scheduler, targets, epochs=10):
        for index, epoch in enumerate(torch.arange(0, epochs, 0.1)):
            epoch = round(epoch.item(), 1)
            scheduler.step(epoch)
            for param_group, target in zip(self.opt.param_groups, targets):
                self.assertEqual(
                    target[index],
                    param_group["lr"],
                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
                        epoch, target[index], param_group["lr"]
                    ),
                    atol=1e-5,
                    rtol=0,
                )

    def _test_interleaved_CosineAnnealingWarmRestarts(self, scheduler, targets, epochs):
        for index, epoch in enumerate(epochs):
            scheduler.step(epoch)
            for param_group, target in zip(self.opt.param_groups, targets):
                self.assertEqual(
                    target[index],
                    param_group["lr"],
                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
                        epoch, target[index], param_group["lr"]
                    ),
                    atol=1e-5,
                    rtol=0,
                )

    def _test_against_closed_form(self, scheduler, closed_form_scheduler, epochs=10):
        self.setUp()
        targets = []
        for epoch in range(epochs):
            closed_form_scheduler.optimizer.step()
            with warnings.catch_warnings(record=True) as w:
                closed_form_scheduler.step(epoch)
                self._check_warning_is_epoch_deprecation_warning(w)
            targets.append([group["lr"] for group in self.opt.param_groups])
        self.setUp()
        for epoch in range(epochs):
            self.opt.step()
            scheduler.step()
            for i, param_group in enumerate(self.opt.param_groups):
                self.assertEqual(
                    targets[epoch][i],
                    param_group["lr"],
                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
                        epoch, targets[epoch][i], param_group["lr"]
                    ),
                    atol=1e-5,
                    rtol=0,
                )

    def _test_reduce_lr_on_plateau(
        self, schedulers, targets, metrics, epochs=10, verbose=False
    ):
        if isinstance(schedulers, (LRScheduler, ReduceLROnPlateau)):
            schedulers = [schedulers]
        for epoch in range(epochs):
            self.opt.step()
            for scheduler in schedulers:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(metrics[epoch])
                else:
                    scheduler.step()
            if verbose:
                print("epoch{}:\tlr={}".format(epoch, self.opt.param_groups[0]["lr"]))
            for param_group, target in zip(self.opt.param_groups, targets):
                self.assertEqual(
                    target[epoch],
                    param_group["lr"],
                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
                        epoch, target[epoch], param_group["lr"]
                    ),
                    atol=1e-5,
                    rtol=0,
                )

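    # Shared helper for the cyclic/OneCycle tests: at every batch it compares
    # the current lr (and momentum or beta1, depending on the optimizer)
    # against the precomputed targets, then steps the optimizer and scheduler.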
    def _test_cycle_lr(
        self,
        scheduler,
        lr_targets,
        momentum_targets,
        batch_iterations,
        verbose=False,
        use_beta1=False,
    ):
        for batch_num in range(batch_iterations):
            if verbose:
                if "momentum" in self.opt.param_groups[0].keys():
                    print(
                        "batch{}:\tlr={},momentum={}".format(
                            batch_num,
                            self.opt.param_groups[0]["lr"],
                            self.opt.param_groups[0]["momentum"],
                        )
                    )
                elif use_beta1 and "betas" in self.opt.param_groups[0].keys():
                    print(
                        "batch{}:\tlr={},beta1={}".format(
                            batch_num,
                            self.opt.param_groups[0]["lr"],
                            self.opt.param_groups[0]["betas"][0],
                        )
                    )
                else:
                    print(
                        "batch{}:\tlr={}".format(
                            batch_num, self.opt.param_groups[0]["lr"]
                        )
                    )

            for param_group, lr_target, momentum_target in zip(
                self.opt.param_groups, lr_targets, momentum_targets
            ):
                self.assertEqual(
                    lr_target[batch_num],
                    param_group["lr"],
                    msg="LR is wrong in batch_num {}: expected {}, got {}".format(
                        batch_num, lr_target[batch_num], param_group["lr"]
                    ),
                    atol=1e-5,
                    rtol=0,
                )

                if use_beta1 and "betas" in param_group.keys():
                    self.assertEqual(
                        momentum_target[batch_num],
                        param_group["betas"][0],
                        msg="Beta1 is wrong in batch_num {}: expected {}, got {}".format(
                            batch_num,
                            momentum_target[batch_num],
                            param_group["betas"][0],
                        ),
                        atol=1e-5,
                        rtol=0,
                    )
                elif "momentum" in param_group.keys():
                    self.assertEqual(
                        momentum_target[batch_num],
                        param_group["momentum"],
                        msg="Momentum is wrong in batch_num {}: expected {}, got {}".format(
                            batch_num,
                            momentum_target[batch_num],
                            param_group["momentum"],
                        ),
                        atol=1e-5,
                        rtol=0,
                    )
            self.opt.step()
            scheduler.step()

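    # Regression scenario for the issue linked below: after handing over from
    # CosineAnnealingLR to CyclicLR mid-training, the learning rate must never
    # exceed CyclicLR's max_lr.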
    def test_cosine_then_cyclic(self):
        # https://github.com/pytorch/pytorch/issues/21965

        max_lr = 0.3
        base_lr = 0.1
        optim_lr = 0.5

        model = torch.nn.Linear(2, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=optim_lr)
        lr_scheduler_1 = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=20, eta_min=0.1
        )
        lr_scheduler_2 = torch.optim.lr_scheduler.CyclicLR(
            optimizer, base_lr=base_lr, max_lr=max_lr, step_size_up=1, step_size_down=3
        )

        for i in range(40):
            optimizer.step()
            if i <= lr_scheduler_1.T_max:
                lr_scheduler_1.step()
            else:
                lr_scheduler_2.step()
            last_lr = optimizer.param_groups[0]["lr"]

        self.assertLessEqual(last_lr, max_lr)


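# Small fully-connected and convolutional modules used by the update_bn tests
# below; compute_preactivation exposes the input of the BatchNorm layer so the
# tests can compute reference running statistics by hand.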
class SWATestDNN(torch.nn.Module):
    def __init__(self, input_features):
        super().__init__()
        self.n_features = 100
        self.fc1 = torch.nn.Linear(input_features, self.n_features)
        self.bn = torch.nn.BatchNorm1d(self.n_features)

    def compute_preactivation(self, x):
        return self.fc1(x)

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn(x)
        return x


class SWATestCNN(torch.nn.Module):
    def __init__(self, input_channels):
        super().__init__()
        self.n_features = 10
        self.conv1 = torch.nn.Conv2d(
            input_channels, self.n_features, kernel_size=3, padding=1
        )
        self.bn = torch.nn.BatchNorm2d(self.n_features, momentum=0.3)

    def compute_preactivation(self, x):
        return self.conv1(x)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn(x)
        return x


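# Tests for torch.optim.swa_utils: AveragedModel (equal-weight and EMA
# averaging, device placement, state_dict round-trips) and update_bn.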
class TestSWAUtils(TestCase):
    def _test_averaged_model(self, net_device, swa_device):
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
            torch.nn.BatchNorm2d(5, momentum=0.3),
            torch.nn.Conv2d(5, 2, kernel_size=3),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 10),
        ).to(net_device)

        averaged_dnn = AveragedModel(dnn, device=swa_device)
        averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]
        n_updates = 10
        for i in range(n_updates):
            for p, p_avg in zip(dnn.parameters(), averaged_params):
                p.detach().add_(torch.randn_like(p))
                p_avg += p.detach() / n_updates
            averaged_dnn.update_parameters(dnn)

        for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
            self.assertEqual(p_avg, p_swa)
            # Check that AveragedModel is on the correct device
            self.assertTrue(p_swa.device == swa_device)
            self.assertTrue(p.device == net_device)
        self.assertTrue(averaged_dnn.n_averaged.device == swa_device)

    def test_averaged_model_all_devices(self):
        cpu = torch.device("cpu")
        self._test_averaged_model(cpu, cpu)
        if torch.cuda.is_available():
            cuda = torch.device(0)
            self._test_averaged_model(cuda, cpu)
            self._test_averaged_model(cpu, cuda)
            self._test_averaged_model(cuda, cuda)

    def test_averaged_model_mixed_device(self):
        if not torch.cuda.is_available():
            return
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)
        )
        dnn[0].cuda()
        dnn[1].cpu()
        averaged_dnn = AveragedModel(dnn)
        averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]
        n_updates = 10
        for i in range(n_updates):
            for p, p_avg in zip(dnn.parameters(), averaged_params):
                p.detach().add_(torch.randn_like(p))
                p_avg += p.detach() / n_updates
            averaged_dnn.update_parameters(dnn)

        for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
            self.assertEqual(p_avg, p_swa)
            # Check that AveragedModel is on the correct device
            self.assertTrue(p_avg.device == p_swa.device)

    def test_averaged_model_state_dict(self):
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)
        )
        averaged_dnn = AveragedModel(dnn)
        averaged_dnn2 = AveragedModel(dnn)
        n_updates = 10
        for i in range(n_updates):
            for p in dnn.parameters():
                p.detach().add_(torch.randn_like(p))
            averaged_dnn.update_parameters(dnn)
        averaged_dnn2.load_state_dict(averaged_dnn.state_dict())
        for p_swa, p_swa2 in zip(averaged_dnn.parameters(), averaged_dnn2.parameters()):
            self.assertEqual(p_swa, p_swa2)
        self.assertTrue(averaged_dnn.n_averaged == averaged_dnn2.n_averaged)

    def test_averaged_model_exponential(self):
        # Test AveragedModel with EMA as avg_fn
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.BatchNorm2d(5, momentum=0.3),
            torch.nn.Linear(5, 10),
        )
        alpha = 0.9

        def avg_fn(p_avg, p, n_avg):
            return alpha * p_avg + (1 - alpha) * p

        averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn)
        averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]
        n_updates = 10
        for i in range(n_updates):
            updated_averaged_params = []
            for p, p_avg in zip(dnn.parameters(), averaged_params):
                p.detach().add_(torch.randn_like(p))
                if i == 0:
                    updated_averaged_params.append(p.clone())
                else:
                    updated_averaged_params.append(
                        (p_avg * alpha + p * (1 - alpha)).clone()
                    )
            for b in dnn.buffers():
                if b.size() != torch.Size([]):
                    b.detach_().add_(torch.randn_like(b))

            averaged_dnn.update_parameters(dnn)
            averaged_params = updated_averaged_params

        for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
            self.assertEqual(p_avg, p_swa)
        for b_avg, b_swa in zip(dnn.buffers(), averaged_dnn.module.buffers()):
            self.assertEqual(b_avg, b_swa)

    def test_averaged_model_exponential_buffers(self):
        # Test AveragedModel with EMA as avg_fn and use_buffers as True.
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.BatchNorm2d(5, momentum=0.3),
            torch.nn.Linear(5, 10),
        )
        alpha = 0.9

        def avg_fn(p_avg, p, n_avg):
            return alpha * p_avg + (1 - alpha) * p

        averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, use_buffers=True)
        dnn_params = itertools.chain(dnn.parameters(), dnn.buffers())
        averaged_params = [
            torch.zeros_like(param)
            for param in dnn_params
            if param.size() != torch.Size([])
        ]
        n_updates = 10
        for i in range(n_updates):
            updated_averaged_params = []
            for p, p_avg in zip(dnn_params, averaged_params):
                if p.size() == torch.Size([]):
                    continue
                p.detach().add_(torch.randn_like(p))
                if i == 0:
                    updated_averaged_params.append(p.clone())
                else:
                    updated_averaged_params.append(
                        (p_avg * alpha + p * (1 - alpha)).clone()
                    )
            averaged_dnn.update_parameters(dnn)
            averaged_params = updated_averaged_params

        for p_avg, p_swa in zip(
            averaged_params,
            itertools.chain(
                averaged_dnn.module.parameters(), averaged_dnn.module.buffers()
            ),
        ):
            self.assertEqual(p_avg, p_swa)

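    # Shared helper: accumulate the per-feature mean and variance of the
    # BatchNorm input over the loader by hand, then check that update_bn
    # recovers the same running statistics for several loader formats.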
    def _test_update_bn(self, dnn, dl_x, dl_xy, cuda):

        preactivation_sum = torch.zeros(dnn.n_features)
        preactivation_squared_sum = torch.zeros(dnn.n_features)
        if cuda:
            preactivation_sum = preactivation_sum.cuda()
            preactivation_squared_sum = preactivation_squared_sum.cuda()
        total_num = 0
        for x in dl_x:
            x = x[0]
            if cuda:
                x = x.cuda()

            dnn.forward(x)
            preactivations = dnn.compute_preactivation(x)
            if len(preactivations.shape) == 4:
                preactivations = preactivations.transpose(1, 3)
            preactivations = preactivations.contiguous().view(-1, dnn.n_features)
            total_num += preactivations.shape[0]

            preactivation_sum += torch.sum(preactivations, dim=0)
            preactivation_squared_sum += torch.sum(preactivations**2, dim=0)

        preactivation_mean = preactivation_sum / total_num
        preactivation_var = preactivation_squared_sum / total_num
        preactivation_var = preactivation_var - preactivation_mean**2

        update_bn(dl_xy, dnn, device=x.device)
        self.assertEqual(preactivation_mean, dnn.bn.running_mean)
        self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)

        def _reset_bn(module):
            if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
                module.running_mean = torch.zeros_like(module.running_mean)
                module.running_var = torch.ones_like(module.running_var)

        # reset batch norm and run update_bn again
        dnn.apply(_reset_bn)
        update_bn(dl_xy, dnn, device=x.device)
        self.assertEqual(preactivation_mean, dnn.bn.running_mean)
        self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)
        # using the dl_x loader instead of dl_xy
        dnn.apply(_reset_bn)
        update_bn(dl_x, dnn, device=x.device)
        self.assertEqual(preactivation_mean, dnn.bn.running_mean)
        self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)

    def test_update_bn_dnn(self):
        # Test update_bn for a fully-connected network with BatchNorm1d
        objects, input_features = 100, 5
        x = torch.rand(objects, input_features)
        y = torch.rand(objects)
        ds_x = torch.utils.data.TensorDataset(x)
        ds_xy = torch.utils.data.TensorDataset(x, y)
        dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
        dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True)
        dnn = SWATestDNN(input_features=input_features)
        dnn.train()
        self._test_update_bn(dnn, dl_x, dl_xy, False)
        if torch.cuda.is_available():
            dnn = SWATestDNN(input_features=input_features)
            dnn.train()
            self._test_update_bn(dnn.cuda(), dl_x, dl_xy, True)
        self.assertTrue(dnn.training)

    def test_update_bn_cnn(self):
        # Test update_bn for a convolutional network with BatchNorm2d
        objects = 100
        input_channels = 3
        height, width = 5, 5
        x = torch.rand(objects, input_channels, height, width)
        y = torch.rand(objects)
        ds_x = torch.utils.data.TensorDataset(x)
        ds_xy = torch.utils.data.TensorDataset(x, y)
        dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
        dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True)
        dnn = SWATestCNN(input_channels=input_channels)
        dnn.train()
        self._test_update_bn(dnn, dl_x, dl_xy, False)
        if torch.cuda.is_available():
            dnn = SWATestCNN(input_channels=input_channels)
            dnn.train()
            self._test_update_bn(dnn.cuda(), dl_x, dl_xy, True)
        self.assertTrue(dnn.training)

    def test_bn_update_eval_momentum(self):
        # check that update_bn preserves eval mode
        objects = 100
        input_channels = 3
        height, width = 5, 5
        x = torch.rand(objects, input_channels, height, width)
        ds_x = torch.utils.data.TensorDataset(x)
        dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
        dnn = SWATestCNN(input_channels=input_channels)
        dnn.eval()
        update_bn(dl_x, dnn)
        self.assertFalse(dnn.training)

        # check that momentum is preserved
        self.assertEqual(dnn.bn.momentum, 0.3)


instantiate_parametrized_tests(TestLRScheduler)


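# Helper for the differentiable-optimizer tests: runs a single optimizer step
# on a cloned parameter with the given differentiable state and returns the
# updated parameter together with every state tensor that requires grad, so
# gradcheck can differentiate through the step.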
def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
    # `ignored` is the list of values in `opt_differentiable_state`; we pass
    # them so that `gradcheck` correctly tracks the state tensors as function
    # inputs, because otherwise it can't unpack the values in the
    # `opt_differentiable_state` dict.
    p = p.clone()
    p.grad = grad
    opt_differentiable_state = {
        k: v.clone() if isinstance(v, torch.Tensor) else v
        for k, v in opt_differentiable_state.items()
    }
    opt = opt_class([p], **kwargs)
    opt.state[p].update(opt_differentiable_state)
    opt.step()
    return (p,) + tuple(
        v
        for v in opt.state[p].values()
        if isinstance(v, torch.Tensor) and v.requires_grad
    )


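# Each test below seeds the optimizer state with float64 tensors (gradcheck
# expects double precision) and runs gradcheck through one optimizer step with
# differentiable=True.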
class TestDifferentiableOptimizer(TestCase):
    def test_sgd(self):
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        mbuff = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state = {"momentum_buffer": mbuff}
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                torch.optim.SGD,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                torch.optim.Adam,
                {"lr": 0.9, "differentiable": True, "amsgrad": True},
                *state.values(),
            ),
        )

    def test_rmsprop(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["step"] = 0
        state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["momentum_buffer"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        # This can cause issues with large values and nan due to sqrt ops
        state["grad_avg"] = 1e-2 * torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                torch.optim.RMSprop,
                {
                    "lr": 0.9,
                    "maximize": True,
                    "momentum": 0.9,
                    "differentiable": True,
                    "centered": True,
                    "weight_decay": 0.1,
                },
                *state.values(),
            ),
        )

    def test_adadelta(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["acc_delta"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                torch.optim.Adadelta,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adagrad(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["sum"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                torch.optim.Adagrad,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adamax(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_inf"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                torch.optim.Adamax,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )

    @skipIfTorchDynamo("The inplace mu update fails with dynamo, "
                       "since this is only happening when differentiable is enabled, skipping for now")
    def test_asgd(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step`, `eta`, and `mu` are not continuous variables (even though we
        # define them as floats) and so they shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["eta"] = torch.tensor(0.9, requires_grad=False, dtype=torch.float64)
        state["mu"] = torch.tensor(1.0, requires_grad=False, dtype=torch.float64)
        state["ax"] = torch.rand(10, requires_grad=True, dtype=torch.float64)

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                torch.optim.ASGD,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )

    def test_rprop(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["prev"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["step_size"] = torch.rand(10, requires_grad=True, dtype=torch.float64)

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                torch.optim.Rprop,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adamw(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                torch.optim.AdamW,
                {"lr": 0.9, "differentiable": True, "amsgrad": True},
                *state.values(),
            ),
        )

    def test_nadam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["mu_product"] = torch.tensor(1.0, requires_grad=True, dtype=torch.float64)

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                torch.optim.NAdam,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )

    def test_radam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                torch.optim.RAdam,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )


if __name__ == "__main__":
    run_tests()