mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: I opened an issue explaining some of my frustrations with the current state of schedulers. While most points that I raised in [that issue](https://github.com/pytorch/pytorch/issues/8741#issuecomment-404449697) need to be discussed more thoroughly before being implemented, there are some that are not so difficult to fix. This PR changes the way the LambdaLR scheduler gets serialized: > The lr_lambda functions are only saved if they are callable objects (which can be stateful). > There is no point in saving functions/lambdas as you need their definition before unpickling and they are stateless. This has the big advantage that the scheduler is serializable, even if you use lambda functions or locally defined functions (i.e. a function defined inside another function). Does this functionality need any unit tests? Pull Request resolved: https://github.com/pytorch/pytorch/pull/9927 Differential Revision: D9055505 Pulled By: soumith fbshipit-source-id: 6c1cec588beedd098ec7d2bce6a9add27f29e48f
416 lines
16 KiB
Python
416 lines
16 KiB
Python
import types
|
|
import math
|
|
import torch
|
|
from torch._six import inf
|
|
from bisect import bisect_right
|
|
from functools import partial
|
|
from .optimizer import Optimizer
|
|
|
|
|
|
class _LRScheduler(object):
    """Base class for schedulers that derive the learning rate from an epoch index.

    Subclasses implement :meth:`get_lr`, which must return one learning rate
    per entry of ``optimizer.param_groups``. :meth:`step` advances
    ``last_epoch`` and writes those values into the optimizer.
    """

    def __init__(self, optimizer, last_epoch=-1):
        """Attach the scheduler to ``optimizer``.

        Arguments:
            optimizer (Optimizer): Wrapped optimizer.
            last_epoch (int): The index of last epoch. With the default of -1
                every param group records its current ``lr`` as
                ``initial_lr`` (unless already present); with any other value
                each group must already carry ``initial_lr``.

        Raises:
            TypeError: if ``optimizer`` is not an ``Optimizer``.
            KeyError: if ``last_epoch != -1`` and a param group lacks
                ``initial_lr``.
        """
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        param_groups = optimizer.param_groups
        if last_epoch != -1:
            # Resuming: the base learning rates must already be recorded.
            for index, group in enumerate(param_groups):
                if 'initial_lr' in group:
                    continue
                raise KeyError("param 'initial_lr' is not specified "
                               "in param_groups[{}] when resuming an optimizer".format(index))
        else:
            # Fresh start: remember the current lr as the base value.
            for group in param_groups:
                group.setdefault('initial_lr', group['lr'])

        self.base_lrs = [group['initial_lr'] for group in param_groups]

        # Push the value for epoch ``last_epoch + 1`` into the optimizer, then
        # rewind the counter so that the next argument-less ``step()`` lands
        # on that same epoch.
        self.step(last_epoch + 1)
        self.last_epoch = last_epoch

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        The result is a shallow copy of ``self.__dict__`` without the
        ``optimizer`` entry.
        """
        state = dict(self.__dict__)
        state.pop('optimizer', None)
        return state

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        vars(self).update(state_dict)

    def get_lr(self):
        """Return the learning rates for ``self.last_epoch`` (one per group).

        Must be provided by subclasses; the base implementation raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

    def step(self, epoch=None):
        """Move to ``epoch`` (default: ``last_epoch + 1``) and update the lrs.

        Arguments:
            epoch (int, optional): epoch index to jump to.
        """
        self.last_epoch = self.last_epoch + 1 if epoch is None else epoch
        groups = self.optimizer.param_groups
        new_lrs = self.get_lr()
        for group, new_lr in zip(groups, new_lrs):
            group['lr'] = new_lr
|
|
|
|
|
|
class LambdaLR(_LRScheduler):
    """Sets the learning rate of each parameter group to the initial lr
    times a given function. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in optimizer.param_groups.
        last_epoch (int): The index of last epoch. Default: -1.

    Example:
        >>> # Assuming optimizer has two groups.
        >>> lambda1 = lambda epoch: epoch // 30
        >>> lambda2 = lambda epoch: 0.95 ** epoch
        >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
        >>> for epoch in range(100):
        >>>     scheduler.step()
        >>>     train(...)
        >>>     validate(...)
    """

    def __init__(self, optimizer, lr_lambda, last_epoch=-1):
        """Store one multiplier callable per param group and initialise the base.

        Raises:
            ValueError: if ``lr_lambda`` is a list/tuple whose length differs
                from the number of param groups.
        """
        self.optimizer = optimizer
        if isinstance(lr_lambda, (list, tuple)):
            if len(lr_lambda) != len(optimizer.param_groups):
                raise ValueError("Expected {} lr_lambdas, but got {}".format(
                    len(optimizer.param_groups), len(lr_lambda)))
            self.lr_lambdas = list(lr_lambda)
        else:
            # A single callable is shared by every param group.
            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
        self.last_epoch = last_epoch
        super(LambdaLR, self).__init__(optimizer, last_epoch)

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The learning rate lambda functions will only be saved if they are
        callable objects and not if they are functions or lambdas: the
        ``'lr_lambdas'`` entry holds, per lambda, either ``None`` (plain
        function/lambda) or a copy of the callable object's ``__dict__``.
        """
        state = {key: value for key, value in self.__dict__.items()
                 if key not in ('optimizer', 'lr_lambdas')}
        state['lr_lambdas'] = [
            None if isinstance(fn, types.FunctionType) else fn.__dict__.copy()
            for fn in self.lr_lambdas
        ]
        return state

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        The passed ``state_dict`` is left unmodified, so the same dict can be
        loaded more than once (or into several schedulers).

        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.

        Raises:
            KeyError: if ``state_dict`` has no ``'lr_lambdas'`` entry.
        """
        # Read (do not pop) the lambda states: popping would silently strip
        # the entry from the caller's dict.
        saved_lambdas = state_dict['lr_lambdas']
        self.__dict__.update(
            {key: value for key, value in state_dict.items()
             if key != 'lr_lambdas'})

        # The callables themselves are kept; only their saved attributes are
        # restored, and only for entries that were saved (not None).
        for idx, fn_state in enumerate(saved_lambdas):
            if fn_state is not None:
                self.lr_lambdas[idx].__dict__.update(fn_state)

    def get_lr(self):
        """Return ``base_lr * lr_lambda(last_epoch)`` for every param group."""
        return [base_lr * lmbda(self.last_epoch)
                for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
|
|
|
|
|
|
class StepLR(_LRScheduler):
    """Multiplies the initial lr of each parameter group by ``gamma`` once
    per completed block of ``step_size`` epochs. When last_epoch=-1, sets
    initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        step_size (int): Period of learning rate decay.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 60
        >>> # lr = 0.0005   if 60 <= epoch < 90
        >>> # ...
        >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
        >>> for epoch in range(100):
        >>>     scheduler.step()
        >>>     train(...)
        >>>     validate(...)
    """

    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
        # The hyper-parameters must exist before the base initialiser runs,
        # because it calls ``step()`` and therefore ``get_lr()``.
        self.step_size = step_size
        self.gamma = gamma
        super(StepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return ``base_lr * gamma ** (last_epoch // step_size)`` per group."""
        decayed = []
        for initial in self.base_lrs:
            decayed.append(
                initial * self.gamma ** (self.last_epoch // self.step_size))
        return decayed
|
|
|
|
|
|
class MultiStepLR(_LRScheduler):
    """Set the learning rate of each parameter group to the initial lr decayed
    by gamma once the number of epoch reaches one of the milestones. When
    last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list): List of epoch indices. Must be increasing.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.

    Raises:
        ValueError: if ``milestones`` is not sorted in increasing order.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 80
        >>> # lr = 0.0005   if epoch >= 80
        >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
        >>> for epoch in range(100):
        >>>     scheduler.step()
        >>>     train(...)
        >>>     validate(...)
    """

    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1):
        if list(milestones) != sorted(milestones):
            # Format the offending value into the message; passing it as a
            # second exception argument would leave the '{}' placeholder
            # unfilled.
            raise ValueError('Milestones should be a list of'
                             ' increasing integers. Got {}'.format(milestones))
        self.milestones = milestones
        self.gamma = gamma
        super(MultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return ``base_lr * gamma ** k`` per group, where ``k`` is the
        number of milestones that are ``<= last_epoch``."""
        return [base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch)
                for base_lr in self.base_lrs]
|
|
|
|
|
|
class ExponentialLR(_LRScheduler):
    """Decays the initial lr of each parameter group by a factor of ``gamma``
    for every epoch. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, gamma, last_epoch=-1):
        # ``gamma`` must be set first: the base initialiser calls ``step()``,
        # which in turn calls ``get_lr()``.
        self.gamma = gamma
        super(ExponentialLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return ``base_lr * gamma ** last_epoch`` for every param group."""
        decayed = []
        for initial in self.base_lrs:
            decayed.append(initial * self.gamma ** self.last_epoch)
        return decayed
|
|
|
|
|
|
class CosineAnnealingLR(_LRScheduler):
    r"""Anneal the learning rate of each parameter group along a cosine
    curve, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::

        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
        \cos(\frac{T_{cur}}{T_{max}}\pi))

    When last_epoch=-1, sets initial lr as lr.

    The schedule was proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Only the cosine
    annealing part of SGDR is implemented here, not the restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
        # Both hyper-parameters are read by ``get_lr()``, which the base
        # initialiser triggers through ``step()``.
        self.T_max = T_max
        self.eta_min = eta_min
        super(CosineAnnealingLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the cosine-annealed learning rate for every param group,
        evaluated at ``last_epoch``."""
        annealed = []
        for initial in self.base_lrs:
            span = initial - self.eta_min
            annealed.append(
                self.eta_min +
                span * (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2)
        return annealed
|
|
|
|
|
|
class ReduceLROnPlateau(object):
    """Reduce learning rate when a metric has stopped improving.

    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This scheduler is fed a metric value
    on every :meth:`step`; once the value has failed to improve for more
    than ``patience`` consecutive steps, the learning rate of every param
    group is multiplied by ``factor``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode, lr will
            be reduced when the quantity monitored has stopped
            decreasing; in `max` mode it will be reduced when the
            quantity monitored has stopped increasing. Default: 'min'.
        factor (float): Factor by which the learning rate will be
            reduced. new_lr = lr * factor. Must be < 1.0. Default: 0.1.
        patience (int): Number of epochs with no improvement after
            which learning rate will be reduced. For example, if
            `patience = 2`, then we will ignore the first 2 epochs
            with no improvement, and will only decrease the LR after the
            3rd epoch if the loss still hasn't improved then.
            Default: 10.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
        threshold (float): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
            dynamic_threshold = best * ( 1 + threshold ) in 'max'
            mode or best * ( 1 - threshold ) in `min` mode.
            In `abs` mode, dynamic_threshold = best + threshold in
            `max` mode or best - threshold in `min` mode. Default: 'rel'.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after lr has been reduced. Default: 0.
        min_lr (float or list): A scalar or a list of scalars. A
            lower bound on the learning rate of all param groups
            or each group respectively. Default: 0.
        eps (float): Minimal decay applied to lr. If the difference
            between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
        >>> for epoch in range(10):
        >>>     train(...)
        >>>     val_loss = validate(...)
        >>>     # Note that step should be called after validate()
        >>>     scheduler.step(val_loss)
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                 verbose=False, threshold=1e-4, threshold_mode='rel',
                 cooldown=0, min_lr=0, eps=1e-8):
        """Validate the arguments and put the scheduler into its initial state.

        Raises:
            ValueError: if ``factor >= 1.0``, if ``min_lr`` is a list/tuple
                whose length differs from the number of param groups, or if
                ``mode`` / ``threshold_mode`` is not a known value.
            TypeError: if ``optimizer`` is not an ``Optimizer``.
        """
        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0.')
        self.factor = factor

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        # One lower bound per param group; a scalar is broadcast to all groups.
        group_count = len(optimizer.param_groups)
        if isinstance(min_lr, (list, tuple)):
            if len(min_lr) != group_count:
                raise ValueError("expected {} min_lrs, got {}".format(
                    group_count, len(min_lr)))
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * group_count

        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        # The next four attributes are placeholders; they are filled in by
        # _init_is_better() and _reset() below.
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.is_better = None
        self.eps = eps
        self.last_epoch = -1
        self._init_is_better(mode=mode, threshold=threshold,
                             threshold_mode=threshold_mode)
        self._reset()

    def _reset(self):
        """Resets num_bad_epochs counter and cooldown counter, and sets
        ``best`` back to the worst value for the current mode."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def step(self, metrics, epoch=None):
        """Record a new metric value and reduce the lr if it has plateaued.

        Arguments:
            metrics: current value of the monitored quantity.
            epoch (int, optional): epoch index; defaults to
                ``last_epoch + 1``.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        improved = self.is_better(metrics, self.best)
        if improved:
            self.best = metrics
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

    def _reduce_lr(self, epoch):
        """Multiply every group's lr by ``factor``, clamped from below by the
        group's ``min_lr``; changes not larger than ``eps`` are skipped."""
        for index, group in enumerate(self.optimizer.param_groups):
            previous = float(group['lr'])
            reduced = max(previous * self.factor, self.min_lrs[index])
            if not (previous - reduced > self.eps):
                continue
            group['lr'] = reduced
            if self.verbose:
                print('Epoch {:5d}: reducing learning rate'
                      ' of group {} to {:.4e}.'.format(epoch, index, reduced))

    @property
    def in_cooldown(self):
        """``True`` while the cooldown counter is positive."""
        return self.cooldown_counter > 0

    def _cmp(self, mode, threshold_mode, threshold, a, best):
        """Return whether ``a`` beats ``best`` under the given mode and
        threshold settings.

        Any combination other than the three tested explicitly is handled
        like ``mode == 'max'`` with ``threshold_mode == 'abs'``.
        """
        if mode == 'min':
            if threshold_mode == 'rel':
                rel_epsilon = 1. - threshold
                return a < best * rel_epsilon
            if threshold_mode == 'abs':
                return a < best - threshold
        elif mode == 'max' and threshold_mode == 'rel':
            rel_epsilon = threshold + 1.
            return a > best * rel_epsilon

        # mode == 'max' and threshold_mode == 'abs'
        return a > best + threshold

    def _init_is_better(self, mode, threshold, threshold_mode):
        """Validate the mode settings, then set ``mode_worse`` and bind
        ``is_better`` to :meth:`_cmp` with those settings.

        Raises:
            ValueError: if ``mode`` or ``threshold_mode`` is unknown.
        """
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if threshold_mode not in {'rel', 'abs'}:
            raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')

        # Worst possible starting point: +inf when minimising, -inf otherwise.
        self.mode_worse = inf if mode == 'min' else -inf

        self.is_better = partial(self._cmp, mode, threshold_mode, threshold)

    def state_dict(self):
        """Return every instance attribute except ``optimizer`` and
        ``is_better`` as a :class:`dict`."""
        excluded = {'optimizer', 'is_better'}
        state = {}
        for key, value in self.__dict__.items():
            if key not in excluded:
                state[key] = value
        return state

    def load_state_dict(self, state_dict):
        """Restore attributes from ``state_dict`` and rebuild ``is_better``
        from the restored mode and threshold settings."""
        vars(self).update(state_dict)
        self._init_is_better(mode=self.mode, threshold=self.threshold,
                             threshold_mode=self.threshold_mode)
|