mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Solves the next most important use case in https://github.com/pytorch/pytorch/issues/68052. I have kept the style as close to that in SGD as seemed reasonable, given the slight differences in their internal implementations. All feedback welcome! cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang Pull Request resolved: https://github.com/pytorch/pytorch/pull/68164 Reviewed By: VitalyFedyunin Differential Revision: D32994129 Pulled By: albanD fbshipit-source-id: 65c57c3f3dbbd3e3e5338d51def54482503e8850
154 lines
7.8 KiB
Python
154 lines
7.8 KiB
Python
import torch
|
|
from . import _functional as F
|
|
from .optimizer import Optimizer
|
|
|
|
|
|
class Adam(Optimizer):
    r"""Implements Adam algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \beta_1, \beta_2
                \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)}          \\
            &\hspace{13mm}      \lambda \text{ (weight decay)}, \: \epsilon \text{ (epsilon)},
                \: \textit{amsgrad}, \: \textit{maximize}                                        \\
            &\textbf{initialize} :  m_0 \leftarrow 0 \text{ (first moment)},
                v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}:                                       \\
            &\hspace{10mm}g_t           \leftarrow   -\nabla_{\theta} f_t (\theta_{t-1})         \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})          \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0                                           \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda  \theta_{t-1}                            \\
            &\hspace{5mm}m_t           \leftarrow   \beta_1 m_{t-1} + (1 - \beta_1) g_t          \\
            &\hspace{5mm}v_t           \leftarrow   \beta_2 v_{t-1} + (1-\beta_2) g^2_t          \\
            &\hspace{5mm}\widehat{m_t} \leftarrow   m_t/\big(1-\beta_1^t \big)                   \\
            &\hspace{5mm}\widehat{v_t} \leftarrow   v_t/\big(1-\beta_2^t \big)                   \\
            &\hspace{5mm}\textbf{if} \: amsgrad                                                  \\
            &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
                \widehat{v_t})                                                                   \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big)                                 \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}} + \epsilon \big)                                       \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        maximize (bool, optional): maximize the params based on the objective, instead of
            minimizing (default: False)

    Raises:
        ValueError: if ``lr``, ``eps`` or ``weight_decay`` is negative, or if
            either beta lies outside ``[0, 1)``.

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, *, maximize: bool = False):
        # Validate hyperparameters eagerly so misconfiguration fails at
        # construction time rather than on the first step().
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad, maximize=maximize)
        super(Adam, self).__init__(params, defaults)

    def __setstate__(self, state):
        """Restore pickled optimizer state, back-filling newer group options.

        State pickled by older versions may lack options added later. ``step``
        reads both ``group['amsgrad']`` and ``group['maximize']``
        unconditionally, so missing keys are filled with their constructor
        defaults (``False``) instead of raising ``KeyError`` on the next step.
        """
        super(Adam, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)
            group.setdefault('maximize', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure`` if one was given, otherwise ``None``.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure needs autograd
            # enabled to recompute gradients.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Parallel lists (one entry per parameter that has a gradient)
            # handed to the functional implementation F.adam.
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            beta1, beta2 = group['betas']

            for p in group['params']:
                # Parameters without a gradient are skipped entirely this step.
                if p.grad is not None:
                    params_with_grad.append(p)
                    if p.grad.is_sparse:
                        raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                    grads.append(p.grad)

                    state = self.state[p]
                    # Lazy state initialization
                    if len(state) == 0:
                        state['step'] = 0
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        # Exponential moving average of squared gradient values
                        state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        if group['amsgrad']:
                            # Maintains max of all exp. moving avg. of sq. grad. values
                            state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                    exp_avgs.append(state['exp_avg'])
                    exp_avg_sqs.append(state['exp_avg_sq'])

                    if group['amsgrad']:
                        max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                    # update the steps for each param group update
                    state['step'] += 1
                    # record the step after step update (t in the bias corrections)
                    state_steps.append(state['step'])

            F.adam(params_with_grad,
                   grads,
                   exp_avgs,
                   exp_avg_sqs,
                   max_exp_avg_sqs,
                   state_steps,
                   amsgrad=group['amsgrad'],
                   beta1=beta1,
                   beta2=beta2,
                   lr=group['lr'],
                   weight_decay=group['weight_decay'],
                   eps=group['eps'],
                   maximize=group['maximize'])
        return loss
|