mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
* added functionality for state_dict/load_state_dict for lr_scheduler * fixed linting issues/removed unused import * refactor lr_scheduler state_dicts/state_dict holds everything __dict__ but optimizer * changed documentation in lr_scheduler * Update lr_scheduler.py
377 lines
14 KiB
Python
377 lines
14 KiB
Python
import math
|
|
from bisect import bisect_right
|
|
from functools import partial
|
|
from .optimizer import Optimizer
|
|
|
|
|
|
class _LRScheduler(object):
|
|
def __init__(self, optimizer, last_epoch=-1):
|
|
if not isinstance(optimizer, Optimizer):
|
|
raise TypeError('{} is not an Optimizer'.format(
|
|
type(optimizer).__name__))
|
|
self.optimizer = optimizer
|
|
if last_epoch == -1:
|
|
for group in optimizer.param_groups:
|
|
group.setdefault('initial_lr', group['lr'])
|
|
else:
|
|
for i, group in enumerate(optimizer.param_groups):
|
|
if 'initial_lr' not in group:
|
|
raise KeyError("param 'initial_lr' is not specified "
|
|
"in param_groups[{}] when resuming an optimizer".format(i))
|
|
self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
|
|
self.step(last_epoch + 1)
|
|
self.last_epoch = last_epoch
|
|
|
|
def __getstate__(self):
|
|
return self.state_dict()
|
|
|
|
def __setstate__(self, state):
|
|
self.load_state_dict(state)
|
|
|
|
def state_dict(self):
|
|
"""Returns the state of the scheduler as a :class:`dict`.
|
|
|
|
It contains an entry for every variable in self.__dict__ which
|
|
is not the optimizer.
|
|
"""
|
|
return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
|
|
|
|
def load_state_dict(self, state_dict):
|
|
"""Loads the schedulers state.
|
|
|
|
Arguments:
|
|
state_dict (dict): scheduler state. Should be an object returned
|
|
from a call to :meth:`state_dict`.
|
|
"""
|
|
self.__dict__.update(state_dict)
|
|
|
|
def get_lr(self):
|
|
raise NotImplementedError
|
|
|
|
def step(self, epoch=None):
|
|
if epoch is None:
|
|
epoch = self.last_epoch + 1
|
|
self.last_epoch = epoch
|
|
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
|
|
param_group['lr'] = lr
|
|
|
|
|
|
class LambdaLR(_LRScheduler):
    """Scales each parameter group's initial lr by a user-supplied factor.

    At epoch ``e`` the learning rate of group ``i`` is
    ``base_lrs[i] * lr_lambdas[i](e)``. When last_epoch=-1, sets initial lr
    as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in optimizer.param_groups.
        last_epoch (int): The index of last epoch. Default: -1.

    Note:
        ``lr_lambdas`` is a plain instance attribute, so it is included in
        ``state_dict()``. NOTE(review): plain lambdas are not picklable, so
        serializing such a state dict with pickle may fail — confirm for
        your checkpoint format.

    Example:
        >>> # Assuming optimizer has two groups.
        >>> lambda1 = lambda epoch: epoch // 30
        >>> lambda2 = lambda epoch: 0.95 ** epoch
        >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
        >>> for epoch in range(100):
        >>>     scheduler.step()
        >>>     train(...)
        >>>     validate(...)
    """

    def __init__(self, optimizer, lr_lambda, last_epoch=-1):
        self.optimizer = optimizer
        group_count = len(optimizer.param_groups)
        if isinstance(lr_lambda, (list, tuple)):
            if len(lr_lambda) != group_count:
                raise ValueError("Expected {} lr_lambdas, but got {}".format(
                    group_count, len(lr_lambda)))
            self.lr_lambdas = list(lr_lambda)
        else:
            # One callable shared by every parameter group.
            self.lr_lambdas = [lr_lambda] * group_count
        self.last_epoch = last_epoch
        super(LambdaLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        epoch = self.last_epoch
        return [base_lr * factor_fn(epoch)
                for factor_fn, base_lr in zip(self.lr_lambdas, self.base_lrs)]
|
|
|
|
|
|
class StepLR(_LRScheduler):
    """Decays each parameter group's initial lr by ``gamma`` once every
    ``step_size`` epochs. When last_epoch=-1, sets initial lr as lr.

    At epoch ``e`` the learning rate is
    ``base_lr * gamma ** (e // step_size)``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        step_size (int): Period of learning rate decay.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 60
        >>> # lr = 0.0005   if 60 <= epoch < 90
        >>> # ...
        >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
        >>> for epoch in range(100):
        >>>     scheduler.step()
        >>>     train(...)
        >>>     validate(...)
    """

    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
        self.step_size = step_size
        self.gamma = gamma
        super(StepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # Each completed period of ``step_size`` epochs contributes one factor of gamma.
        return [initial_lr * self.gamma ** (self.last_epoch // self.step_size)
                for initial_lr in self.base_lrs]
|
|
|
|
|
|
class MultiStepLR(_LRScheduler):
    """Set the learning rate of each parameter group to the initial lr decayed
    by gamma once the number of epoch reaches one of the milestones. When
    last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list): Epoch indices in non-decreasing order. Any iterable
            is accepted; it is copied into a list. Every milestone that is
            <= the current epoch contributes one factor of ``gamma``.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.

    Raises:
        ValueError: If ``milestones`` is not sorted.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 80
        >>> # lr = 0.0005   if epoch >= 80
        >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
        >>> for epoch in range(100):
        >>>     scheduler.step()
        >>>     train(...)
        >>>     validate(...)
    """

    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1):
        # Materialize once: iterating the argument twice (once for the copy,
        # once for sorted()) would wrongly reject generators/iterators.
        milestones = list(milestones)
        if milestones != sorted(milestones):
            raise ValueError('Milestones should be a list of'
                             ' increasing integers. Got {}'.format(milestones))
        self.milestones = milestones
        self.gamma = gamma
        super(MultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # bisect_right counts the milestones <= last_epoch, i.e. how many decays
        # have already happened.
        return [base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch)
                for base_lr in self.base_lrs]
|
|
|
|
|
|
class ExponentialLR(_LRScheduler):
    """Decays each parameter group's initial lr by ``gamma`` every epoch, so
    that at epoch ``e`` the learning rate is ``base_lr * gamma ** e``. When
    last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, gamma, last_epoch=-1):
        self.gamma = gamma
        super(ExponentialLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        epoch = self.last_epoch
        return [initial_lr * self.gamma ** epoch for initial_lr in self.base_lrs]
|
|
|
|
|
|
class CosineAnnealingLR(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::

        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
        \cos(\frac{T_{cur}}{T_{max}}\pi))

    When last_epoch=-1, sets initial lr as lr.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
    implements the cosine annealing part of SGDR, and not the restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
        self.T_max = T_max
        self.eta_min = eta_min
        super(CosineAnnealingLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        def _anneal(base_lr):
            # Interpolate between base_lr (cosine = 1) and eta_min (cosine = -1).
            cosine = math.cos(math.pi * self.last_epoch / self.T_max)
            return self.eta_min + (base_lr - self.eta_min) * (1 + cosine) / 2

        return [_anneal(base_lr) for base_lr in self.base_lrs]
|
|
|
|
|
|
class ReduceLROnPlateau(object):
    """Reduce learning rate when a metric has stopped improving.
    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This scheduler reads a metrics
    quantity and if no improvement is seen for a 'patience' number
    of epochs, the learning rate is reduced.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode, lr will
            be reduced when the quantity monitored has stopped
            decreasing; in `max` mode it will be reduced when the
            quantity monitored has stopped increasing. Default: 'min'.
        factor (float): Factor by which the learning rate will be
            reduced. new_lr = lr * factor. Default: 0.1.
        patience (int): Number of epochs with no improvement after
            which learning rate will be reduced. Default: 10.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
        threshold (float): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
            dynamic_threshold = best * ( 1 + threshold ) in `max`
            mode or best * ( 1 - threshold ) in `min` mode.
            In `abs` mode, dynamic_threshold = best + threshold in
            `max` mode or best - threshold in `min` mode. Default: 'rel'.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after lr has been reduced. Default: 0.
        min_lr (float or list): A scalar or a list of scalars. A
            lower bound on the learning rate of all param groups
            or each group respectively. Default: 0.
        eps (float): Minimal decay applied to lr. If the difference
            between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
        >>> for epoch in range(10):
        >>>     train(...)
        >>>     val_loss = validate(...)
        >>>     # Note that step should be called after validate()
        >>>     scheduler.step(val_loss)
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                 verbose=False, threshold=1e-4, threshold_mode='rel',
                 cooldown=0, min_lr=0, eps=1e-8):

        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0.')
        self.factor = factor

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        if isinstance(min_lr, list) or isinstance(min_lr, tuple):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} min_lrs, got {}".format(
                    len(optimizer.param_groups), len(min_lr)))
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.is_better = None
        self.eps = eps
        self.last_epoch = -1
        self._init_is_better(mode=mode, threshold=threshold,
                             threshold_mode=threshold_mode)
        self._reset()

    def _reset(self):
        """Resets num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def step(self, metrics, epoch=None):
        """Record one epoch's ``metrics`` value and reduce the lr on a plateau.

        Args:
            metrics: The monitored quantity, compared against the best value
                seen so far via ``is_better``.
            epoch (int, optional): Epoch index; defaults to ``last_epoch + 1``.
        """
        current = metrics
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

    def _reduce_lr(self, epoch):
        # Multiply each group's lr by ``factor``, clamped to that group's
        # min_lr; changes no larger than eps are skipped.
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            if old_lr - new_lr > self.eps:
                param_group['lr'] = new_lr
                if self.verbose:
                    print('Epoch {:5d}: reducing learning rate'
                          ' of group {} to {:.4e}.'.format(epoch, i, new_lr))

    @property
    def in_cooldown(self):
        return self.cooldown_counter > 0

    def _cmp(self, mode, threshold_mode, threshold, a, best):
        # Return True when ``a`` improves on ``best`` by more than the threshold.
        if mode == 'min' and threshold_mode == 'rel':
            rel_epsilon = 1. - threshold
            return a < best * rel_epsilon

        elif mode == 'min' and threshold_mode == 'abs':
            return a < best - threshold

        elif mode == 'max' and threshold_mode == 'rel':
            rel_epsilon = threshold + 1.
            return a > best * rel_epsilon

        else:  # mode == 'max' and threshold_mode == 'abs':
            return a > best + threshold

    def _init_is_better(self, mode, threshold, threshold_mode):
        # Validate the modes, set the starting "worst" value, and bind the
        # comparison function used by step().
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if threshold_mode not in {'rel', 'abs'}:
            raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')

        if mode == 'min':
            self.mode_worse = float('inf')
        else:  # mode == 'max':
            self.mode_worse = (-float('inf'))

        self.is_better = partial(self._cmp, mode, threshold_mode, threshold)

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ except the
        optimizer and ``is_better``. ``is_better`` is a ``partial`` over this
        instance's bound ``_cmp``, so it is rebuilt by :meth:`load_state_dict`
        instead of being stored.
        """
        return {key: value for key, value in self.__dict__.items()
                if key not in {'optimizer', 'is_better'}}

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)
        self._init_is_better(mode=self.mode, threshold=self.threshold,
                             threshold_mode=self.threshold_mode)
|