pytorch/torch/optim/__init__.py
Michael Acar a4b2f3e213 Implement AdamW optimizer (#21250)
Summary:
# What is this?
This is an implementation of the AdamW optimizer as found in [the fastai library](803894051b/fastai/callback.py) and originally introduced in the paper [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101). It decouples the weight decay regularization step from the gradient-based optimization step during training.
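To make the decoupling concrete, here is a minimal sketch of a single AdamW-style update on one tensor. This is only an illustration, not the code in this PR; the function name `adamw_step` and its default hyperparameters are made up for the example.

```python
import torch

def adamw_step(p, grad, exp_avg, exp_avg_sq, step, lr=1e-3,
               betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
    """One AdamW-style update on a single tensor (sketch, not the PR's code)."""
    beta1, beta2 = betas

    # Decoupled weight decay: shrink the weights directly, outside the
    # gradient-based update, instead of adding weight_decay * p to the gradient.
    p.mul_(1 - lr * weight_decay)

    # Standard Adam moment estimates on the raw (un-penalized) gradient.
    exp_avg.mul_(beta1).add_((1 - beta1) * grad)
    exp_avg_sq.mul_(beta2).add_((1 - beta2) * grad * grad)

    # Bias-corrected Adam step.
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
    p.add_(-(lr / bias_correction1) * exp_avg / denom)
```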

There have already been several unsuccessful attempts to get this into PyTorch in one form or another: https://github.com/pytorch/pytorch/pull/17468, https://github.com/pytorch/pytorch/pull/10866, https://github.com/pytorch/pytorch/pull/3740, https://github.com/pytorch/pytorch/pull/4429. Hopefully this one goes through.
# Why is this important?
Via a simple reparameterization, it can be shown that L2 regularization has a weight decay effect in the case of SGD optimization. Because of this, L2 regularization became synonymous with weight decay. However, that equivalence breaks down for more complex adaptive optimization schemes. The paper [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101) argues that this is a key reason why models trained with SGD achieve better generalization than those trained with Adam: weight decay is a very effective regularizer, while L2 regularization on its own is much less effective. By explicitly decaying the weights, we can achieve state-of-the-art results while still taking advantage of the quick convergence properties of adaptive optimization schemes.
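As a quick numerical illustration of this point (again, not part of the patch), a single-step comparison shows that the L2-penalty formulation and direct weight decay coincide under SGD but diverge once the gradient is rescaled by an Adam-style adaptive denominator. The `adam_like` helper below is a deliberately simplified one-step stand-in for the real update, and all values are arbitrary:

```python
import torch

torch.manual_seed(0)
w = torch.randn(5)
grad = torch.randn(5)   # gradient of the unregularized loss
lr, wd = 0.1, 0.01

# SGD: folding the L2 penalty into the gradient and decaying the weights
# directly produce the same update.
sgd_l2 = w - lr * (grad + wd * w)
sgd_decay = w * (1 - lr * wd) - lr * grad
print(torch.allclose(sgd_l2, sgd_decay))    # True

def adam_like(g, eps=1e-8):
    # Simplified one-step Adam direction: after bias correction the first
    # step reduces to g / (|g| + eps), i.e. roughly sign(g).
    return g / (g.abs() + eps)

# Adam-style: the L2 term is normalized away by the adaptive denominator,
# so the two formulations no longer match.
adam_l2 = w - lr * adam_like(grad + wd * w)
adam_decay = w * (1 - lr * wd) - lr * adam_like(grad)
print(torch.allclose(adam_l2, adam_decay))  # False
```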
# How was this tested?
Test cases were added to `test_optim.py`, and I also ran a [small experiment](https://gist.github.com/mjacar/0c9809b96513daff84fe3d9938f08638) to validate that this implementation matches the fastai implementation.
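Once merged, usage should look like any other optimizer in `torch.optim`; a minimal sketch, with hyperparameters and the toy model chosen only for illustration:

```python
import torch
import torch.nn as nn
from torch.optim import AdamW

model = nn.Linear(10, 1)
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

# One training step on random data.
x, y = torch.randn(32, 10), torch.randn(32, 1)
loss = (model(x) - y).pow(2).mean()

optimizer.zero_grad()
loss.backward()
optimizer.step()
```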
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21250

Differential Revision: D16060339

Pulled By: vincentqb

fbshipit-source-id: ded7cc9cfd3fde81f655b9ffb3e3d6b3543a4709
2019-07-02 09:09:10 -07:00

"""
:mod:`torch.optim` is a package implementing various optimization algorithms.
Most commonly used methods are already supported, and the interface is general
enough that more sophisticated ones can also be easily integrated in the
future.
"""
from .adadelta import Adadelta # noqa: F401
from .adagrad import Adagrad # noqa: F401
from .adam import Adam # noqa: F401
from .adamw import AdamW # noqa: F401
from .sparse_adam import SparseAdam # noqa: F401
from .adamax import Adamax # noqa: F401
from .asgd import ASGD # noqa: F401
from .sgd import SGD # noqa: F401
from .rprop import Rprop # noqa: F401
from .rmsprop import RMSprop # noqa: F401
from .optimizer import Optimizer # noqa: F401
from .lbfgs import LBFGS # noqa: F401
from . import lr_scheduler # noqa: F401
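# Remove the submodule names (adadelta, adam, ...) imported implicitly above
# so that only the optimizer classes and lr_scheduler are exposed by the package.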
del adadelta
del adagrad
del adam
del adamw
del sparse_adam
del adamax
del asgd
del sgd
del rprop
del rmsprop
del optimizer
del lbfgs