Add Nesterov Adam (NAdam) algorithm to the multi-tensor optimizers API (#59165)
Summary:
Previously, in PR https://github.com/pytorch/pytorch/issues/59009, we added NAdam to the optimizers. In this PR we propose a multi-tensor version of NAdam for PyTorch.
NAdam was proposed by Timothy Dozat in the paper https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ and the accompanying report http://cs229.stanford.edu/proj2015/054_report.pdf.
It has become one of the most widely used algorithms in the deep learning community.
It is worth noting that the implementation of NAdam is inspired by the Keras implementation:
f9d3868495/tensorflow/python/keras/optimizer_v2/nadam.py
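For orientation only (not part of this commit), here is a minimal usage sketch of the new multi-tensor optimizer; the toy model, data, and iteration count below are arbitrary illustrations, and the hyperparameter values simply spell out the defaults of the new NAdam class further down in the diff:

```python
import torch
from torch import optim

# Toy model and synthetic data, purely for illustration.
model = torch.nn.Linear(10, 1)
# Defaults written out explicitly: lr=2e-3, momentum_decay=4e-3.
opt = optim._multi_tensor.NAdam(model.parameters(), lr=2e-3, momentum_decay=4e-3)

for _ in range(5):
    opt.zero_grad()
    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    opt.step()
```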
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59165
Reviewed By: vincentqb
Differential Revision: D29360577
Pulled By: iramazanli
fbshipit-source-id: 0fe14016303b2df2cb8cc31912a2674acf63d1e5
parent 3bfe15085d
commit f0e972a481
@@ -318,6 +318,10 @@ class TestOptim(TestCase):
            ((optim.AdamW, optim._multi_tensor.AdamW), dict(weight_decay=1., amsgrad=False)),
            ((optim.AdamW, optim._multi_tensor.AdamW), dict(weight_decay=0., amsgrad=True)),
            ((optim.AdamW, optim._multi_tensor.AdamW), dict(weight_decay=0., amsgrad=False)),
            ((optim.NAdam, optim._multi_tensor.NAdam), dict(weight_decay=0., momentum_decay=6e-3)),
            ((optim.NAdam, optim._multi_tensor.NAdam), dict(weight_decay=1., momentum_decay=6e-3)),
            ((optim.NAdam, optim._multi_tensor.NAdam), dict(weight_decay=0., momentum_decay=4e-3)),
            ((optim.NAdam, optim._multi_tensor.NAdam), dict(weight_decay=0.01, momentum_decay=4e-3)),
            ((optim.SGD, optim._multi_tensor.SGD), dict(lr=0.2, momentum=1, dampening=0, weight_decay=1, nesterov=True)),
            ((optim.SGD, optim._multi_tensor.SGD), dict(lr=0.2, momentum=1, dampening=0.5, weight_decay=1, nesterov=False)),
            ((optim.RAdam, optim._multi_tensor.RAdam), dict(weight_decay=0)),
@@ -378,7 +382,7 @@ class TestOptim(TestCase):
                res.append(model.parameters())

            for p1, p2 in zip(res[0], res[1]):
                self.assertEqual(p1, p2)
                self.assertEqual(p1, p2, atol=5e-5, rtol=0)

    def test_adam(self):
        for optimizer in [optim.Adam, optim_mt.Adam]:
@@ -481,25 +485,26 @@ class TestOptim(TestCase):
                optimizer(None, lr=1e-2, rho=1.1)

    def test_nadam(self):
        self._test_basic_cases(
            lambda weight, bias: optim.NAdam([weight, bias], lr=1e-3)
        )
        self._test_basic_cases(
            lambda weight, bias: optim.NAdam(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-3)
        )
        self._test_basic_cases(
            lambda weight, bias: optim.NAdam([weight, bias], lr=1e-3, weight_decay=0.1, momentum_decay=6e-3)
        )
        self._test_basic_cases(
            lambda weight, bias: optim.NAdam([weight, bias], lr=1e-3, weight_decay=0.1, momentum_decay=6e-3),
            [lambda opt: ExponentialLR(opt, gamma=0.9)]
        )
        with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
            optim.NAdam(None, lr=1e-2, betas=(1.0, 0.0))
        with self.assertRaisesRegex(ValueError, "Invalid momentum_decay value: -0.2"):
            optim.NAdam(None, lr=1e-2, momentum_decay=-0.2)
        for optimizer in [optim.NAdam, optim_mt.NAdam]:
            self._test_basic_cases(
                lambda weight, bias: optimizer([weight, bias], lr=1e-3)
            )
            self._test_basic_cases(
                lambda weight, bias: optimizer(
                    self._build_params_dict(weight, bias, lr=1e-2),
                    lr=1e-3)
            )
            self._test_basic_cases(
                lambda weight, bias: optimizer([weight, bias], lr=1e-3, weight_decay=0.1, momentum_decay=6e-3)
            )
            self._test_basic_cases(
                lambda weight, bias: optimizer([weight, bias], lr=1e-3, weight_decay=0.1, momentum_decay=6e-3),
                [lambda opt: ExponentialLR(opt, gamma=0.9)]
            )
            with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
                optimizer(None, lr=1e-2, betas=(1.0, 0.0))
            with self.assertRaisesRegex(ValueError, "Invalid momentum_decay value: -0.2"):
                optimizer(None, lr=1e-2, momentum_decay=-0.2)

    def test_adagrad(self):
        self._test_basic_cases(
@@ -7,6 +7,7 @@ future.

from .adam import Adam
from .adamw import AdamW
from .nadam import NAdam
from .sgd import SGD
from .radam import RAdam as RAdam
from .rmsprop import RMSprop
@@ -19,6 +20,7 @@ del adam
del adamw
del sgd
del radam
del nadam
del rmsprop
del rprop
del asgd
@@ -1,5 +1,6 @@
from .adam import Adam as Adam
from .adamw import AdamW as AdamW
from .nadam import NAdam as NAdam
from .sgd import SGD as SGD
from .radam import RAdam as RAdam
from .rmsprop import RMSprop as RMSprop
@@ -123,9 +123,9 @@ def radam(params: List[Tensor],
          lr: float,
          weight_decay: float,
          eps: float):
    r"""Functional API that performs Adam algorithm computation.
    r"""Functional API that performs RAdam algorithm computation.

    See :class:`~torch.optim.Adam` for details.
    See :class:`~torch.optim.RAdam` for details.
    """

    # maximum length of the approximated SMA
@@ -158,3 +158,49 @@ def radam(params: List[Tensor],
    denom = [torch.ones_like(exp_av, memory_format=torch.preserve_format) for exp_av in exp_avg]
    step_size = [(lr * rect / bc) * -1 for rect, bc in zip(unrectified, bias_correction1)]
    torch._foreach_addcdiv_(params, exp_avg, denom, step_size)


def nadam(params: List[Tensor],
          grads: List[Tensor],
          exp_avg: List[Tensor],
          exp_avg_sq: List[Tensor],
          mu_products: List[Tensor],
          states: List[Dict],
          *,
          beta1: float,
          beta2: float,
          lr: float,
          weight_decay: float,
          momentum_decay: float,
          eps: float):
    r"""Functional API that performs NAdam algorithm computation.

    See :class:`~torch.optim.NAdam` for details.
    """

    bias_correction1 = [1 - beta1 ** state['step'] for state in states]
    bias_correction2 = [1 - beta2 ** state['step'] for state in states]
    mus = [beta1 * (1. - 0.5 * (0.96 ** (state['step'] * momentum_decay))) for state in states]
    mu_nexts = [beta1 * (1. - 0.5 * (0.96 ** ((state['step'] + 1) * momentum_decay)))
                for state in states]
    if weight_decay != 0:
        torch._foreach_add_(grads, params, alpha=weight_decay)

    # Decay the first and second moment running average coefficient
    torch._foreach_mul_(exp_avg, beta1)
    torch._foreach_add_(exp_avg, grads, alpha=1 - beta1)

    torch._foreach_mul_(exp_avg_sq, beta2)
    torch._foreach_addcmul_(exp_avg_sq, grads, grads, 1 - beta2)

    exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sq)
    bias_correction_sqrt = [math.sqrt(bc) for bc in bias_correction2]
    torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt)
    denom = torch._foreach_add(exp_avg_sq_sqrt, eps)

    step_size_grads = [(lr * (1. - mu) / (1. - mu_product)) * -1
                       for mu_product, mu in zip(mu_products, mus)]
    step_size_expavg = [(lr * mu_next / (1. - mu_product * mu_next)) * -1
                        for mu_product, mu_next in zip(mu_products, mu_nexts)]
    torch._foreach_addcdiv_(params, grads, denom, step_size_grads)
    torch._foreach_addcdiv_(params, exp_avg, denom, step_size_expavg)
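For readability, here is a hypothetical single-tensor sketch of the per-parameter update that the `_foreach_` calls above vectorize. `nadam_single_step` is not part of the PR; it exists only to mirror the logic of `F.nadam`, and it assumes the caller keeps `step` and the running `mu_product`, as the `NAdam.step` method below does:

```python
import math
import torch

def nadam_single_step(param, grad, exp_avg, exp_avg_sq, mu_product, step,
                      beta1=0.9, beta2=0.999, lr=2e-3, weight_decay=0.0,
                      momentum_decay=4e-3, eps=1e-8):
    # Hypothetical single-tensor mirror of the multi-tensor update above.
    bias_correction2 = 1 - beta2 ** step
    mu = beta1 * (1. - 0.5 * (0.96 ** (step * momentum_decay)))
    mu_next = beta1 * (1. - 0.5 * (0.96 ** ((step + 1) * momentum_decay)))
    mu_product = mu_product * mu  # running product of mus, maintained by the caller

    if weight_decay != 0:
        grad = grad.add(param, alpha=weight_decay)

    # Biased first and second moment estimates.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    denom = exp_avg_sq.sqrt().div_(math.sqrt(bias_correction2)).add_(eps)
    param.addcdiv_(grad, denom, value=-lr * (1. - mu) / (1. - mu_product))
    param.addcdiv_(exp_avg, denom, value=-lr * mu_next / (1. - mu_product * mu_next))
    return mu_product
```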
torch/optim/_multi_tensor/nadam.py (new file, 130 lines)
@@ -0,0 +1,130 @@
import torch
from . import _functional as F
from ..optimizer import Optimizer
from collections import defaultdict

class NAdam(Optimizer):
    r"""Implements NAdam algorithm with multi tensor APIs.

    It has been proposed in `Incorporating Nesterov Momentum into Adam`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        momentum_decay (float, optional): momentum momentum_decay (default: 4e-3)

    .. _Incorporating Nesterov Momentum into Adam:
        https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ
    """

    def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, momentum_decay=4e-3):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= momentum_decay:
            raise ValueError("Invalid momentum_decay value: {}".format(momentum_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, momentum_decay=momentum_decay)
        super(NAdam, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avg = []
            exp_avg_sq = []
            mu_products = []
            states = []
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is not None:
                    if p.grad.is_sparse:
                        raise RuntimeError('NAdam does not support sparse gradients')
                    params_with_grad.append(p)
                    grads.append(p.grad)

            for p in params_with_grad:
                state = self.state[p]

                # Lazy state initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['mu_product'] = 1.
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avg.append(state['exp_avg'])
                exp_avg_sq.append(state['exp_avg_sq'])

                state['step'] += 1
                states.append(state)

                mu = beta1 * (1. - 0.5 * (0.96 ** (state['step'] * group['momentum_decay'])))
                state['mu_product'] *= mu
                mu_products.append(state['mu_product'])

            F.nadam(params_with_grad,
                    grads,
                    exp_avg,
                    exp_avg_sq,
                    mu_products,
                    states,
                    beta1=beta1,
                    beta2=beta2,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    momentum_decay=group['momentum_decay'],
                    eps=group['eps'])

        return loss

    # TODO: refactor to a base class once foreach ops are in a good shape.
    def zero_grad(self, set_to_none: bool = False):
        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    if set_to_none:
                        p.grad = None
                    else:
                        if p.grad.grad_fn is not None:
                            p.grad.detach_()
                        else:
                            p.grad.requires_grad_(False)

                        if p.grad.is_sparse:
                            p.grad.zero_()
                        else:
                            per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad)

        for _, per_dtype_grads in per_device_and_dtype_grads.items():
            for grads in per_dtype_grads.values():
                torch._foreach_zero_(grads)
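As a quick sanity check (again, not part of the diff), the single-tensor and multi-tensor implementations can be compared on a toy problem, much like the updated TestOptim case does; the `atol=5e-5, rtol=0` tolerance is taken from the test hunk above, everything else is an arbitrary example:

```python
import torch
from torch import optim

torch.manual_seed(0)
weight = torch.randn(5, 3, requires_grad=True)
weight_mt = weight.detach().clone().requires_grad_(True)

opt = optim.NAdam([weight], lr=1e-3, momentum_decay=6e-3)
opt_mt = optim._multi_tensor.NAdam([weight_mt], lr=1e-3, momentum_decay=6e-3)

for _ in range(10):
    for o, w in ((opt, weight), (opt_mt, weight_mt)):
        o.zero_grad()
        w.pow(2).sum().backward()
        o.step()

assert torch.allclose(weight, weight_mt, atol=5e-5, rtol=0)
```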
torch/optim/_multi_tensor/nadam.pyi (new file, 5 lines)
@@ -0,0 +1,5 @@
from typing import Tuple
from ..optimizer import _params_t, Optimizer

class NAdam(Optimizer):
    def __init__(self, params: _params_t, lr: float=..., betas: Tuple[float, float]=..., eps: float=..., weight_decay: float=..., momentum_decay: float=...) -> None: ...