[BE][optim] abstract out docstrings, add differentiable docs (#92336)
1. Abstract out common docstrings; I'm sure there are more, but let this be a first step. 2. Add differentiable docs to the optimizers that are actually differentiable.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92336
Approved by: https://github.com/albanD
This commit is contained in: parent 0035340488, commit 0070c546b5
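The pattern this commit applies across the optimizers can be sketched as follows. This is a minimal, illustrative example, not the actual PyTorch code: ToyOptimizer and the shortened fragment texts are made up, while the real shared fragments (_foreach_doc, _capturable_doc, _differentiable_doc, _maximize_doc) live in the shared .optimizer module shown in the diff below.

# Minimal sketch of the docstring-sharing pattern (illustrative names, shortened text).
_maximize_doc = r"""maximize (bool, optional): maximize the params based on the
        objective, instead of minimizing (default: False)"""

_differentiable_doc = r"""differentiable (bool, optional): whether autograd should
        occur through the optimizer step in training (default: False)"""


class ToyOptimizer:
    # The class body no longer carries the docstring.
    def step(self):
        pass


# The docstring is attached after the class definition and composed from the
# shared fragments, so every optimizer documents these flags with identical wording.
ToyOptimizer.__doc__ = r"""Implements a toy optimizer.

    Args:
        lr (float, optional): learning rate (default: 1e-3)
        {maximize}
        {differentiable}
    """.format(maximize=_maximize_doc, differentiable=_differentiable_doc)

if __name__ == "__main__":
    print(ToyOptimizer.__doc__)  # the maximize/differentiable fragments appear interpolated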
@@ -1,7 +1,8 @@
 import torch
 from torch import Tensor

-from .optimizer import Optimizer, _use_grad_for_differentiable, _default_to_foreach
+from .optimizer import (Optimizer, _use_grad_for_differentiable, _default_to_foreach,
+                        _differentiable_doc, _foreach_doc, _maximize_doc)
 from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
 from typing import List, Optional

@@ -9,56 +10,6 @@ __all__ = ["Adadelta", "adadelta"]
|
|||
|
||||
|
||||
class Adadelta(Optimizer):
|
||||
r"""Implements Adadelta algorithm.
|
||||
|
||||
.. math::
|
||||
\begin{aligned}
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)},
|
||||
\: f(\theta) \text{ (objective)}, \: \rho \text{ (decay)},
|
||||
\: \lambda \text{ (weight decay)} \\
|
||||
&\textbf{initialize} : v_0 \leftarrow 0 \: \text{ (square avg)},
|
||||
\: u_0 \leftarrow 0 \: \text{ (accumulate variables)} \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
|
||||
&\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
|
||||
&\hspace{5mm}if \: \lambda \neq 0 \\
|
||||
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
|
||||
&\hspace{5mm} v_t \leftarrow v_{t-1} \rho + g^2_t (1 - \rho) \\
|
||||
&\hspace{5mm}\Delta x_t \leftarrow \frac{\sqrt{u_{t-1} +
|
||||
\epsilon }}{ \sqrt{v_t + \epsilon} }g_t \hspace{21mm} \\
|
||||
&\hspace{5mm} u_t \leftarrow u_{t-1} \rho +
|
||||
\Delta x^2_t (1 - \rho) \\
|
||||
&\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \Delta x_t \\
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
&\bf{return} \: \theta_t \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
\end{aligned}
|
||||
|
||||
For further details regarding the algorithm we refer to `ADADELTA: An Adaptive Learning Rate Method`_.
|
||||
|
||||
Args:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
rho (float, optional): coefficient used for computing a running average
|
||||
of squared gradients (default: 0.9)
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-6)
|
||||
lr (float, optional): coefficient that scale delta before it is applied
|
||||
to the parameters (default: 1.0)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
foreach (bool, optional): whether foreach implementation of optimizer is used.
|
||||
Since the foreach implementation is usually significantly faster than
|
||||
the for-loop implementation on CUDA, we try to use it whenever possible
|
||||
(all parameters are on CUDA). Else, we continue with the for-loop
|
||||
implementation. (default: None)
|
||||
maximize (bool, optional): maximize the params based on the objective, instead of
|
||||
minimizing (default: False)
|
||||
|
||||
.. _ADADELTA\: An Adaptive Learning Rate Method:
|
||||
https://arxiv.org/abs/1212.5701
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
|
|
@@ -171,6 +122,54 @@ class Adadelta(Optimizer):
|
|||
return loss
|
||||
|
||||
|
||||
Adadelta.__doc__ = r"""Implements Adadelta algorithm.
|
||||
|
||||
.. math::
|
||||
\begin{aligned}
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)},
|
||||
\: f(\theta) \text{ (objective)}, \: \rho \text{ (decay)},
|
||||
\: \lambda \text{ (weight decay)} \\
|
||||
&\textbf{initialize} : v_0 \leftarrow 0 \: \text{ (square avg)},
|
||||
\: u_0 \leftarrow 0 \: \text{ (accumulate variables)} \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
|
||||
&\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
|
||||
&\hspace{5mm}if \: \lambda \neq 0 \\
|
||||
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
|
||||
&\hspace{5mm} v_t \leftarrow v_{t-1} \rho + g^2_t (1 - \rho) \\
|
||||
&\hspace{5mm}\Delta x_t \leftarrow \frac{\sqrt{u_{t-1} +
|
||||
\epsilon }}{ \sqrt{v_t + \epsilon} }g_t \hspace{21mm} \\
|
||||
&\hspace{5mm} u_t \leftarrow u_{t-1} \rho +
|
||||
\Delta x^2_t (1 - \rho) \\
|
||||
&\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \Delta x_t \\
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
&\bf{return} \: \theta_t \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
\end{aligned}
|
||||
|
||||
For further details regarding the algorithm we refer to `ADADELTA: An Adaptive Learning Rate Method`_.
|
||||
""" + r"""
|
||||
Args:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
rho (float, optional): coefficient used for computing a running average
|
||||
of squared gradients (default: 0.9)
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-6)
|
||||
lr (float, optional): coefficient that scale delta before it is applied
|
||||
to the parameters (default: 1.0)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
{foreach}
|
||||
{maximize}
|
||||
{differentiable}
|
||||
|
||||
.. _ADADELTA\: An Adaptive Learning Rate Method:
|
||||
https://arxiv.org/abs/1212.5701
|
||||
|
||||
""".format(foreach=_foreach_doc, maximize=_maximize_doc, differentiable=_differentiable_doc)
|
||||
|
||||
|
||||
def adadelta(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
|
|
|
|||
|
|
@@ -1,55 +1,14 @@
 import torch
 from torch import Tensor

-from .optimizer import Optimizer, _use_grad_for_differentiable, _get_value
+from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value,
+                        _differentiable_doc, _maximize_doc)
 from typing import List, Optional

 __all__ = ["Adagrad", "adagrad"]

class Adagrad(Optimizer):
|
||||
r"""Implements Adagrad algorithm.
|
||||
|
||||
.. math::
|
||||
\begin{aligned}
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
|
||||
\text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
|
||||
&\hspace{12mm} \tau \text{ (initial accumulator value)}, \: \eta\text{ (lr decay)}\\
|
||||
&\textbf{initialize} : state\_sum_0 \leftarrow 0 \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
|
||||
&\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
|
||||
&\hspace{5mm} \tilde{\gamma} \leftarrow \gamma / (1 +(t-1) \eta) \\
|
||||
&\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\
|
||||
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
|
||||
&\hspace{5mm}state\_sum_t \leftarrow state\_sum_{t-1} + g^2_t \\
|
||||
&\hspace{5mm}\theta_t \leftarrow
|
||||
\theta_{t-1}- \tilde{\gamma} \frac{g_t}{\sqrt{state\_sum_t}+\epsilon} \\
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
&\bf{return} \: \theta_t \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
\end{aligned}
|
||||
|
||||
For further details regarding the algorithm we refer to `Adaptive Subgradient Methods for Online Learning
|
||||
and Stochastic Optimization`_.
|
||||
|
||||
Args:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 1e-2)
|
||||
lr_decay (float, optional): learning rate decay (default: 0)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-10)
|
||||
foreach (bool, optional): whether foreach implementation of optimizer is used (default: None)
|
||||
maximize (bool, optional): maximize the params based on the objective, instead of
|
||||
minimizing (default: False)
|
||||
|
||||
.. _Adaptive Subgradient Methods for Online Learning and Stochastic
|
||||
Optimization: http://jmlr.org/papers/v12/duchi11a.html
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
|
|
@@ -178,6 +137,51 @@ class Adagrad(Optimizer):
|
|||
return loss
|
||||
|
||||
|
||||
Adagrad.__doc__ = r"""Implements Adagrad algorithm.
|
||||
|
||||
.. math::
|
||||
\begin{aligned}
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
|
||||
\text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
|
||||
&\hspace{12mm} \tau \text{ (initial accumulator value)}, \: \eta\text{ (lr decay)}\\
|
||||
&\textbf{initialize} : state\_sum_0 \leftarrow 0 \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
|
||||
&\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
|
||||
&\hspace{5mm} \tilde{\gamma} \leftarrow \gamma / (1 +(t-1) \eta) \\
|
||||
&\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\
|
||||
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
|
||||
&\hspace{5mm}state\_sum_t \leftarrow state\_sum_{t-1} + g^2_t \\
|
||||
&\hspace{5mm}\theta_t \leftarrow
|
||||
\theta_{t-1}- \tilde{\gamma} \frac{g_t}{\sqrt{state\_sum_t}+\epsilon} \\
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
&\bf{return} \: \theta_t \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
\end{aligned}
|
||||
|
||||
For further details regarding the algorithm we refer to `Adaptive Subgradient Methods for Online Learning
|
||||
and Stochastic Optimization`_.
|
||||
""" + r"""
|
||||
Args:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 1e-2)
|
||||
lr_decay (float, optional): learning rate decay (default: 0)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-10)
|
||||
foreach (bool, optional): whether foreach implementation of optimizer
|
||||
is used (default: None)
|
||||
{maximize}
|
||||
{differentiable}
|
||||
|
||||
.. _Adaptive Subgradient Methods for Online Learning and Stochastic
|
||||
Optimization: http://jmlr.org/papers/v12/duchi11a.html
|
||||
|
||||
""".format(maximize=_maximize_doc, differentiable=_differentiable_doc)
|
||||
|
||||
|
||||
def adagrad(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
|
|
|
|||
|
|
@@ -2,7 +2,8 @@ from typing import cast, List, Optional, Dict

 import torch
 from torch import Tensor
-from .optimizer import Optimizer, _use_grad_for_differentiable, _get_value, _stack_if_compiling, _dispatch_sqrt
+from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _stack_if_compiling,
+                        _dispatch_sqrt, _capturable_doc, _differentiable_doc, _maximize_doc)
 from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

 __all__ = ['Adam', 'adam']

@@ -47,81 +48,6 @@ def _get_fp16AMP_params(
|
|||
return _MultiDeviceReplicator(found_inf_combined)
|
||||
|
||||
class Adam(Optimizer):
|
||||
r"""Implements Adam algorithm.
|
||||
|
||||
.. math::
|
||||
\begin{aligned}
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
|
||||
\text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\
|
||||
&\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad},
|
||||
\:\textit{maximize} \\
|
||||
&\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
|
||||
v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
|
||||
|
||||
&\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
|
||||
&\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
|
||||
&\hspace{5mm}\textbf{else} \\
|
||||
&\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
|
||||
&\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
|
||||
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
|
||||
&\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
|
||||
&\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
|
||||
&\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
|
||||
&\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
|
||||
&\hspace{5mm}\textbf{if} \: amsgrad \\
|
||||
&\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
|
||||
\widehat{v_t}) \\
|
||||
&\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
|
||||
\big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
|
||||
&\hspace{5mm}\textbf{else} \\
|
||||
&\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
|
||||
\big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
&\bf{return} \: \theta_t \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
\end{aligned}
|
||||
|
||||
For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
|
||||
|
||||
Args:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 1e-3)
|
||||
betas (Tuple[float, float], optional): coefficients used for computing
|
||||
running averages of gradient and its square (default: (0.9, 0.999))
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
amsgrad (bool, optional): whether to use the AMSGrad variant of this
|
||||
algorithm from the paper `On the Convergence of Adam and Beyond`_
|
||||
(default: False)
|
||||
foreach (bool, optional): whether foreach implementation of optimizer
|
||||
is used (default: None)
|
||||
maximize (bool, optional): maximize the params based on the objective, instead of
|
||||
minimizing (default: False)
|
||||
capturable (bool, optional): whether this instance is safe to capture in a CUDA graph.
|
||||
Passing True can impair ungraphed performance, so if you don't intend to
|
||||
graph capture this instance, leave it False (default: False)
|
||||
differentiable (bool, optional): whether autograd should occur through the optimizer step
|
||||
in training otherwise, the step() function runs in a torch.no_grad() context.
|
||||
Setting to True can impair performance, so leave it False if you don't intend to run
|
||||
autograd through this instance (default: False)
|
||||
fused (bool, optional): whether the fused implementation (CUDA only) is used.
|
||||
Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
|
||||
are supported. Since the fused implementation is usually significantly faster than
|
||||
the for-loop implementation, we try to use it whenever possible (all parameters
|
||||
are on CUDA and are of a supported type). Else, we continue with the for-loop
|
||||
implementation. (default: None)
|
||||
|
||||
.. _Adam\: A Method for Stochastic Optimization:
|
||||
https://arxiv.org/abs/1412.6980
|
||||
.. _On the Convergence of Adam and Beyond:
|
||||
https://openreview.net/forum?id=ryQu7f-RZ
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
|
||||
weight_decay=0, amsgrad=False, *, foreach: Optional[bool] = None,
|
||||
maximize: bool = False, capturable: bool = False,
|
||||
|
|
@@ -285,6 +211,77 @@ class Adam(Optimizer):
|
|||
return loss
|
||||
|
||||
|
||||
Adam.__doc__ = r"""Implements Adam algorithm.
|
||||
|
||||
.. math::
|
||||
\begin{aligned}
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
|
||||
\text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\
|
||||
&\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad},
|
||||
\:\textit{maximize} \\
|
||||
&\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
|
||||
v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
|
||||
|
||||
&\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
|
||||
&\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
|
||||
&\hspace{5mm}\textbf{else} \\
|
||||
&\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
|
||||
&\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
|
||||
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
|
||||
&\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
|
||||
&\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
|
||||
&\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
|
||||
&\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
|
||||
&\hspace{5mm}\textbf{if} \: amsgrad \\
|
||||
&\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
|
||||
\widehat{v_t}) \\
|
||||
&\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
|
||||
\big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
|
||||
&\hspace{5mm}\textbf{else} \\
|
||||
&\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
|
||||
\big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
&\bf{return} \: \theta_t \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
\end{aligned}
|
||||
|
||||
For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
|
||||
""" + r"""
|
||||
Args:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 1e-3)
|
||||
betas (Tuple[float, float], optional): coefficients used for computing
|
||||
running averages of gradient and its square (default: (0.9, 0.999))
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
amsgrad (bool, optional): whether to use the AMSGrad variant of this
|
||||
algorithm from the paper `On the Convergence of Adam and Beyond`_
|
||||
(default: False)
|
||||
foreach (bool, optional): whether foreach implementation of optimizer
|
||||
is used (default: None)
|
||||
{maximize}
|
||||
{capturable}
|
||||
{differentiable}
|
||||
fused (bool, optional): whether the fused implementation (CUDA only) is used.
|
||||
Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
|
||||
are supported. Since the fused implementation is usually significantly faster than
|
||||
the for-loop implementation, we try to use it whenever possible (all parameters
|
||||
are on CUDA and are of a supported type). Else, we continue with the for-loop
|
||||
implementation. (default: None)
|
||||
|
||||
.. _Adam\: A Method for Stochastic Optimization:
|
||||
https://arxiv.org/abs/1412.6980
|
||||
.. _On the Convergence of Adam and Beyond:
|
||||
https://openreview.net/forum?id=ryQu7f-RZ
|
||||
|
||||
""".format(maximize=_maximize_doc, capturable=_capturable_doc, differentiable=_differentiable_doc)
|
||||
|
||||
|
||||
def adam(params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
exp_avgs: List[Tensor],
|
||||
|
|
|
|||
|
|
@@ -1,56 +1,14 @@
 import torch
 from torch import Tensor

-from .optimizer import Optimizer, _use_grad_for_differentiable, _get_value, _stack_if_compiling
+from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _stack_if_compiling,
+                        _maximize_doc, _differentiable_doc)
 from typing import List, Optional

 __all__ = ["Adamax", "adamax"]

class Adamax(Optimizer):
|
||||
r"""Implements Adamax algorithm (a variant of Adam based on infinity norm).
|
||||
|
||||
.. math::
|
||||
\begin{aligned}
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
|
||||
\text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)},
|
||||
\: \lambda \text{ (weight decay)}, \\
|
||||
&\hspace{13mm} \epsilon \text{ (epsilon)} \\
|
||||
&\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
|
||||
u_0 \leftarrow 0 \text{ ( infinity norm)} \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
|
||||
&\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
|
||||
&\hspace{5mm}if \: \lambda \neq 0 \\
|
||||
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
|
||||
&\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
|
||||
&\hspace{5mm}u_t \leftarrow \mathrm{max}(\beta_2 u_{t-1}, |g_{t}|+\epsilon) \\
|
||||
&\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \frac{\gamma m_t}{(1-\beta^t_1) u_t} \\
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
&\bf{return} \: \theta_t \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
\end{aligned}
|
||||
|
||||
For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
|
||||
|
||||
Args:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 2e-3)
|
||||
betas (Tuple[float, float], optional): coefficients used for computing
|
||||
running averages of gradient and its square
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
foreach (bool, optional): whether foreach implementation of optimizer is used (default: None)
|
||||
maximize (bool, optional): maximize the params based on the objective, instead of
|
||||
minimizing (default: False)
|
||||
|
||||
.. _Adam\: A Method for Stochastic Optimization:
|
||||
https://arxiv.org/abs/1412.6980
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
|
|
@@ -173,6 +131,51 @@ class Adamax(Optimizer):
|
|||
return loss
|
||||
|
||||
|
||||
Adamax.__doc__ = r"""Implements Adamax algorithm (a variant of Adam based on infinity norm).
|
||||
|
||||
.. math::
|
||||
\begin{aligned}
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
|
||||
\text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)},
|
||||
\: \lambda \text{ (weight decay)}, \\
|
||||
&\hspace{13mm} \epsilon \text{ (epsilon)} \\
|
||||
&\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
|
||||
u_0 \leftarrow 0 \text{ ( infinity norm)} \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
|
||||
&\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
|
||||
&\hspace{5mm}if \: \lambda \neq 0 \\
|
||||
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
|
||||
&\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
|
||||
&\hspace{5mm}u_t \leftarrow \mathrm{max}(\beta_2 u_{t-1}, |g_{t}|+\epsilon) \\
|
||||
&\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \frac{\gamma m_t}{(1-\beta^t_1) u_t} \\
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
&\bf{return} \: \theta_t \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
\end{aligned}
|
||||
|
||||
For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
|
||||
""" + r"""
|
||||
Args:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 2e-3)
|
||||
betas (Tuple[float, float], optional): coefficients used for computing
|
||||
running averages of gradient and its square
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
foreach (bool, optional): whether foreach implementation of optimizer is used (default: None)
|
||||
{maximize}
|
||||
{differentiable}
|
||||
|
||||
.. _Adam\: A Method for Stochastic Optimization:
|
||||
https://arxiv.org/abs/1412.6980
|
||||
|
||||
""".format(maximize=_maximize_doc, differentiable=_differentiable_doc)
|
||||
|
||||
|
||||
def adamax(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
|
|
|
|||
|
|
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
-from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt,
-                        _stack_if_compiling, _default_to_foreach)
+from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt, _stack_if_compiling,
+                        _capturable_doc, _differentiable_doc, _foreach_doc, _maximize_doc, _default_to_foreach)
 from typing import List, Optional
 from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

@@ -9,73 +9,6 @@ __all__ = ["AdamW", "adamw"]
|
|||
|
||||
|
||||
class AdamW(Optimizer):
|
||||
r"""Implements AdamW algorithm.
|
||||
|
||||
.. math::
|
||||
\begin{aligned}
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
|
||||
\text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
|
||||
\: \epsilon \text{ (epsilon)} \\
|
||||
&\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
|
||||
\: \textit{maximize} \\
|
||||
&\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
|
||||
\text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
|
||||
|
||||
&\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
|
||||
&\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
|
||||
&\hspace{5mm}\textbf{else} \\
|
||||
&\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
|
||||
&\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
|
||||
&\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
|
||||
&\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
|
||||
&\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
|
||||
&\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
|
||||
&\hspace{5mm}\textbf{if} \: amsgrad \\
|
||||
&\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
|
||||
\widehat{v_t}) \\
|
||||
&\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
|
||||
\big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
|
||||
&\hspace{5mm}\textbf{else} \\
|
||||
&\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
|
||||
\big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
&\bf{return} \: \theta_t \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
\end{aligned}
|
||||
|
||||
For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_.
|
||||
|
||||
Args:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 1e-3)
|
||||
betas (Tuple[float, float], optional): coefficients used for computing
|
||||
running averages of gradient and its square (default: (0.9, 0.999))
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay coefficient (default: 1e-2)
|
||||
amsgrad (bool, optional): whether to use the AMSGrad variant of this
|
||||
algorithm from the paper `On the Convergence of Adam and Beyond`_
|
||||
(default: False)
|
||||
maximize (bool, optional): maximize the params based on the objective, instead of
|
||||
minimizing (default: False)
|
||||
foreach (bool, optional): whether foreach implementation of optimizer is used.
|
||||
If unspecified by the user (so foreach is None), we will try to use foreach
|
||||
over the for-loop implementation on CUDA, since it is usually significantly
|
||||
more performant. (default: None)
|
||||
capturable (bool, optional): whether this instance is safe to capture in a CUDA graph.
|
||||
Passing True can impair ungraphed performance, so if you don't intend to
|
||||
graph capture this instance, leave it False (default: False)
|
||||
|
||||
.. _Decoupled Weight Decay Regularization:
|
||||
https://arxiv.org/abs/1711.05101
|
||||
.. _On the Convergence of Adam and Beyond:
|
||||
https://openreview.net/forum?id=ryQu7f-RZ
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
|
|
@@ -218,6 +151,73 @@ class AdamW(Optimizer):
|
|||
return loss
|
||||
|
||||
|
||||
AdamW.__doc__ = r"""Implements AdamW algorithm.
|
||||
|
||||
.. math::
|
||||
\begin{aligned}
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
|
||||
\text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
|
||||
\: \epsilon \text{ (epsilon)} \\
|
||||
&\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
|
||||
\: \textit{maximize} \\
|
||||
&\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
|
||||
\text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
|
||||
|
||||
&\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
|
||||
&\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
|
||||
&\hspace{5mm}\textbf{else} \\
|
||||
&\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
|
||||
&\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
|
||||
&\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
|
||||
&\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
|
||||
&\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
|
||||
&\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
|
||||
&\hspace{5mm}\textbf{if} \: amsgrad \\
|
||||
&\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
|
||||
\widehat{v_t}) \\
|
||||
&\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
|
||||
\big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
|
||||
&\hspace{5mm}\textbf{else} \\
|
||||
&\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
|
||||
\big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
&\bf{return} \: \theta_t \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
\end{aligned}
|
||||
|
||||
For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_.
|
||||
""" + r"""
|
||||
Args:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 1e-3)
|
||||
betas (Tuple[float, float], optional): coefficients used for computing
|
||||
running averages of gradient and its square (default: (0.9, 0.999))
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay coefficient (default: 1e-2)
|
||||
amsgrad (bool, optional): whether to use the AMSGrad variant of this
|
||||
algorithm from the paper `On the Convergence of Adam and Beyond`_
|
||||
(default: False)
|
||||
{maximize}
|
||||
{foreach}
|
||||
{capturable}
|
||||
{differentiable}
|
||||
|
||||
.. _Decoupled Weight Decay Regularization:
|
||||
https://arxiv.org/abs/1711.05101
|
||||
.. _On the Convergence of Adam and Beyond:
|
||||
https://openreview.net/forum?id=ryQu7f-RZ
|
||||
|
||||
""".format(maximize=_maximize_doc,
|
||||
foreach=_foreach_doc,
|
||||
capturable=_capturable_doc,
|
||||
differentiable=_differentiable_doc)
|
||||
|
||||
|
||||
def adamw(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
|
|
|
|||
|
|
@@ -1,7 +1,8 @@
 import torch
 from torch import Tensor

-from .optimizer import Optimizer, _use_grad_for_differentiable, _get_value
+from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value,
+                        _differentiable_doc, _maximize_doc)
 from torch._utils import is_compiling
 from typing import List, Optional

@@ -14,28 +15,6 @@ def _to_tensor(x):
|
|||
return x
|
||||
|
||||
class ASGD(Optimizer):
|
||||
"""Implements Averaged Stochastic Gradient Descent.
|
||||
|
||||
It has been proposed in `Acceleration of stochastic approximation by
|
||||
averaging`_.
|
||||
|
||||
Args:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 1e-2)
|
||||
lambd (float, optional): decay term (default: 1e-4)
|
||||
alpha (float, optional): power for eta update (default: 0.75)
|
||||
t0 (float, optional): point at which to start averaging (default: 1e6)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
foreach (bool, optional): whether foreach implementation of optimizer
|
||||
is used (default: None)
|
||||
maximize (bool, optional): maximize the params based on the objective, instead of
|
||||
minimizing (default: False)
|
||||
|
||||
.. _Acceleration of stochastic approximation by averaging:
|
||||
https://dl.acm.org/citation.cfm?id=131098
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
|
|
@@ -157,6 +136,30 @@ class ASGD(Optimizer):
|
|||
return loss
|
||||
|
||||
|
||||
ASGD.__doc__ = r"""Implements Averaged Stochastic Gradient Descent.
|
||||
|
||||
It has been proposed in `Acceleration of stochastic approximation by
|
||||
averaging`_.
|
||||
|
||||
Args:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 1e-2)
|
||||
lambd (float, optional): decay term (default: 1e-4)
|
||||
alpha (float, optional): power for eta update (default: 0.75)
|
||||
t0 (float, optional): point at which to start averaging (default: 1e6)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
foreach (bool, optional): whether foreach implementation of optimizer
|
||||
is used (default: None)
|
||||
{maximize}
|
||||
{differentiable}
|
||||
|
||||
.. _Acceleration of stochastic approximation by averaging:
|
||||
https://dl.acm.org/citation.cfm?id=131098
|
||||
|
||||
""".format(maximize=_maximize_doc, differentiable=_differentiable_doc)
|
||||
|
||||
|
||||
def asgd(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
|
|
|
|||
|
|
@@ -1,59 +1,12 @@
 import torch
 from torch import Tensor
-from .optimizer import Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt, _stack_if_compiling
+from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt, _stack_if_compiling,
+                        _differentiable_doc)
 from typing import List, Optional

 __all__ = ['NAdam', 'nadam']

class NAdam(Optimizer):
|
||||
r"""Implements NAdam algorithm.
|
||||
|
||||
.. math::
|
||||
\begin{aligned}
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{input} : \gamma_t \text{ (lr)}, \: \beta_1,\beta_2 \text{ (betas)},
|
||||
\: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\
|
||||
&\hspace{13mm} \: \lambda \text{ (weight decay)}, \:\psi \text{ (momentum decay)} \\
|
||||
&\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
|
||||
v_0 \leftarrow 0 \text{ ( second moment)} \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
|
||||
&\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
|
||||
&\hspace{5mm}if \: \lambda \neq 0 \\
|
||||
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
|
||||
&\hspace{5mm} \mu_t \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{t \psi} \big) \\
|
||||
&\hspace{5mm} \mu_{t+1} \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{(t+1)\psi}\big)\\
|
||||
&\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
|
||||
&\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
|
||||
&\hspace{5mm}\widehat{m_t} \leftarrow \mu_{t+1} m_t/(1-\prod_{i=1}^{t+1}\mu_i)\\[-1.ex]
|
||||
& \hspace{11mm} + (1-\mu_t) g_t /(1-\prod_{i=1}^{t} \mu_{i}) \\
|
||||
&\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
|
||||
&\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
|
||||
\big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
&\bf{return} \: \theta_t \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
\end{aligned}
|
||||
|
||||
For further details regarding the algorithm we refer to `Incorporating Nesterov Momentum into Adam`_.
|
||||
|
||||
Args:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 2e-3)
|
||||
betas (Tuple[float, float], optional): coefficients used for computing
|
||||
running averages of gradient and its square (default: (0.9, 0.999))
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
momentum_decay (float, optional): momentum momentum_decay (default: 4e-3)
|
||||
foreach (bool, optional): whether foreach implementation of optimizer
|
||||
is used (default: None)
|
||||
|
||||
.. _Incorporating Nesterov Momentum into Adam:
|
||||
https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
|
||||
weight_decay=0, momentum_decay=4e-3, *, foreach: Optional[bool] = None,
|
||||
differentiable: bool = False):
|
||||
|
|
@@ -153,6 +106,56 @@ class NAdam(Optimizer):
|
|||
|
||||
return loss
|
||||
|
||||
NAdam.__doc__ = r"""Implements NAdam algorithm.
|
||||
|
||||
.. math::
|
||||
\begin{aligned}
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{input} : \gamma_t \text{ (lr)}, \: \beta_1,\beta_2 \text{ (betas)},
|
||||
\: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\
|
||||
&\hspace{13mm} \: \lambda \text{ (weight decay)}, \:\psi \text{ (momentum decay)} \\
|
||||
&\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
|
||||
v_0 \leftarrow 0 \text{ ( second moment)} \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
|
||||
&\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
|
||||
&\hspace{5mm}if \: \lambda \neq 0 \\
|
||||
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
|
||||
&\hspace{5mm} \mu_t \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{t \psi} \big) \\
|
||||
&\hspace{5mm} \mu_{t+1} \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{(t+1)\psi}\big)\\
|
||||
&\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
|
||||
&\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
|
||||
&\hspace{5mm}\widehat{m_t} \leftarrow \mu_{t+1} m_t/(1-\prod_{i=1}^{t+1}\mu_i)\\[-1.ex]
|
||||
& \hspace{11mm} + (1-\mu_t) g_t /(1-\prod_{i=1}^{t} \mu_{i}) \\
|
||||
&\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
|
||||
&\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
|
||||
\big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
&\bf{return} \: \theta_t \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
\end{aligned}
|
||||
|
||||
For further details regarding the algorithm we refer to `Incorporating Nesterov Momentum into Adam`_.
|
||||
""" + r"""
|
||||
Args:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 2e-3)
|
||||
betas (Tuple[float, float], optional): coefficients used for computing
|
||||
running averages of gradient and its square (default: (0.9, 0.999))
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
momentum_decay (float, optional): momentum momentum_decay (default: 4e-3)
|
||||
foreach (bool, optional): whether foreach implementation of optimizer
|
||||
is used (default: None)
|
||||
{differentiable}
|
||||
|
||||
.. _Incorporating Nesterov Momentum into Adam:
|
||||
https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ
|
||||
|
||||
""".format(differentiable=_differentiable_doc)
|
||||
|
||||
|
||||
def nadam(params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
|
|
|
|||
|
|
@@ -67,6 +67,27 @@ def _default_to_foreach(tensorlists: List[List[torch.Tensor]], differentiable: b
    )


# Common doc strings among optimizers
_foreach_doc = r"""foreach (bool, optional): whether foreach implementation of optimizer
        is used. If unspecified by the user (so foreach is None), we will try to use
        foreach over the for-loop implementation on CUDA, since it is usually
        significantly more performant. (default: None)"""

_capturable_doc = r"""capturable (bool, optional): whether this instance is safe to
        capture in a CUDA graph. Passing True can impair ungraphed performance,
        so if you don't intend to graph capture this instance, leave it False
        (default: False)"""

_differentiable_doc = r"""differentiable (bool, optional): whether autograd should
        occur through the optimizer step in training. Otherwise, the step()
        function runs in a torch.no_grad() context. Setting to True can impair
        performance, so leave it False if you don't intend to run autograd
        through this instance (default: False)"""

_maximize_doc = r"""maximize (bool, optional): maximize the params based on the
        objective, instead of minimizing (default: False)"""


def register_optimizer_step_pre_hook(hook: Callable[..., None]) -> RemovableHandle:
    r"""Register a pre hook common to all optimizers. The hook should have the following
    signature::

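The fragments above are what each per-optimizer __doc__ assignment in this commit interpolates via str.format. As a quick runtime sanity check, one could verify that the shared wording actually shows up in a rendered docstring; this is a hedged sketch that assumes a torch build which already includes this change:

# Sketch: confirm the shared fragments appear in a composed optimizer docstring.
# Assumes a torch build that already contains this commit.
import torch

doc = torch.optim.Adam.__doc__ or ""
for flag in ("maximize", "capturable", "differentiable"):
    # Each shared fragment begins with "<flag> (bool, optional)".
    print(flag, "documented:", f"{flag} (bool, optional)" in doc)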
@@ -2,71 +2,14 @@ import math
 import torch
 from torch import Tensor

-from .optimizer import Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt, _stack_if_compiling
+from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt, _stack_if_compiling,
+                        _differentiable_doc)
 from typing import List, Optional

 __all__ = ["RAdam", "radam"]

class RAdam(Optimizer):
|
||||
r"""Implements RAdam algorithm.
|
||||
|
||||
.. math::
|
||||
\begin{aligned}
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{input} : \gamma \text{ (lr)}, \: \beta_1, \beta_2
|
||||
\text{ (betas)}, \: \theta_0 \text{ (params)}, \:f(\theta) \text{ (objective)}, \:
|
||||
\lambda \text{ (weightdecay)}, \\
|
||||
&\hspace{13mm} \epsilon \text{ (epsilon)} \\
|
||||
&\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
|
||||
v_0 \leftarrow 0 \text{ ( second moment)}, \\
|
||||
&\hspace{18mm} \rho_{\infty} \leftarrow 2/(1-\beta_2) -1 \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
|
||||
&\hspace{6mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
|
||||
&\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\
|
||||
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
|
||||
&\hspace{6mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
|
||||
&\hspace{6mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
|
||||
&\hspace{6mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
|
||||
&\hspace{6mm}\rho_t \leftarrow \rho_{\infty} -
|
||||
2 t \beta^t_2 /\big(1-\beta_2^t \big) \\[0.1.ex]
|
||||
&\hspace{6mm}\textbf{if} \: \rho_t > 5 \\
|
||||
&\hspace{12mm} l_t \leftarrow \sqrt{ (1-\beta^t_2) / \big( v_t +\epsilon \big) } \\
|
||||
&\hspace{12mm} r_t \leftarrow
|
||||
\sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_{\infty}}{(\rho_{\infty}-4)(\rho_{\infty}-2) \rho_t}} \\
|
||||
&\hspace{12mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t} r_t l_t \\
|
||||
&\hspace{6mm}\textbf{else} \\
|
||||
&\hspace{12mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t} \\
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
&\bf{return} \: \theta_t \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
\end{aligned}
|
||||
|
||||
For further details regarding the algorithm we refer to `On the variance of the adaptive learning rate and beyond`_.
|
||||
|
||||
This implementation uses the same weight_decay implementation as Adam (were the weight_decay is applied
|
||||
to the gradient) and not the one from AdamW (were weight_decay is applied to the update). This
|
||||
is different from the `author's implementation`_.
|
||||
|
||||
Args:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 1e-3)
|
||||
betas (Tuple[float, float], optional): coefficients used for computing
|
||||
running averages of gradient and its square (default: (0.9, 0.999))
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
foreach (bool, optional): whether foreach implementation of optimizer
|
||||
is used (default: None)
|
||||
|
||||
.. _On the variance of the adaptive learning rate and beyond:
|
||||
https://arxiv.org/abs/1908.03265
|
||||
.. _author's implementation:
|
||||
https://github.com/LiyuanLucasLiu/RAdam
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
|
|
@@ -177,6 +120,67 @@ class RAdam(Optimizer):
|
|||
return loss
|
||||
|
||||
|
||||
RAdam.__doc__ = r"""Implements RAdam algorithm.
|
||||
|
||||
.. math::
|
||||
\begin{aligned}
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{input} : \gamma \text{ (lr)}, \: \beta_1, \beta_2
|
||||
\text{ (betas)}, \: \theta_0 \text{ (params)}, \:f(\theta) \text{ (objective)}, \:
|
||||
\lambda \text{ (weightdecay)}, \\
|
||||
&\hspace{13mm} \epsilon \text{ (epsilon)} \\
|
||||
&\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
|
||||
v_0 \leftarrow 0 \text{ ( second moment)}, \\
|
||||
&\hspace{18mm} \rho_{\infty} \leftarrow 2/(1-\beta_2) -1 \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
|
||||
&\hspace{6mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
|
||||
&\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\
|
||||
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
|
||||
&\hspace{6mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
|
||||
&\hspace{6mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
|
||||
&\hspace{6mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
|
||||
&\hspace{6mm}\rho_t \leftarrow \rho_{\infty} -
|
||||
2 t \beta^t_2 /\big(1-\beta_2^t \big) \\[0.1.ex]
|
||||
&\hspace{6mm}\textbf{if} \: \rho_t > 5 \\
|
||||
&\hspace{12mm} l_t \leftarrow \sqrt{ (1-\beta^t_2) / \big( v_t +\epsilon \big) } \\
|
||||
&\hspace{12mm} r_t \leftarrow
|
||||
\sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_{\infty}}{(\rho_{\infty}-4)(\rho_{\infty}-2) \rho_t}} \\
|
||||
&\hspace{12mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t} r_t l_t \\
|
||||
&\hspace{6mm}\textbf{else} \\
|
||||
&\hspace{12mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t} \\
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
&\bf{return} \: \theta_t \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
\end{aligned}
|
||||
|
||||
For further details regarding the algorithm we refer to `On the variance of the adaptive learning rate and beyond`_.
|
||||
|
||||
This implementation uses the same weight_decay implementation as Adam (were the weight_decay is applied
|
||||
to the gradient) and not the one from AdamW (were weight_decay is applied to the update). This
|
||||
is different from the `author's implementation`_.
|
||||
""" + r"""
|
||||
Args:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 1e-3)
|
||||
betas (Tuple[float, float], optional): coefficients used for computing
|
||||
running averages of gradient and its square (default: (0.9, 0.999))
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
foreach (bool, optional): whether foreach implementation of optimizer
|
||||
is used (default: None)
|
||||
{differentiable}
|
||||
|
||||
.. _On the variance of the adaptive learning rate and beyond:
|
||||
https://arxiv.org/abs/1908.03265
|
||||
.. _author's implementation:
|
||||
https://github.com/LiyuanLucasLiu/RAdam
|
||||
|
||||
""".format(differentiable=_differentiable_doc)
|
||||
|
||||
|
||||
def radam(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
|
|
|
|||
|
|
@@ -1,73 +1,12 @@
 import torch
 from torch import Tensor
-from .optimizer import Optimizer, _use_grad_for_differentiable
+from .optimizer import Optimizer, _use_grad_for_differentiable, _differentiable_doc, _maximize_doc
 from typing import List, Optional

 __all__ = ["RMSprop", "rmsprop"]

class RMSprop(Optimizer):
|
||||
r"""Implements RMSprop algorithm.
|
||||
|
||||
.. math::
|
||||
\begin{aligned}
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{input} : \alpha \text{ (alpha)},\: \gamma \text{ (lr)},
|
||||
\: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\
|
||||
&\hspace{13mm} \lambda \text{ (weight decay)},\: \mu \text{ (momentum)},\: centered\\
|
||||
&\textbf{initialize} : v_0 \leftarrow 0 \text{ (square average)}, \:
|
||||
\textbf{b}_0 \leftarrow 0 \text{ (buffer)}, \: g^{ave}_0 \leftarrow 0 \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\
|
||||
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
|
||||
&\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
|
||||
&\hspace{5mm}if \: \lambda \neq 0 \\
|
||||
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
|
||||
&\hspace{5mm}v_t \leftarrow \alpha v_{t-1} + (1 - \alpha) g^2_t
|
||||
\hspace{8mm} \\
|
||||
&\hspace{5mm} \tilde{v_t} \leftarrow v_t \\
|
||||
&\hspace{5mm}if \: centered \\
|
||||
&\hspace{10mm} g^{ave}_t \leftarrow g^{ave}_{t-1} \alpha + (1-\alpha) g_t \\
|
||||
&\hspace{10mm} \tilde{v_t} \leftarrow \tilde{v_t} - \big(g^{ave}_{t} \big)^2 \\
|
||||
&\hspace{5mm}if \: \mu > 0 \\
|
||||
&\hspace{10mm} \textbf{b}_t\leftarrow \mu \textbf{b}_{t-1} +
|
||||
g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big) \\
|
||||
&\hspace{10mm} \theta_t \leftarrow \theta_{t-1} - \gamma \textbf{b}_t \\
|
||||
&\hspace{5mm} else \\
|
||||
&\hspace{10mm}\theta_t \leftarrow \theta_{t-1} -
|
||||
\gamma g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big) \hspace{3mm} \\
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
&\bf{return} \: \theta_t \\[-1.ex]
|
||||
&\rule{110mm}{0.4pt} \\[-1.ex]
|
||||
\end{aligned}
|
||||
|
||||
For further details regarding the algorithm we refer to
|
||||
`lecture notes <https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_ by G. Hinton.
|
||||
and centered version `Generating Sequences
|
||||
With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
|
||||
The implementation here takes the square root of the gradient average before
|
||||
adding epsilon (note that TensorFlow interchanges these two operations). The effective
|
||||
learning rate is thus :math:`\gamma/(\sqrt{v} + \epsilon)` where :math:`\gamma`
|
||||
is the scheduled learning rate and :math:`v` is the weighted moving average
|
||||
of the squared gradient.
|
||||
|
||||
Args:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 1e-2)
|
||||
momentum (float, optional): momentum factor (default: 0)
|
||||
alpha (float, optional): smoothing constant (default: 0.99)
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-8)
|
||||
centered (bool, optional) : if ``True``, compute the centered RMSProp,
|
||||
the gradient is normalized by an estimation of its variance
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
foreach (bool, optional): whether foreach implementation of optimizer
|
||||
is used (default: None)
|
||||
maximize (bool, optional): maximize the params based on the objective, instead of
|
||||
minimizing (default: False)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
|
|
@@ -194,6 +133,68 @@ class RMSprop(Optimizer):
|
|||
return loss
|
||||
|
||||
|
||||
RMSprop.__doc__ = r"""Implements RMSprop algorithm.

    .. math::
        \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \alpha \text{ (alpha)},\: \gamma \text{ (lr)},
                \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\
            &\hspace{13mm} \lambda \text{ (weight decay)},\: \mu \text{ (momentum)},\: centered\\
            &\textbf{initialize} : v_0 \leftarrow 0 \text{ (square average)}, \:
                \textbf{b}_0 \leftarrow 0 \text{ (buffer)}, \: g^{ave}_0 \leftarrow 0 \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}if \: \lambda \neq 0 \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
            &\hspace{5mm}v_t \leftarrow \alpha v_{t-1} + (1 - \alpha) g^2_t
                \hspace{8mm} \\
            &\hspace{5mm} \tilde{v_t} \leftarrow v_t \\
            &\hspace{5mm}if \: centered \\
            &\hspace{10mm} g^{ave}_t \leftarrow g^{ave}_{t-1} \alpha + (1-\alpha) g_t \\
            &\hspace{10mm} \tilde{v_t} \leftarrow \tilde{v_t} - \big(g^{ave}_{t} \big)^2 \\
            &\hspace{5mm}if \: \mu > 0 \\
            &\hspace{10mm} \textbf{b}_t\leftarrow \mu \textbf{b}_{t-1} +
                g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big) \\
            &\hspace{10mm} \theta_t \leftarrow \theta_{t-1} - \gamma \textbf{b}_t \\
            &\hspace{5mm} else \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} -
                \gamma g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big) \hspace{3mm} \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
        \end{aligned}

    For further details regarding the algorithm we refer to
    `lecture notes <https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_ by G. Hinton,
    and for the centered version to `Generating Sequences
    With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
    The implementation here takes the square root of the gradient average before
    adding epsilon (note that TensorFlow interchanges these two operations). The effective
    learning rate is thus :math:`\gamma/(\sqrt{v} + \epsilon)` where :math:`\gamma`
    is the scheduled learning rate and :math:`v` is the weighted moving average
    of the squared gradient.
    """ + r"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        momentum (float, optional): momentum factor (default: 0)
        alpha (float, optional): smoothing constant (default: 0.99)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        centered (bool, optional) : if ``True``, compute the centered RMSProp,
            the gradient is normalized by an estimation of its variance
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        foreach (bool, optional): whether foreach implementation of optimizer
            is used (default: None)
        {maximize}
        {differentiable}

    """.format(maximize=_maximize_doc, differentiable=_differentiable_doc)
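A note on the templating pattern in the assignment above: the LaTeX half of the docstring is concatenated as a separate raw string outside the .format() call, because str.format would read LaTeX braces such as \begin{aligned} or \text{...} as replacement fields and typically fail with a KeyError; only the Args half, which carries the {maximize} and {differentiable} placeholders, is formatted. Below is a minimal standalone sketch of the same idea; the names _maximize_doc_example and _ToyOptimizer are hypothetical and not code from this diff.

# Minimal sketch (hypothetical names) of the __doc__ templating pattern above.
_maximize_doc_example = """maximize (bool, optional): maximize the params based on the
        objective, instead of minimizing (default: False)"""

# LaTeX piece: kept out of .format() so its braces are left untouched.
_math_part = r"""Implements a toy optimizer.

    .. math::
        \theta_t \leftarrow \theta_{t-1} - \gamma g_t
"""

# Plain-text piece: safe to format, since its only braces are the placeholder.
_args_part = r"""
    Args:
        lr (float, optional): learning rate (default: 1e-2)
        {maximize}
""".format(maximize=_maximize_doc_example)


class _ToyOptimizer:
    pass


_ToyOptimizer.__doc__ = _math_part + _args_part

Assigning __doc__ after the class body, instead of writing the docstring inline, is what lets the shared fragments be spliced in once rather than copied into every optimizer.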
def rmsprop(
    params: List[Tensor],
    grads: List[Tensor],
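The RMSprop docstring above points out that this implementation adds eps after taking the square root of the running average, whereas TensorFlow adds it inside the root. A small numeric sketch (illustrative values only, not code from the diff) shows how far apart the two denominators can be when the running average v is tiny:

import math

g, v, lr, eps = 0.01, 1e-12, 1e-2, 1e-8   # illustrative values

denom_here = math.sqrt(v) + eps       # sqrt first, then add eps (this implementation)
denom_tf_style = math.sqrt(v + eps)   # eps inside the root (TensorFlow-style)

print(lr * g / denom_here)       # ~99: eps barely damps the step when v is near zero
print(lr * g / denom_tf_style)   # ~1.0: eps inside the root dominates the denominator

Because of this gap, an eps value tuned for one convention may not transfer directly to the other.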
@@ -1,62 +1,12 @@
import torch
from torch import Tensor
from .optimizer import Optimizer, _use_grad_for_differentiable
from .optimizer import Optimizer, _use_grad_for_differentiable, _differentiable_doc, _maximize_doc
from typing import List, Optional

__all__ = ["Rprop", "rprop"]


class Rprop(Optimizer):
    r"""Implements the resilient backpropagation algorithm.

    .. math::
        \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \theta_0 \in \mathbf{R}^d \text{ (params)},f(\theta)
                \text{ (objective)}, \\
            &\hspace{13mm} \eta_{+/-} \text{ (etaplus, etaminus)}, \Gamma_{max/min}
                \text{ (step sizes)} \\
            &\textbf{initialize} : g^0_{prev} \leftarrow 0,
                \: \eta_0 \leftarrow \text{lr (learning rate)} \\
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm} \textbf{for} \text{ } i = 0, 1, \ldots, d-1 \: \mathbf{do} \\
            &\hspace{10mm} \textbf{if} \: g^i_{prev} g^i_t > 0 \\
            &\hspace{15mm} \eta^i_t \leftarrow \mathrm{min}(\eta^i_{t-1} \eta_{+},
                \Gamma_{max}) \\
            &\hspace{10mm} \textbf{else if} \: g^i_{prev} g^i_t < 0 \\
            &\hspace{15mm} \eta^i_t \leftarrow \mathrm{max}(\eta^i_{t-1} \eta_{-},
                \Gamma_{min}) \\
            &\hspace{15mm} g^i_t \leftarrow 0 \\
            &\hspace{10mm} \textbf{else} \: \\
            &\hspace{15mm} \eta^i_t \leftarrow \eta^i_{t-1} \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1}- \eta_t \mathrm{sign}(g_t) \\
            &\hspace{5mm}g_{prev} \leftarrow g_t \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
        \end{aligned}

    For further details regarding the algorithm we refer to the paper
    `A Direct Adaptive Method for Faster Backpropagation Learning: The RPROP Algorithm
    <http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.21.1417>`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        etas (Tuple[float, float], optional): pair of (etaminus, etaplus), that
            are multiplicative increase and decrease factors
            (default: (0.5, 1.2))
        step_sizes (Tuple[float, float], optional): a pair of minimal and
            maximal allowed step sizes (default: (1e-6, 50))
        foreach (bool, optional): whether foreach implementation of optimizer
            is used (default: None)
        maximize (bool, optional): maximize the params based on the objective, instead of
            minimizing (default: False)
    """

    def __init__(
        self,
        params,

@@ -168,6 +118,57 @@ class Rprop(Optimizer):
        return loss

Rprop.__doc__ = r"""Implements the resilient backpropagation algorithm.

    .. math::
        \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \theta_0 \in \mathbf{R}^d \text{ (params)},f(\theta)
                \text{ (objective)}, \\
            &\hspace{13mm} \eta_{+/-} \text{ (etaplus, etaminus)}, \Gamma_{max/min}
                \text{ (step sizes)} \\
            &\textbf{initialize} : g^0_{prev} \leftarrow 0,
                \: \eta_0 \leftarrow \text{lr (learning rate)} \\
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm} \textbf{for} \text{ } i = 0, 1, \ldots, d-1 \: \mathbf{do} \\
            &\hspace{10mm} \textbf{if} \: g^i_{prev} g^i_t > 0 \\
            &\hspace{15mm} \eta^i_t \leftarrow \mathrm{min}(\eta^i_{t-1} \eta_{+},
                \Gamma_{max}) \\
            &\hspace{10mm} \textbf{else if} \: g^i_{prev} g^i_t < 0 \\
            &\hspace{15mm} \eta^i_t \leftarrow \mathrm{max}(\eta^i_{t-1} \eta_{-},
                \Gamma_{min}) \\
            &\hspace{15mm} g^i_t \leftarrow 0 \\
            &\hspace{10mm} \textbf{else} \: \\
            &\hspace{15mm} \eta^i_t \leftarrow \eta^i_{t-1} \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1}- \eta_t \mathrm{sign}(g_t) \\
            &\hspace{5mm}g_{prev} \leftarrow g_t \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
        \end{aligned}

    For further details regarding the algorithm we refer to the paper
    `A Direct Adaptive Method for Faster Backpropagation Learning: The RPROP Algorithm
    <http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.21.1417>`_.
    """ + r"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        etas (Tuple[float, float], optional): pair of (etaminus, etaplus), that
            are multiplicative increase and decrease factors
            (default: (0.5, 1.2))
        step_sizes (Tuple[float, float], optional): a pair of minimal and
            maximal allowed step sizes (default: (1e-6, 50))
        foreach (bool, optional): whether foreach implementation of optimizer
            is used (default: None)
        {maximize}
        {differentiable}

    """.format(maximize=_maximize_doc, differentiable=_differentiable_doc)


def rprop(
    params: List[Tensor],
    grads: List[Tensor],
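As a companion to the Rprop math block above, here is a plain-Python sketch of the sign-based rule (the helper rprop_step is hypothetical, not the torch.optim implementation): each parameter's step size grows by etaplus while the gradient sign is stable, shrinks by etaminus on a sign flip (with that gradient zeroed so the next comparison does not shrink it again), and stays clamped to the configured step_sizes range.

# Illustrative per-parameter Rprop step over plain Python lists (not library code).
def rprop_step(params, grads, prev_grads, step_sizes,
               etaminus=0.5, etaplus=1.2, step_min=1e-6, step_max=50.0):
    for i in range(len(params)):
        g = grads[i]
        if prev_grads[i] * g > 0:            # same sign: grow the step size
            step_sizes[i] = min(step_sizes[i] * etaplus, step_max)
        elif prev_grads[i] * g < 0:          # sign flip: shrink and skip this update
            step_sizes[i] = max(step_sizes[i] * etaminus, step_min)
            g = 0.0
        # else: a zero gradient leaves the step size unchanged
        sign = 1.0 if g > 0 else -1.0 if g < 0 else 0.0
        params[i] -= step_sizes[i] * sign    # theta_t = theta_{t-1} - eta_t * sign(g_t)
        prev_grads[i] = g
    return params

# One step over two parameters; the second sees a sign flip, so only its step size changes:
print(rprop_step([1.0, -2.0], [0.3, -0.1], [0.2, 0.4], [0.01, 0.01]))  # ~[0.988, -2.0]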
@@ -1,100 +1,12 @@
import torch
from torch import Tensor
from .optimizer import Optimizer, required, _use_grad_for_differentiable
from .optimizer import Optimizer, required, _use_grad_for_differentiable, _differentiable_doc, _maximize_doc
from typing import List, Optional
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

__all__ = ['SGD', 'sgd']


class SGD(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum).

    .. math::
        \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
                \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
            &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
                \:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
            &\hspace{5mm}\textbf{if} \: \mu \neq 0 \\
            &\hspace{10mm}\textbf{if} \: t > 1 \\
            &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\
            &\hspace{10mm}\textbf{else} \\
            &\hspace{15mm} \textbf{b}_t \leftarrow g_t \\
            &\hspace{10mm}\textbf{if} \: \textit{nesterov} \\
            &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t \\
            &\hspace{10mm}\textbf{else} \\[-1.ex]
            &\hspace{15mm} g_t \leftarrow \textbf{b}_t \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize} \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma g_t \\[-1.ex]
            &\hspace{5mm}\textbf{else} \\[-1.ex]
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
        \end{aligned}

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        maximize (bool, optional): maximize the params based on the objective, instead of
            minimizing (default: False)
        foreach (bool, optional): whether foreach implementation of optimizer
            is used (default: None)

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et. al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.

        This is in contrast to Sutskever et. al. and
        other frameworks which employ an update of the form

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} - v_{t+1}.
            \end{aligned}

        The Nesterov version is analogously modified.

        Moreover, the initial value of the momentum buffer is set to the
        gradient value at the first step. This is in contrast to some other
        frameworks that initialize it to all zeros.
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, *, maximize: bool = False, foreach: Optional[bool] = None,
                 differentiable: bool = False):

@@ -180,6 +92,98 @@ class SGD(Optimizer):
        return loss

SGD.__doc__ = r"""\
Implements stochastic gradient descent (optionally with momentum).

    .. math::
        \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
                \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
            &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
                \:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
            &\hspace{5mm}\textbf{if} \: \mu \neq 0 \\
            &\hspace{10mm}\textbf{if} \: t > 1 \\
            &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\
            &\hspace{10mm}\textbf{else} \\
            &\hspace{15mm} \textbf{b}_t \leftarrow g_t \\
            &\hspace{10mm}\textbf{if} \: \textit{nesterov} \\
            &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t \\
            &\hspace{10mm}\textbf{else} \\[-1.ex]
            &\hspace{15mm} g_t \leftarrow \textbf{b}_t \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize} \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma g_t \\[-1.ex]
            &\hspace{5mm}\textbf{else} \\[-1.ex]
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
        \end{aligned}

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.
    """ + r"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        {maximize}
        foreach (bool, optional): whether foreach implementation of optimizer
            is used (default: None)
        {differentiable}
    """.format(maximize=_maximize_doc, differentiable=_differentiable_doc) + r"""

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.

        This is in contrast to Sutskever et al. and
        other frameworks which employ an update of the form

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} - v_{t+1}.
            \end{aligned}

        The Nesterov version is analogously modified.

        Moreover, the initial value of the momentum buffer is set to the
        gradient value at the first step. This is in contrast to some other
        frameworks that initialize it to all zeros.

    """


def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
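The note in the SGD docstring contrasts this momentum formulation with the Sutskever et al. one. The scalar sketch below (illustrative numbers only, not library code; dampening, weight decay and Nesterov are assumed off) shows where the learning rate enters in each form: with a constant learning rate the two trajectories coincide, and they diverge only once the learning rate changes between steps, for example under a scheduler.

lr, mu, g = 0.1, 0.9, 1.0   # illustrative constants

# Convention used here: the buffer accumulates raw gradients, lr scales the applied step.
v_a, p_a = 0.0, 0.0
for _ in range(3):
    v_a = mu * v_a + g
    p_a = p_a - lr * v_a

# Sutskever-style: lr is folded into the velocity itself.
v_b, p_b = 0.0, 0.0
for _ in range(3):
    v_b = mu * v_b + lr * g
    p_b = p_b - v_b

print(p_a, p_b)   # ~-0.561 for both with a constant lr; they differ once lr varies per step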
@@ -1,30 +1,10 @@
import torch
from . import _functional as F
from .optimizer import Optimizer
from .optimizer import Optimizer, _maximize_doc

__all__ = ['SparseAdam']


class SparseAdam(Optimizer):
    r"""Implements lazy version of Adam algorithm suitable for sparse tensors.

    In this variant, only moments that show up in the gradient get updated, and
    only those portions of the gradient get applied to the parameters.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        maximize (bool, optional): maximize the params based on the objective, instead of
            minimizing (default: False)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, maximize: bool = False):
        if not 0.0 < lr:
            raise ValueError("Invalid learning rate: {}".format(lr))

@@ -116,3 +96,23 @@ class SparseAdam(Optimizer):
                   maximize=maximize)

        return loss

SparseAdam.__doc__ = r"""Implements lazy version of Adam algorithm suitable for sparse tensors.

    In this variant, only moments that show up in the gradient get updated, and
    only those portions of the gradient get applied to the parameters.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        {maximize}

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980

    """.format(maximize=_maximize_doc)
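To make the lazy behaviour described above concrete, here is a hedged usage sketch (arbitrary sizes, not part of the diff): an nn.Embedding created with sparse=True is the typical source of sparse gradients, and SparseAdam then updates only the embedding rows, and their first and second moments, that actually appear in the batch.

import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=1000, embedding_dim=16, sparse=True)
opt = torch.optim.SparseAdam(emb.parameters(), lr=1e-3)

idx = torch.tensor([3, 17, 17, 999])   # only these rows receive gradient entries
loss = emb(idx).pow(2).sum()
opt.zero_grad()
loss.backward()                        # emb.weight.grad is a sparse tensor
opt.step()                             # only rows 3, 17 and 999 (and their moments) change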