From f0e972a481b51ddf9d21d3ba8e8ef0ec2e15cd8a Mon Sep 17 00:00:00 2001 From: Ilqar Ramazanli Date: Sun, 27 Jun 2021 16:57:32 -0700 Subject: [PATCH] To add Nesterov Adam algorithm for multi-tensor optimizers API (#59165) Summary: Previously in the PR: https://github.com/pytorch/pytorch/issues/59009 we added NAdam to Optimizers. Here in this PR we are proposing multi-tensor version of NAdam for PyTorch. Nadam has been proposed in the paper https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ and report and report : http://cs229.stanford.edu/proj2015/054_report.pdf by Timothy Dozat. It has been one of the most used algorithm in Deep Learning community. It worth to noting that the implementation of NAdam is inspired by the implementation for Keras : https://github.com/tensorflow/tensorflow/blob/f9d386849581d15d72f6f1f96f12aac230a8edbe/tensorflow/python/keras/optimizer_v2/nadam.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/59165 Reviewed By: vincentqb Differential Revision: D29360577 Pulled By: iramazanli fbshipit-source-id: 0fe14016303b2df2cb8cc31912a2674acf63d1e5 --- test/test_optim.py | 45 ++++---- torch/optim/_multi_tensor/__init__.py | 2 + torch/optim/_multi_tensor/__init__.pyi | 1 + torch/optim/_multi_tensor/_functional.py | 50 ++++++++- torch/optim/_multi_tensor/nadam.py | 130 +++++++++++++++++++++++ torch/optim/_multi_tensor/nadam.pyi | 5 + 6 files changed, 211 insertions(+), 22 deletions(-) create mode 100644 torch/optim/_multi_tensor/nadam.py create mode 100644 torch/optim/_multi_tensor/nadam.pyi diff --git a/test/test_optim.py b/test/test_optim.py index d764601083d..7488e4fac34 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -318,6 +318,10 @@ class TestOptim(TestCase): ((optim.AdamW, optim._multi_tensor.AdamW), dict(weight_decay=1., amsgrad=False)), ((optim.AdamW, optim._multi_tensor.AdamW), dict(weight_decay=0., amsgrad=True)), ((optim.AdamW, optim._multi_tensor.AdamW), dict(weight_decay=0., amsgrad=False)), + ((optim.NAdam, optim._multi_tensor.NAdam), dict(weight_decay=0., momentum_decay=6e-3)), + ((optim.NAdam, optim._multi_tensor.NAdam), dict(weight_decay=1., momentum_decay=6e-3)), + ((optim.NAdam, optim._multi_tensor.NAdam), dict(weight_decay=0., momentum_decay=4e-3)), + ((optim.NAdam, optim._multi_tensor.NAdam), dict(weight_decay=0.01, momentum_decay=4e-3)), ((optim.SGD, optim._multi_tensor.SGD), dict(lr=0.2, momentum=1, dampening=0, weight_decay=1, nesterov=True)), ((optim.SGD, optim._multi_tensor.SGD), dict(lr=0.2, momentum=1, dampening=0.5, weight_decay=1, nesterov=False)), ((optim.RAdam, optim._multi_tensor.RAdam), dict(weight_decay=0)), @@ -378,7 +382,7 @@ class TestOptim(TestCase): res.append(model.parameters()) for p1, p2 in zip(res[0], res[1]): - self.assertEqual(p1, p2) + self.assertEqual(p1, p2, atol=5e-5, rtol=0) def test_adam(self): for optimizer in [optim.Adam, optim_mt.Adam]: @@ -481,25 +485,26 @@ class TestOptim(TestCase): optimizer(None, lr=1e-2, rho=1.1) def test_nadam(self): - self._test_basic_cases( - lambda weight, bias: optim.NAdam([weight, bias], lr=1e-3) - ) - self._test_basic_cases( - lambda weight, bias: optim.NAdam( - self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-3) - ) - self._test_basic_cases( - lambda weight, bias: optim.NAdam([weight, bias], lr=1e-3, weight_decay=0.1, momentum_decay=6e-3) - ) - self._test_basic_cases( - lambda weight, bias: optim.NAdam([weight, bias], lr=1e-3, weight_decay=0.1, momentum_decay=6e-3), - [lambda opt: ExponentialLR(opt, gamma=0.9)] - ) - with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"): - optim.NAdam(None, lr=1e-2, betas=(1.0, 0.0)) - with self.assertRaisesRegex(ValueError, "Invalid momentum_decay value: -0.2"): - optim.NAdam(None, lr=1e-2, momentum_decay=-0.2) + for optimizer in [optim.NAdam, optim_mt.NAdam]: + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3) + ) + self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict(weight, bias, lr=1e-2), + lr=1e-3) + ) + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3, weight_decay=0.1, momentum_decay=6e-3) + ) + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3, weight_decay=0.1, momentum_decay=6e-3), + [lambda opt: ExponentialLR(opt, gamma=0.9)] + ) + with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"): + optimizer(None, lr=1e-2, betas=(1.0, 0.0)) + with self.assertRaisesRegex(ValueError, "Invalid momentum_decay value: -0.2"): + optimizer(None, lr=1e-2, momentum_decay=-0.2) def test_adagrad(self): self._test_basic_cases( diff --git a/torch/optim/_multi_tensor/__init__.py b/torch/optim/_multi_tensor/__init__.py index 9962b1d84e7..4600a1cc1f1 100644 --- a/torch/optim/_multi_tensor/__init__.py +++ b/torch/optim/_multi_tensor/__init__.py @@ -7,6 +7,7 @@ future. from .adam import Adam from .adamw import AdamW +from .nadam import NAdam from .sgd import SGD from .radam import RAdam as RAdam from .rmsprop import RMSprop @@ -19,6 +20,7 @@ del adam del adamw del sgd del radam +del nadam del rmsprop del rprop del asgd diff --git a/torch/optim/_multi_tensor/__init__.pyi b/torch/optim/_multi_tensor/__init__.pyi index 51d3cbe470b..58d521bcfe9 100644 --- a/torch/optim/_multi_tensor/__init__.pyi +++ b/torch/optim/_multi_tensor/__init__.pyi @@ -1,5 +1,6 @@ from .adam import Adam as Adam from .adamw import AdamW as AdamW +from .nadam import NAdam as NAdam from .sgd import SGD as SGD from .radam import RAdam as RAdam from .rmsprop import RMSprop as RMSprop diff --git a/torch/optim/_multi_tensor/_functional.py b/torch/optim/_multi_tensor/_functional.py index 988906ba939..0b1d0b23d55 100644 --- a/torch/optim/_multi_tensor/_functional.py +++ b/torch/optim/_multi_tensor/_functional.py @@ -123,9 +123,9 @@ def radam(params: List[Tensor], lr: float, weight_decay: float, eps: float): - r"""Functional API that performs Adam algorithm computation. + r"""Functional API that performs RAdam algorithm computation. - See :class:`~torch.optim.Adam` for details. + See :class:`~torch.optim.RAdam` for details. """ # maximum length of the approximated SMA @@ -158,3 +158,49 @@ def radam(params: List[Tensor], denom = [torch.ones_like(exp_av, memory_format=torch.preserve_format) for exp_av in exp_avg] step_size = [(lr * rect / bc) * -1 for rect, bc in zip(unrectified, bias_correction1)] torch._foreach_addcdiv_(params, exp_avg, denom, step_size) + + +def nadam(params: List[Tensor], + grads: List[Tensor], + exp_avg: List[Tensor], + exp_avg_sq: List[Tensor], + mu_products: List[Tensor], + states: List[Dict], + *, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + momentum_decay: float, + eps: float): + r"""Functional API that performs NAdam algorithm computation. + + See :class:`~torch.optim.NAdam` for details. + """ + + bias_correction1 = [1 - beta1 ** state['step'] for state in states] + bias_correction2 = [1 - beta2 ** state['step'] for state in states] + mus = [beta1 * (1. - 0.5 * (0.96 ** (state['step'] * momentum_decay))) for state in states] + mu_nexts = [beta1 * (1. - 0.5 * (0.96 ** ((state['step'] + 1) * momentum_decay))) + for state in states] + if weight_decay != 0: + torch._foreach_add_(grads, params, alpha=weight_decay) + + # Decay the first and second moment running average coefficient + torch._foreach_mul_(exp_avg, beta1) + torch._foreach_add_(exp_avg, grads, alpha=1 - beta1) + + torch._foreach_mul_(exp_avg_sq, beta2) + torch._foreach_addcmul_(exp_avg_sq, grads, grads, 1 - beta2) + + exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sq) + bias_correction_sqrt = [math.sqrt(bc) for bc in bias_correction2] + torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt) + denom = torch._foreach_add(exp_avg_sq_sqrt, eps) + + step_size_grads = [(lr * (1. - mu) / (1. - mu_product)) * -1 + for mu_product, mu in zip(mu_products, mus)] + step_size_expavg = [(lr * mu_next / (1. - mu_product * mu_next)) * -1 + for mu_product, mu_next in zip(mu_products, mu_nexts)] + torch._foreach_addcdiv_(params, grads, denom, step_size_grads) + torch._foreach_addcdiv_(params, exp_avg, denom, step_size_expavg) diff --git a/torch/optim/_multi_tensor/nadam.py b/torch/optim/_multi_tensor/nadam.py new file mode 100644 index 00000000000..6416581be69 --- /dev/null +++ b/torch/optim/_multi_tensor/nadam.py @@ -0,0 +1,130 @@ +import torch +from . import _functional as F +from ..optimizer import Optimizer +from collections import defaultdict + +class NAdam(Optimizer): + r"""Implements NAdam algorithm with multi tensor APIs. + + It has been proposed in `Incorporating Nesterov Momentum into Adam`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 2e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + momentum_decay (float, optional): momentum momentum_decay (default: 4e-3) + + .. _Incorporating Nesterov Momentum into Adam: + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ + """ + + def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, momentum_decay=4e-3): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= momentum_decay: + raise ValueError("Invalid momentum_decay value: {}".format(momentum_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, momentum_decay=momentum_decay) + super(NAdam, self).__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avg = [] + exp_avg_sq = [] + mu_products = [] + states = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is not None: + if p.grad.is_sparse: + raise RuntimeError('NAdam does not support sparse gradients') + params_with_grad.append(p) + grads.append(p.grad) + + for p in params_with_grad: + state = self.state[p] + + # Lazy state initialization + if len(state) == 0: + state['step'] = 0 + state['mu_product'] = 1. + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avg.append(state['exp_avg']) + exp_avg_sq.append(state['exp_avg_sq']) + + state['step'] += 1 + states.append(state) + + mu = beta1 * (1. - 0.5 * (0.96 ** (state['step'] * group['momentum_decay']))) + state['mu_product'] *= mu + mu_products.append(state['mu_product']) + + F.nadam(params_with_grad, + grads, + exp_avg, + exp_avg_sq, + mu_products, + states, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + momentum_decay=group['momentum_decay'], + eps=group['eps']) + + return loss + + # TODO: refactor to a base class once foreach ops are in a good shape. + def zero_grad(self, set_to_none: bool = False): + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) + for group in self.param_groups: + for p in group['params']: + if p.grad is not None: + if set_to_none: + p.grad = None + else: + if p.grad.grad_fn is not None: + p.grad.detach_() + else: + p.grad.requires_grad_(False) + + if p.grad.is_sparse: + p.grad.zero_() + else: + per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad) + + for _, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._foreach_zero_(grads) diff --git a/torch/optim/_multi_tensor/nadam.pyi b/torch/optim/_multi_tensor/nadam.pyi new file mode 100644 index 00000000000..6513a66b3c6 --- /dev/null +++ b/torch/optim/_multi_tensor/nadam.pyi @@ -0,0 +1,5 @@ +from typing import Tuple +from ..optimizer import _params_t, Optimizer + +class NAdam(Optimizer): + def __init__(self, params: _params_t, lr: float=..., betas: Tuple[float, float]=..., eps: float=..., weight_decay: float=..., momentum_decay: float=...) -> None: ...