Add Nesterov Adam (NAdam) algorithm to the multi-tensor optimizers API (#59165)

Summary:
Previously, in PR https://github.com/pytorch/pytorch/issues/59009, we added NAdam to the optimizers. In this PR we propose a multi-tensor version of NAdam for PyTorch.

NAdam was proposed by Timothy Dozat in the paper https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ and the accompanying report http://cs229.stanford.edu/proj2015/054_report.pdf.

It has become one of the most widely used algorithms in the deep learning community.

It is worth noting that this implementation of NAdam is inspired by the Keras implementation:
f9d3868495/tensorflow/python/keras/optimizer_v2/nadam.py
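
For reference, a minimal usage sketch under the module layout introduced here (torch.optim._multi_tensor.NAdam). The parameter tensors, loop counts, and hyperparameters below are illustrative; the tolerance mirrors the one used in the updated test.

    import torch
    from torch import optim
    from torch.optim import _multi_tensor as optim_mt

    # Two copies of the same parameter so the single- and multi-tensor
    # variants of NAdam can be stepped side by side.
    w_single = torch.randn(10, 5, requires_grad=True)
    w_multi = w_single.detach().clone().requires_grad_(True)

    opt_single = optim.NAdam([w_single], lr=1e-3, momentum_decay=4e-3)
    opt_multi = optim_mt.NAdam([w_multi], lr=1e-3, momentum_decay=4e-3)

    for _ in range(5):
        for w, opt in ((w_single, opt_single), (w_multi, opt_multi)):
            opt.zero_grad()
            w.sum().backward()
            opt.step()

    # The two implementations should agree up to small numerical error
    # (the test below compares with atol=5e-5).
    assert torch.allclose(w_single, w_multi, atol=5e-5)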

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59165

Reviewed By: vincentqb

Differential Revision: D29360577

Pulled By: iramazanli

fbshipit-source-id: 0fe14016303b2df2cb8cc31912a2674acf63d1e5
Ilqar Ramazanli 2021-06-27 16:57:32 -07:00 committed by Facebook GitHub Bot
parent 3bfe15085d
commit f0e972a481
6 changed files with 211 additions and 22 deletions

View File

@@ -318,6 +318,10 @@ class TestOptim(TestCase):
((optim.AdamW, optim._multi_tensor.AdamW), dict(weight_decay=1., amsgrad=False)),
((optim.AdamW, optim._multi_tensor.AdamW), dict(weight_decay=0., amsgrad=True)),
((optim.AdamW, optim._multi_tensor.AdamW), dict(weight_decay=0., amsgrad=False)),
((optim.NAdam, optim._multi_tensor.NAdam), dict(weight_decay=0., momentum_decay=6e-3)),
((optim.NAdam, optim._multi_tensor.NAdam), dict(weight_decay=1., momentum_decay=6e-3)),
((optim.NAdam, optim._multi_tensor.NAdam), dict(weight_decay=0., momentum_decay=4e-3)),
((optim.NAdam, optim._multi_tensor.NAdam), dict(weight_decay=0.01, momentum_decay=4e-3)),
((optim.SGD, optim._multi_tensor.SGD), dict(lr=0.2, momentum=1, dampening=0, weight_decay=1, nesterov=True)),
((optim.SGD, optim._multi_tensor.SGD), dict(lr=0.2, momentum=1, dampening=0.5, weight_decay=1, nesterov=False)),
((optim.RAdam, optim._multi_tensor.RAdam), dict(weight_decay=0)),
@@ -378,7 +382,7 @@ class TestOptim(TestCase):
res.append(model.parameters())
for p1, p2 in zip(res[0], res[1]):
self.assertEqual(p1, p2)
self.assertEqual(p1, p2, atol=5e-5, rtol=0)
def test_adam(self):
for optimizer in [optim.Adam, optim_mt.Adam]:
@@ -481,25 +485,26 @@ class TestOptim(TestCase):
optimizer(None, lr=1e-2, rho=1.1)
def test_nadam(self):
self._test_basic_cases(
lambda weight, bias: optim.NAdam([weight, bias], lr=1e-3)
)
self._test_basic_cases(
lambda weight, bias: optim.NAdam(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3)
)
self._test_basic_cases(
lambda weight, bias: optim.NAdam([weight, bias], lr=1e-3, weight_decay=0.1, momentum_decay=6e-3)
)
self._test_basic_cases(
lambda weight, bias: optim.NAdam([weight, bias], lr=1e-3, weight_decay=0.1, momentum_decay=6e-3),
[lambda opt: ExponentialLR(opt, gamma=0.9)]
)
with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
optim.NAdam(None, lr=1e-2, betas=(1.0, 0.0))
with self.assertRaisesRegex(ValueError, "Invalid momentum_decay value: -0.2"):
optim.NAdam(None, lr=1e-2, momentum_decay=-0.2)
for optimizer in [optim.NAdam, optim_mt.NAdam]:
self._test_basic_cases(
lambda weight, bias: optimizer([weight, bias], lr=1e-3)
)
self._test_basic_cases(
lambda weight, bias: optimizer(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3)
)
self._test_basic_cases(
lambda weight, bias: optimizer([weight, bias], lr=1e-3, weight_decay=0.1, momentum_decay=6e-3)
)
self._test_basic_cases(
lambda weight, bias: optimizer([weight, bias], lr=1e-3, weight_decay=0.1, momentum_decay=6e-3),
[lambda opt: ExponentialLR(opt, gamma=0.9)]
)
with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
optimizer(None, lr=1e-2, betas=(1.0, 0.0))
with self.assertRaisesRegex(ValueError, "Invalid momentum_decay value: -0.2"):
optimizer(None, lr=1e-2, momentum_decay=-0.2)
def test_adagrad(self):
self._test_basic_cases(

View File

@@ -7,6 +7,7 @@ future.
from .adam import Adam
from .adamw import AdamW
from .nadam import NAdam
from .sgd import SGD
from .radam import RAdam as RAdam
from .rmsprop import RMSprop
@@ -19,6 +20,7 @@ del adam
del adamw
del sgd
del radam
del nadam
del rmsprop
del rprop
del asgd

View File

@@ -1,5 +1,6 @@
from .adam import Adam as Adam
from .adamw import AdamW as AdamW
from .nadam import NAdam as NAdam
from .sgd import SGD as SGD
from .radam import RAdam as RAdam
from .rmsprop import RMSprop as RMSprop

View File

@@ -123,9 +123,9 @@ def radam(params: List[Tensor],
lr: float,
weight_decay: float,
eps: float):
r"""Functional API that performs Adam algorithm computation.
r"""Functional API that performs RAdam algorithm computation.
See :class:`~torch.optim.Adam` for details.
See :class:`~torch.optim.RAdam` for details.
"""
# maximum length of the approximated SMA
@@ -158,3 +158,49 @@ def radam(params: List[Tensor],
denom = [torch.ones_like(exp_av, memory_format=torch.preserve_format) for exp_av in exp_avg]
step_size = [(lr * rect / bc) * -1 for rect, bc in zip(unrectified, bias_correction1)]
torch._foreach_addcdiv_(params, exp_avg, denom, step_size)
def nadam(params: List[Tensor],
grads: List[Tensor],
exp_avg: List[Tensor],
exp_avg_sq: List[Tensor],
mu_products: List[Tensor],
states: List[Dict],
*,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
momentum_decay: float,
eps: float):
r"""Functional API that performs NAdam algorithm computation.
See :class:`~torch.optim.NAdam` for details.
"""
bias_correction1 = [1 - beta1 ** state['step'] for state in states]
bias_correction2 = [1 - beta2 ** state['step'] for state in states]
mus = [beta1 * (1. - 0.5 * (0.96 ** (state['step'] * momentum_decay))) for state in states]
mu_nexts = [beta1 * (1. - 0.5 * (0.96 ** ((state['step'] + 1) * momentum_decay)))
for state in states]
if weight_decay != 0:
torch._foreach_add_(grads, params, alpha=weight_decay)
# Decay the first and second moment running average coefficient
torch._foreach_mul_(exp_avg, beta1)
torch._foreach_add_(exp_avg, grads, alpha=1 - beta1)
torch._foreach_mul_(exp_avg_sq, beta2)
torch._foreach_addcmul_(exp_avg_sq, grads, grads, 1 - beta2)
exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sq)
bias_correction_sqrt = [math.sqrt(bc) for bc in bias_correction2]
torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt)
denom = torch._foreach_add(exp_avg_sq_sqrt, eps)
step_size_grads = [(lr * (1. - mu) / (1. - mu_product)) * -1
for mu_product, mu in zip(mu_products, mus)]
step_size_expavg = [(lr * mu_next / (1. - mu_product * mu_next)) * -1
for mu_product, mu_next in zip(mu_products, mu_nexts)]
torch._foreach_addcdiv_(params, grads, denom, step_size_grads)
torch._foreach_addcdiv_(params, exp_avg, denom, step_size_expavg)
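
As a cross-check, here is a single-tensor sketch of the update that the _foreach_* calls above apply across the whole parameter list. The helper name nadam_single and its argument layout are illustrative, and mu_product is assumed to already include the current step's mu, as maintained by the optimizer's step() further below; call it on plain tensors (or under torch.no_grad()).

    import math
    import torch

    def nadam_single(param, grad, exp_avg, exp_avg_sq, mu_product, step,
                     *, beta1=0.9, beta2=0.999, lr=2e-3,
                     weight_decay=0.0, momentum_decay=4e-3, eps=1e-8):
        # Interpolation factors for the current and the next step.
        mu = beta1 * (1. - 0.5 * (0.96 ** (step * momentum_decay)))
        mu_next = beta1 * (1. - 0.5 * (0.96 ** ((step + 1) * momentum_decay)))
        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running averages.
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        # Nesterov-style update: mix the raw gradient with the first moment.
        param.addcdiv_(grad, denom, value=-lr * (1. - mu) / (1. - mu_product))
        param.addcdiv_(exp_avg, denom, value=-lr * mu_next / (1. - mu_product * mu_next))

Only bias_correction2 enters the denominator here; the Nesterov-corrected step sizes use the running product of mu values in place of the usual first-moment bias correction.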

View File

@@ -0,0 +1,130 @@
import torch
from . import _functional as F
from ..optimizer import Optimizer
from collections import defaultdict
class NAdam(Optimizer):
r"""Implements NAdam algorithm with multi tensor APIs.
It has been proposed in `Incorporating Nesterov Momentum into Adam`_.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 2e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
momentum_decay (float, optional): momentum decay (default: 4e-3)
.. _Incorporating Nesterov Momentum into Adam:
https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ
"""
def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, momentum_decay=4e-3):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
if not 0.0 <= momentum_decay:
raise ValueError("Invalid momentum_decay value: {}".format(momentum_decay))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, momentum_decay=momentum_decay)
super(NAdam, self).__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad = []
grads = []
exp_avg = []
exp_avg_sq = []
mu_products = []
states = []
beta1, beta2 = group['betas']
for p in group['params']:
if p.grad is not None:
if p.grad.is_sparse:
raise RuntimeError('NAdam does not support sparse gradients')
params_with_grad.append(p)
grads.append(p.grad)
for p in params_with_grad:
state = self.state[p]
# Lazy state initialization
if len(state) == 0:
state['step'] = 0
state['mu_product'] = 1.
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
exp_avg.append(state['exp_avg'])
exp_avg_sq.append(state['exp_avg_sq'])
state['step'] += 1
states.append(state)
mu = beta1 * (1. - 0.5 * (0.96 ** (state['step'] * group['momentum_decay'])))
state['mu_product'] *= mu
mu_products.append(state['mu_product'])
F.nadam(params_with_grad,
grads,
exp_avg,
exp_avg_sq,
mu_products,
states,
beta1=beta1,
beta2=beta2,
lr=group['lr'],
weight_decay=group['weight_decay'],
momentum_decay=group['momentum_decay'],
eps=group['eps'])
return loss
# TODO: refactor to a base class once foreach ops are in a good shape.
def zero_grad(self, set_to_none: bool = False):
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
for group in self.param_groups:
for p in group['params']:
if p.grad is not None:
if set_to_none:
p.grad = None
else:
if p.grad.grad_fn is not None:
p.grad.detach_()
else:
p.grad.requires_grad_(False)
if p.grad.is_sparse:
p.grad.zero_()
else:
per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad)
for _, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values():
torch._foreach_zero_(grads)
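
A minimal training-loop sketch using the class defined above; the nn.Linear model and random data are illustrative, and the defaults follow the constructor (lr=2e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, momentum_decay=4e-3).

    import torch
    from torch import nn
    from torch.optim import _multi_tensor as optim_mt

    model = nn.Linear(20, 10)
    optimizer = optim_mt.NAdam(model.parameters(), lr=2e-3, momentum_decay=4e-3)

    for _ in range(3):
        # set_to_none=True releases the gradients instead of zeroing them in place.
        optimizer.zero_grad(set_to_none=True)
        loss = model(torch.randn(8, 20)).pow(2).mean()
        loss.backward()
        optimizer.step()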

View File

@@ -0,0 +1,5 @@
from typing import Tuple
from ..optimizer import _params_t, Optimizer
class NAdam(Optimizer):
def __init__(self, params: _params_t, lr: float=..., betas: Tuple[float, float]=..., eps: float=..., weight_decay: float=..., momentum_decay: float=...) -> None: ...