Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Summary:
As discussed in https://github.com/pytorch/pytorch/issues/46392, this makes the code more readable and possibly more performant.
It also fixes a bug, uncovered while making this change, where the arguments of `map` were passed in the wrong order (see the short illustration after this summary): 030a24906e (diff-5bb26bd3a23ee3bb540aeadcc0385df2a4e48de39f87ed9ea76b21990738fe98L1537-R1537)
Fixes https://github.com/pytorch/pytorch/issues/46392
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46461
Reviewed By: ailzhang
Differential Revision: D24367015
Pulled By: ezyang
fbshipit-source-id: d55a67933cc22346b00544c9671f09982ad920e7
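For illustration only (hypothetical code, not taken from the diff above): the class of bug being fixed is swapping the function and iterable arguments of `map`, a mistake a list comprehension cannot express.

    # map takes the function first and the iterable second; reversing them fails only at runtime.
    lrs = list(map(lambda group: group['lr'], optimizer.param_groups))
    # The equivalent list comprehension makes the iteration explicit.
    lrs = [group['lr'] for group in optimizer.param_groups]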
1320 lines
56 KiB
Python
import types
import math
from torch._six import inf
from functools import wraps
import warnings
import weakref
from collections import Counter
from bisect import bisect_right

from .optimizer import Optimizer

EPOCH_DEPRECATION_WARNING = (
    "The epoch parameter in `scheduler.step()` was not necessary and is being "
    "deprecated where possible. Please use `scheduler.step()` to step the "
    "scheduler. During the deprecation, if epoch is different from None, the "
    "closed form is used instead of the new chainable form, where available. "
    "Please open an issue if you are unable to replicate your use case: "
    "https://github.com/pytorch/pytorch/issues/new/choose."
)

SAVE_STATE_WARNING = "Please also save or load the state of the optimizer when saving or loading the scheduler."

class _LRScheduler(object):

    def __init__(self, optimizer, last_epoch=-1, verbose=False):

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        # Initialize epoch and base learning rates
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `lr_scheduler.step()` is called after
        # `optimizer.step()`
        def with_counter(method):
            if getattr(method, '_with_counter', False):
                # `optimizer.step()` has already been replaced, return.
                return method

            # Keep a weak reference to the optimizer instance to prevent
            # cyclic references.
            instance_ref = weakref.ref(method.__self__)
            # Get the unbound method for the same purpose.
            func = method.__func__
            cls = instance_ref().__class__
            del method

            @wraps(func)
            def wrapper(*args, **kwargs):
                instance = instance_ref()
                instance._step_count += 1
                wrapped = func.__get__(instance, cls)
                return wrapped(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True
            return wrapper

        self.optimizer.step = with_counter(self.optimizer.step)
        self.optimizer._step_count = 0
        self._step_count = 0
        self.verbose = verbose

        self.step()

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_lr(self):
        """ Return last computed learning rate by current scheduler.
        """
        return self._last_lr

    def get_lr(self):
        # Compute learning rate using chainable form of the scheduler
        raise NotImplementedError

    def print_lr(self, is_verbose, group, lr, epoch=None):
        """Display the current learning rate.
        """
        if is_verbose:
            if epoch is None:
                print('Adjusting learning rate'
                      ' of group {} to {:.4e}.'.format(group, lr))
            else:
                print('Epoch {:5d}: adjusting learning rate'
                      ' of group {} to {:.4e}.'.format(epoch, group, lr))

    def step(self, epoch=None):
        # Raise a warning if old pattern is detected
        # https://github.com/pytorch/pytorch/issues/20124
        if self._step_count == 1:
            if not hasattr(self.optimizer.step, "_with_counter"):
                warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                              "initialization. Please, make sure to call `optimizer.step()` before "
                              "`lr_scheduler.step()`. See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

            # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
            elif self.optimizer._step_count < 1:
                warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                              "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                              "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
                              "will result in PyTorch skipping the first value of the learning rate schedule. "
                              "See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
        self._step_count += 1

        class _enable_get_lr_call:

            def __init__(self, o):
                self.o = o

            def __enter__(self):
                self.o._get_lr_called_within_step = True
                return self

            def __exit__(self, type, value, traceback):
                self.o._get_lr_called_within_step = False

        with _enable_get_lr_call(self):
            if epoch is None:
                self.last_epoch += 1
                values = self.get_lr()
            else:
                warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
                self.last_epoch = epoch
                if hasattr(self, "_get_closed_form_lr"):
                    values = self._get_closed_form_lr()
                else:
                    values = self.get_lr()

        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group['lr'] = lr
            self.print_lr(self.verbose, i, lr, epoch)

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

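# Illustrative sketch (not part of the original module): the warning logic in
# `_LRScheduler.step()` above assumes the post-1.1.0 calling order, in which the
# optimizer is stepped before the scheduler. A typical epoch-level loop with a
# chainable scheduler therefore looks like the following (`model`, `loss_fn`,
# and `loader` are placeholder names):
#
#   >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
#   >>> for epoch in range(100):
#   >>>     for inputs, targets in loader:
#   >>>         optimizer.zero_grad()
#   >>>         loss_fn(model(inputs), targets).backward()
#   >>>         optimizer.step()
#   >>>     scheduler.step()  # once per epoch, after optimizer.step()
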
class LambdaLR(_LRScheduler):
    """Sets the learning rate of each parameter group to the initial lr
    times a given function. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in optimizer.param_groups.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> # Assuming optimizer has two groups.
        >>> lambda1 = lambda epoch: epoch // 30
        >>> lambda2 = lambda epoch: 0.95 ** epoch
        >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
        self.optimizer = optimizer

        if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
        else:
            if len(lr_lambda) != len(optimizer.param_groups):
                raise ValueError("Expected {} lr_lambdas, but got {}".format(
                    len(optimizer.param_groups), len(lr_lambda)))
            self.lr_lambdas = list(lr_lambda)
        super(LambdaLR, self).__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The learning rate lambda functions will only be saved if they are callable objects
        and not if they are functions or lambdas.
        """

        warnings.warn(SAVE_STATE_WARNING, UserWarning)
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')}
        state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)

        for idx, fn in enumerate(self.lr_lambdas):
            if not isinstance(fn, types.FunctionType):
                state_dict['lr_lambdas'][idx] = fn.__dict__.copy()

        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """

        warnings.warn(SAVE_STATE_WARNING, UserWarning)
        lr_lambdas = state_dict.pop('lr_lambdas')
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict['lr_lambdas'] = lr_lambdas

        for idx, fn in enumerate(lr_lambdas):
            if fn is not None:
                self.lr_lambdas[idx].__dict__.update(fn)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        return [base_lr * lmbda(self.last_epoch)
                for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]

class MultiplicativeLR(_LRScheduler):
    """Multiply the learning rate of each parameter group by the factor given
    in the specified function. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in optimizer.param_groups.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> lmbda = lambda epoch: 0.95
        >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
        self.optimizer = optimizer

        if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
        else:
            if len(lr_lambda) != len(optimizer.param_groups):
                raise ValueError("Expected {} lr_lambdas, but got {}".format(
                    len(optimizer.param_groups), len(lr_lambda)))
            self.lr_lambdas = list(lr_lambda)
        super(MultiplicativeLR, self).__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The learning rate lambda functions will only be saved if they are callable objects
        and not if they are functions or lambdas.
        """
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')}
        state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)

        for idx, fn in enumerate(self.lr_lambdas):
            if not isinstance(fn, types.FunctionType):
                state_dict['lr_lambdas'][idx] = fn.__dict__.copy()

        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        lr_lambdas = state_dict.pop('lr_lambdas')
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict['lr_lambdas'] = lr_lambdas

        for idx, fn in enumerate(lr_lambdas):
            if fn is not None:
                self.lr_lambdas[idx].__dict__.update(fn)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch > 0:
            return [group['lr'] * lmbda(self.last_epoch)
                    for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)]
        else:
            return list(self.base_lrs)

class StepLR(_LRScheduler):
    """Decays the learning rate of each parameter group by gamma every
    step_size epochs. Notice that such decay can happen simultaneously with
    other changes to the learning rate from outside this scheduler. When
    last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        step_size (int): Period of learning rate decay.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05    if epoch < 30
        >>> # lr = 0.005   if 30 <= epoch < 60
        >>> # lr = 0.0005  if 60 <= epoch < 90
        >>> # ...
        >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False):
        self.step_size = step_size
        self.gamma = gamma
        super(StepLR, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
            return [group['lr'] for group in self.optimizer.param_groups]
        return [group['lr'] * self.gamma
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
                for base_lr in self.base_lrs]

class MultiStepLR(_LRScheduler):
    """Decays the learning rate of each parameter group by gamma once the
    number of epoch reaches one of the milestones. Notice that such decay can
    happen simultaneously with other changes to the learning rate from outside
    this scheduler. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list): List of epoch indices. Must be increasing.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05    if epoch < 30
        >>> # lr = 0.005   if 30 <= epoch < 80
        >>> # lr = 0.0005  if epoch >= 80
        >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False):
        self.milestones = Counter(milestones)
        self.gamma = gamma
        super(MultiStepLR, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch not in self.milestones:
            return [group['lr'] for group in self.optimizer.param_groups]
        return [group['lr'] * self.gamma ** self.milestones[self.last_epoch]
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        milestones = list(sorted(self.milestones.elements()))
        return [base_lr * self.gamma ** bisect_right(milestones, self.last_epoch)
                for base_lr in self.base_lrs]

class ExponentialLR(_LRScheduler):
    """Decays the learning rate of each parameter group by gamma every epoch.
    When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, gamma, last_epoch=-1, verbose=False):
        self.gamma = gamma
        super(ExponentialLR, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            return self.base_lrs
        return [group['lr'] * self.gamma
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return [base_lr * self.gamma ** self.last_epoch
                for base_lr in self.base_lrs]

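# Usage sketch for ExponentialLR (illustrative, not part of the original file):
# with a base lr of 0.1 and gamma=0.9, the lr after t scheduler steps is
# 0.1 * 0.9 ** t, i.e. 0.1, 0.09, 0.081, ...
#
#   >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   >>> scheduler = ExponentialLR(optimizer, gamma=0.9)
#   >>> for epoch in range(20):
#   >>>     train(...)
#   >>>     scheduler.step()
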
class CosineAnnealingLR(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::
        \begin{aligned}
            \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
            + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
            & T_{cur} \neq (2k+1)T_{max}; \\
            \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
            \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
            & T_{cur} = (2k+1)T_{max}.
        \end{aligned}

    When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
    is defined recursively, the learning rate can be simultaneously modified
    outside this scheduler by other operators. If the learning rate is set
    solely by this scheduler, the learning rate at each step becomes:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
    implements the cosine annealing part of SGDR, and not the restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False):
        self.T_max = T_max
        self.eta_min = eta_min
        super(CosineAnnealingLR, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            return self.base_lrs
        elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            return [group['lr'] + (base_lr - self.eta_min) *
                    (1 - math.cos(math.pi / self.T_max)) / 2
                    for base_lr, group in
                    zip(self.base_lrs, self.optimizer.param_groups)]
        return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) /
                (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *
                (group['lr'] - self.eta_min) + self.eta_min
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return [self.eta_min + (base_lr - self.eta_min) *
                (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2
                for base_lr in self.base_lrs]

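# Usage sketch for CosineAnnealingLR (illustrative, not part of the original
# file): anneal from the initial lr down towards eta_min over T_max scheduler
# steps, following the closed-form cosine curve given in the docstring.
#
#   >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   >>> scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-5)
#   >>> for epoch in range(100):
#   >>>     train(...)
#   >>>     scheduler.step()
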
class ReduceLROnPlateau(object):
    """Reduce learning rate when a metric has stopped improving.
    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This scheduler reads a metrics
    quantity and if no improvement is seen for a 'patience' number
    of epochs, the learning rate is reduced.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode, lr will
            be reduced when the quantity monitored has stopped
            decreasing; in `max` mode it will be reduced when the
            quantity monitored has stopped increasing. Default: 'min'.
        factor (float): Factor by which the learning rate will be
            reduced. new_lr = lr * factor. Default: 0.1.
        patience (int): Number of epochs with no improvement after
            which learning rate will be reduced. For example, if
            `patience = 2`, then we will ignore the first 2 epochs
            with no improvement, and will only decrease the LR after the
            3rd epoch if the loss still hasn't improved then.
            Default: 10.
        threshold (float): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
            dynamic_threshold = best * ( 1 + threshold ) in 'max'
            mode or best * ( 1 - threshold ) in `min` mode.
            In `abs` mode, dynamic_threshold = best + threshold in
            `max` mode or best - threshold in `min` mode. Default: 'rel'.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after lr has been reduced. Default: 0.
        min_lr (float or list): A scalar or a list of scalars. A
            lower bound on the learning rate of all param groups
            or each group respectively. Default: 0.
        eps (float): Minimal decay applied to lr. If the difference
            between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
        >>> for epoch in range(10):
        >>>     train(...)
        >>>     val_loss = validate(...)
        >>>     # Note that step should be called after validate()
        >>>     scheduler.step(val_loss)
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                 threshold=1e-4, threshold_mode='rel', cooldown=0,
                 min_lr=0, eps=1e-8, verbose=False):

        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0.')
        self.factor = factor

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        if isinstance(min_lr, list) or isinstance(min_lr, tuple):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} min_lrs, got {}".format(
                    len(optimizer.param_groups), len(min_lr)))
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.eps = eps
        self.last_epoch = 0
        self._init_is_better(mode=mode, threshold=threshold,
                             threshold_mode=threshold_mode)
        self._reset()

    def _reset(self):
        """Resets num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def step(self, metrics, epoch=None):
        # convert `metrics` to float, in case it's a zero-dim Tensor
        current = float(metrics)
        if epoch is None:
            epoch = self.last_epoch + 1
        else:
            warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
        self.last_epoch = epoch

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def _reduce_lr(self, epoch):
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            if old_lr - new_lr > self.eps:
                param_group['lr'] = new_lr
                if self.verbose:
                    print('Epoch {:5d}: reducing learning rate'
                          ' of group {} to {:.4e}.'.format(epoch, i, new_lr))

    @property
    def in_cooldown(self):
        return self.cooldown_counter > 0

    def is_better(self, a, best):
        if self.mode == 'min' and self.threshold_mode == 'rel':
            rel_epsilon = 1. - self.threshold
            return a < best * rel_epsilon

        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return a < best - self.threshold

        elif self.mode == 'max' and self.threshold_mode == 'rel':
            rel_epsilon = self.threshold + 1.
            return a > best * rel_epsilon

        else:  # mode == 'max' and epsilon_mode == 'abs':
            return a > best + self.threshold

    def _init_is_better(self, mode, threshold, threshold_mode):
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if threshold_mode not in {'rel', 'abs'}:
            raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')

        if mode == 'min':
            self.mode_worse = inf
        else:  # mode == 'max':
            self.mode_worse = -inf

        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode

    def state_dict(self):
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)
        self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)

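# Worked example of the `is_better` logic above (illustrative, not part of the
# original file): with the defaults mode='min', threshold_mode='rel',
# threshold=1e-4, a new metric value only counts as an improvement if it beats
# the current best by a relative margin of 0.01%.
#
#   >>> scheduler = ReduceLROnPlateau(optimizer, mode='min')
#   >>> scheduler.is_better(0.9998, 1.0)   # 0.9998 < 1.0 * (1 - 1e-4) = 0.9999 -> True
#   >>> scheduler.is_better(0.99995, 1.0)  # 0.99995 is not below 0.9999 -> False
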
class CyclicLR(_LRScheduler):
    r"""Sets the learning rate of each parameter group according to
    cyclical learning rate policy (CLR). The policy cycles the learning
    rate between two boundaries with a constant frequency, as detailed in
    the paper `Cyclical Learning Rates for Training Neural Networks`_.
    The distance between the two boundaries can be scaled on a per-iteration
    or per-cycle basis.

    Cyclical learning rate policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This class has three built-in policies, as put forth in the paper:

    * "triangular": A basic triangular cycle without amplitude scaling.
    * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle.
    * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}`
      at each cycle iteration.

    This implementation was adapted from the github repo: `bckenstler/CLR`_

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        base_lr (float or list): Initial learning rate which is the
            lower boundary in the cycle for each parameter group.
        max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_lr - base_lr).
            The lr at any cycle is the sum of base_lr
            and some scaling of the amplitude; therefore
            max_lr may not actually be reached depending on
            scaling function.
        step_size_up (int): Number of training iterations in the
            increasing half of a cycle. Default: 2000
        step_size_down (int): Number of training iterations in the
            decreasing half of a cycle. If step_size_down is None,
            it is set to step_size_up. Default: None
        mode (str): One of {triangular, triangular2, exp_range}.
            Values correspond to policies detailed above.
            If scale_fn is not None, this argument is ignored.
            Default: 'triangular'
        gamma (float): Constant in 'exp_range' scaling function:
            gamma**(cycle iterations)
            Default: 1.0
        scale_fn (function): Custom scaling policy defined by a single
            argument lambda function, where
            0 <= scale_fn(x) <= 1 for all x >= 0.
            If specified, then 'mode' is ignored.
            Default: None
        scale_mode (str): {'cycle', 'iterations'}.
            Defines whether scale_fn is evaluated on
            cycle number or cycle iterations (training
            iterations since start of cycle).
            Default: 'cycle'
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely
            to learning rate; at the peak of a cycle, momentum is
            'base_momentum' and learning rate is 'max_lr'.
            Default: 0.8
        max_momentum (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_momentum - base_momentum).
            The momentum at any cycle is the difference of max_momentum
            and some scaling of the amplitude; therefore
            base_momentum may not actually be reached depending on
            scaling function. Note that momentum is cycled inversely
            to learning rate; at the start of a cycle, momentum is 'max_momentum'
            and learning rate is 'base_lr'
            Default: 0.9
        last_epoch (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_epoch=-1, the schedule is started from the beginning.
            Default: -1
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()


    .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
    .. _bckenstler/CLR: https://github.com/bckenstler/CLR
    """

    def __init__(self,
                 optimizer,
                 base_lr,
                 max_lr,
                 step_size_up=2000,
                 step_size_down=None,
                 mode='triangular',
                 gamma=1.,
                 scale_fn=None,
                 scale_mode='cycle',
                 cycle_momentum=True,
                 base_momentum=0.8,
                 max_momentum=0.9,
                 last_epoch=-1,
                 verbose=False):

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        base_lrs = self._format_param('base_lr', optimizer, base_lr)
        if last_epoch == -1:
            for lr, group in zip(base_lrs, optimizer.param_groups):
                group['lr'] = lr

        self.max_lrs = self._format_param('max_lr', optimizer, max_lr)

        step_size_up = float(step_size_up)
        step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
        self.total_size = step_size_up + step_size_down
        self.step_ratio = step_size_up / self.total_size

        if mode not in ['triangular', 'triangular2', 'exp_range'] \
                and scale_fn is None:
            raise ValueError('mode is invalid and scale_fn is None')

        self.mode = mode
        self.gamma = gamma

        if scale_fn is None:
            if self.mode == 'triangular':
                self.scale_fn = self._triangular_scale_fn
                self.scale_mode = 'cycle'
            elif self.mode == 'triangular2':
                self.scale_fn = self._triangular2_scale_fn
                self.scale_mode = 'cycle'
            elif self.mode == 'exp_range':
                self.scale_fn = self._exp_range_scale_fn
                self.scale_mode = 'iterations'
        else:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode

        self.cycle_momentum = cycle_momentum
        if cycle_momentum:
            if 'momentum' not in optimizer.defaults:
                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')

            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
            if last_epoch == -1:
                for momentum, group in zip(base_momentums, optimizer.param_groups):
                    group['momentum'] = momentum
            self.base_momentums = [group['momentum'] for group in optimizer.param_groups]
            self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)

        super(CyclicLR, self).__init__(optimizer, last_epoch, verbose)
        self.base_lrs = base_lrs

    def _format_param(self, name, optimizer, param):
        """Return correctly formatted lr/momentum for each param group."""
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError("expected {} values for {}, got {}".format(
                    len(optimizer.param_groups), name, len(param)))
            return param
        else:
            return [param] * len(optimizer.param_groups)

    def _triangular_scale_fn(self, x):
        return 1.

    def _triangular2_scale_fn(self, x):
        return 1 / (2. ** (x - 1))

    def _exp_range_scale_fn(self, x):
        return self.gamma**(x)

    def get_lr(self):
        """Calculates the learning rate at batch index. This function treats
        `self.last_epoch` as the last batch index.

        If `self.cycle_momentum` is ``True``, this function has a side effect of
        updating the optimizer's momentum.
        """

        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        cycle = math.floor(1 + self.last_epoch / self.total_size)
        x = 1. + self.last_epoch / self.total_size - cycle
        if x <= self.step_ratio:
            scale_factor = x / self.step_ratio
        else:
            scale_factor = (x - 1) / (self.step_ratio - 1)

        lrs = []
        for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
            base_height = (max_lr - base_lr) * scale_factor
            if self.scale_mode == 'cycle':
                lr = base_lr + base_height * self.scale_fn(cycle)
            else:
                lr = base_lr + base_height * self.scale_fn(self.last_epoch)
            lrs.append(lr)

        if self.cycle_momentum:
            momentums = []
            for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums):
                base_height = (max_momentum - base_momentum) * scale_factor
                if self.scale_mode == 'cycle':
                    momentum = max_momentum - base_height * self.scale_fn(cycle)
                else:
                    momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
                momentums.append(momentum)
            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
                param_group['momentum'] = momentum

        return lrs

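# Sketch of a custom scale_fn for CyclicLR (illustrative, not part of the
# original file): the docstring only requires 0 <= scale_fn(x) <= 1 for x >= 0,
# so a per-cycle exponential decay of the amplitude can be supplied directly
# instead of one of the built-in modes.
#
#   >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
#   >>> scheduler = CyclicLR(optimizer, base_lr=0.001, max_lr=0.01,
#   >>>                      scale_fn=lambda x: 0.9 ** x, scale_mode='cycle')
#   >>> for batch in data_loader:
#   >>>     train_batch(...)
#   >>>     scheduler.step()  # called after every batch, as the docstring requires
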
class CosineAnnealingWarmRestarts(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
    is the number of epochs since the last restart and :math:`T_{i}` is the number
    of epochs between two warm restarts in SGDR:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)

    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
    When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_0 (int): Number of iterations for the first restart.
        T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
        eta_min (float, optional): Minimum learning rate. Default: 0.
        last_epoch (int, optional): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False):
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        self.T_0 = T_0
        self.T_i = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min

        super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch, verbose)

        self.T_cur = self.last_epoch

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
                for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Step could be called after every batch update

        Example:
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> iters = len(dataloader)
            >>> for epoch in range(20):
            >>>     for i, sample in enumerate(dataloader):
            >>>         inputs, labels = sample['inputs'], sample['labels']
            >>>         optimizer.zero_grad()
            >>>         outputs = net(inputs)
            >>>         loss = criterion(outputs, labels)
            >>>         loss.backward()
            >>>         optimizer.step()
            >>>         scheduler.step(epoch + i / iters)

        This function can be called in an interleaved way.

        Example:
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> for epoch in range(20):
            >>>     scheduler.step()
            >>> scheduler.step(26)
            >>> scheduler.step()  # scheduler.step(27), instead of scheduler(20)
        """

        if epoch is None and self.last_epoch < 0:
            epoch = 0

        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.T_cur = self.T_cur - self.T_i
                self.T_i = self.T_i * self.T_mult
        else:
            if epoch < 0:
                raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                else:
                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
        self.last_epoch = math.floor(epoch)

        class _enable_get_lr_call:

            def __init__(self, o):
                self.o = o

            def __enter__(self):
                self.o._get_lr_called_within_step = True
                return self

            def __exit__(self, type, value, traceback):
                self.o._get_lr_called_within_step = False
                return self

        with _enable_get_lr_call(self):
            for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
                param_group, lr = data
                param_group['lr'] = lr
                self.print_lr(self.verbose, i, lr, epoch)

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

class OneCycleLR(_LRScheduler):
    r"""Sets the learning rate of each parameter group according to the
    1cycle learning rate policy. The 1cycle policy anneals the learning
    rate from an initial learning rate to some maximum learning rate and then
    from that maximum learning rate to some minimum learning rate much lower
    than the initial learning rate.
    This policy was initially described in the paper `Super-Convergence:
    Very Fast Training of Neural Networks Using Large Learning Rates`_.

    The 1cycle learning rate policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This scheduler is not chainable.

    Note also that the total number of steps in the cycle can be determined in one
    of two ways (listed in order of precedence):

    #. A value for total_steps is explicitly provided.
    #. A number of epochs (epochs) and a number of steps per epoch
       (steps_per_epoch) are provided.
       In this case, the number of total steps is inferred by
       total_steps = epochs * steps_per_epoch

    You must either provide a value for total_steps or provide a value for both
    epochs and steps_per_epoch.

    The default behaviour of this scheduler follows the fastai implementation of 1cycle, which
    claims that "unpublished work has shown even better results by using only two phases". To
    mimic the behaviour of the original paper instead, set ``three_phase=True``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group.
        total_steps (int): The total number of steps in the cycle. Note that
            if a value is not provided here, then it must be inferred by providing
            a value for epochs and steps_per_epoch.
            Default: None
        epochs (int): The number of epochs to train for. This is used along
            with steps_per_epoch in order to infer the total number of steps in the cycle
            if a value for total_steps is not provided.
            Default: None
        steps_per_epoch (int): The number of steps per epoch to train for. This is
            used along with epochs in order to infer the total number of steps in the
            cycle if a value for total_steps is not provided.
            Default: None
        pct_start (float): The percentage of the cycle (in number of steps) spent
            increasing the learning rate.
            Default: 0.3
        anneal_strategy (str): {'cos', 'linear'}
            Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
            linear annealing.
            Default: 'cos'
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely
            to learning rate; at the peak of a cycle, momentum is
            'base_momentum' and learning rate is 'max_lr'.
            Default: 0.85
        max_momentum (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_momentum - base_momentum).
            Note that momentum is cycled inversely
            to learning rate; at the start of a cycle, momentum is 'max_momentum'
            and learning rate is 'base_lr'
            Default: 0.95
        div_factor (float): Determines the initial learning rate via
            initial_lr = max_lr/div_factor
            Default: 25
        final_div_factor (float): Determines the minimum learning rate via
            min_lr = initial_lr/final_div_factor
            Default: 1e4
        three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the
            learning rate according to 'final_div_factor' instead of modifying the second
            phase (the first two phases will be symmetrical about the step indicated by
            'pct_start').
        last_epoch (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_epoch=-1, the schedule is started from the beginning.
            Default: -1
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()


    .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
        https://arxiv.org/abs/1708.07120
    """
    def __init__(self,
                 optimizer,
                 max_lr,
                 total_steps=None,
                 epochs=None,
                 steps_per_epoch=None,
                 pct_start=0.3,
                 anneal_strategy='cos',
                 cycle_momentum=True,
                 base_momentum=0.85,
                 max_momentum=0.95,
                 div_factor=25.,
                 final_div_factor=1e4,
                 three_phase=False,
                 last_epoch=-1,
                 verbose=False):

        # Validate optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        # Validate total_steps
        if total_steps is None and epochs is None and steps_per_epoch is None:
            raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)")
        elif total_steps is not None:
            if total_steps <= 0 or not isinstance(total_steps, int):
                raise ValueError("Expected positive integer total_steps, but got {}".format(total_steps))
            self.total_steps = total_steps
        else:
            if epochs <= 0 or not isinstance(epochs, int):
                raise ValueError("Expected positive integer epochs, but got {}".format(epochs))
            if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
                raise ValueError("Expected positive integer steps_per_epoch, but got {}".format(steps_per_epoch))
            self.total_steps = epochs * steps_per_epoch

        if three_phase:
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    'start_lr': 'initial_lr',
                    'end_lr': 'max_lr',
                    'start_momentum': 'max_momentum',
                    'end_momentum': 'base_momentum',
                },
                {
                    'end_step': float(2 * pct_start * self.total_steps) - 2,
                    'start_lr': 'max_lr',
                    'end_lr': 'initial_lr',
                    'start_momentum': 'base_momentum',
                    'end_momentum': 'max_momentum',
                },
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': 'initial_lr',
                    'end_lr': 'min_lr',
                    'start_momentum': 'max_momentum',
                    'end_momentum': 'max_momentum',
                },
            ]
        else:
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    'start_lr': 'initial_lr',
                    'end_lr': 'max_lr',
                    'start_momentum': 'max_momentum',
                    'end_momentum': 'base_momentum',
                },
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': 'max_lr',
                    'end_lr': 'min_lr',
                    'start_momentum': 'base_momentum',
                    'end_momentum': 'max_momentum',
                },
            ]

        # Validate pct_start
        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
            raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start))

        # Validate anneal_strategy
        if anneal_strategy not in ['cos', 'linear']:
            raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy))
        elif anneal_strategy == 'cos':
            self.anneal_func = self._annealing_cos
        elif anneal_strategy == 'linear':
            self.anneal_func = self._annealing_linear

        # Initialize learning rate variables
        max_lrs = self._format_param('max_lr', self.optimizer, max_lr)
        if last_epoch == -1:
            for idx, group in enumerate(self.optimizer.param_groups):
                group['initial_lr'] = max_lrs[idx] / div_factor
                group['max_lr'] = max_lrs[idx]
                group['min_lr'] = group['initial_lr'] / final_div_factor

        # Initialize momentum variables
        self.cycle_momentum = cycle_momentum
        if self.cycle_momentum:
            if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
            self.use_beta1 = 'betas' in self.optimizer.defaults
            max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
            if last_epoch == -1:
                for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups):
                    if self.use_beta1:
                        _, beta2 = group['betas']
                        group['betas'] = (m_momentum, beta2)
                    else:
                        group['momentum'] = m_momentum
                    group['max_momentum'] = m_momentum
                    group['base_momentum'] = b_momentum

        super(OneCycleLR, self).__init__(optimizer, last_epoch, verbose)

    def _format_param(self, name, optimizer, param):
        """Return correctly formatted lr/momentum for each param group."""
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError("expected {} values for {}, got {}".format(
                    len(optimizer.param_groups), name, len(param)))
            return param
        else:
            return [param] * len(optimizer.param_groups)

    def _annealing_cos(self, start, end, pct):
        "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    def _annealing_linear(self, start, end, pct):
        "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        return (end - start) * pct + start

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        lrs = []
        step_num = self.last_epoch

        if step_num > self.total_steps:
            raise ValueError("Tried to step {} times. The specified number of total steps is {}"
                             .format(step_num + 1, self.total_steps))

        for group in self.optimizer.param_groups:
            start_step = 0
            for i, phase in enumerate(self._schedule_phases):
                end_step = phase['end_step']
                if step_num <= end_step or i == len(self._schedule_phases) - 1:
                    pct = (step_num - start_step) / (end_step - start_step)
                    computed_lr = self.anneal_func(group[phase['start_lr']], group[phase['end_lr']], pct)
                    if self.cycle_momentum:
                        computed_momentum = self.anneal_func(group[phase['start_momentum']], group[phase['end_momentum']], pct)
                    break
                start_step = phase['end_step']

            lrs.append(computed_lr)
            if self.cycle_momentum:
                if self.use_beta1:
                    _, beta2 = group['betas']
                    group['betas'] = (computed_momentum, beta2)
                else:
                    group['momentum'] = computed_momentum

        return lrs
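# Usage sketch for OneCycleLR with an explicit total_steps (illustrative, not
# part of the original file): with max_lr=0.01 and the defaults div_factor=25
# and final_div_factor=1e4, the schedule starts at 0.01 / 25 = 4e-4, peaks at
# 0.01, and ends at 4e-4 / 1e4 = 4e-8.
#
#   >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
#   >>> scheduler = OneCycleLR(optimizer, max_lr=0.01, total_steps=1000)
#   >>> for step in range(1000):
#   >>>     train_batch(...)
#   >>>     optimizer.step()
#   >>>     scheduler.step()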