[optim] prevent unintended aliasing in lr_scheduler; update type annotations/docs (#163120)

1. Prevents unintended aliasing of `self._last_lr`/`get_last_lr(...)` with `group["lr"]` when `group["lr"]` is a tensor.
2. Prevents unintended aliasing of `LRScheduler.base_lrs` with the `group["initial_lr"]`s.
3. Updates `test/optim/test_lrscheduler.py` to test tensor LRs.
4. Changes type annotations for `_last_lr`, `get_last_lr()`, `base_lrs`, `get_lr()`, and `_get_closed_form_lr()` from `list[float]` to `list[float | Tensor]`; adds documentation.

Fixes #163103

LR schedulers can behave in unexpected ways when using a tensor LR due to patterns like this:
```python
self._last_lr: list[float] = [group["lr"] for group in self.optimizer.param_groups]
```
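When `group["lr"]` is a tensor, that comprehension stores the tensor object itself, so the scheduler's bookkeeping shares storage with the optimizer. A minimal sketch of how the aliasing could bite (pre-fix behavior; `ExponentialLR` stands in for any affected scheduler):
```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR

opt = SGD([torch.nn.Parameter(torch.randn(2))], lr=torch.tensor(0.1))
sched = ExponentialLR(opt, gamma=0.9)

last = sched.get_last_lr()[0]  # pre-fix: this *is* opt.param_groups[0]["lr"]
sched.step()                   # step() fills group["lr"] in place...
print(last)                    # ...so the "snapshot" silently became 0.09

last.fill_(1.0)                # and mutating it corrupts the optimizer's lr
```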

This PR adds a helper to address this:
```python
def _param_groups_val_list(optimizer: Optimizer, key: str) -> list[Any]:
    """Create a list containing group[key] for each optimizer param_group.
    Prevents aliasing when group[key] could be a Tensor.
    Raises a KeyError when group[key] does not exist.
    """
    return [
        group[key].clone() if isinstance(group[key], Tensor) else group[key]
        for group in optimizer.param_groups
    ]
```
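As a quick illustration of the intended semantics (hypothetical usage; `opt` is any optimizer whose param group uses a tensor LR):
```python
lrs = _param_groups_val_list(opt, "lr")
assert lrs[0] is not opt.param_groups[0]["lr"]    # a clone, not an alias
lrs[0].fill_(123.0)                               # mutating the copy...
assert opt.param_groups[0]["lr"].item() != 123.0  # ...leaves the optimizer's lr intact
```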
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163120
Approved by: https://github.com/janeyx99
Filip 2025-09-25 06:58:55 +00:00 committed by PyTorch MergeBot
parent ad869c58f5
commit 2b6a74abf1
3 changed files with 506 additions and 77 deletions

test/optim/test_lrscheduler.py

@@ -77,7 +77,7 @@ class TestLRScheduler(TestCase):
self.opt = SGD(
[
{"params": self.net.conv1.parameters()},
{"params": self.net.conv2.parameters(), "lr": 0.5},
{"params": self.net.conv2.parameters(), "lr": torch.tensor(0.5)},
],
lr=0.05,
)
@@ -2530,7 +2530,7 @@ class TestLRScheduler(TestCase):
],
)
def test_constant_initial_lr(self, LRClass):
-# Test that the initial learning rate is constant
+# Test that the initial learning rate is constant and that it does not alias base_lrs
lr = torch.as_tensor(0.1)
opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr)
sch = LRClass(opt)
@@ -2544,6 +2544,7 @@ class TestLRScheduler(TestCase):
for group, ori_group in zip(opt.param_groups, ori_param_groups):
self.assertEqual(group["initial_lr"], ori_group["initial_lr"])
self.assertEqual(sch.base_lrs, [0.1])
+self.assertIsNot(sch.base_lrs[0], group["initial_lr"])
def test_constant_initial_params_cyclelr(self):
# Test that the initial learning rate is constant

torch/optim/lr_scheduler.py

@@ -79,9 +79,20 @@ def _format_param(name: str, optimizer: Optimizer, param):
return list(map(_copy, param))
+def _param_groups_val_list(optimizer: Optimizer, key: str) -> list[Any]:
+"""Create a list containing group[key] for each optimizer param_group.
+Prevents aliasing when group[key] could be a Tensor.
+Raises a KeyError when group[key] does not exist.
+"""
+return [
+group[key].clone() if isinstance(group[key], Tensor) else group[key]
+for group in optimizer.param_groups
+]
def _update_param_group_val(param_group: dict[str, Any], key: str, val: float | Tensor):
"""Set param_group[key] to val without aliasing or assignment when they're both tensors.
Raises a KeyError if param_group[key] does not exist.
"""Set param_group[key] to val without aliasing or assignment when they're
both tensors. Raises a KeyError if param_group[key] does not exist.
"""
if isinstance(param_group[key], Tensor):
param_group[key].fill_(_to_scalar(val))
@@ -90,7 +101,25 @@ def _update_param_group_val(param_group: dict[str, Any], key: str, val: float |
class LRScheduler:
r"""Adjusts the learning rate during optimization."""
r"""Base class for all learning rate schedulers.
Subclasses implement :meth:`get_lr` and optionally override :meth:`step` to
define scheduling behavior.
Args:
optimizer (Optimizer): The optimizer this scheduler will adjust the
learning rates of.
last_epoch (int): Index of the last epoch seen by the scheduler. Use
``-1`` (default) to initialize the scheduler. Only use a non-default
value when restoring this scheduler from a saved checkpoint.
.. warning::
Initializing a scheduler overwrites its optimizer's
``param_group["lr"]``\s. When restoring a checkpoint, initialize the
scheduler **before** calling your optimizer's
:meth:`~torch.optim.Optimizer.load_state_dict` to avoid overwriting the
loaded learning rates.
"""
_get_lr_called_within_step: bool = False
_is_initial: bool = False
@@ -121,9 +150,9 @@ class LRScheduler:
"1. You're trying to resume training from a checkpoint but haven't properly loaded the optimizer state\n"
"2. You're using last_epoch >= 0 for a fresh training run (not recommended)"
)
-self.base_lrs: list[float] = [
-group["initial_lr"] for group in optimizer.param_groups
-]
+self.base_lrs: list[float | Tensor] = _param_groups_val_list(
+optimizer, "initial_lr"
+)
self.last_epoch = last_epoch
# Following https://github.com/pytorch/pytorch/issues/20124
@@ -161,7 +190,7 @@ class LRScheduler:
def state_dict(self) -> dict[str, Any]:
"""Return the state of the scheduler as a :class:`dict`.
-It contains an entry for every variable in self.__dict__ which
+It contains an entry for every variable in ``self.__dict__`` which
is not the optimizer.
"""
return {
@@ -177,16 +206,58 @@ class LRScheduler:
"""
self.__dict__.update(state_dict)
-def get_last_lr(self) -> list[float]:
-"""Return last computed learning rate by current scheduler."""
+def get_last_lr(self) -> list[float | Tensor]:
+r"""Get the most recent learning rates computed by this scheduler.
+Returns:
+list[float | Tensor]: A :class:`list` of learning rates with entries
+for each of the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups`, with the same types as
+their ``group["lr"]``\s.
+.. note::
+The returned :class:`~torch.Tensor`\s are copies, and never alias
+the optimizer's ``group["lr"]``\s.
+"""
+# We always update self._last_lr with _param_groups_val_list, so it's a
+# .clone() of the group["lr"]s. If we didn't do this, the user could
+# corrupt their learning rates by modifying the outputs in place.
return self._last_lr
-def get_lr(self) -> list[float]:
-"""Compute learning rate using chainable form of the scheduler."""
+def get_lr(self) -> list[float | Tensor]:
+r"""Compute the next learning rate for each of the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups`.
+Returns:
+list[float | Tensor]: A :class:`list` of learning rates for each of
+the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+same types as their current ``group["lr"]``\s.
+.. note::
+If you're trying to inspect the most recent learning rate, use
+:meth:`get_last_lr()` instead.
+.. note::
+The returned :class:`~torch.Tensor`\s are copies, and never alias
+the optimizer's ``group["lr"]``\s.
+"""
raise NotImplementedError
def step(self, epoch: Optional[int] = None) -> None:
"""Perform a step."""
"""Step the scheduler.
Args:
epoch (int, optional):
.. deprecated:: 1.4
If provided, sets :attr:`last_epoch` to ``epoch`` and uses
:meth:`_get_closed_form_lr` if it is available. This is not
universally supported. Use :meth:`step` without arguments
instead.
.. note::
Call this method after calling the optimizer's
:meth:`~torch.optim.Optimizer.step`.
"""
# Raise a warning if old pattern is detected
# https://github.com/pytorch/pytorch/issues/20124
if self._step_count == 1:
@@ -224,16 +295,18 @@ class LRScheduler:
else:
self.last_epoch = epoch
if hasattr(self, "_get_closed_form_lr"):
-values = cast(list[float], self._get_closed_form_lr())
+values = cast(
+list[Union[float, Tensor]], self._get_closed_form_lr()
+)
else:
values = self.get_lr()
for param_group, lr in zip(self.optimizer.param_groups, values):
_update_param_group_val(param_group, "lr", lr)
-self._last_lr: list[float] = [
-group["lr"] for group in self.optimizer.param_groups
-]
+self._last_lr: list[float | Tensor] = _param_groups_val_list(
+self.optimizer, "lr"
+)
def _warn_get_lr_called_within_step(lr_scheduler: LRScheduler) -> None:
@@ -333,8 +406,7 @@ class LambdaLR(LRScheduler):
def state_dict(self) -> dict[str, Any]:
"""Return the state of the scheduler as a :class:`dict`.
-It contains an entry for every variable in self.__dict__ which
-is not the optimizer.
+It contains an entry for every variable in ``self.__dict__`` which is not the optimizer.
The learning rate lambda functions will only be saved if they are callable objects
and not if they are functions or lambdas.
@@ -374,8 +446,26 @@ class LambdaLR(LRScheduler):
self.lr_lambdas[idx].__dict__.update(fn)
@override
-def get_lr(self) -> list[float]:
-"""Compute learning rate."""
+def get_lr(self) -> list[float | Tensor]:
+r"""Compute the next learning rate for each of the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups`.
+Scales the :attr:`base_lrs` by the outputs of the :attr:`lr_lambdas` at
+:attr:`last_epoch`.
+Returns:
+list[float | Tensor]: A :class:`list` of learning rates for each of
+the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+same types as their current ``group["lr"]``\s.
+.. note::
+If you're trying to inspect the most recent learning rate, use
+:meth:`get_last_lr()` instead.
+.. note::
+The returned :class:`~torch.Tensor`\s are copies, and never alias
+the optimizer's ``group["lr"]``\s.
+"""
_warn_get_lr_called_within_step(self)
return [
@@ -436,7 +526,7 @@ class MultiplicativeLR(LRScheduler):
def state_dict(self) -> dict[str, Any]:
"""Return the state of the scheduler as a :class:`dict`.
-It contains an entry for every variable in self.__dict__ which
+It contains an entry for every variable in ``self.__dict__`` which
is not the optimizer.
The learning rate lambda functions will only be saved if they are callable objects
and not if they are functions or lambdas.
@@ -473,8 +563,27 @@ class MultiplicativeLR(LRScheduler):
self.lr_lambdas[idx].__dict__.update(fn)
@override
-def get_lr(self) -> list[float]:
-"""Compute the learning rate of each parameter group."""
+def get_lr(self) -> list[float | Tensor]:
+r"""Compute the next learning rate for each of the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups`.
+Scales the current ``group["lr"]``\s in each of the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups` by the outputs of the
+:attr:`lr_lambdas` at :attr:`last_epoch`.
+Returns:
+list[float | Tensor]: A :class:`list` of learning rates for each of
+the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+same types as their current ``group["lr"]``\s.
+.. note::
+If you're trying to inspect the most recent learning rate, use
+:meth:`get_last_lr()` instead.
+.. note::
+The returned :class:`~torch.Tensor`\s are copies, and never alias
+the optimizer's ``group["lr"]``\s.
+"""
_warn_get_lr_called_within_step(self)
if not self._is_initial:
@@ -483,7 +592,7 @@ class MultiplicativeLR(LRScheduler):
for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)
]
else:
return [group["lr"] for group in self.optimizer.param_groups]
return _param_groups_val_list(self.optimizer, "lr")
class StepLR(LRScheduler):
@@ -527,15 +636,46 @@ class StepLR(LRScheduler):
super().__init__(optimizer, last_epoch)
@override
-def get_lr(self) -> list[float]:
-"""Compute the learning rate of each parameter group."""
+def get_lr(self) -> list[float | Tensor]:
+r"""Compute the next learning rate for each of the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups`.
+If the current epoch is a non-zero multiple of :attr:`step_size`, we
+scale the current ``group["lr"]``\s in the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups` by :attr:`gamma`.
+Returns:
+list[float | Tensor]: A :class:`list` of learning rates for each of
+the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+same types as their current ``group["lr"]``\s.
+.. note::
+If you're trying to inspect the most recent learning rate, use
+:meth:`get_last_lr()` instead.
+.. note::
+The returned :class:`~torch.Tensor`\s are copies, and never alias
+the optimizer's ``group["lr"]``\s.
+"""
_warn_get_lr_called_within_step(self)
if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
return [group["lr"] for group in self.optimizer.param_groups]
return _param_groups_val_list(self.optimizer, "lr")
return [group["lr"] * self.gamma for group in self.optimizer.param_groups]
-def _get_closed_form_lr(self) -> list[float]:
+def _get_closed_form_lr(self) -> list[float | Tensor]:
+r"""Compute learning rates for each of the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
+a closed-form formula.
+Uses :attr:`base_lrs` to compute learning rates. This method is called
+when an epoch is passed to :meth:`step`.
+Returns:
+list[float | Tensor]: A :class:`list` of learning rates for each of
+the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+same types as their current ``group["lr"]``\s.
+"""
return [
base_lr * self.gamma ** (self.last_epoch // self.step_size)
for base_lr in self.base_lrs
@@ -582,18 +722,53 @@ class MultiStepLR(LRScheduler):
super().__init__(optimizer, last_epoch)
@override
-def get_lr(self) -> list[float]:
-"""Compute the learning rate of each parameter group."""
+def get_lr(self) -> list[float | Tensor]:
+r"""Compute the next learning rate for each of the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups`.
+If the current epoch is in :attr:`milestones`, decays the
+``group["lr"]``\s in the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups` by :attr:`gamma`.
+Returns:
+list[float | Tensor]: A :class:`list` of learning rates for each of
+the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+same types as their current ``group["lr"]``\s.
+.. note::
+If you're trying to inspect the most recent learning rate, use
+:meth:`get_last_lr()` instead.
+.. note::
+The returned :class:`~torch.Tensor`\s are copies, and never alias
+the optimizer's ``group["lr"]``\s.
+.. note::
+If the current epoch appears in :attr:`milestones` ``n`` times, we
+scale by :attr:`gamma` to the power of ``n``.
+"""
_warn_get_lr_called_within_step(self)
if self.last_epoch not in self.milestones:
return [group["lr"] for group in self.optimizer.param_groups]
return _param_groups_val_list(self.optimizer, "lr")
return [
group["lr"] * self.gamma ** self.milestones[self.last_epoch]
for group in self.optimizer.param_groups
]
def _get_closed_form_lr(self):
r"""Compute learning rates for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
a closed-form formula.
Uses :attr:`base_lrs` to compute learning rates. This method is called
when an epoch is passed to :meth:`step`.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
"""
milestones = sorted(self.milestones.elements())
return [
base_lr * self.gamma ** bisect_right(milestones, self.last_epoch)
@@ -651,21 +826,53 @@ class ConstantLR(LRScheduler):
super().__init__(optimizer, last_epoch)
@override
-def get_lr(self) -> list[float]:
-"""Compute the learning rate of each parameter group."""
+def get_lr(self) -> list[float | Tensor]:
+r"""Compute the next learning rate for each of the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups`.
+When :attr:`last_epoch` is 0, this method scales the ``group["lr"]``\s
+in each of the optimizer's :attr:`~torch.optim.Optimizer.param_groups`
+by :attr:`factor`. Once :attr:`total_iters` is reached, it undoes this,
+scaling by ``1 / factor``.
+Returns:
+list[float | Tensor]: A :class:`list` of learning rates for each of
+the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+same types as their current ``group["lr"]``\s.
+.. note::
+If you're trying to inspect the most recent learning rate, use
+:meth:`get_last_lr()` instead.
+.. note::
+The returned :class:`~torch.Tensor`\s are copies, and never alias
+the optimizer's ``group["lr"]``\s.
+"""
_warn_get_lr_called_within_step(self)
if self.last_epoch == 0:
return [group["lr"] * self.factor for group in self.optimizer.param_groups]
if self.last_epoch != self.total_iters:
return [group["lr"] for group in self.optimizer.param_groups]
return _param_groups_val_list(self.optimizer, "lr")
return [
group["lr"] * (1.0 / self.factor) for group in self.optimizer.param_groups
]
def _get_closed_form_lr(self):
r"""Compute learning rates for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
a closed-form formula.
Uses :attr:`base_lrs` to compute learning rates. This method is called
when an epoch is passed to :meth:`step`.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
"""
return [
base_lr
* (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor))
@@ -733,8 +940,28 @@ class LinearLR(LRScheduler):
super().__init__(optimizer, last_epoch)
@override
-def get_lr(self) -> list[float]:
-"""Compute the learning rate."""
+def get_lr(self) -> list[float | Tensor]:
+r"""Compute the next learning rate for each of the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups`.
+Scales the ``group["lr"]``\s in the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups` such that successive steps
+interpolate linearly from :attr:`start_factor` up to :attr:`end_factor`
+across :attr:`total_iters` steps.
+Returns:
+list[float | Tensor]: A :class:`list` of learning rates for each of
+the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+same types as their current ``group["lr"]``\s.
+.. note::
+If you're trying to inspect the most recent learning rate, use
+:meth:`get_last_lr()` instead.
+.. note::
+The returned :class:`~torch.Tensor`\s are copies, and never alias
+the optimizer's ``group["lr"]``\s.
+"""
_warn_get_lr_called_within_step(self)
if self.last_epoch == 0:
@@ -743,7 +970,7 @@ class LinearLR(LRScheduler):
]
if self._is_initial or self.last_epoch > self.total_iters:
return [group["lr"] for group in self.optimizer.param_groups]
return _param_groups_val_list(self.optimizer, "lr")
return [
group["lr"]
@@ -759,6 +986,18 @@ class LinearLR(LRScheduler):
]
def _get_closed_form_lr(self):
r"""Compute learning rates for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
a closed-form formula.
Uses :attr:`base_lrs` to compute learning rates. This method is called
when an epoch is passed to :meth:`step`.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
"""
return [
base_lr
* (
@@ -802,17 +1041,47 @@ class ExponentialLR(LRScheduler):
super().__init__(optimizer, last_epoch)
@override
-def get_lr(self) -> list[float]:
-"""Compute the learning rate of each parameter group."""
+def get_lr(self) -> list[float | Tensor]:
+r"""Compute the next learning rate for each of the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups`.
+Multiplies the current ``group["lr"]``\s in the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups` by :attr:`gamma`.
+Returns:
+list[float | Tensor]: A :class:`list` of learning rates for each of
+the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+same types as their current ``group["lr"]``\s.
+.. note::
+If you're trying to inspect the most recent learning rate, use
+:meth:`get_last_lr()` instead.
+.. note::
+The returned :class:`~torch.Tensor`\s are copies, and never alias
+the optimizer's ``group["lr"]``\s.
+"""
_warn_get_lr_called_within_step(self)
# when loading from a checkpoint, we don't want _initial_step (called from the constructor)
# to update the lr one more step ahead of itself.
if self._is_initial:
return [group["lr"] for group in self.optimizer.param_groups]
return _param_groups_val_list(self.optimizer, "lr")
return [group["lr"] * self.gamma for group in self.optimizer.param_groups]
def _get_closed_form_lr(self):
r"""Compute learning rates for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
a closed-form formula.
Uses :attr:`base_lrs` to compute learning rates. This method is called
when an epoch is passed to :meth:`step`.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
"""
return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs]
@@ -935,7 +1204,7 @@ class SequentialLR(LRScheduler):
def state_dict(self) -> dict[str, Any]:
"""Return the state of the scheduler as a :class:`dict`.
-It contains an entry for every variable in self.__dict__ which
+It contains an entry for every variable in ``self.__dict__`` which
is not the optimizer.
The wrapped scheduler states will also be saved.
"""
@@ -1008,12 +1277,38 @@ class PolynomialLR(LRScheduler):
super().__init__(optimizer, last_epoch)
@override
-def get_lr(self) -> list[float]:
-"""Compute the learning rate."""
+def get_lr(self) -> list[float | Tensor]:
+r"""Compute the next learning rate for each of the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups`.
+Scales the ``group["lr"]``\s in the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups` such that the learning rates
+follow
+.. math::
+\texttt{base\_lr} \cdot \left(1 - \frac{\texttt{last\_epoch}}
+{\texttt{total\_iters}} \right)^\texttt{power}
+Returns the current learning rates unchanged after :attr:`total_iters`
+is reached.
+Returns:
+list[float | Tensor]: A :class:`list` of learning rates for each of
+the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+same types as their current ``group["lr"]``\s.
+.. note::
+If you're trying to inspect the most recent learning rate, use
+:meth:`get_last_lr()` instead.
+.. note::
+The returned :class:`~torch.Tensor`\s are copies, and never alias
+the optimizer's ``group["lr"]``\s.
+"""
_warn_get_lr_called_within_step(self)
if self._is_initial or self.last_epoch > self.total_iters:
return [group["lr"] for group in self.optimizer.param_groups]
return _param_groups_val_list(self.optimizer, "lr")
decay_factor = (
(1.0 - self.last_epoch / self.total_iters)
@@ -1021,7 +1316,19 @@
) ** self.power
return [group["lr"] * decay_factor for group in self.optimizer.param_groups]
-def _get_closed_form_lr(self):
+def _get_closed_form_lr(self) -> list[float | Tensor]:
+r"""Compute learning rates for each of the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
+a closed-form formula.
+Uses :attr:`base_lrs` to compute learning rates. This method is called
+when an epoch is passed to :meth:`step`.
+Returns:
+list[float | Tensor]: A :class:`list` of learning rates for each of
+the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+same types as their current ``group["lr"]``\s.
+"""
return [
(
base_lr
@@ -1094,12 +1401,36 @@ class CosineAnnealingLR(LRScheduler):
super().__init__(optimizer, last_epoch)
@override
-def get_lr(self) -> list[float]:
-"""Retrieve the learning rate of each parameter group."""
+def get_lr(self) -> list[float | Tensor]:
+r"""Compute the next learning rate for each of the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups`.
+Scales the ``group["lr"]``\s in the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups` such that their learning
+rates approximate
+.. math::
+\texttt{eta\_min} + \frac{1}{2} (\texttt{base\_lr} -
+\texttt{eta\_min}) \left(1 + \cos\left(\pi \cdot
+\frac{\texttt{last\_epoch}}{\texttt{T\_max}}\right) \right)
+Returns:
+list[float | Tensor]: A :class:`list` of learning rates for each of
+the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+same types as their current ``group["lr"]``\s.
+.. note::
+If you're trying to inspect the most recent learning rate, use
+:meth:`get_last_lr()` instead.
+.. note::
+The returned :class:`~torch.Tensor`\s are copies, and never alias
+the optimizer's ``group["lr"]``\s.
+"""
_warn_get_lr_called_within_step(self)
if self._is_initial:
return [group["lr"] for group in self.optimizer.param_groups]
return _param_groups_val_list(self.optimizer, "lr")
elif self._step_count == 1 and self.last_epoch > 0:
return [
self.eta_min
@@ -1122,7 +1453,19 @@
for group in self.optimizer.param_groups
]
-def _get_closed_form_lr(self) -> list[float]:
+def _get_closed_form_lr(self) -> list[float | Tensor]:
+r"""Compute learning rates for each of the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
+a closed-form formula.
+Uses :attr:`base_lrs` to compute learning rates. This method is called
+when an epoch is passed to :meth:`step`.
+Returns:
+list[float | Tensor]: A :class:`list` of learning rates for each of
+the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+same types as their current ``group["lr"]``\s.
+"""
return [
self.eta_min
+ (base_lr - self.eta_min)
@@ -1191,23 +1534,19 @@ class ChainedScheduler(LRScheduler):
)
self._schedulers = schedulers
self.optimizer = optimizer
-self._last_lr = [
-group["lr"] for group in self._schedulers[-1].optimizer.param_groups
-]
+self._last_lr = _param_groups_val_list(self._schedulers[-1].optimizer, "lr")
def step(self) -> None: # type: ignore[override]
"""Perform a step."""
for scheduler in self._schedulers:
scheduler.step()
-self._last_lr = [
-group["lr"] for group in self._schedulers[-1].optimizer.param_groups
-]
+self._last_lr = _param_groups_val_list(self._schedulers[-1].optimizer, "lr")
@override
def state_dict(self) -> dict[str, Any]:
"""Return the state of the scheduler as a :class:`dict`.
-It contains an entry for every variable in self.__dict__ which
+It contains an entry for every variable in ``self.__dict__`` which
is not the optimizer.
The wrapped scheduler states will also be saved.
"""
@@ -1334,7 +1673,7 @@ class ReduceLROnPlateau(LRScheduler):
self.cooldown = cooldown
self.eps = eps
self.last_epoch = 0
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
self._last_lr = _param_groups_val_list(self.optimizer, "lr")
self._init_is_better(
mode=mode, threshold=threshold, threshold_mode=threshold_mode
)
@@ -1371,7 +1710,7 @@
self.cooldown_counter = self.cooldown
self.num_bad_epochs = 0
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
self._last_lr = _param_groups_val_list(self.optimizer, "lr")
def _reduce_lr(self, epoch):
if len(self.optimizer.param_groups) != len(self.min_lrs):
@@ -1561,11 +1900,7 @@ class CyclicLR(LRScheduler):
base_lrs = _format_param("base_lr", optimizer, base_lr)
if last_epoch == -1:
for lr, group in zip(base_lrs, optimizer.param_groups):
if isinstance(group["lr"], Tensor):
lr_val = lr.item() if isinstance(lr, Tensor) else lr
group["lr"].fill_(lr_val)
else:
group["lr"] = lr
_update_param_group_val(group, "lr", lr)
self.max_lrs = _format_param("max_lr", optimizer, max_lr)
@@ -1649,13 +1984,34 @@
return gamma**x
@override
-def get_lr(self) -> list[float]:
-"""Calculate the learning rate at batch index.
+def get_lr(self) -> list[float | Tensor]:
+r"""Compute the next learning rate for each of the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups`.
-This function treats `self.last_epoch` as the last batch index.
+Advances each ``group["lr"]`` in the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups` along a cycle between the
+group's ``base_lr`` and ``max_lr`` using :meth:`scale_fn`.
-If `self.cycle_momentum` is ``True``, this function has a side effect of
-updating the optimizer's momentum.
+Returns:
+list[float | Tensor]: A :class:`list` of learning rates for each of
+the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+same types as their current ``group["lr"]``\s.
+.. note::
+If you're trying to inspect the most recent learning rate, use
+:meth:`get_last_lr()` instead.
+.. note::
+The returned :class:`~torch.Tensor`\s are copies, and never alias
+the optimizer's ``group["lr"]``\s.
+.. note::
+This method treats :attr:`last_epoch` as the index of the previous
+batch.
+.. note::
+When :attr:`cycle_momentum` is ``True``, this method has a side
+effect of updating the optimizer's momentum.
+"""
"""
_warn_get_lr_called_within_step(self)
@@ -1700,7 +2056,7 @@
def state_dict(self) -> dict[str, Any]: # noqa: D102
"""Return the state of the scheduler as a :class:`dict`.
-It contains an entry for every variable in self.__dict__ which
+It contains an entry for every variable in ``self.__dict__`` which
is not the optimizer.
The learning rate lambda functions will only be saved if they are callable objects
and not if they are functions or lambdas.
@@ -1795,8 +2151,36 @@ class CosineAnnealingWarmRestarts(LRScheduler):
super().__init__(optimizer, last_epoch)
@override
-def get_lr(self) -> list[float]:
-"""Compute the initial learning rate."""
+def get_lr(self) -> list[float | Tensor]:
+r"""Compute the next learning rate for each of the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups`.
+Computes learning rates for the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups` following:
+.. math::
+\texttt{eta\_min} + \frac{1}{2}(\texttt{base\_lr} -
+\texttt{eta\_min})\left(1 + \cos\left(\pi \cdot
+\frac{\texttt{T\_cur}}{\texttt{T\_i}}\right)\right)
+Where :attr:`T_cur` is the number of epochs since the last restart and
+:attr:`T_i` is the number of epochs between two restarts. Both
+:attr:`T_cur` and :attr:`T_i` are updated in :meth:`step`, and
+:attr:`T_i` becomes :attr:`T_mult` times larger after each restart.
+Returns:
+list[float | Tensor]: A :class:`list` of learning rates for each of
+the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+same types as their current ``group["lr"]``\s.
+.. note::
+If you're trying to inspect the most recent learning rate, use
+:meth:`get_last_lr()` instead.
+.. note::
+The returned :class:`~torch.Tensor`\s are copies, and never alias
+the optimizer's ``group["lr"]``\s.
+"""
_warn_get_lr_called_within_step(self)
return [
@@ -1869,7 +2253,7 @@ class CosineAnnealingWarmRestarts(LRScheduler):
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
_update_param_group_val(param_group, "lr", lr)
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
self._last_lr = _param_groups_val_list(self.optimizer, "lr")
class _SchedulePhase(TypedDict):
@@ -2141,8 +2525,31 @@ class OneCycleLR(LRScheduler):
return (end - start) * pct + start
@override
-def get_lr(self) -> list[float]:
-"""Compute the learning rate of each parameter group."""
+def get_lr(self) -> list[float | Tensor]:
+r"""Compute the next learning rate for each of the optimizer's
+:attr:`~torch.optim.Optimizer.param_groups`.
+Finds the appropriate :attr:`_schedule_phases` entry for the current
+step and interpolates between its ``start_lr`` and ``end_lr`` using
+:meth:`_anneal_func`.
+Returns:
+list[float | Tensor]: A :class:`list` of learning rates for each of
+the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+same types as their current ``group["lr"]``\s.
+.. note::
+If you're trying to inspect the most recent learning rate, use
+:meth:`get_last_lr()` instead.
+.. note::
+The returned :class:`~torch.Tensor`\s are copies, and never alias
+the optimizer's ``group["lr"]``\s.
+.. note::
+When :attr:`cycle_momentum` is ``True``, this method has a side
+effect of updating the optimizer's momentum.
+"""
_warn_get_lr_called_within_step(self)
lrs = []

torch/optim/swa_utils.py

@@ -454,8 +454,29 @@ class SWALR(LRScheduler):
return swa_lr
return (lr - alpha * swa_lr) / (1 - alpha)
@override
def get_lr(self):
"""Get learning rate."""
r"""Compute the next learning rate for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups`.
Uses :attr:`anneal_func` to interpolate between each group's
``group["lr"]`` and ``group["swa_lr"]`` over :attr:`anneal_epochs`
epochs. Once :attr:`anneal_epochs` is reached, keeps the learning rate
fixed at ``group["swa_lr"]``.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
.. note::
If you're trying to inspect the most recent learning rate, use
:meth:`get_last_lr()` instead.
.. note::
The returned :class:`~torch.Tensor`\s are copies, and never alias
the optimizer's ``group["lr"]``\s.
"""
# `_get_lr_called_within_step` is only available during `_enable_get_lr_call`,
# so we ignore the type error here. See `LRScheduler.step()` for more details.
if not self._get_lr_called_within_step: