r"""Functional interface"""
import math
import torch
from torch import Tensor
from typing import List, Optional

# TODO: use foreach API in optim._functional to do all the computation


def _make_sparse(grad, grad_indices, values):
    """Build a sparse COO tensor with the same size as ``grad``.

    If either ``grad_indices`` or ``values`` has no elements, returns
    ``torch.empty_like(grad)`` instead of constructing a COO tensor.
    """
    size = grad.size()
    if grad_indices.numel() == 0 or values.numel() == 0:
        return torch.empty_like(grad)
    return torch.sparse_coo_tensor(grad_indices, values, size)


def adagrad(params: List[Tensor],
            grads: List[Tensor],
            state_sums: List[Tensor],
            state_steps: List[int],
            *,
            lr: float,
            weight_decay: float,
            lr_decay: float,
            eps: float):
    r"""Functional API that performs Adagrad algorithm computation.

    See :class:`~torch.optim.Adagrad` for details.

    ``params`` and ``state_sums`` are updated in place; each ``state_sums``
    entry accumulates the element-wise squared gradients. The effective
    learning rate is ``lr / (1 + (step - 1) * lr_decay)``, so step 1 uses
    ``lr`` unchanged. Sparse gradients are supported only without weight decay.

    Raises:
        RuntimeError: if ``weight_decay`` is non-zero and a gradient is sparse.
    """
    for (param, grad, state_sum, step) in zip(params, grads, state_sums, state_steps):
        if weight_decay != 0:
            if grad.is_sparse:
                raise RuntimeError("weight_decay option is not compatible with sparse gradients")
            grad = grad.add(param, alpha=weight_decay)

        clr = lr / (1 + (step - 1) * lr_decay)

        if grad.is_sparse:
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()

            state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
            # Only the entries present in the gradient are needed for this update.
            std = state_sum.sparse_mask(grad)
            std_values = std._values().sqrt_().add_(eps)
            param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
        else:
            state_sum.addcmul_(grad, grad, value=1)
            std = state_sum.sqrt().add_(eps)
            param.addcdiv_(grad, std, value=-clr)


def adam(params: List[Tensor],
         grads: List[Tensor],
         exp_avgs: List[Tensor],
         exp_avg_sqs: List[Tensor],
         max_exp_avg_sqs: List[Tensor],
         state_steps: List[int],
         *,
         amsgrad: bool,
         beta1: float,
         beta2: float,
         lr: float,
         weight_decay: float,
         eps: float):
    r"""Functional API that performs Adam algorithm computation.

    See :class:`~torch.optim.Adam` for details.

    ``params``, ``exp_avgs`` and ``exp_avg_sqs`` are updated in place.
    ``max_exp_avg_sqs`` is only indexed (and updated in place) when
    ``amsgrad`` is true. Weight decay is added to the gradient (L2 penalty).
    ``state_steps`` entries must be >= 1: at step 0 the bias corrections
    are zero.
    """
    for i, param in enumerate(params):

        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1

        param.addcdiv_(exp_avg, denom, value=-step_size)


def adamw(params: List[Tensor],
          grads: List[Tensor],
          exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor],
          max_exp_avg_sqs: List[Tensor],
          state_steps: List[int],
          *,
          amsgrad: bool,
          beta1: float,
          beta2: float,
          lr: float,
          weight_decay: float,
          eps: float):
    r"""Functional API that performs AdamW algorithm computation.

    See :class:`~torch.optim.AdamW` for details.

    Identical to :func:`adam` except that weight decay is decoupled. Each
    parameter is first scaled in place by ``1 - lr * weight_decay``, and the
    gradient itself is not modified. ``max_exp_avg_sqs`` is only indexed
    when ``amsgrad`` is true.
    """
    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        # Perform stepweight decay
        param.mul_(1 - lr * weight_decay)

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1

        param.addcdiv_(exp_avg, denom, value=-step_size)


def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool):
    r"""Functional API that performs SGD algorithm computation.

    See :class:`~torch.optim.SGD` for details.

    ``params`` are updated in place. When ``momentum`` is non-zero, a ``None``
    entry of ``momentum_buffer_list`` is replaced by a new buffer: a detached
    copy of the (weight-decayed) gradient. Existing buffers are updated in
    place.
    """
    for i, param in enumerate(params):

        d_p = d_p_list[i]
        if weight_decay != 0:
            d_p = d_p.add(param, alpha=weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                # First step: the buffer starts as the gradient itself (no dampening).
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf

        param.add_(d_p, alpha=-lr)


def adadelta(params: List[Tensor],
             grads: List[Tensor],
             square_avgs: List[Tensor],
             acc_deltas: List[Tensor],
             *,
             lr: float,
             rho: float,
             eps: float,
             weight_decay: float):
    r"""Functional API that performs Adadelta algorithm computation.

    See :class:`~torch.optim.Adadelta` for details.

    ``params``, ``square_avgs`` (running average of squared gradients) and
    ``acc_deltas`` (running average of squared updates) are updated in place.
    """

    for (param, grad, square_avg, acc_delta) in zip(params, grads, square_avgs, acc_deltas):
        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho)
        std = square_avg.add(eps).sqrt_()
        delta = acc_delta.add(eps).sqrt_().div_(std).mul_(grad)
        param.add_(delta, alpha=-lr)
        acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho)


def rmsprop(params: List[Tensor],
            grads: List[Tensor],
            square_avgs: List[Tensor],
            grad_avgs: List[Tensor],
            momentum_buffer_list: List[Tensor],
            *,
            lr: float,
            alpha: float,
            eps: float,
            weight_decay: float,
            momentum: float,
            centered: bool):
    r"""Functional API that performs rmsprop algorithm computation.

    See :class:`~torch.optim.RMSProp` for details.

    ``params`` and ``square_avgs`` are updated in place. ``grad_avgs`` is only
    indexed when ``centered`` is true. ``momentum_buffer_list`` is only
    indexed when ``momentum > 0``. Both are updated in place when used.
    """

    for i, param in enumerate(params):
        grad = grads[i]
        square_avg = square_avgs[i]

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

        if centered:
            grad_avg = grad_avgs[i]
            grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
            # Centered variant: normalize by the estimated variance E[g^2] - E[g]^2.
            avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(eps)
        else:
            avg = square_avg.sqrt().add_(eps)

        if momentum > 0:
            buf = momentum_buffer_list[i]
            buf.mul_(momentum).addcdiv_(grad, avg)
            param.add_(buf, alpha=-lr)
        else:
            param.addcdiv_(grad, avg, value=-lr)


def rprop(params: List[Tensor],
          grads: List[Tensor],
          prevs: List[Tensor],
          step_sizes: List[Tensor],
          *,
          step_size_min: float,
          step_size_max: float,
          etaminus: float,
          etaplus: float):
    r"""Functional API that performs rprop algorithm computation.

    See :class:`~torch.optim.Rprop` for details.

    ``params``, ``prevs`` (previous gradient) and ``step_sizes`` (per-element
    step size) are updated in place. Step sizes are multiplied by ``etaplus``
    where the gradient kept its sign and by ``etaminus`` where it flipped. The
    result is clamped to ``[step_size_min, step_size_max]``.
    """

    for i, param in enumerate(params):
        grad = grads[i]
        prev = prevs[i]
        step_size = step_sizes[i]

        sign = grad.mul(prev).sign()
        sign[sign.gt(0)] = etaplus
        sign[sign.lt(0)] = etaminus
        sign[sign.eq(0)] = 1

        # update stepsizes with step size updates
        step_size.mul_(sign).clamp_(step_size_min, step_size_max)

        # for dir<0, dfdx=0
        # for dir>=0 dfdx=dfdx
        grad = grad.clone(memory_format=torch.preserve_format)
        grad[sign.eq(etaminus)] = 0

        # update parameters
        param.addcmul_(grad.sign(), step_size, value=-1)

        prev.copy_(grad)


def adamax(params: List[Tensor],
           grads: List[Tensor],
           exp_avgs: List[Tensor],
           exp_infs: List[Tensor],
           state_steps: List[int],
           *,
           eps: float,
           beta1: float,
           beta2: float,
           lr: float,
           weight_decay: float):
    r"""Functional API that performs adamax algorithm computation.

    See :class:`~torch.optim.Adamax` for details.

    ``params``, ``exp_avgs`` and ``exp_infs`` (the exponentially weighted
    infinity norm) are updated in place. ``state_steps`` entries must be >= 1
    because of the ``1 - beta1 ** step`` bias correction.
    """

    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_inf = exp_infs[i]
        step = state_steps[i]

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Update biased first moment estimate.
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        # Update the exponentially weighted infinity norm.
        norm_buf = torch.cat([
            exp_inf.mul_(beta2).unsqueeze(0),
            grad.abs().add_(eps).unsqueeze_(0)
        ], 0)
        # Element-wise max of the decayed norm and |grad| + eps, written into exp_inf.
        torch.amax(norm_buf, 0, keepdim=False, out=exp_inf)

        bias_correction = 1 - beta1 ** step
        clr = lr / bias_correction

        param.addcdiv_(exp_avg, exp_inf, value=-clr)


def asgd(params: List[Tensor],
         grads: List[Tensor],
         axs: List[Tensor],
         mus: List[float],
         etas: List[float],
         *,
         weight_decay: float,
         lambd: float):
    r"""Functional API that performs asgd algorithm computation.

    See :class:`~torch.optim.ASGD` for details.

    ``params`` and ``axs`` (the averaged parameters) are updated in place.
    ``mus`` and ``etas`` are only read. When ``mu == 1`` the average is
    replaced by the current parameter.
    """

    for i, param in enumerate(params):
        grad = grads[i]
        mu = mus[i]
        ax = axs[i]
        eta = etas[i]

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # decay term
        param.mul_(1 - lambd * eta)

        # update parameter
        param.add_(grad, alpha=-eta)

        # averaging
        if mu != 1:
            ax.add_(param.sub(ax).mul(mu))
        else:
            ax.copy_(param)


def nadam(params: List[Tensor],
          grads: List[Tensor],
          exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor],
          mu_products: List[float],
          state_steps: List[int],
          *,
          beta1: float,
          beta2: float,
          lr: float,
          weight_decay: float,
          momentum_decay: float,
          eps: float):
    r"""Functional API that performs NAdam algorithm computation.

    See :class:`~torch.optim.NAdam` for details.

    ``params``, ``exp_avgs`` and ``exp_avg_sqs`` are updated in place.
    ``mu_products[i]`` is the running product of the momentum coefficients
    mu for the steps before ``state_steps[i]``. It is only read here, and the
    updated product is NOT written back, so the caller must advance it
    between steps.
    """

    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        mu_product = mu_products[i]
        step = state_steps[i]

        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # calculate the momentum cache \mu^{t} and \mu^{t+1}
        mu = beta1 * (1. - 0.5 * (0.96 ** (step * momentum_decay)))
        mu_next = beta1 * (1. - 0.5 * (0.96 ** ((step + 1) * momentum_decay)))
        # After this line mu_product is prod_{i<=t} mu^{i}. The product through
        # step t+1 needs only one more factor, mu^{t+1}. Multiplying by mu
        # again would count mu^{t} twice.
        mu_product = mu_product * mu
        mu_product_next = mu_product * mu_next

        # decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        denom = exp_avg_sq.div(bias_correction2).sqrt().add_(eps)
        # Nesterov-style update: current-gradient term plus look-ahead momentum term.
        param.addcdiv_(grad, denom, value=-lr * (1. - mu) / (1. - mu_product))
        param.addcdiv_(exp_avg, denom, value=-lr * mu_next / (1. - mu_product_next))


def radam(params: List[Tensor],
          grads: List[Tensor],
          exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor],
          state_steps: List[int],
          *,
          beta1: float,
          beta2: float,
          lr: float,
          weight_decay: float,
          eps: float):
    r"""Functional API that performs RAdam algorithm computation.

    See :class:`~torch.optim.RAdam` for details.

    ``params``, ``exp_avgs`` and ``exp_avg_sqs`` are updated in place. While
    the approximated SMA length ``rho_t`` is <= 5, the update uses only the
    bias-corrected first moment, with no adaptive term.
    """

    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        # correcting bias for the first moving moment
        bias_corrected_exp_avg = exp_avg / bias_correction1

        # maximum length of the approximated SMA
        rho_inf = 2 / (1 - beta2) - 1
        # compute the length of the approximated SMA
        rho_t = rho_inf - 2 * step * (beta2 ** step) / bias_correction2

        if rho_t > 5.:
            # Compute the variance rectification term and update parameters accordingly
            rect = math.sqrt((rho_t - 4) * (rho_t - 2) * rho_inf / ((rho_inf - 4) * (rho_inf - 2) * rho_t))
            adaptive_lr = math.sqrt(bias_correction2) / exp_avg_sq.sqrt().add_(eps)

            param.add_(bias_corrected_exp_avg * lr * adaptive_lr * rect, alpha=-1.0)
        else:
            param.add_(bias_corrected_exp_avg * lr, alpha=-1.0)