[optim] prevent unintended aliasing in lr_scheduler; update type annotations/docs (#163120)
1. Prevents unintended aliasing of `self._last_lr`/`get_last_lr(...)` with `group["lr"]` when `group["lr"]` is a tensor.
2. Prevents unintended aliasing of `LRScheduler.base_lrs` with the `group["initial_lr"]`s.
3. Updates `test/optim/test_lrscheduler.py` to test tensor LRs.
4. Changes type annotations for `_last_lr`, `get_last_lr()`, `base_lrs`, `get_lr()`, and `_get_closed_form_lr()` from `list[float]` to `list[float | Tensor]`; adds documentation.

Fixes #163103

LR schedulers can behave in unexpected ways when using a tensor LR due to patterns like this:

```python
self._last_lr: list[float] = [group["lr"] for group in self.optimizer.param_groups]
```

This PR adds a helper to address this:

```python
def _param_groups_val_list(optimizer: Optimizer, key: str) -> list[Any]:
    """Create a list containing group[key] for each optimizer param_group.

    Prevents aliasing when group[key] could be a Tensor.

    Raises a KeyError when group[key] does not exist.
    """
    return [
        group[key].clone() if isinstance(group[key], Tensor) else group[key]
        for group in optimizer.param_groups
    ]
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163120
Approved by: https://github.com/janeyx99
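To make the aliasing failure mode concrete, here is a minimal sketch of the pattern the PR removes versus the clone-based pattern the new helper uses. It assumes a PyTorch build where `SGD` accepts a tensor learning rate (as exercised by the updated tests); the variable names are illustrative, not taken from the codebase:

```python
import torch
from torch.optim import SGD

param = torch.nn.Parameter(torch.randn(2))
opt = SGD([param], lr=torch.tensor(0.1))  # tensor LR
group = opt.param_groups[0]

# Aliasing pattern: the list holds the very same tensor object as group["lr"].
last_lr_aliased = [g["lr"] for g in opt.param_groups]

# Clone-based pattern, mirroring _param_groups_val_list: tensor entries are copied.
last_lr_copied = [
    g["lr"].clone() if isinstance(g["lr"], torch.Tensor) else g["lr"]
    for g in opt.param_groups
]

# Simulate a scheduler updating the group's learning rate in place.
group["lr"].fill_(0.05)

print(last_lr_aliased[0])  # tensor(0.0500) -- changed along with group["lr"]
print(last_lr_copied[0])   # tensor(0.1000) -- independent snapshot
```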
This commit is contained in:
parent ad869c58f5
commit 2b6a74abf1
test/optim/test_lrscheduler.py

@@ -77,7 +77,7 @@ class TestLRScheduler(TestCase):
         self.opt = SGD(
             [
                 {"params": self.net.conv1.parameters()},
-                {"params": self.net.conv2.parameters(), "lr": 0.5},
+                {"params": self.net.conv2.parameters(), "lr": torch.tensor(0.5)},
             ],
             lr=0.05,
         )
@@ -2530,7 +2530,7 @@ class TestLRScheduler(TestCase):
         ],
     )
     def test_constant_initial_lr(self, LRClass):
-        # Test that the initial learning rate is constant
+        # Test that the initial learning rate is constant and that it does not alias base_lrs
        lr = torch.as_tensor(0.1)
        opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr)
        sch = LRClass(opt)
@@ -2544,6 +2544,7 @@ class TestLRScheduler(TestCase):
         for group, ori_group in zip(opt.param_groups, ori_param_groups):
             self.assertEqual(group["initial_lr"], ori_group["initial_lr"])
         self.assertEqual(sch.base_lrs, [0.1])
+        self.assertIsNot(sch.base_lrs[0], group["initial_lr"])
 
     def test_constant_initial_params_cyclelr(self):
         # Test that the initial learning rate is constant
torch/optim/lr_scheduler.py

@@ -79,9 +79,20 @@ def _format_param(name: str, optimizer: Optimizer, param):
     return list(map(_copy, param))
 
 
+def _param_groups_val_list(optimizer: Optimizer, key: str) -> list[Any]:
+    """Create a list containing group[key] for each optimizer param_group.
+    Prevents aliasing when group[key] could be a Tensor.
+    Raises a KeyError when group[key] does not exist.
+    """
+    return [
+        group[key].clone() if isinstance(group[key], Tensor) else group[key]
+        for group in optimizer.param_groups
+    ]
+
+
 def _update_param_group_val(param_group: dict[str, Any], key: str, val: float | Tensor):
-    """Set param_group[key] to val without aliasing or assignment when they're both tensors.
-    Raises a KeyError if param_group[key] does not exist.
+    """Set param_group[key] to val without aliasing or assignment when they're
+    both tensors. Raises a KeyError if param_group[key] does not exist.
     """
     if isinstance(param_group[key], Tensor):
         param_group[key].fill_(_to_scalar(val))
@@ -90,7 +101,25 @@ def _update_param_group_val(param_group: dict[str, Any], key: str, val: float |
 
 
 class LRScheduler:
-    r"""Adjusts the learning rate during optimization."""
+    r"""Base class for all learning rate schedulers.
+
+    Subclasses implement :meth:`get_lr` and optionally override :meth:`step` to
+    define scheduling behavior.
+
+    Args:
+        optimizer (Optimizer): The optimizer this scheduler will adjust the
+            learning rates of.
+        last_epoch (int): Index of the last epoch seen by the scheduler. Use
+            ``-1`` (default) to initialize the scheduler. Only use a non-default
+            value when restoring this scheduler from a saved checkpoint.
+
+    .. warning::
+        Initializing a scheduler overwrites its optimizer's
+        ``param_group["lr"]``\s. When restoring a checkpoint, initialize the
+        scheduler **before** calling your optimizer's
+        :meth:`~torch.optim.Optimizer.load_state_dict` to avoid overwriting the
+        loaded learning rates.
+    """
 
     _get_lr_called_within_step: bool = False
     _is_initial: bool = False
@@ -121,9 +150,9 @@ class LRScheduler:
                 "1. You're trying to resume training from a checkpoint but haven't properly loaded the optimizer state\n"
                 "2. You're using last_epoch >= 0 for a fresh training run (not recommended)"
             )
-        self.base_lrs: list[float] = [
-            group["initial_lr"] for group in optimizer.param_groups
-        ]
+        self.base_lrs: list[float | Tensor] = _param_groups_val_list(
+            optimizer, "initial_lr"
+        )
         self.last_epoch = last_epoch
 
         # Following https://github.com/pytorch/pytorch/issues/20124
@@ -161,7 +190,7 @@ class LRScheduler:
     def state_dict(self) -> dict[str, Any]:
         """Return the state of the scheduler as a :class:`dict`.
 
-        It contains an entry for every variable in self.__dict__ which
+        It contains an entry for every variable in ``self.__dict__`` which
         is not the optimizer.
         """
         return {
@@ -177,16 +206,58 @@ class LRScheduler:
         """
         self.__dict__.update(state_dict)
 
-    def get_last_lr(self) -> list[float]:
-        """Return last computed learning rate by current scheduler."""
+    def get_last_lr(self) -> list[float | Tensor]:
+        r"""Get the most recent learning rates computed by this scheduler.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates with entries
+                for each of the optimizer's
+                :attr:`~torch.optim.Optimizer.param_groups`, with the same types as
+                their ``group["lr"]``\s.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
+        # We always update self._last_lr with _param_groups_val_list, so it's a
+        # .clone() of the group["lr"]s. If we didn't do this, the user could
+        # corrupt their learning rates by modifying the outputs in place.
         return self._last_lr
 
-    def get_lr(self) -> list[float]:
-        """Compute learning rate using chainable form of the scheduler."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+                the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+                same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         raise NotImplementedError
 
     def step(self, epoch: Optional[int] = None) -> None:
-        """Perform a step."""
+        """Step the scheduler.
+
+        Args:
+            epoch (int, optional):
+                .. deprecated:: 1.4
+                    If provided, sets :attr:`last_epoch` to ``epoch`` and uses
+                    :meth:`_get_closed_form_lr` if it is available. This is not
+                    universally supported. Use :meth:`step` without arguments
+                    instead.
+
+        .. note::
+            Call this method after calling the optimizer's
+            :meth:`~torch.optim.Optimizer.step`.
+        """
         # Raise a warning if old pattern is detected
         # https://github.com/pytorch/pytorch/issues/20124
         if self._step_count == 1:
@@ -224,16 +295,18 @@ class LRScheduler:
             else:
                 self.last_epoch = epoch
                 if hasattr(self, "_get_closed_form_lr"):
-                    values = cast(list[float], self._get_closed_form_lr())
+                    values = cast(
+                        list[Union[float, Tensor]], self._get_closed_form_lr()
+                    )
                 else:
                     values = self.get_lr()
 
         for param_group, lr in zip(self.optimizer.param_groups, values):
             _update_param_group_val(param_group, "lr", lr)
 
-        self._last_lr: list[float] = [
-            group["lr"] for group in self.optimizer.param_groups
-        ]
+        self._last_lr: list[float | Tensor] = _param_groups_val_list(
+            self.optimizer, "lr"
+        )
 
 
 def _warn_get_lr_called_within_step(lr_scheduler: LRScheduler) -> None:
@@ -333,8 +406,7 @@ class LambdaLR(LRScheduler):
     def state_dict(self) -> dict[str, Any]:
         """Return the state of the scheduler as a :class:`dict`.
 
-        It contains an entry for every variable in self.__dict__ which
-        is not the optimizer.
+        It contains an entry for every variable in ``self.__dict__`` which is not the optimizer.
         The learning rate lambda functions will only be saved if they are callable objects
         and not if they are functions or lambdas.
 
@@ -374,8 +446,26 @@ class LambdaLR(LRScheduler):
             self.lr_lambdas[idx].__dict__.update(fn)
 
     @override
-    def get_lr(self) -> list[float]:
-        """Compute learning rate."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        Scales the :attr:`base_lrs` by the outputs of the :attr:`lr_lambdas` at
+        :attr:`last_epoch`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+                the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+                same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         _warn_get_lr_called_within_step(self)
 
         return [
@@ -436,7 +526,7 @@ class MultiplicativeLR(LRScheduler):
     def state_dict(self) -> dict[str, Any]:
         """Return the state of the scheduler as a :class:`dict`.
 
-        It contains an entry for every variable in self.__dict__ which
+        It contains an entry for every variable in ``self.__dict__`` which
         is not the optimizer.
         The learning rate lambda functions will only be saved if they are callable objects
         and not if they are functions or lambdas.
@@ -473,8 +563,27 @@ class MultiplicativeLR(LRScheduler):
             self.lr_lambdas[idx].__dict__.update(fn)
 
     @override
-    def get_lr(self) -> list[float]:
-        """Compute the learning rate of each parameter group."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        Scales the current ``group["lr"]``\s in each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` by the outputs of the
+        :attr:`lr_lambdas` at :attr:`last_epoch`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+                the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+                same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         _warn_get_lr_called_within_step(self)
 
         if not self._is_initial:
@@ -483,7 +592,7 @@ class MultiplicativeLR(LRScheduler):
                 for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)
             ]
         else:
-            return [group["lr"] for group in self.optimizer.param_groups]
+            return _param_groups_val_list(self.optimizer, "lr")
 
 
 class StepLR(LRScheduler):
@@ -527,15 +636,46 @@ class StepLR(LRScheduler):
         super().__init__(optimizer, last_epoch)
 
     @override
-    def get_lr(self) -> list[float]:
-        """Compute the learning rate of each parameter group."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        If the current epoch is a non-zero multiple of :attr:`step_size`, we
+        scale the current ``group["lr"]``\s in the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` by :attr:`gamma`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+                the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+                same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         _warn_get_lr_called_within_step(self)
 
         if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
-            return [group["lr"] for group in self.optimizer.param_groups]
+            return _param_groups_val_list(self.optimizer, "lr")
         return [group["lr"] * self.gamma for group in self.optimizer.param_groups]
 
-    def _get_closed_form_lr(self) -> list[float]:
+    def _get_closed_form_lr(self) -> list[float | Tensor]:
+        r"""Compute learning rates for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
+        a closed-form formula.
+
+        Uses :attr:`base_lrs` to compute learning rates. This method is called
+        when an epoch is passed to :meth:`step`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+                the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+                same types as their current ``group["lr"]``\s.
+        """
         return [
             base_lr * self.gamma ** (self.last_epoch // self.step_size)
             for base_lr in self.base_lrs
@@ -582,18 +722,53 @@ class MultiStepLR(LRScheduler):
         super().__init__(optimizer, last_epoch)
 
     @override
-    def get_lr(self) -> list[float]:
-        """Compute the learning rate of each parameter group."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        If the current epoch is in :attr:`milestones`, decays the
+        ``group["lr"]``\s in the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` by :attr:`gamma`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+                the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+                same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+
+        .. note::
+            If the current epoch appears in :attr:`milestones` ``n`` times, we
+            scale by :attr:`gamma` to the power of ``n``
+        """
         _warn_get_lr_called_within_step(self)
 
         if self.last_epoch not in self.milestones:
-            return [group["lr"] for group in self.optimizer.param_groups]
+            return _param_groups_val_list(self.optimizer, "lr")
         return [
             group["lr"] * self.gamma ** self.milestones[self.last_epoch]
             for group in self.optimizer.param_groups
         ]
 
     def _get_closed_form_lr(self):
+        r"""Compute learning rates for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
+        a closed-form formula.
+
+        Uses :attr:`base_lrs` to compute learning rates. This method is called
+        when an epoch is passed to :meth:`step`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+                the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+                same types as their current ``group["lr"]``\s.
+        """
         milestones = sorted(self.milestones.elements())
         return [
             base_lr * self.gamma ** bisect_right(milestones, self.last_epoch)
@@ -651,21 +826,53 @@ class ConstantLR(LRScheduler):
         super().__init__(optimizer, last_epoch)
 
     @override
-    def get_lr(self) -> list[float]:
-        """Compute the learning rate of each parameter group."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        When :attr:`last_epoch` is 0, this method scales the ``group["lr"]``\s
+        in each of the optimizer's :attr:`~torch.optim.Optimizer.param_groups`
+        by :attr:`factor`. Once :attr:`total_iters` is reached, it undoes this,
+        scaling by ``1 / factor``.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+                the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+                same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         _warn_get_lr_called_within_step(self)
 
         if self.last_epoch == 0:
             return [group["lr"] * self.factor for group in self.optimizer.param_groups]
 
         if self.last_epoch != self.total_iters:
-            return [group["lr"] for group in self.optimizer.param_groups]
+            return _param_groups_val_list(self.optimizer, "lr")
 
         return [
             group["lr"] * (1.0 / self.factor) for group in self.optimizer.param_groups
         ]
 
     def _get_closed_form_lr(self):
+        r"""Compute learning rates for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
+        a closed-form formula.
+
+        Uses :attr:`base_lrs` to compute learning rates. This method is called
+        when an epoch is passed to :meth:`step`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+                the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+                same types as their current ``group["lr"]``\s.
+        """
         return [
             base_lr
             * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor))
@@ -733,8 +940,28 @@ class LinearLR(LRScheduler):
         super().__init__(optimizer, last_epoch)
 
     @override
-    def get_lr(self) -> list[float]:
-        """Compute the learning rate."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        Scales the ``group["lr"]``\s in the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` such that successive steps
+        interpolate linearly from :attr:`start_factor` up to :attr:`end_factor`
+        across :attr:`total_iters` steps.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+                the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+                same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         _warn_get_lr_called_within_step(self)
 
         if self.last_epoch == 0:
@@ -743,7 +970,7 @@ class LinearLR(LRScheduler):
             ]
 
         if self._is_initial or self.last_epoch > self.total_iters:
-            return [group["lr"] for group in self.optimizer.param_groups]
+            return _param_groups_val_list(self.optimizer, "lr")
 
         return [
             group["lr"]
@@ -759,6 +986,18 @@ class LinearLR(LRScheduler):
         ]
 
     def _get_closed_form_lr(self):
+        r"""Compute learning rates for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
+        a closed-form formula.
+
+        Uses :attr:`base_lrs` to compute learning rates. This method is called
+        when an epoch is passed to :meth:`step`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+                the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+                same types as their current ``group["lr"]``\s.
+        """
         return [
             base_lr
             * (
@@ -802,17 +1041,47 @@ class ExponentialLR(LRScheduler):
         super().__init__(optimizer, last_epoch)
 
     @override
-    def get_lr(self) -> list[float]:
-        """Compute the learning rate of each parameter group."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        Multiplies the current ``group["lr"]``\s in the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` by :attr:`gamma`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+                the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+                same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         _warn_get_lr_called_within_step(self)
 
         # when loading from a checkpoint, we don't want _initial_step (called from the constructor)
         # to update the lr one more step ahead of itself.
         if self._is_initial:
-            return [group["lr"] for group in self.optimizer.param_groups]
+            return _param_groups_val_list(self.optimizer, "lr")
         return [group["lr"] * self.gamma for group in self.optimizer.param_groups]
 
     def _get_closed_form_lr(self):
+        r"""Compute learning rates for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
+        a closed-form formula.
+
+        Uses :attr:`base_lrs` to compute learning rates. This method is called
+        when an epoch is passed to :meth:`step`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+                the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+                same types as their current ``group["lr"]``\s.
+        """
         return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs]
 
 
@@ -935,7 +1204,7 @@ class SequentialLR(LRScheduler):
     def state_dict(self) -> dict[str, Any]:
         """Return the state of the scheduler as a :class:`dict`.
 
-        It contains an entry for every variable in self.__dict__ which
+        It contains an entry for every variable in ``self.__dict__`` which
         is not the optimizer.
         The wrapped scheduler states will also be saved.
         """
@@ -1008,12 +1277,38 @@ class PolynomialLR(LRScheduler):
         super().__init__(optimizer, last_epoch)
 
     @override
-    def get_lr(self) -> list[float]:
-        """Compute the learning rate."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        Scales the ``group["lr"]``\s in the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` such that the learning rates
+        follow
+
+        .. math::
+            \texttt{base\_lr} \cdot \left(1 - \frac{\texttt{last\_epoch}}
+            {\texttt{total\_iters}} \right)^\texttt{power}
+
+        Returns the current learning rates unchanged after :attr:`total_iters`
+        is reached.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+                the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+                same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         _warn_get_lr_called_within_step(self)
 
         if self._is_initial or self.last_epoch > self.total_iters:
-            return [group["lr"] for group in self.optimizer.param_groups]
+            return _param_groups_val_list(self.optimizer, "lr")
 
         decay_factor = (
             (1.0 - self.last_epoch / self.total_iters)
@@ -1021,7 +1316,19 @@ class PolynomialLR(LRScheduler):
         ) ** self.power
         return [group["lr"] * decay_factor for group in self.optimizer.param_groups]
 
-    def _get_closed_form_lr(self):
+    def _get_closed_form_lr(self) -> list[float | Tensor]:
+        r"""Compute learning rates for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
+        a closed-form formula.
+
+        Uses :attr:`base_lrs` to compute learning rates. This method is called
+        when an epoch is passed to :meth:`step`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+                the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+                same types as their current ``group["lr"]``\s.
+        """
         return [
             (
                 base_lr
@@ -1094,12 +1401,36 @@ class CosineAnnealingLR(LRScheduler):
         super().__init__(optimizer, last_epoch)
 
     @override
-    def get_lr(self) -> list[float]:
-        """Retrieve the learning rate of each parameter group."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        Scales the ``group["lr"]``\s in the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` such that their learning
+        rates approximate
+
+        .. math::
+            \texttt{eta\_min} + \frac{1}{2} (\texttt{base\_lr} -
+            \texttt{eta\_min}) \left(1 + \cos\left(\pi \cdot
+            \frac{\texttt{last\_epoch}}{\texttt{T\_max}}\right) \right)
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+                the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+                same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         _warn_get_lr_called_within_step(self)
 
         if self._is_initial:
-            return [group["lr"] for group in self.optimizer.param_groups]
+            return _param_groups_val_list(self.optimizer, "lr")
         elif self._step_count == 1 and self.last_epoch > 0:
             return [
                 self.eta_min
@@ -1122,7 +1453,19 @@ class CosineAnnealingLR(LRScheduler):
             for group in self.optimizer.param_groups
         ]
 
-    def _get_closed_form_lr(self) -> list[float]:
+    def _get_closed_form_lr(self) -> list[float | Tensor]:
+        r"""Compute learning rates for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
+        a closed-form formula.
+
+        Uses :attr:`base_lrs` to compute learning rates. This method is called
+        when an epoch is passed to :meth:`step`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+                the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+                same types as their current ``group["lr"]``\s.
+        """
         return [
             self.eta_min
             + (base_lr - self.eta_min)
@@ -1191,23 +1534,19 @@ class ChainedScheduler(LRScheduler):
             )
         self._schedulers = schedulers
         self.optimizer = optimizer
-        self._last_lr = [
-            group["lr"] for group in self._schedulers[-1].optimizer.param_groups
-        ]
+        self._last_lr = _param_groups_val_list(self._schedulers[-1].optimizer, "lr")
 
     def step(self) -> None:  # type: ignore[override]
         """Perform a step."""
         for scheduler in self._schedulers:
             scheduler.step()
-        self._last_lr = [
-            group["lr"] for group in self._schedulers[-1].optimizer.param_groups
-        ]
+        self._last_lr = _param_groups_val_list(self._schedulers[-1].optimizer, "lr")
 
     @override
     def state_dict(self) -> dict[str, Any]:
         """Return the state of the scheduler as a :class:`dict`.
 
-        It contains an entry for every variable in self.__dict__ which
+        It contains an entry for every variable in ``self.__dict__`` which
         is not the optimizer.
         The wrapped scheduler states will also be saved.
         """
@@ -1334,7 +1673,7 @@ class ReduceLROnPlateau(LRScheduler):
         self.cooldown = cooldown
         self.eps = eps
         self.last_epoch = 0
-        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
+        self._last_lr = _param_groups_val_list(self.optimizer, "lr")
         self._init_is_better(
             mode=mode, threshold=threshold, threshold_mode=threshold_mode
         )
@@ -1371,7 +1710,7 @@ class ReduceLROnPlateau(LRScheduler):
             self.cooldown_counter = self.cooldown
             self.num_bad_epochs = 0
 
-        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
+        self._last_lr = _param_groups_val_list(self.optimizer, "lr")
 
     def _reduce_lr(self, epoch):
         if len(self.optimizer.param_groups) != len(self.min_lrs):
@@ -1561,11 +1900,7 @@ class CyclicLR(LRScheduler):
         base_lrs = _format_param("base_lr", optimizer, base_lr)
         if last_epoch == -1:
             for lr, group in zip(base_lrs, optimizer.param_groups):
-                if isinstance(group["lr"], Tensor):
-                    lr_val = lr.item() if isinstance(lr, Tensor) else lr
-                    group["lr"].fill_(lr_val)
-                else:
-                    group["lr"] = lr
+                _update_param_group_val(group, "lr", lr)
 
         self.max_lrs = _format_param("max_lr", optimizer, max_lr)
 
@@ -1649,13 +1984,34 @@ class CyclicLR(LRScheduler):
             return gamma**x
 
     @override
-    def get_lr(self) -> list[float]:
-        """Calculate the learning rate at batch index.
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
 
-        This function treats `self.last_epoch` as the last batch index.
+        Advances each ``group["lr"]`` in the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` along a cycle between the
+        group's ``base_lr`` and ``max_lr`` using :meth:`scale_fn`.
 
-        If `self.cycle_momentum` is ``True``, this function has a side effect of
-        updating the optimizer's momentum.
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+                the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+                same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+
+        .. note::
+            This method treats :attr:`last_epoch` as the index of the previous
+            batch.
+
+        .. note::
+            When :attr:`cycle_momentum` is ``True``, this method has a side
+            effect of updating the optimizer's momentum.
         """
         _warn_get_lr_called_within_step(self)
 
@@ -1700,7 +2056,7 @@ class CyclicLR(LRScheduler):
     def state_dict(self) -> dict[str, Any]:  # noqa: D102
         """Return the state of the scheduler as a :class:`dict`.
 
-        It contains an entry for every variable in self.__dict__ which
+        It contains an entry for every variable in ``self.__dict__`` which
         is not the optimizer.
         The learning rate lambda functions will only be saved if they are callable objects
         and not if they are functions or lambdas.
@@ -1795,8 +2151,36 @@ class CosineAnnealingWarmRestarts(LRScheduler):
         super().__init__(optimizer, last_epoch)
 
     @override
-    def get_lr(self) -> list[float]:
-        """Compute the initial learning rate."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        Computes learning rates for the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` following:
+
+        .. math::
+            \texttt{eta\_min} + \frac{1}{2}(\texttt{base\_lr} -
+            \texttt{eta\_min})\left(1 + \cos\left(\pi \cdot
+            \frac{\texttt{T\_cur}}{\texttt{T\_i}}\right)\right)
+
+        Where :attr:`T_cur` is the number of epochs since the last restart and
+        :attr:`T_i` is the number of epochs between two restarts. Both
+        :attr:`T_cur` and :attr:`T_i` are updated in :meth:`step`, and
+        :attr:`T_i` becomes :attr:`T_mult` times larger after each restart.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+                the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+                same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         _warn_get_lr_called_within_step(self)
 
         return [
@@ -1869,7 +2253,7 @@ class CosineAnnealingWarmRestarts(LRScheduler):
         for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
             _update_param_group_val(param_group, "lr", lr)
 
-        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
+        self._last_lr = _param_groups_val_list(self.optimizer, "lr")
 
 
 class _SchedulePhase(TypedDict):
@@ -2141,8 +2525,31 @@ class OneCycleLR(LRScheduler):
         return (end - start) * pct + start
 
     @override
-    def get_lr(self) -> list[float]:
-        """Compute the learning rate of each parameter group."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        Finds the appropriate :attr:`_schedule_phases` entry for the current
+        step and interpolates between its ``start_lr`` and ``end_lr`` using
+        :meth:`_anneal_func`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+                the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+                same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+
+        .. note::
+            When :attr:`cycle_momentum` is ``True``, this method has a side
+            effect of updating the optimizer's momentum.
+        """
         _warn_get_lr_called_within_step(self)
 
         lrs = []
torch/optim/swa_utils.py

@@ -454,8 +454,29 @@ class SWALR(LRScheduler):
             return swa_lr
         return (lr - alpha * swa_lr) / (1 - alpha)
 
+    @override
     def get_lr(self):
-        """Get learning rate."""
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        Uses :attr:`anneal_func` to interpolate between each group's
+        ``group["lr"]`` and ``group["swa_lr"]`` over :attr:`anneal_epochs`
+        epochs. Once :attr:`anneal_epochs` is reached, keeps the learning rate
+        fixed at ``group["swa_lr"]``.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+                the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+                same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         # `_get_lr_called_within_step` is only available `_enable_get_lr_call`,
         # so we ignore the type error here. See `LRScheduler.step()` for more details.
         if not self._get_lr_called_within_step:
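With the change in place, the scheduler-facing accessors hold copies rather than the optimizer's own tensors. A quick end-to-end check in the spirit of the updated test, assuming a build that includes this PR (the scheduler choice and values here are illustrative):

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

param = torch.nn.Parameter(torch.randn(1))
opt = SGD([param], lr=torch.as_tensor(0.1))
sch = StepLR(opt, step_size=1, gamma=0.5)
group = opt.param_groups[0]

# base_lrs and get_last_lr() no longer alias the optimizer's tensors.
assert sch.base_lrs[0] is not group["initial_lr"]
assert sch.get_last_lr()[0] is not group["lr"]

opt.step()
sch.step()

# group["lr"] was updated in place; get_last_lr() returns a fresh copy of it.
print(group["lr"], sch.get_last_lr()[0])
```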