mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings. I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, it seems there are no instances of it in our codebase, so I am enabling it so that it stays that way. :) Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519 Approved by: https://github.com/ezyang
1724 lines
73 KiB
Python
1724 lines
73 KiB
Python
import types
|
|
import math
|
|
from torch import inf
|
|
from functools import wraps, partial
|
|
import warnings
|
|
import weakref
|
|
from collections import Counter
|
|
from bisect import bisect_right
|
|
|
|
from .optimizer import Optimizer
|
|
|
|
# Names exported by this module (order kept as originally declared).
__all__ = [
    'LambdaLR',
    'MultiplicativeLR',
    'StepLR',
    'MultiStepLR',
    'ConstantLR',
    'LinearLR',
    'ExponentialLR',
    'SequentialLR',
    'CosineAnnealingLR',
    'ChainedScheduler',
    'ReduceLROnPlateau',
    'CyclicLR',
    'CosineAnnealingWarmRestarts',
    'OneCycleLR',
    'PolynomialLR',
    'LRScheduler',
]
|
|
|
|
# Message emitted by ``LRScheduler.step`` whenever an explicit ``epoch``
# argument is passed.
EPOCH_DEPRECATION_WARNING = (
    "The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. "
    "Please use `scheduler.step()` to step the scheduler. "
    "During the deprecation, if epoch is different from None, the closed form is used "
    "instead of the new chainable form, where available. "
    "Please open an issue if you are unable to replicate your use case: "
    "https://github.com/pytorch/pytorch/issues/new/choose."
)
|
|
|
|
|
|
class LRScheduler:
    """Base class for schedulers that rewrite ``param_group['lr']`` on each step.

    Subclasses provide :meth:`get_lr` (and optionally ``_get_closed_form_lr``);
    :meth:`step` writes the returned values into the optimizer's param groups.
    """

    def __init__(self, optimizer, last_epoch=-1, verbose=False):
        """Attach ``optimizer``, record the base learning rates and take the initial step.

        Args:
            optimizer (Optimizer): Wrapped optimizer.
            last_epoch (int): The index of last epoch. Default: -1.
            verbose (bool): If ``True``, prints a message to stdout for
                each update. Default: ``False``.

        Raises:
            TypeError: if ``optimizer`` is not an ``Optimizer``.
            KeyError: if ``last_epoch != -1`` and a param group has no
                ``'initial_lr'`` entry.
        """
        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
        self.optimizer = optimizer

        # Initialize epoch and base learning rates
        param_groups = optimizer.param_groups
        if last_epoch == -1:
            # Fresh start: remember the current lr as the initial one.
            for group in param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            # Resuming: every group must already carry its initial lr.
            for index, group in enumerate(param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   f"in param_groups[{index}] when resuming an optimizer")
        self.base_lrs = [group['initial_lr'] for group in param_groups]
        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `lr_scheduler.step()` is called after
        # `optimizer.step()`
        def with_counter(method):
            if getattr(method, '_with_counter', False):
                # `optimizer.step()` has already been replaced, return.
                return method

            # Hold the optimizer only through a weak reference, and keep the
            # unbound function, so the wrapper creates no reference cycle.
            owner_ref = weakref.ref(method.__self__)
            unbound = method.__func__
            owner_cls = owner_ref().__class__
            del method

            @wraps(unbound)
            def wrapper(*args, **kwargs):
                owner = owner_ref()
                owner._step_count += 1
                bound = unbound.__get__(owner, owner_cls)
                return bound(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True
            return wrapper

        self.optimizer.step = with_counter(self.optimizer.step)
        self.verbose = verbose

        self._initial_step()

    def _initial_step(self):
        """Reset both step counters to zero and perform one scheduler step."""
        self.optimizer._step_count = 0
        self._step_count = 0
        self.step()

    def state_dict(self):
        """Return the scheduler state as a :class:`dict`.

        Every entry of ``self.__dict__`` is included except the optimizer.
        """
        return {name: value for name, value in self.__dict__.items() if name != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Load the scheduler state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_lr(self):
        """Return the learning rates most recently computed by this scheduler."""
        return self._last_lr

    def get_lr(self):
        """Compute the learning rates (chainable form); subclasses must override."""
        raise NotImplementedError

    def print_lr(self, is_verbose, group, lr, epoch=None):
        """Print the current learning rate when ``is_verbose`` is truthy."""
        if not is_verbose:
            return
        if epoch is None:
            print(f'Adjusting learning rate of group {group} to {lr:.4e}.')
            return
        # Float epochs get two decimals, anything else a zero-padded integer.
        epoch_format = "%.2f" if isinstance(epoch, float) else "%.5d"
        epoch_str = epoch_format % epoch
        print(f'Epoch {epoch_str}: adjusting learning rate of group {group} to {lr:.4e}.')

    def step(self, epoch=None):
        """Advance the schedule and write the new learning rates to the optimizer.

        Args:
            epoch (int, optional): if given, ``last_epoch`` is set to this value
                (with a deprecation warning) and the closed form is used when the
                scheduler defines one; otherwise ``last_epoch`` is incremented.
        """
        # Raise a warning if old pattern is detected
        # https://github.com/pytorch/pytorch/issues/20124
        if self._step_count == 1:
            if not hasattr(self.optimizer.step, "_with_counter"):
                warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                              "initialization. Please, make sure to call `optimizer.step()` before "
                              "`lr_scheduler.step()`. See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

            # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
            elif self.optimizer._step_count < 1:
                warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                              "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                              "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
                              "will result in PyTorch skipping the first value of the learning rate schedule. "
                              "See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
        self._step_count += 1

        with _enable_get_lr_call(self):
            if epoch is None:
                self.last_epoch += 1
                values = self.get_lr()
            else:
                warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
                self.last_epoch = epoch
                # Prefer the closed form when the scheduler provides one.
                compute = self._get_closed_form_lr if hasattr(self, "_get_closed_form_lr") else self.get_lr
                values = compute()

        for index, (param_group, lr) in enumerate(zip(self.optimizer.param_groups, values)):
            param_group['lr'] = lr
            self.print_lr(self.verbose, index, lr, epoch)

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
|
|
|
|
|
|
# Including _LRScheduler for backwards compatibility
# Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler).
class _LRScheduler(LRScheduler):
    """Backwards-compatible name for :class:`LRScheduler`; adds no behavior."""
|
|
|
|
|
|
class _enable_get_lr_call:
    """Context manager that flags ``o._get_lr_called_within_step`` while active.

    The flag is set to ``True`` on entry and to ``False`` on exit.
    """

    def __init__(self, o):
        # Object whose `_get_lr_called_within_step` attribute is toggled.
        self.o = o

    def __enter__(self):
        setattr(self.o, '_get_lr_called_within_step', True)
        return self

    def __exit__(self, type, value, traceback):
        # Returns None, so an exception raised inside the block propagates.
        setattr(self.o, '_get_lr_called_within_step', False)
|
|
|
|
|
|
class LambdaLR(LRScheduler):
    """Sets the learning rate of each parameter group to the initial lr
    times a given function. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in optimizer.param_groups.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer has two groups.
        >>> lambda1 = lambda epoch: epoch // 30
        >>> lambda2 = lambda epoch: 0.95 ** epoch
        >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
        self.optimizer = optimizer

        num_groups = len(optimizer.param_groups)
        if isinstance(lr_lambda, (list, tuple)):
            # One callable per param group is required.
            if len(lr_lambda) != num_groups:
                raise ValueError(f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}")
            self.lr_lambdas = list(lr_lambda)
        else:
            # A single callable is shared by every param group.
            self.lr_lambdas = [lr_lambda] * num_groups
        super().__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        """Return the scheduler state as a :class:`dict`.

        Every entry of ``self.__dict__`` except the optimizer is included.
        The learning rate lambda functions will only be saved if they are callable objects
        and not if they are functions or lambdas.

        When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
        """
        state = {name: value for name, value in self.__dict__.items()
                 if name not in ('optimizer', 'lr_lambdas')}
        # Plain functions are stored as None; callable objects contribute a
        # copy of their instance __dict__.
        state['lr_lambdas'] = [
            None if isinstance(fn, types.FunctionType) else fn.__dict__.copy()
            for fn in self.lr_lambdas
        ]
        return state

    def load_state_dict(self, state_dict):
        """Load the scheduler state.

        When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        saved_lambdas = state_dict.pop('lr_lambdas')
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict['lr_lambdas'] = saved_lambdas

        for position, saved in enumerate(saved_lambdas):
            if saved is None:
                continue
            self.lr_lambdas[position].__dict__.update(saved)

    def get_lr(self):
        """Return ``base_lr * lr_lambda(last_epoch)`` for each param group."""
        if not self._get_lr_called_within_step:
            # UserWarning is the default category of warnings.warn.
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        epoch = self.last_epoch
        return [base_lr * lmbda(epoch)
                for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
|
|
|
|
|
|
class MultiplicativeLR(LRScheduler):
    """Multiply the learning rate of each parameter group by the factor given
    in the specified function. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in optimizer.param_groups.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> # xdoctest: +SKIP
        >>> lmbda = lambda epoch: 0.95
        >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
        self.optimizer = optimizer

        num_groups = len(optimizer.param_groups)
        if isinstance(lr_lambda, (list, tuple)):
            # One callable per param group is required.
            if len(lr_lambda) != num_groups:
                raise ValueError(f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}")
            self.lr_lambdas = list(lr_lambda)
        else:
            # A single callable is shared by every param group.
            self.lr_lambdas = [lr_lambda] * num_groups
        super().__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        """Return the scheduler state as a :class:`dict`.

        Every entry of ``self.__dict__`` except the optimizer is included.
        The learning rate lambda functions will only be saved if they are callable objects
        and not if they are functions or lambdas.
        """
        state = {name: value for name, value in self.__dict__.items()
                 if name not in ('optimizer', 'lr_lambdas')}
        # Plain functions are stored as None; callable objects contribute a
        # copy of their instance __dict__.
        state['lr_lambdas'] = [
            None if isinstance(fn, types.FunctionType) else fn.__dict__.copy()
            for fn in self.lr_lambdas
        ]
        return state

    def load_state_dict(self, state_dict):
        """Load the scheduler state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        saved_lambdas = state_dict.pop('lr_lambdas')
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict['lr_lambdas'] = saved_lambdas

        for position, saved in enumerate(saved_lambdas):
            if saved is None:
                continue
            self.lr_lambdas[position].__dict__.update(saved)

    def get_lr(self):
        """Return each group's current lr, scaled by its lambda once ``last_epoch > 0``."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        groups = self.optimizer.param_groups
        if not self.last_epoch > 0:
            # No multiplication is applied until last_epoch is positive.
            return [group['lr'] for group in groups]
        return [group['lr'] * lmbda(self.last_epoch)
                for lmbda, group in zip(self.lr_lambdas, groups)]
|
|
|
|
|
|
class StepLR(LRScheduler):
    """Decays the learning rate of each parameter group by gamma every
    step_size epochs. Notice that such decay can happen simultaneously with
    other changes to the learning rate from outside this scheduler. When
    last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        step_size (int): Period of learning rate decay.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 60
        >>> # lr = 0.0005   if 60 <= epoch < 90
        >>> # ...
        >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False):
        self.step_size = step_size
        self.gamma = gamma
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Chainable form: scale the current lr by gamma on every ``step_size``-th epoch."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        groups = self.optimizer.param_groups
        # Epoch 0 never decays; afterwards only exact multiples of step_size do.
        decay_now = self.last_epoch != 0 and self.last_epoch % self.step_size == 0
        if decay_now:
            return [group['lr'] * self.gamma for group in groups]
        return [group['lr'] for group in groups]

    def _get_closed_form_lr(self):
        """Closed form: ``base_lr * gamma ** (last_epoch // step_size)``."""
        completed_periods = self.last_epoch // self.step_size
        return [base_lr * self.gamma ** completed_periods
                for base_lr in self.base_lrs]
|
|
|
|
|
|
class MultiStepLR(LRScheduler):
    """Decays the learning rate of each parameter group by gamma once the
    number of epoch reaches one of the milestones. Notice that such decay can
    happen simultaneously with other changes to the learning rate from outside
    this scheduler. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list): List of epoch indices. Must be increasing.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 80
        >>> # lr = 0.0005   if epoch >= 80
        >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False):
        # A Counter keeps the multiplicity of repeated milestones.
        self.milestones = Counter(milestones)
        self.gamma = gamma
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Chainable form: apply gamma once per occurrence of ``last_epoch`` in the milestones."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        groups = self.optimizer.param_groups
        if self.last_epoch in self.milestones:
            occurrences = self.milestones[self.last_epoch]
            return [group['lr'] * self.gamma ** occurrences for group in groups]
        return [group['lr'] for group in groups]

    def _get_closed_form_lr(self):
        """Closed form: gamma raised to the number of milestones at or before ``last_epoch``."""
        ordered = sorted(self.milestones.elements())
        passed = bisect_right(ordered, self.last_epoch)
        return [base_lr * self.gamma ** passed
                for base_lr in self.base_lrs]
|
|
|
|
|
|
class ConstantLR(LRScheduler):
    """Decays the learning rate of each parameter group by a small constant factor until the
    number of epoch reaches a pre-defined milestone: total_iters. Notice that such decay can
    happen simultaneously with other changes to the learning rate from outside this scheduler.
    When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        factor (float): The number we multiply learning rate until the milestone. Default: 1./3.
        total_iters (int): The number of steps that the scheduler decays the learning rate.
            Default: 5.
        last_epoch (int): The index of the last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: if ``factor`` is outside the closed interval [0, 1].

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.025   if epoch == 0
        >>> # lr = 0.025   if epoch == 1
        >>> # lr = 0.025   if epoch == 2
        >>> # lr = 0.025   if epoch == 3
        >>> # lr = 0.05    if epoch >= 4
        >>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False):
        if factor > 1.0 or factor < 0:
            raise ValueError('Constant multiplicative factor expected to be between 0 and 1.')

        self.factor = factor
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Chainable form.

        Multiplies the current lr by ``factor`` at epoch 0, divides it back out
        at epoch ``total_iters``, and leaves it untouched on every other epoch.
        """
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            return [group['lr'] * self.factor for group in self.optimizer.param_groups]

        # The original tested `last_epoch > total_iters or last_epoch != total_iters`;
        # the first operand is implied by the second, so only `!=` is needed.
        if self.last_epoch != self.total_iters:
            return [group['lr'] for group in self.optimizer.param_groups]

        # Only last_epoch == total_iters reaches this point: undo the factor.
        # NOTE(review): factor == 0 passes the constructor check but raises
        # ZeroDivisionError here — confirm whether 0 should be accepted.
        return [group['lr'] * (1.0 / self.factor) for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        """Closed form: ``base_lr * factor`` before ``total_iters``, ``base_lr`` from then on."""
        return [base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor))
                for base_lr in self.base_lrs]
|
|
|
|
|
|
class LinearLR(LRScheduler):
    """Decays the learning rate of each parameter group by linearly changing small
    multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters.
    Notice that such decay can happen simultaneously with other changes to the learning rate
    from outside this scheduler. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        start_factor (float): The number we multiply learning rate in the first epoch.
            The multiplication factor changes towards end_factor in the following epochs.
            Default: 1./3.
        end_factor (float): The number we multiply learning rate at the end of linear changing
            process. Default: 1.0.
        total_iters (int): The number of iterations that multiplicative factor reaches to 1.
            Default: 5.
        last_epoch (int): The index of the last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.025    if epoch == 0
        >>> # lr = 0.03125  if epoch == 1
        >>> # lr = 0.0375   if epoch == 2
        >>> # lr = 0.04375  if epoch == 3
        >>> # lr = 0.05    if epoch >= 4
        >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1,
                 verbose=False):
        # start_factor must lie in (0, 1]; end_factor in [0, 1].
        if start_factor > 1.0 or start_factor <= 0:
            raise ValueError('Starting multiplicative factor expected to be greater than 0 and less or equal to 1.')

        if end_factor > 1.0 or end_factor < 0:
            raise ValueError('Ending multiplicative factor expected to be between 0 and 1.')

        self.start_factor = start_factor
        self.end_factor = end_factor
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Chainable form of the linear factor schedule."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        groups = self.optimizer.param_groups
        if self.last_epoch == 0:
            return [group['lr'] * self.start_factor for group in groups]

        if self.last_epoch > self.total_iters:
            # Past the ramp: the learning rate is left as it is.
            return [group['lr'] for group in groups]

        # Ratio between this epoch's factor and the previous epoch's factor.
        factor_span = self.end_factor - self.start_factor
        return [group['lr'] * (1. + factor_span /
                (self.total_iters * self.start_factor + (self.last_epoch - 1) * factor_span))
                for group in groups]

    def _get_closed_form_lr(self):
        """Closed form: interpolate from start_factor to end_factor over ``total_iters`` epochs."""
        factor_span = self.end_factor - self.start_factor
        progress = min(self.total_iters, self.last_epoch)
        return [base_lr * (self.start_factor + factor_span * progress / self.total_iters)
                for base_lr in self.base_lrs]
|
|
|
|
|
|
class ExponentialLR(LRScheduler):
    """Decays the learning rate of each parameter group by gamma every epoch.
    When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, gamma, last_epoch=-1, verbose=False):
        self.gamma = gamma
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Chainable form: multiply the current lr by gamma on every epoch except epoch 0."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        groups = self.optimizer.param_groups
        if self.last_epoch != 0:
            return [group['lr'] * self.gamma for group in groups]
        return [group['lr'] for group in groups]

    def _get_closed_form_lr(self):
        """Closed form: ``base_lr * gamma ** last_epoch``."""
        epoch = self.last_epoch
        return [base_lr * self.gamma ** epoch
                for base_lr in self.base_lrs]
|
|
|
|
|
|
class SequentialLR(LRScheduler):
    """Receives the list of schedulers that is expected to be called sequentially during
    optimization process and milestone points that provides exact intervals to reflect
    which scheduler is supposed to be called at a given epoch.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        schedulers (list): List of chained schedulers.
        milestones (list): List of integers that reflects milestone points.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): Does nothing.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 1. for all groups
        >>> # lr = 0.1     if epoch == 0
        >>> # lr = 0.1     if epoch == 1
        >>> # lr = 0.9     if epoch == 2
        >>> # lr = 0.81    if epoch == 3
        >>> # lr = 0.729   if epoch == 4
        >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
        >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
        >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False):
        # Every wrapped scheduler must drive the optimizer passed in here.
        for scheduler_idx, candidate in enumerate(schedulers):
            if candidate.optimizer != optimizer:
                raise ValueError(
                    "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
                    f"got schedulers at index {scheduler_idx} to be different than the optimizer passed in."
                )

            if candidate.optimizer != schedulers[0].optimizer:
                raise ValueError(
                    "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
                    f"got schedulers at index {0} and {scheduler_idx} to be different."
                )
        if len(milestones) != len(schedulers) - 1:
            raise ValueError(
                "Sequential Schedulers expects number of schedulers provided to be one more "
                f"than the number of milestone points, but got number of schedulers {len(schedulers)} and the "
                f"number of milestones to be equal to {len(milestones)}"
            )
        self._schedulers = schedulers
        self._milestones = milestones
        self.last_epoch = last_epoch + 1
        self.optimizer = optimizer

        # Reset learning rates back to initial values
        for group in self.optimizer.param_groups:
            group["lr"] = group["initial_lr"]

        # "Undo" the step performed by other schedulers
        for wrapped in self._schedulers:
            wrapped.last_epoch -= 1

        # Perform the initial step for only the first scheduler
        first = self._schedulers[0]
        first._initial_step()

        self._last_lr = schedulers[0].get_last_lr()

    def step(self):
        """Advance one epoch and step the scheduler that owns the current interval."""
        self.last_epoch += 1
        idx = bisect_right(self._milestones, self.last_epoch)
        active = self._schedulers[idx]
        # Landing exactly on a milestone starts the next scheduler at epoch 0.
        at_milestone = idx > 0 and self._milestones[idx - 1] == self.last_epoch
        if at_milestone:
            active.step(0)
        else:
            active.step()

        self._last_lr = active.get_last_lr()

    def state_dict(self):
        """Return the scheduler state as a :class:`dict`.

        Every entry of ``self.__dict__`` except the optimizer is included.
        The wrapped scheduler states will also be saved.
        """
        state = {name: value for name, value in self.__dict__.items()
                 if name not in ('optimizer', '_schedulers')}
        state['_schedulers'] = [wrapped.state_dict() for wrapped in self._schedulers]
        return state

    def load_state_dict(self, state_dict):
        """Load the scheduler state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        saved_schedulers = state_dict.pop('_schedulers')
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict['_schedulers'] = saved_schedulers

        for position, saved in enumerate(saved_schedulers):
            self._schedulers[position].load_state_dict(saved)
|
|
|
|
|
|
class PolynomialLR(LRScheduler):
    """Decays the learning rate of each parameter group using a polynomial function
    in the given total_iters. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5.
        power (float): The power of the polynomial. Default: 1.0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> # xdoctest: +SKIP("undefined vars")
        >>> # Assuming optimizer uses lr = 0.001 for all groups
        >>> # lr = 0.001     if epoch == 0
        >>> # lr = 0.00075   if epoch == 1
        >>> # lr = 0.00050   if epoch == 2
        >>> # lr = 0.00025   if epoch == 3
        >>> # lr = 0.0       if epoch >= 4
        >>> scheduler = PolynomialLR(self.opt, total_iters=4, power=1.0)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, total_iters=5, power=1.0, last_epoch=-1, verbose=False):
        self.total_iters = total_iters
        self.power = power
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Chainable form: scale the current lr by the ratio of consecutive polynomial values."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        groups = self.optimizer.param_groups
        if self.last_epoch == 0 or self.last_epoch > self.total_iters:
            # Nothing to decay at epoch 0 or once total_iters has been passed.
            return [group["lr"] for group in groups]

        remaining_now = 1.0 - self.last_epoch / self.total_iters
        remaining_before = 1.0 - (self.last_epoch - 1) / self.total_iters
        decay_factor = (remaining_now / remaining_before) ** self.power
        return [group["lr"] * decay_factor for group in groups]

    def _get_closed_form_lr(self):
        """Closed form: ``base_lr * (1 - min(total_iters, last_epoch) / total_iters) ** power``."""
        elapsed = min(self.total_iters, self.last_epoch)
        return [
            base_lr * (1.0 - elapsed / self.total_iters) ** self.power
            for base_lr in self.base_lrs
        ]
|
|
|
|
|
|
class CosineAnnealingLR(LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::
        \begin{aligned}
            \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
            + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
            & T_{cur} \neq (2k+1)T_{max}; \\
            \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
            \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
            & T_{cur} = (2k+1)T_{max}.
        \end{aligned}

    When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
    is defined recursively, the learning rate can be simultaneously modified
    outside this scheduler by other operators. If the learning rate is set
    solely by this scheduler, the learning rate at each step becomes:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
    implements the cosine annealing part of SGDR, and not the restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False):
        self.T_max = T_max
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Chainable (recursive) form of the cosine annealing schedule."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        groups = self.optimizer.param_groups
        if self.last_epoch == 0:
            return [group['lr'] for group in groups]

        if self._step_count == 1 and self.last_epoch > 0:
            # First step with a positive last_epoch: evaluate the cosine
            # expression directly from the base learning rates.
            cos_term = 1 + math.cos((self.last_epoch) * math.pi / self.T_max)
            return [self.eta_min + (base_lr - self.eta_min) * cos_term / 2
                    for base_lr, _ in zip(self.base_lrs, groups)]

        if (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            # The previous epoch was an odd multiple of T_max: use the
            # additive update (second case of the class docstring formula).
            cos_gap = 1 - math.cos(math.pi / self.T_max)
            return [group['lr'] + (base_lr - self.eta_min) * cos_gap / 2
                    for base_lr, group in zip(self.base_lrs, groups)]

        # General recursive update from the previous epoch's learning rate.
        numerator = 1 + math.cos(math.pi * self.last_epoch / self.T_max)
        denominator = 1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)
        ratio = numerator / denominator
        return [ratio * (group['lr'] - self.eta_min) + self.eta_min
                for group in groups]

    def _get_closed_form_lr(self):
        """Closed form: ``eta_min + (base_lr - eta_min) * (1 + cos(pi * last_epoch / T_max)) / 2``."""
        cos_term = 1 + math.cos(math.pi * self.last_epoch / self.T_max)
        return [self.eta_min + (base_lr - self.eta_min) * cos_term / 2
                for base_lr in self.base_lrs]
|
|
|
|
|
|
class ChainedScheduler(LRScheduler):
    """Chains list of learning rate schedulers. It takes a list of chainable learning
    rate schedulers and performs consecutive step() functions belonging to them by just
    one call.

    Args:
        schedulers (list): List of chained schedulers. Must contain at least
            one scheduler, and all of them must wrap the same optimizer.

    Raises:
        ValueError: If ``schedulers`` is empty or the schedulers do not all
            belong to the same optimizer.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 1. for all groups
        >>> # lr = 0.09     if epoch == 0
        >>> # lr = 0.081    if epoch == 1
        >>> # lr = 0.729    if epoch == 2
        >>> # lr = 0.6561   if epoch == 3
        >>> # lr = 0.59049  if epoch >= 4
        >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
        >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
        >>> scheduler = ChainedScheduler([scheduler1, scheduler2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, schedulers):
        # Materialise once so any iterable is accepted and later indexing is safe.
        schedulers = list(schedulers)
        if not schedulers:
            raise ValueError(
                "ChainedScheduler expects at least one scheduler to be chained, but got no scheduler."
            )

        shared_optimizer = schedulers[0].optimizer
        for scheduler_idx, scheduler in enumerate(schedulers[1:], start=1):
            if scheduler.optimizer != shared_optimizer:
                raise ValueError(
                    "ChainedScheduler expects all schedulers to belong to the same optimizer, but "
                    f"got schedulers at index {0} and {scheduler_idx} to be different"
                )

        # NOTE: the base-class __init__ is not invoked here; the wrapped
        # schedulers carry their own state.
        self._schedulers = schedulers
        self.optimizer = shared_optimizer
        self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups]

    def step(self):
        """Step every wrapped scheduler in order, then record the resulting lrs."""
        for scheduler in self._schedulers:
            scheduler.step()
        self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups]

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The wrapped scheduler states will also be saved.
        """
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
        # Wrapped schedulers are stored as their own state dicts, in chain order.
        state_dict['_schedulers'] = [s.state_dict() for s in self._schedulers]
        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # Read instead of pop so the caller's dict is never mutated
        # (https://github.com/pytorch/pytorch/issues/32756).
        nested_states = state_dict['_schedulers']
        self.__dict__.update(
            {key: value for key, value in state_dict.items() if key != '_schedulers'}
        )

        for idx, s in enumerate(nested_states):
            self._schedulers[idx].load_state_dict(s)
|
|
|
|
|
|
class ReduceLROnPlateau:
    """Reduce learning rate when a metric has stopped improving.
    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This scheduler reads a metrics
    quantity and if no improvement is seen for a 'patience' number
    of epochs, the learning rate is reduced.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode, lr will
            be reduced when the quantity monitored has stopped
            decreasing; in `max` mode it will be reduced when the
            quantity monitored has stopped increasing. Default: 'min'.
        factor (float): Factor by which the learning rate will be
            reduced. new_lr = lr * factor. Default: 0.1.
        patience (int): Number of epochs with no improvement after
            which learning rate will be reduced. For example, if
            `patience = 2`, then we will ignore the first 2 epochs
            with no improvement, and will only decrease the LR after the
            3rd epoch if the loss still hasn't improved then.
            Default: 10.
        threshold (float): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
            dynamic_threshold = best * ( 1 + threshold ) in 'max'
            mode or best * ( 1 - threshold ) in `min` mode.
            In `abs` mode, dynamic_threshold = best + threshold in
            `max` mode or best - threshold in `min` mode. Default: 'rel'.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after lr has been reduced. Default: 0.
        min_lr (float or list): A scalar or a list of scalars. A
            lower bound on the learning rate of all param groups
            or each group respectively. Default: 0.
        eps (float): Minimal decay applied to lr. If the difference
            between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: If ``factor >= 1.0``, if ``min_lr`` is a sequence whose
            length differs from the number of param groups, or if ``mode`` /
            ``threshold_mode`` is not one of the supported values.
        TypeError: If ``optimizer`` is not an :class:`Optimizer`.

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
        >>> for epoch in range(10):
        >>>     train(...)
        >>>     val_loss = validate(...)
        >>>     # Note that step should be called after validate()
        >>>     scheduler.step(val_loss)
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                 threshold=1e-4, threshold_mode='rel', cooldown=0,
                 min_lr=0, eps=1e-8, verbose=False):

        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0.')
        self.factor = factor

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
        self.optimizer = optimizer

        if isinstance(min_lr, (list, tuple)):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError(f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}")
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.eps = eps
        self.last_epoch = 0
        self._init_is_better(mode=mode, threshold=threshold,
                             threshold_mode=threshold_mode)
        self._reset()
        # Make the last-known lrs available before the first `step()` call;
        # previously this attribute only appeared after stepping.
        self._last_lr = [group['lr'] for group in optimizer.param_groups]

    def _reset(self):
        """Resets num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def step(self, metrics, epoch=None):
        """Record one observation of the monitored metric and reduce lr if it plateaued.

        Args:
            metrics: Value of the monitored quantity; converted with ``float()``.
            epoch (optional): Explicit epoch index. Passing it emits a
                deprecation warning; when omitted, ``last_epoch + 1`` is used.
        """
        # convert `metrics` to float, in case it's a zero-dim Tensor
        current = float(metrics)
        if epoch is None:
            epoch = self.last_epoch + 1
        else:
            warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
        self.last_epoch = epoch

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def _reduce_lr(self, epoch):
        """Multiply each group's lr by ``factor``, clamped below by its ``min_lr``."""
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            # Skip updates smaller than `eps`.
            if old_lr - new_lr > self.eps:
                param_group['lr'] = new_lr
                if self.verbose:
                    epoch_str = ("%.2f" if isinstance(epoch, float) else
                                 "%.5d") % epoch
                    print(f'Epoch {epoch_str}: reducing learning rate of group {i} to {new_lr:.4e}.')

    @property
    def in_cooldown(self):
        """Whether the scheduler is still inside a post-reduction cooldown window."""
        return self.cooldown_counter > 0

    def is_better(self, a, best):
        """Return whether ``a`` improves on ``best`` under the configured mode and threshold."""
        if self.mode == 'min' and self.threshold_mode == 'rel':
            rel_epsilon = 1. - self.threshold
            return a < best * rel_epsilon

        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return a < best - self.threshold

        elif self.mode == 'max' and self.threshold_mode == 'rel':
            rel_epsilon = self.threshold + 1.
            return a > best * rel_epsilon

        else:  # mode == 'max' and epsilon_mode == 'abs':
            return a > best + self.threshold

    def _init_is_better(self, mode, threshold, threshold_mode):
        """Validate the comparison settings and set the worst-possible starting value."""
        # f-strings keep the message identical for str input and still raise
        # ValueError (instead of a TypeError from `+`) for non-str input.
        if mode not in {'min', 'max'}:
            raise ValueError(f'mode {mode} is unknown!')
        if threshold_mode not in {'rel', 'abs'}:
            raise ValueError(f'threshold mode {threshold_mode} is unknown!')

        if mode == 'min':
            self.mode_worse = inf
        else:  # mode == 'max':
            self.mode_worse = -inf

        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode

    def state_dict(self):
        """Return every instance attribute except the optimizer."""
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Restore attributes from ``state_dict`` and re-derive the comparison settings."""
        self.__dict__.update(state_dict)
        self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)
|
|
|
|
|
|
class CyclicLR(LRScheduler):
    r"""Sets the learning rate of each parameter group according to
    cyclical learning rate policy (CLR). The policy cycles the learning
    rate between two boundaries with a constant frequency, as detailed in
    the paper `Cyclical Learning Rates for Training Neural Networks`_.
    The distance between the two boundaries can be scaled on a per-iteration
    or per-cycle basis.

    Cyclical learning rate policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This class has three built-in policies, as put forth in the paper:

    * "triangular": A basic triangular cycle without amplitude scaling.
    * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle.
    * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}`
      at each cycle iteration.

    This implementation was adapted from the github repo: `bckenstler/CLR`_

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        base_lr (float or list): Initial learning rate which is the
            lower boundary in the cycle for each parameter group.
        max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_lr - base_lr).
            The lr at any cycle is the sum of base_lr
            and some scaling of the amplitude; therefore
            max_lr may not actually be reached depending on
            scaling function.
        step_size_up (int): Number of training iterations in the
            increasing half of a cycle. Default: 2000
        step_size_down (int): Number of training iterations in the
            decreasing half of a cycle. If step_size_down is None,
            it is set to step_size_up. Default: None
        mode (str): One of {triangular, triangular2, exp_range}.
            Values correspond to policies detailed above.
            If scale_fn is not None, this argument is ignored.
            Default: 'triangular'
        gamma (float): Constant in 'exp_range' scaling function:
            gamma**(cycle iterations)
            Default: 1.0
        scale_fn (function): Custom scaling policy defined by a single
            argument lambda function, where
            0 <= scale_fn(x) <= 1 for all x >= 0.
            If specified, then 'mode' is ignored.
            Default: None
        scale_mode (str): {'cycle', 'iterations'}.
            Defines whether scale_fn is evaluated on
            cycle number or cycle iterations (training
            iterations since start of cycle).
            Default: 'cycle'
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely
            to learning rate; at the peak of a cycle, momentum is
            'base_momentum' and learning rate is 'max_lr'.
            Default: 0.8
        max_momentum (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_momentum - base_momentum).
            The momentum at any cycle is the difference of max_momentum
            and some scaling of the amplitude; therefore
            base_momentum may not actually be reached depending on
            scaling function. Note that momentum is cycled inversely
            to learning rate; at the start of a cycle, momentum is 'max_momentum'
            and learning rate is 'base_lr'
            Default: 0.9
        last_epoch (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_epoch=-1, the schedule is started from the beginning.
            Default: -1
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()


    .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
    .. _bckenstler/CLR: https://github.com/bckenstler/CLR
    """

    def __init__(self,
                 optimizer,
                 base_lr,
                 max_lr,
                 step_size_up=2000,
                 step_size_down=None,
                 mode='triangular',
                 gamma=1.,
                 scale_fn=None,
                 scale_mode='cycle',
                 cycle_momentum=True,
                 base_momentum=0.8,
                 max_momentum=0.9,
                 last_epoch=-1,
                 verbose=False):

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
        self.optimizer = optimizer

        base_lrs = self._format_param('base_lr', optimizer, base_lr)
        if last_epoch == -1:
            # Fresh schedule: start every group at the lower boundary.
            for lr, group in zip(base_lrs, optimizer.param_groups):
                group['lr'] = lr

        self.max_lrs = self._format_param('max_lr', optimizer, max_lr)

        step_size_up = float(step_size_up)
        step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
        self.total_size = step_size_up + step_size_down
        # Fraction of a full cycle spent increasing the lr.
        self.step_ratio = step_size_up / self.total_size

        if mode not in ['triangular', 'triangular2', 'exp_range'] \
                and scale_fn is None:
            raise ValueError('mode is invalid and scale_fn is None')

        self.mode = mode
        self.gamma = gamma

        self._scale_fn_ref = None
        self._scale_fn_custom = scale_fn
        self.scale_mode = scale_mode
        self._init_scale_fn()

        self.cycle_momentum = cycle_momentum
        if cycle_momentum:
            if 'momentum' not in optimizer.defaults:
                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')

            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
            if last_epoch == -1:
                for momentum, group in zip(base_momentums, optimizer.param_groups):
                    group['momentum'] = momentum
            self.base_momentums = [group['momentum'] for group in optimizer.param_groups]
            self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)

        super().__init__(optimizer, last_epoch, verbose)
        # Re-assigned after the base-class constructor so the user-supplied
        # lower boundaries are what `get_lr` reads from here on.
        self.base_lrs = base_lrs

    def _init_scale_fn(self):
        """Select the built-in scaling function for ``mode`` unless a custom one is set."""
        if self._scale_fn_custom is not None:
            return
        if self.mode == 'triangular':
            self._scale_fn_ref = self._triangular_scale_fn
            self.scale_mode = 'cycle'
        elif self.mode == 'triangular2':
            self._scale_fn_ref = self._triangular2_scale_fn
            self.scale_mode = 'cycle'
        elif self.mode == 'exp_range':
            self._scale_fn_ref = partial(self._exp_range_scale_fn, self.gamma)
            self.scale_mode = 'iterations'

    def _format_param(self, name, optimizer, param):
        """Return correctly formatted lr/momentum for each param group.

        A list or tuple must have one entry per param group and is copied into
        a new list, so later changes to the caller's sequence cannot leak into
        the scheduler. A scalar is broadcast to every group.

        Raises:
            ValueError: If a sequence of the wrong length is given.
        """
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError(f"expected {len(optimizer.param_groups)} values for {name}, got {len(param)}")
            return list(param)
        return [param] * len(optimizer.param_groups)

    def scale_fn(self, x):
        """Apply the custom scaling function if one was given, else the built-in one."""
        if self._scale_fn_custom is not None:
            return self._scale_fn_custom(x)
        return self._scale_fn_ref(x)

    @staticmethod
    def _triangular_scale_fn(x):
        """Constant amplitude: every cycle uses the full range."""
        return 1.

    @staticmethod
    def _triangular2_scale_fn(x):
        """Halve the amplitude for each successive value of ``x``."""
        return 1 / (2. ** (x - 1))

    @staticmethod
    def _exp_range_scale_fn(gamma, x):
        """Scale the amplitude by ``gamma ** x``."""
        return gamma ** x

    def get_lr(self):
        """Calculates the learning rate at batch index. This function treats
        `self.last_epoch` as the last batch index.

        If `self.cycle_momentum` is ``True``, this function has a side effect of
        updating the optimizer's momentum.
        """

        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        # `cycle` is the 1-based cycle number, `x` the position inside it.
        cycle = math.floor(1 + self.last_epoch / self.total_size)
        x = 1. + self.last_epoch / self.total_size - cycle
        if x <= self.step_ratio:
            scale_factor = x / self.step_ratio
        else:
            scale_factor = (x - 1) / (self.step_ratio - 1)

        # The scaling function sees either the cycle number or the raw batch index.
        scale_arg = cycle if self.scale_mode == 'cycle' else self.last_epoch

        lrs = []
        for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
            base_height = (max_lr - base_lr) * scale_factor
            lrs.append(base_lr + base_height * self.scale_fn(scale_arg))

        if self.cycle_momentum:
            momentums = []
            for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums):
                base_height = (max_momentum - base_momentum) * scale_factor
                # Momentum moves opposite to the lr: subtract from the maximum.
                momentums.append(max_momentum - base_height * self.scale_fn(scale_arg))
            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
                param_group['momentum'] = momentum

        return lrs

    def state_dict(self):
        """Return the scheduler state without the derived ``_scale_fn_ref`` entry."""
        state = super().state_dict()
        # `_scale_fn_ref` is re-created by `_init_scale_fn()` in `load_state_dict`,
        # so it does not need to be stored.
        state.pop("_scale_fn_ref", None)
        return state

    def load_state_dict(self, state_dict):
        """Load the scheduler state and rebuild the built-in scaling function."""
        super().load_state_dict(state_dict)
        self._init_scale_fn()
|
|
|
|
|
|
|
|
class CosineAnnealingWarmRestarts(LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
    is the number of epochs since the last restart and :math:`T_{i}` is the number
    of epochs between two warm restarts in SGDR:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)

    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
    When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_0 (int): Number of iterations for the first restart.
        T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
        eta_min (float, optional): Minimum learning rate. Default: 0.
        last_epoch (int, optional): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: If ``T_0`` is not a positive integer, ``T_mult`` is not an
            integer >= 1, or ``eta_min`` is not a float or int.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False):
        # Check the type first so non-numeric input raises ValueError rather
        # than a TypeError from the comparison.
        if not isinstance(T_0, int) or T_0 <= 0:
            raise ValueError(f"Expected positive integer T_0, but got {T_0}")
        if not isinstance(T_mult, int) or T_mult < 1:
            raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}")
        if not isinstance(eta_min, (float, int)):
            raise ValueError(f"Expected float or int eta_min, but got {eta_min} of type {type(eta_min)}")
        self.T_0 = T_0
        self.T_i = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min
        self.T_cur = last_epoch
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Return the cosine-annealed lr of each group for the current ``T_cur`` / ``T_i``."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
                for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Step could be called after every batch update

        Example:
            >>> # xdoctest: +SKIP("Undefined vars")
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> iters = len(dataloader)
            >>> for epoch in range(20):
            >>>     for i, sample in enumerate(dataloader):
            >>>         inputs, labels = sample['inputs'], sample['labels']
            >>>         optimizer.zero_grad()
            >>>         outputs = net(inputs)
            >>>         loss = criterion(outputs, labels)
            >>>         loss.backward()
            >>>         optimizer.step()
            >>>         scheduler.step(epoch + i / iters)

        This function can be called in an interleaved way.

        Example:
            >>> # xdoctest: +SKIP("Undefined vars")
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> for epoch in range(20):
            >>>     scheduler.step()
            >>> scheduler.step(26)
            >>> scheduler.step() # scheduler.step(27), instead of scheduler(20)

        Raises:
            ValueError: If an explicit ``epoch`` is negative.
        """

        if epoch is None and self.last_epoch < 0:
            epoch = 0

        if epoch is None:
            # Incremental mode: advance by one and restart when the period is used up.
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.T_cur = self.T_cur - self.T_i
                self.T_i = self.T_i * self.T_mult
        else:
            if epoch < 0:
                raise ValueError(f"Expected non-negative epoch, but got {epoch}")
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                else:
                    growth = self.T_mult - 1

                    def restart_epoch(k):
                        # Epoch at which the k-th restart happens:
                        # T_0 * (1 + T_mult + ... + T_mult**(k-1)); exact in integers.
                        return self.T_0 * (self.T_mult ** k - 1) // growth

                    n = int(math.log((epoch / self.T_0 * growth + 1), self.T_mult))
                    # The floating-point logarithm can land just below (or above)
                    # an exact integer at restart boundaries; correct `n` against
                    # the exact integer boundaries so `epoch` falls in cycle n.
                    while epoch >= restart_epoch(n + 1):
                        n += 1
                    while n > 0 and epoch < restart_epoch(n):
                        n -= 1
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
        self.last_epoch = math.floor(epoch)

        # Flag the call as coming from `step()` so `get_lr()` stays silent.
        # try/finally resets the flag and lets exceptions propagate
        # (a context manager whose __exit__ returns a truthy value would hide them).
        self._get_lr_called_within_step = True
        try:
            for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
                param_group, lr = data
                param_group['lr'] = lr
                self.print_lr(self.verbose, i, lr, epoch)
        finally:
            self._get_lr_called_within_step = False

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
|
|
|
|
|
|
class OneCycleLR(LRScheduler):
    r"""Sets the learning rate of each parameter group according to the
    1cycle learning rate policy. The 1cycle policy anneals the learning
    rate from an initial learning rate to some maximum learning rate and then
    from that maximum learning rate to some minimum learning rate much lower
    than the initial learning rate.
    This policy was initially described in the paper `Super-Convergence:
    Very Fast Training of Neural Networks Using Large Learning Rates`_.

    The 1cycle learning rate policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This scheduler is not chainable.

    Note also that the total number of steps in the cycle can be determined in one
    of two ways (listed in order of precedence):

    #. A value for total_steps is explicitly provided.
    #. A number of epochs (epochs) and a number of steps per epoch
       (steps_per_epoch) are provided.
       In this case, the number of total steps is inferred by
       total_steps = epochs * steps_per_epoch

    You must either provide a value for total_steps or provide a value for both
    epochs and steps_per_epoch.

    The default behaviour of this scheduler follows the fastai implementation of 1cycle, which
    claims that "unpublished work has shown even better results by using only two phases". To
    mimic the behaviour of the original paper instead, set ``three_phase=True``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group.
        total_steps (int): The total number of steps in the cycle. Note that
            if a value is not provided here, then it must be inferred by providing
            a value for epochs and steps_per_epoch.
            Default: None
        epochs (int): The number of epochs to train for. This is used along
            with steps_per_epoch in order to infer the total number of steps in the cycle
            if a value for total_steps is not provided.
            Default: None
        steps_per_epoch (int): The number of steps per epoch to train for. This is
            used along with epochs in order to infer the total number of steps in the
            cycle if a value for total_steps is not provided.
            Default: None
        pct_start (float): The percentage of the cycle (in number of steps) spent
            increasing the learning rate.
            Default: 0.3
        anneal_strategy (str): {'cos', 'linear'}
            Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
            linear annealing.
            Default: 'cos'
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely
            to learning rate; at the peak of a cycle, momentum is
            'base_momentum' and learning rate is 'max_lr'.
            Default: 0.85
        max_momentum (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_momentum - base_momentum).
            Note that momentum is cycled inversely
            to learning rate; at the start of a cycle, momentum is 'max_momentum'
            and learning rate is 'base_lr'
            Default: 0.95
        div_factor (float): Determines the initial learning rate via
            initial_lr = max_lr/div_factor
            Default: 25
        final_div_factor (float): Determines the minimum learning rate via
            min_lr = initial_lr/final_div_factor
            Default: 1e4
        three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the
            learning rate according to 'final_div_factor' instead of modifying the second
            phase (the first two phases will be symmetrical about the step indicated by
            'pct_start').
        last_epoch (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_epoch=-1, the schedule is started from the beginning.
            Default: -1
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        TypeError: If ``optimizer`` is not an :class:`Optimizer`.
        ValueError: If neither ``total_steps`` nor both ``epochs`` and
            ``steps_per_epoch`` are given, if any of them is not a positive
            integer, if ``pct_start`` is not a float in [0, 1], if
            ``anneal_strategy`` is unknown, or if ``cycle_momentum`` is enabled
            for an optimizer without ``momentum`` / ``betas``.

    Example:
        >>> # xdoctest: +SKIP
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         optimizer.step()
        >>>         scheduler.step()


    .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
        https://arxiv.org/abs/1708.07120
    """
    def __init__(self,
                 optimizer,
                 max_lr,
                 total_steps=None,
                 epochs=None,
                 steps_per_epoch=None,
                 pct_start=0.3,
                 anneal_strategy='cos',
                 cycle_momentum=True,
                 base_momentum=0.85,
                 max_momentum=0.95,
                 div_factor=25.,
                 final_div_factor=1e4,
                 three_phase=False,
                 last_epoch=-1,
                 verbose=False):

        # Validate optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
        self.optimizer = optimizer

        # Validate total_steps. `total_steps` takes precedence; otherwise BOTH
        # `epochs` and `steps_per_epoch` are required. Type checks come first so
        # bad input raises ValueError rather than a TypeError from the comparison.
        if total_steps is not None:
            if not isinstance(total_steps, int) or total_steps <= 0:
                raise ValueError(f"Expected positive integer total_steps, but got {total_steps}")
            self.total_steps = total_steps
        elif epochs is not None and steps_per_epoch is not None:
            if not isinstance(epochs, int) or epochs <= 0:
                raise ValueError(f"Expected positive integer epochs, but got {epochs}")
            if not isinstance(steps_per_epoch, int) or steps_per_epoch <= 0:
                raise ValueError(f"Expected positive integer steps_per_epoch, but got {steps_per_epoch}")
            self.total_steps = epochs * steps_per_epoch
        else:
            raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)")

        # Validate pct_start before it is used to lay out the phases.
        if not isinstance(pct_start, float) or pct_start < 0 or pct_start > 1:
            raise ValueError(f"Expected float between 0 and 1 pct_start, but got {pct_start}")

        # Each phase names the param-group keys to anneal between and the
        # (possibly fractional) step index at which it ends.
        if three_phase:
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    'start_lr': 'initial_lr',
                    'end_lr': 'max_lr',
                    'start_momentum': 'max_momentum',
                    'end_momentum': 'base_momentum',
                },
                {
                    'end_step': float(2 * pct_start * self.total_steps) - 2,
                    'start_lr': 'max_lr',
                    'end_lr': 'initial_lr',
                    'start_momentum': 'base_momentum',
                    'end_momentum': 'max_momentum',
                },
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': 'initial_lr',
                    'end_lr': 'min_lr',
                    'start_momentum': 'max_momentum',
                    'end_momentum': 'max_momentum',
                },
            ]
        else:
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    'start_lr': 'initial_lr',
                    'end_lr': 'max_lr',
                    'start_momentum': 'max_momentum',
                    'end_momentum': 'base_momentum',
                },
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': 'max_lr',
                    'end_lr': 'min_lr',
                    'start_momentum': 'base_momentum',
                    'end_momentum': 'max_momentum',
                },
            ]

        # Validate anneal_strategy
        if anneal_strategy not in ['cos', 'linear']:
            raise ValueError(f"anneal_strategy must be one of 'cos' or 'linear', instead got {anneal_strategy}")
        elif anneal_strategy == 'cos':
            self.anneal_func = self._annealing_cos
        elif anneal_strategy == 'linear':
            self.anneal_func = self._annealing_linear

        # Initialize learning rate variables
        max_lrs = self._format_param('max_lr', self.optimizer, max_lr)
        if last_epoch == -1:
            for idx, group in enumerate(self.optimizer.param_groups):
                group['initial_lr'] = max_lrs[idx] / div_factor
                group['max_lr'] = max_lrs[idx]
                group['min_lr'] = group['initial_lr'] / final_div_factor

        # Initialize momentum variables
        self.cycle_momentum = cycle_momentum
        if self.cycle_momentum:
            if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
            # Optimizers exposing `betas` have their first beta cycled instead of `momentum`.
            self.use_beta1 = 'betas' in self.optimizer.defaults
            max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
            if last_epoch == -1:
                for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups):
                    if self.use_beta1:
                        group['betas'] = (m_momentum, *group['betas'][1:])
                    else:
                        group['momentum'] = m_momentum
                    group['max_momentum'] = m_momentum
                    group['base_momentum'] = b_momentum

        super().__init__(optimizer, last_epoch, verbose)

    def _format_param(self, name, optimizer, param):
        """Return correctly formatted lr/momentum for each param group.

        A list or tuple must have one entry per param group and is copied into
        a new list; a scalar is broadcast to every group.

        Raises:
            ValueError: If a sequence of the wrong length is given.
        """
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError(f"expected {len(optimizer.param_groups)} values for {name}, got {len(param)}")
            return list(param)
        return [param] * len(optimizer.param_groups)

    @staticmethod
    def _annealing_cos(start, end, pct):
        "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    @staticmethod
    def _annealing_linear(start, end, pct):
        "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        return (end - start) * pct + start

    def get_lr(self):
        """Return the lr of each group for step ``last_epoch``.

        When ``cycle_momentum`` is enabled this also writes the annealed
        momentum (or first beta) into each param group.

        Raises:
            ValueError: If the scheduler was stepped more than ``total_steps`` times.
        """
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        lrs = []
        step_num = self.last_epoch

        if step_num > self.total_steps:
            raise ValueError(
                f"Tried to step {step_num} times. The specified number of total steps is {self.total_steps}"
            )

        final_phase = len(self._schedule_phases) - 1
        for group in self.optimizer.param_groups:
            start_step = 0
            for i, phase in enumerate(self._schedule_phases):
                end_step = phase['end_step']
                # The last phase also absorbs any step past its nominal end.
                if step_num <= end_step or i == final_phase:
                    span = end_step - start_step
                    # A zero-length phase (e.g. pct_start * total_steps == 1) would
                    # divide by zero; treat it as already complete, which matches
                    # pct == 1 at `step_num == end_step` for a regular phase.
                    pct = (step_num - start_step) / span if span else 1.0
                    computed_lr = self.anneal_func(group[phase['start_lr']], group[phase['end_lr']], pct)
                    if self.cycle_momentum:
                        computed_momentum = self.anneal_func(group[phase['start_momentum']], group[phase['end_momentum']], pct)
                    break
                start_step = phase['end_step']

            lrs.append(computed_lr)
            if self.cycle_momentum:
                if self.use_beta1:
                    group['betas'] = (computed_momentum, *group['betas'][1:])
                else:
                    group['momentum'] = computed_momentum

        return lrs
|