diff --git a/test/test_optim.py b/test/test_optim.py
index d764601083d..7488e4fac34 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -318,6 +318,10 @@ class TestOptim(TestCase):
             ((optim.AdamW, optim._multi_tensor.AdamW), dict(weight_decay=1., amsgrad=False)),
             ((optim.AdamW, optim._multi_tensor.AdamW), dict(weight_decay=0., amsgrad=True)),
             ((optim.AdamW, optim._multi_tensor.AdamW), dict(weight_decay=0., amsgrad=False)),
+            ((optim.NAdam, optim._multi_tensor.NAdam), dict(weight_decay=0., momentum_decay=6e-3)),
+            ((optim.NAdam, optim._multi_tensor.NAdam), dict(weight_decay=1., momentum_decay=6e-3)),
+            ((optim.NAdam, optim._multi_tensor.NAdam), dict(weight_decay=0., momentum_decay=4e-3)),
+            ((optim.NAdam, optim._multi_tensor.NAdam), dict(weight_decay=0.01, momentum_decay=4e-3)),
             ((optim.SGD, optim._multi_tensor.SGD), dict(lr=0.2, momentum=1, dampening=0, weight_decay=1, nesterov=True)),
             ((optim.SGD, optim._multi_tensor.SGD), dict(lr=0.2, momentum=1, dampening=0.5, weight_decay=1, nesterov=False)),
             ((optim.RAdam, optim._multi_tensor.RAdam), dict(weight_decay=0)),
@@ -378,7 +382,7 @@ class TestOptim(TestCase):
             res.append(model.parameters())

         for p1, p2 in zip(res[0], res[1]):
-            self.assertEqual(p1, p2)
+            self.assertEqual(p1, p2, atol=5e-5, rtol=0)

     def test_adam(self):
         for optimizer in [optim.Adam, optim_mt.Adam]:
@@ -481,25 +485,26 @@ class TestOptim(TestCase):
             optimizer(None, lr=1e-2, rho=1.1)

     def test_nadam(self):
-        self._test_basic_cases(
-            lambda weight, bias: optim.NAdam([weight, bias], lr=1e-3)
-        )
-        self._test_basic_cases(
-            lambda weight, bias: optim.NAdam(
-                self._build_params_dict(weight, bias, lr=1e-2),
-                lr=1e-3)
-        )
-        self._test_basic_cases(
-            lambda weight, bias: optim.NAdam([weight, bias], lr=1e-3, weight_decay=0.1, momentum_decay=6e-3)
-        )
-        self._test_basic_cases(
-            lambda weight, bias: optim.NAdam([weight, bias], lr=1e-3, weight_decay=0.1, momentum_decay=6e-3),
-            [lambda opt: ExponentialLR(opt, gamma=0.9)]
-        )
-        with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
-            optim.NAdam(None, lr=1e-2, betas=(1.0, 0.0))
-        with self.assertRaisesRegex(ValueError, "Invalid momentum_decay value: -0.2"):
-            optim.NAdam(None, lr=1e-2, momentum_decay=-0.2)
+        for optimizer in [optim.NAdam, optim_mt.NAdam]:
+            self._test_basic_cases(
+                lambda weight, bias: optimizer([weight, bias], lr=1e-3)
+            )
+            self._test_basic_cases(
+                lambda weight, bias: optimizer(
+                    self._build_params_dict(weight, bias, lr=1e-2),
+                    lr=1e-3)
+            )
+            self._test_basic_cases(
+                lambda weight, bias: optimizer([weight, bias], lr=1e-3, weight_decay=0.1, momentum_decay=6e-3)
+            )
+            self._test_basic_cases(
+                lambda weight, bias: optimizer([weight, bias], lr=1e-3, weight_decay=0.1, momentum_decay=6e-3),
+                [lambda opt: ExponentialLR(opt, gamma=0.9)]
+            )
+            with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
+                optimizer(None, lr=1e-2, betas=(1.0, 0.0))
+            with self.assertRaisesRegex(ValueError, "Invalid momentum_decay value: -0.2"):
+                optimizer(None, lr=1e-2, momentum_decay=-0.2)

     def test_adagrad(self):
         self._test_basic_cases(
diff --git a/torch/optim/_multi_tensor/__init__.py b/torch/optim/_multi_tensor/__init__.py
index 9962b1d84e7..4600a1cc1f1 100644
--- a/torch/optim/_multi_tensor/__init__.py
+++ b/torch/optim/_multi_tensor/__init__.py
@@ -7,6 +7,7 @@ future.

 from .adam import Adam
 from .adamw import AdamW
+from .nadam import NAdam
 from .sgd import SGD
 from .radam import RAdam as RAdam
 from .rmsprop import RMSprop
@@ -19,6 +20,7 @@ del adam
 del adamw
 del sgd
 del radam
+del nadam
 del rmsprop
 del rprop
 del asgd
diff --git a/torch/optim/_multi_tensor/__init__.pyi b/torch/optim/_multi_tensor/__init__.pyi
index 51d3cbe470b..58d521bcfe9 100644
--- a/torch/optim/_multi_tensor/__init__.pyi
+++ b/torch/optim/_multi_tensor/__init__.pyi
@@ -1,5 +1,6 @@
 from .adam import Adam as Adam
 from .adamw import AdamW as AdamW
+from .nadam import NAdam as NAdam
 from .sgd import SGD as SGD
 from .radam import RAdam as RAdam
 from .rmsprop import RMSprop as RMSprop
diff --git a/torch/optim/_multi_tensor/_functional.py b/torch/optim/_multi_tensor/_functional.py
index 988906ba939..0b1d0b23d55 100644
--- a/torch/optim/_multi_tensor/_functional.py
+++ b/torch/optim/_multi_tensor/_functional.py
@@ -123,9 +123,9 @@ def radam(params: List[Tensor],
           lr: float,
           weight_decay: float,
           eps: float):
-    r"""Functional API that performs Adam algorithm computation.
+    r"""Functional API that performs RAdam algorithm computation.

-    See :class:`~torch.optim.Adam` for details.
+    See :class:`~torch.optim.RAdam` for details.
     """

     # maximum length of the approximated SMA
@@ -158,3 +158,49 @@
     denom = [torch.ones_like(exp_av, memory_format=torch.preserve_format) for exp_av in exp_avg]
     step_size = [(lr * rect / bc) * -1 for rect, bc in zip(unrectified, bias_correction1)]
     torch._foreach_addcdiv_(params, exp_avg, denom, step_size)
+
+
+def nadam(params: List[Tensor],
+          grads: List[Tensor],
+          exp_avg: List[Tensor],
+          exp_avg_sq: List[Tensor],
+          mu_products: List[Tensor],
+          states: List[Dict],
+          *,
+          beta1: float,
+          beta2: float,
+          lr: float,
+          weight_decay: float,
+          momentum_decay: float,
+          eps: float):
+    r"""Functional API that performs NAdam algorithm computation.
+
+    See :class:`~torch.optim.NAdam` for details.
+    """
+
+    bias_correction1 = [1 - beta1 ** state['step'] for state in states]
+    bias_correction2 = [1 - beta2 ** state['step'] for state in states]
+    mus = [beta1 * (1. - 0.5 * (0.96 ** (state['step'] * momentum_decay))) for state in states]
+    mu_nexts = [beta1 * (1. - 0.5 * (0.96 ** ((state['step'] + 1) * momentum_decay)))
+                for state in states]
+    if weight_decay != 0:
+        torch._foreach_add_(grads, params, alpha=weight_decay)
+
+    # Decay the first and second moment running average coefficient
+    torch._foreach_mul_(exp_avg, beta1)
+    torch._foreach_add_(exp_avg, grads, alpha=1 - beta1)
+
+    torch._foreach_mul_(exp_avg_sq, beta2)
+    torch._foreach_addcmul_(exp_avg_sq, grads, grads, 1 - beta2)
+
+    exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sq)
+    bias_correction_sqrt = [math.sqrt(bc) for bc in bias_correction2]
+    torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt)
+    denom = torch._foreach_add(exp_avg_sq_sqrt, eps)
+
+    step_size_grads = [(lr * (1. - mu) / (1. - mu_product)) * -1
+                       for mu_product, mu in zip(mu_products, mus)]
+    step_size_expavg = [(lr * mu_next / (1. - mu_product * mu_next)) * -1
+                        for mu_product, mu_next in zip(mu_products, mu_nexts)]
+    torch._foreach_addcdiv_(params, grads, denom, step_size_grads)
+    torch._foreach_addcdiv_(params, exp_avg, denom, step_size_expavg)
diff --git a/torch/optim/_multi_tensor/nadam.py b/torch/optim/_multi_tensor/nadam.py
new file mode 100644
index 00000000000..6416581be69
--- /dev/null
+++ b/torch/optim/_multi_tensor/nadam.py
@@ -0,0 +1,130 @@
+import torch
+from . import _functional as F
+from ..optimizer import Optimizer
+from collections import defaultdict
+
+class NAdam(Optimizer):
+    r"""Implements NAdam algorithm with multi tensor APIs.
+
+    It has been proposed in `Incorporating Nesterov Momentum into Adam`_.
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 2e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        momentum_decay (float, optional): momentum decay (default: 4e-3)
+
+    .. _Incorporating Nesterov Momentum into Adam:
+        https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ
+    """
+
+    def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
+                 weight_decay=0, momentum_decay=4e-3):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if not 0.0 <= momentum_decay:
+            raise ValueError("Invalid momentum_decay value: {}".format(momentum_decay))
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay, momentum_decay=momentum_decay)
+        super(NAdam, self).__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avg = []
+            exp_avg_sq = []
+            mu_products = []
+            states = []
+            beta1, beta2 = group['betas']
+
+            for p in group['params']:
+                if p.grad is not None:
+                    if p.grad.is_sparse:
+                        raise RuntimeError('NAdam does not support sparse gradients')
+                    params_with_grad.append(p)
+                    grads.append(p.grad)
+
+            for p in params_with_grad:
+                state = self.state[p]
+
+                # Lazy state initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    state['mu_product'] = 1.
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+                exp_avg.append(state['exp_avg'])
+                exp_avg_sq.append(state['exp_avg_sq'])
+
+                state['step'] += 1
+                states.append(state)
+
+                mu = beta1 * (1. - 0.5 * (0.96 ** (state['step'] * group['momentum_decay'])))
+                state['mu_product'] *= mu
+                mu_products.append(state['mu_product'])
+
+            F.nadam(params_with_grad,
+                    grads,
+                    exp_avg,
+                    exp_avg_sq,
+                    mu_products,
+                    states,
+                    beta1=beta1,
+                    beta2=beta2,
+                    lr=group['lr'],
+                    weight_decay=group['weight_decay'],
+                    momentum_decay=group['momentum_decay'],
+                    eps=group['eps'])
+
+        return loss
+
+    # TODO: refactor to a base class once foreach ops are in a good shape.
+    def zero_grad(self, set_to_none: bool = False):
+        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is not None:
+                    if set_to_none:
+                        p.grad = None
+                    else:
+                        if p.grad.grad_fn is not None:
+                            p.grad.detach_()
+                        else:
+                            p.grad.requires_grad_(False)
+
+                        if p.grad.is_sparse:
+                            p.grad.zero_()
+                        else:
+                            per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad)
+
+        for _, per_dtype_grads in per_device_and_dtype_grads.items():
+            for grads in per_dtype_grads.values():
+                torch._foreach_zero_(grads)
diff --git a/torch/optim/_multi_tensor/nadam.pyi b/torch/optim/_multi_tensor/nadam.pyi
new file mode 100644
index 00000000000..6513a66b3c6
--- /dev/null
+++ b/torch/optim/_multi_tensor/nadam.pyi
@@ -0,0 +1,5 @@
+from typing import Tuple
+from ..optimizer import _params_t, Optimizer
+
+class NAdam(Optimizer):
+    def __init__(self, params: _params_t, lr: float=..., betas: Tuple[float, float]=..., eps: float=..., weight_decay: float=..., momentum_decay: float=...) -> None: ...
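Reviewer note (not part of the patch): a quick way to exercise the new multi-tensor path outside of test_optim.py is to step the existing single-tensor NAdam and the NAdam added here side by side and compare the resulting parameters, mirroring the atol=5e-5 tolerance introduced in the updated test. The sketch below assumes this PR is applied; the model, input, and iteration count are arbitrary illustrative choices.

# Illustrative sketch only: compare single-tensor vs. multi-tensor NAdam.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim._multi_tensor as optim_mt

torch.manual_seed(0)
models = [nn.Linear(4, 2), nn.Linear(4, 2)]
models[1].load_state_dict(models[0].state_dict())  # identical starting weights

optimizers = [
    optim.NAdam(models[0].parameters(), lr=1e-3, weight_decay=0.1, momentum_decay=6e-3),
    optim_mt.NAdam(models[1].parameters(), lr=1e-3, weight_decay=0.1, momentum_decay=6e-3),
]

x = torch.randn(8, 4)
for _ in range(10):
    for model, opt in zip(models, optimizers):
        opt.zero_grad()
        (model(x) ** 2).sum().backward()  # same deterministic loss for both models
        opt.step()

# The two implementations should track each other up to the reordering of
# floating-point operations inside the foreach kernels, hence the nonzero atol.
for p1, p2 in zip(models[0].parameters(), models[1].parameters()):
    assert torch.allclose(p1, p2, atol=5e-5), "single- and multi-tensor NAdam diverged"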