[optim] prevent unintended aliasing in lr_scheduler; update type annotations/docs (#163120)

1. Prevents unintended aliasing of `self._last_lr`/`get_last_lr(...)` with `group["lr"]` when `group["lr"]` is a tensor.
2. Prevents unintended aliasing of `LRScheduler.base_lrs` with the `group["initial_lr"]`s.
3. Updates `test/optim/test_lrscheduler.py` to test tensor LRs.
4. Changes type annotations for `_last_lr`, `get_last_lr()`, `base_lrs`, `get_lr()`, and `_get_closed_form_lr()` from `list[float]` to `list[float | Tensor]`; adds documentation.

Fixes #163103

LR schedulers can behave in unexpected ways when using a tensor LR due to patterns like this:
```python
self._last_lr: list[float] = [group["lr"] for group in self.optimizer.param_groups]
```
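
For a concrete sense of the failure mode, here is a minimal sketch (not the exact reproduction from #163103) of how that list comprehension turns into aliasing when the LR is a tensor: the scheduler ends up storing the very tensor object the optimizer mutates in place on every step.

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

# A 0-dim tensor LR is stored by reference in the optimizer's param_group.
lr = torch.tensor(0.1)
opt = SGD([torch.nn.Parameter(torch.randn(2))], lr=lr)
sched = StepLR(opt, step_size=1, gamma=0.5)

last = sched.get_last_lr()  # before this PR, last[0] could be group["lr"] itself
opt.step()
sched.step()  # updates group["lr"] in place (fill_) when it is a tensor

# With the old aliasing, last[0] silently changed along with group["lr"];
# after this PR, get_last_lr() returns clones, so last[0] still holds 0.1.
print(last[0], opt.param_groups[0]["lr"])
```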

This PR adds a helper to address this:
```python
def _param_groups_val_list(optimizer: Optimizer, key: str) -> list[Any]:
    """Create a list containing group[key] for each optimizer param_group.
    Prevents aliasing when group[key] could be a Tensor.
    Raises a KeyError when group[key] does not exist.
    """
    return [
        group[key].clone() if isinstance(group[key], Tensor) else group[key]
        for group in optimizer.param_groups
    ]
```
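
As a quick usage sketch (the helper is a private, module-level function added by this PR; importing it directly is only for illustration), the cloned entries no longer track in-place updates to the optimizer's tensor LR:

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import _param_groups_val_list  # private helper added by this PR

opt = SGD([torch.nn.Parameter(torch.randn(2))], lr=torch.tensor(0.1))

lrs = _param_groups_val_list(opt, "lr")
assert lrs[0] is not opt.param_groups[0]["lr"]  # a clone, not the same tensor object

opt.param_groups[0]["lr"].fill_(0.01)           # in-place update, as the schedulers do
assert abs(float(lrs[0]) - 0.1) < 1e-6          # the stored copy is unaffected
```
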
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163120
Approved by: https://github.com/janeyx99
Authored by Filip on 2025-09-25 06:58:55 +00:00; committed by PyTorch MergeBot
commit 2b6a74abf1 (parent ad869c58f5)
3 changed files with 506 additions and 77 deletions

`test/optim/test_lrscheduler.py`:

```diff
@@ -77,7 +77,7 @@ class TestLRScheduler(TestCase):
         self.opt = SGD(
             [
                 {"params": self.net.conv1.parameters()},
-                {"params": self.net.conv2.parameters(), "lr": 0.5},
+                {"params": self.net.conv2.parameters(), "lr": torch.tensor(0.5)},
             ],
             lr=0.05,
         )
@@ -2530,7 +2530,7 @@ class TestLRScheduler(TestCase):
         ],
     )
     def test_constant_initial_lr(self, LRClass):
-        # Test that the initial learning rate is constant
+        # Test that the initial learning rate is constant and that it does not alias base_lrs
         lr = torch.as_tensor(0.1)
         opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr)
         sch = LRClass(opt)
@@ -2544,6 +2544,7 @@ class TestLRScheduler(TestCase):
         for group, ori_group in zip(opt.param_groups, ori_param_groups):
             self.assertEqual(group["initial_lr"], ori_group["initial_lr"])
         self.assertEqual(sch.base_lrs, [0.1])
+        self.assertIsNot(sch.base_lrs[0], group["initial_lr"])
 
     def test_constant_initial_params_cyclelr(self):
         # Test that the initial learning rate is constant
```

`torch/optim/lr_scheduler.py`:

```diff
@@ -79,9 +79,20 @@ def _format_param(name: str, optimizer: Optimizer, param):
     return list(map(_copy, param))
 
 
+def _param_groups_val_list(optimizer: Optimizer, key: str) -> list[Any]:
+    """Create a list containing group[key] for each optimizer param_group.
+    Prevents aliasing when group[key] could be a Tensor.
+    Raises a KeyError when group[key] does not exist.
+    """
+    return [
+        group[key].clone() if isinstance(group[key], Tensor) else group[key]
+        for group in optimizer.param_groups
+    ]
+
+
 def _update_param_group_val(param_group: dict[str, Any], key: str, val: float | Tensor):
-    """Set param_group[key] to val without aliasing or assignment when they're both tensors.
-    Raises a KeyError if param_group[key] does not exist.
+    """Set param_group[key] to val without aliasing or assignment when they're
+    both tensors. Raises a KeyError if param_group[key] does not exist.
     """
     if isinstance(param_group[key], Tensor):
         param_group[key].fill_(_to_scalar(val))
@@ -90,7 +101,25 @@ def _update_param_group_val(param_group: dict[str, Any], key: str, val: float |
 
 
 class LRScheduler:
-    r"""Adjusts the learning rate during optimization."""
+    r"""Base class for all learning rate schedulers.
+
+    Subclasses implement :meth:`get_lr` and optionally override :meth:`step` to
+    define scheduling behavior.
+
+    Args:
+        optimizer (Optimizer): The optimizer this scheduler will adjust the
+            learning rates of.
+        last_epoch (int): Index of the last epoch seen by the scheduler. Use
+            ``-1`` (default) to initialize the scheduler. Only use a non-default
+            value when restoring this scheduler from a saved checkpoint.
+
+    .. warning::
+        Initializing a scheduler overwrites its optimizer's
+        ``param_group["lr"]``\s. When restoring a checkpoint, initialize the
+        scheduler **before** calling your optimizer's
+        :meth:`~torch.optim.Optimizer.load_state_dict` to avoid overwriting the
+        loaded learning rates.
+    """
 
     _get_lr_called_within_step: bool = False
     _is_initial: bool = False
@@ -121,9 +150,9 @@ class LRScheduler:
                     "1. You're trying to resume training from a checkpoint but haven't properly loaded the optimizer state\n"
                     "2. You're using last_epoch >= 0 for a fresh training run (not recommended)"
                 )
-        self.base_lrs: list[float] = [
-            group["initial_lr"] for group in optimizer.param_groups
-        ]
+        self.base_lrs: list[float | Tensor] = _param_groups_val_list(
+            optimizer, "initial_lr"
+        )
         self.last_epoch = last_epoch
 
         # Following https://github.com/pytorch/pytorch/issues/20124
@@ -161,7 +190,7 @@ class LRScheduler:
     def state_dict(self) -> dict[str, Any]:
         """Return the state of the scheduler as a :class:`dict`.
 
-        It contains an entry for every variable in self.__dict__ which
+        It contains an entry for every variable in ``self.__dict__`` which
         is not the optimizer.
         """
         return {
@@ -177,16 +206,58 @@ class LRScheduler:
         """
         self.__dict__.update(state_dict)
 
-    def get_last_lr(self) -> list[float]:
-        """Return last computed learning rate by current scheduler."""
+    def get_last_lr(self) -> list[float | Tensor]:
+        r"""Get the most recent learning rates computed by this scheduler.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates with entries
+            for each of the optimizer's
+            :attr:`~torch.optim.Optimizer.param_groups`, with the same types as
+            their ``group["lr"]``\s.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
+        # We always update self._last_lr with _param_groups_val_list, so it's a
+        # .clone() of the group["lr"]s. If we didn't do this, the user could
+        # corrupt their learning rates by modifying the outputs in place.
         return self._last_lr
 
-    def get_lr(self) -> list[float]:
-        """Compute learning rate using chainable form of the scheduler."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+            the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+            same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         raise NotImplementedError
 
     def step(self, epoch: Optional[int] = None) -> None:
-        """Perform a step."""
+        """Step the scheduler.
+
+        Args:
+            epoch (int, optional):
+                .. deprecated:: 1.4
+                    If provided, sets :attr:`last_epoch` to ``epoch`` and uses
+                    :meth:`_get_closed_form_lr` if it is available. This is not
+                    universally supported. Use :meth:`step` without arguments
+                    instead.
+
+        .. note::
+            Call this method after calling the optimizer's
+            :meth:`~torch.optim.Optimizer.step`.
+        """
         # Raise a warning if old pattern is detected
         # https://github.com/pytorch/pytorch/issues/20124
         if self._step_count == 1:
@@ -224,16 +295,18 @@ class LRScheduler:
             else:
                 self.last_epoch = epoch
                 if hasattr(self, "_get_closed_form_lr"):
-                    values = cast(list[float], self._get_closed_form_lr())
+                    values = cast(
+                        list[Union[float, Tensor]], self._get_closed_form_lr()
+                    )
                 else:
                     values = self.get_lr()
 
         for param_group, lr in zip(self.optimizer.param_groups, values):
             _update_param_group_val(param_group, "lr", lr)
 
-        self._last_lr: list[float] = [
-            group["lr"] for group in self.optimizer.param_groups
-        ]
+        self._last_lr: list[float | Tensor] = _param_groups_val_list(
+            self.optimizer, "lr"
+        )
 
 
 def _warn_get_lr_called_within_step(lr_scheduler: LRScheduler) -> None:
@@ -333,8 +406,7 @@ class LambdaLR(LRScheduler):
     def state_dict(self) -> dict[str, Any]:
         """Return the state of the scheduler as a :class:`dict`.
 
-        It contains an entry for every variable in self.__dict__ which
-        is not the optimizer.
+        It contains an entry for every variable in ``self.__dict__`` which is not the optimizer.
 
         The learning rate lambda functions will only be saved if they are callable objects
         and not if they are functions or lambdas.
@@ -374,8 +446,26 @@ class LambdaLR(LRScheduler):
             self.lr_lambdas[idx].__dict__.update(fn)
 
     @override
-    def get_lr(self) -> list[float]:
-        """Compute learning rate."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        Scales the :attr:`base_lrs` by the outputs of the :attr:`lr_lambdas` at
+        :attr:`last_epoch`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+            the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+            same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         _warn_get_lr_called_within_step(self)
 
         return [
@@ -436,7 +526,7 @@ class MultiplicativeLR(LRScheduler):
     def state_dict(self) -> dict[str, Any]:
         """Return the state of the scheduler as a :class:`dict`.
 
-        It contains an entry for every variable in self.__dict__ which
+        It contains an entry for every variable in ``self.__dict__`` which
         is not the optimizer.
         The learning rate lambda functions will only be saved if they are callable objects
         and not if they are functions or lambdas.
@@ -473,8 +563,27 @@ class MultiplicativeLR(LRScheduler):
             self.lr_lambdas[idx].__dict__.update(fn)
 
     @override
-    def get_lr(self) -> list[float]:
-        """Compute the learning rate of each parameter group."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        Scales the current ``group["lr"]``\s in each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` by the outputs of the
+        :attr:`lr_lambdas` at :attr:`last_epoch`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+            the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+            same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         _warn_get_lr_called_within_step(self)
 
         if not self._is_initial:
@@ -483,7 +592,7 @@ class MultiplicativeLR(LRScheduler):
                 for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)
             ]
         else:
-            return [group["lr"] for group in self.optimizer.param_groups]
+            return _param_groups_val_list(self.optimizer, "lr")
 
 
 class StepLR(LRScheduler):
@@ -527,15 +636,46 @@ class StepLR(LRScheduler):
         super().__init__(optimizer, last_epoch)
 
     @override
-    def get_lr(self) -> list[float]:
-        """Compute the learning rate of each parameter group."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        If the current epoch is a non-zero multiple of :attr:`step_size`, we
+        scale the current ``group["lr"]``\s in the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` by :attr:`gamma`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+            the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+            same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         _warn_get_lr_called_within_step(self)
 
         if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
-            return [group["lr"] for group in self.optimizer.param_groups]
+            return _param_groups_val_list(self.optimizer, "lr")
         return [group["lr"] * self.gamma for group in self.optimizer.param_groups]
 
-    def _get_closed_form_lr(self) -> list[float]:
+    def _get_closed_form_lr(self) -> list[float | Tensor]:
+        r"""Compute learning rates for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
+        a closed-form formula.
+
+        Uses :attr:`base_lrs` to compute learning rates. This method is called
+        when an epoch is passed to :meth:`step`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+            the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+            same types as their current ``group["lr"]``\s.
+        """
         return [
             base_lr * self.gamma ** (self.last_epoch // self.step_size)
             for base_lr in self.base_lrs
@@ -582,18 +722,53 @@ class MultiStepLR(LRScheduler):
         super().__init__(optimizer, last_epoch)
 
     @override
-    def get_lr(self) -> list[float]:
-        """Compute the learning rate of each parameter group."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        If the current epoch is in :attr:`milestones`, decays the
+        ``group["lr"]``\s in the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` by :attr:`gamma`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+            the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+            same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+
+        .. note::
+            If the current epoch appears in :attr:`milestones` ``n`` times, we
+            scale by :attr:`gamma` to the power of ``n``
+        """
         _warn_get_lr_called_within_step(self)
 
         if self.last_epoch not in self.milestones:
-            return [group["lr"] for group in self.optimizer.param_groups]
+            return _param_groups_val_list(self.optimizer, "lr")
         return [
             group["lr"] * self.gamma ** self.milestones[self.last_epoch]
             for group in self.optimizer.param_groups
         ]
 
     def _get_closed_form_lr(self):
+        r"""Compute learning rates for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
+        a closed-form formula.
+
+        Uses :attr:`base_lrs` to compute learning rates. This method is called
+        when an epoch is passed to :meth:`step`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+            the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+            same types as their current ``group["lr"]``\s.
+        """
         milestones = sorted(self.milestones.elements())
         return [
             base_lr * self.gamma ** bisect_right(milestones, self.last_epoch)
@@ -651,21 +826,53 @@ class ConstantLR(LRScheduler):
         super().__init__(optimizer, last_epoch)
 
     @override
-    def get_lr(self) -> list[float]:
-        """Compute the learning rate of each parameter group."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        When :attr:`last_epoch` is 0, this method scales the ``group["lr"]``\s
+        in each of the optimizer's :attr:`~torch.optim.Optimizer.param_groups`
+        by :attr:`factor`. Once :attr:`total_iters` is reached, it undoes this,
+        scaling by ``1 / factor``.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+            the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+            same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         _warn_get_lr_called_within_step(self)
 
         if self.last_epoch == 0:
             return [group["lr"] * self.factor for group in self.optimizer.param_groups]
 
         if self.last_epoch != self.total_iters:
-            return [group["lr"] for group in self.optimizer.param_groups]
+            return _param_groups_val_list(self.optimizer, "lr")
 
         return [
             group["lr"] * (1.0 / self.factor) for group in self.optimizer.param_groups
         ]
 
     def _get_closed_form_lr(self):
+        r"""Compute learning rates for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
+        a closed-form formula.
+
+        Uses :attr:`base_lrs` to compute learning rates. This method is called
+        when an epoch is passed to :meth:`step`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+            the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+            same types as their current ``group["lr"]``\s.
+        """
        return [
            base_lr
            * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor))
@@ -733,8 +940,28 @@ class LinearLR(LRScheduler):
         super().__init__(optimizer, last_epoch)
 
     @override
-    def get_lr(self) -> list[float]:
-        """Compute the learning rate."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        Scales the ``group["lr"]``\s in the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` such that successive steps
+        interpolate linearly from :attr:`start_factor` up to :attr:`end_factor`
+        across :attr:`total_iters` steps.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+            the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+            same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         _warn_get_lr_called_within_step(self)
 
         if self.last_epoch == 0:
@@ -743,7 +970,7 @@ class LinearLR(LRScheduler):
             ]
 
         if self._is_initial or self.last_epoch > self.total_iters:
-            return [group["lr"] for group in self.optimizer.param_groups]
+            return _param_groups_val_list(self.optimizer, "lr")
 
         return [
             group["lr"]
@@ -759,6 +986,18 @@ class LinearLR(LRScheduler):
         ]
 
     def _get_closed_form_lr(self):
+        r"""Compute learning rates for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
+        a closed-form formula.
+
+        Uses :attr:`base_lrs` to compute learning rates. This method is called
+        when an epoch is passed to :meth:`step`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+            the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+            same types as their current ``group["lr"]``\s.
+        """
         return [
             base_lr
             * (
@@ -802,17 +1041,47 @@ class ExponentialLR(LRScheduler):
         super().__init__(optimizer, last_epoch)
 
     @override
-    def get_lr(self) -> list[float]:
-        """Compute the learning rate of each parameter group."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        Multiplies the current ``group["lr"]``\s in the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` by :attr:`gamma`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+            the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+            same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         _warn_get_lr_called_within_step(self)
 
         # when loading from a checkpoint, we don't want _initial_step (called from the constructor)
         # to update the lr one more step ahead of itself.
         if self._is_initial:
-            return [group["lr"] for group in self.optimizer.param_groups]
+            return _param_groups_val_list(self.optimizer, "lr")
 
         return [group["lr"] * self.gamma for group in self.optimizer.param_groups]
 
     def _get_closed_form_lr(self):
+        r"""Compute learning rates for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
+        a closed-form formula.
+
+        Uses :attr:`base_lrs` to compute learning rates. This method is called
+        when an epoch is passed to :meth:`step`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+            the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+            same types as their current ``group["lr"]``\s.
+        """
         return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs]
@@ -935,7 +1204,7 @@ class SequentialLR(LRScheduler):
     def state_dict(self) -> dict[str, Any]:
         """Return the state of the scheduler as a :class:`dict`.
 
-        It contains an entry for every variable in self.__dict__ which
+        It contains an entry for every variable in ``self.__dict__`` which
         is not the optimizer.
         The wrapped scheduler states will also be saved.
         """
@@ -1008,12 +1277,38 @@ class PolynomialLR(LRScheduler):
         super().__init__(optimizer, last_epoch)
 
     @override
-    def get_lr(self) -> list[float]:
-        """Compute the learning rate."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        Scales the ``group["lr"]``\s in the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` such that the learning rates
+        follow
+
+        .. math::
+            \texttt{base\_lr} \cdot \left(1 - \frac{\texttt{last\_epoch}}
+            {\texttt{total\_iters}} \right)^\texttt{power}
+
+        Returns the current learning rates unchanged after :attr:`total_iters`
+        is reached.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+            the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+            same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         _warn_get_lr_called_within_step(self)
 
         if self._is_initial or self.last_epoch > self.total_iters:
-            return [group["lr"] for group in self.optimizer.param_groups]
+            return _param_groups_val_list(self.optimizer, "lr")
 
         decay_factor = (
             (1.0 - self.last_epoch / self.total_iters)
@@ -1021,7 +1316,19 @@ class PolynomialLR(LRScheduler):
         ) ** self.power
         return [group["lr"] * decay_factor for group in self.optimizer.param_groups]
 
-    def _get_closed_form_lr(self):
+    def _get_closed_form_lr(self) -> list[float | Tensor]:
+        r"""Compute learning rates for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
+        a closed-form formula.
+
+        Uses :attr:`base_lrs` to compute learning rates. This method is called
+        when an epoch is passed to :meth:`step`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+            the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+            same types as their current ``group["lr"]``\s.
+        """
         return [
             (
                 base_lr
@@ -1094,12 +1401,36 @@ class CosineAnnealingLR(LRScheduler):
         super().__init__(optimizer, last_epoch)
 
     @override
-    def get_lr(self) -> list[float]:
-        """Retrieve the learning rate of each parameter group."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        Scales the ``group["lr"]``\s in the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` such that their learning
+        rates approximate
+
+        .. math::
+            \texttt{eta\_min} + \frac{1}{2} (\texttt{base\_lr} -
+            \texttt{eta\_min}) \left(1 + \cos\left(\pi \cdot
+            \frac{\texttt{last\_epoch}}{\texttt{T\_max}}\right) \right)
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+            the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+            same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         _warn_get_lr_called_within_step(self)
 
         if self._is_initial:
-            return [group["lr"] for group in self.optimizer.param_groups]
+            return _param_groups_val_list(self.optimizer, "lr")
         elif self._step_count == 1 and self.last_epoch > 0:
             return [
                 self.eta_min
@@ -1122,7 +1453,19 @@ class CosineAnnealingLR(LRScheduler):
                 for group in self.optimizer.param_groups
             ]
 
-    def _get_closed_form_lr(self) -> list[float]:
+    def _get_closed_form_lr(self) -> list[float | Tensor]:
+        r"""Compute learning rates for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
+        a closed-form formula.
+
+        Uses :attr:`base_lrs` to compute learning rates. This method is called
+        when an epoch is passed to :meth:`step`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+            the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+            same types as their current ``group["lr"]``\s.
+        """
         return [
             self.eta_min
             + (base_lr - self.eta_min)
@@ -1191,23 +1534,19 @@ class ChainedScheduler(LRScheduler):
         )
 
         self._schedulers = schedulers
         self.optimizer = optimizer
-        self._last_lr = [
-            group["lr"] for group in self._schedulers[-1].optimizer.param_groups
-        ]
+        self._last_lr = _param_groups_val_list(self._schedulers[-1].optimizer, "lr")
 
     def step(self) -> None:  # type: ignore[override]
         """Perform a step."""
         for scheduler in self._schedulers:
             scheduler.step()
-        self._last_lr = [
-            group["lr"] for group in self._schedulers[-1].optimizer.param_groups
-        ]
+        self._last_lr = _param_groups_val_list(self._schedulers[-1].optimizer, "lr")
 
     @override
     def state_dict(self) -> dict[str, Any]:
         """Return the state of the scheduler as a :class:`dict`.
 
-        It contains an entry for every variable in self.__dict__ which
+        It contains an entry for every variable in ``self.__dict__`` which
         is not the optimizer.
         The wrapped scheduler states will also be saved.
         """
@@ -1334,7 +1673,7 @@ class ReduceLROnPlateau(LRScheduler):
         self.cooldown = cooldown
         self.eps = eps
         self.last_epoch = 0
-        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
+        self._last_lr = _param_groups_val_list(self.optimizer, "lr")
         self._init_is_better(
             mode=mode, threshold=threshold, threshold_mode=threshold_mode
         )
@@ -1371,7 +1710,7 @@ class ReduceLROnPlateau(LRScheduler):
             self.cooldown_counter = self.cooldown
             self.num_bad_epochs = 0
 
-        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
+        self._last_lr = _param_groups_val_list(self.optimizer, "lr")
 
     def _reduce_lr(self, epoch):
         if len(self.optimizer.param_groups) != len(self.min_lrs):
@@ -1561,11 +1900,7 @@ class CyclicLR(LRScheduler):
         base_lrs = _format_param("base_lr", optimizer, base_lr)
         if last_epoch == -1:
             for lr, group in zip(base_lrs, optimizer.param_groups):
-                if isinstance(group["lr"], Tensor):
-                    lr_val = lr.item() if isinstance(lr, Tensor) else lr
-                    group["lr"].fill_(lr_val)
-                else:
-                    group["lr"] = lr
+                _update_param_group_val(group, "lr", lr)
 
         self.max_lrs = _format_param("max_lr", optimizer, max_lr)
@@ -1649,13 +1984,34 @@ class CyclicLR(LRScheduler):
         return gamma**x
 
     @override
-    def get_lr(self) -> list[float]:
-        """Calculate the learning rate at batch index.
-
-        This function treats `self.last_epoch` as the last batch index.
-
-        If `self.cycle_momentum` is ``True``, this function has a side effect of
-        updating the optimizer's momentum.
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        Advances each ``group["lr"]`` in the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` along a cycle between the
+        group's ``base_lr`` and ``max_lr`` using :meth:`scale_fn`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+            the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+            same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+
+        .. note::
+            This method treats :attr:`last_epoch` as the index of the previous
+            batch.
+
+        .. note::
+            When :attr:`cycle_momentum` is ``True``, this method has a side
+            effect of updating the optimizer's momentum.
         """
         _warn_get_lr_called_within_step(self)
@@ -1700,7 +2056,7 @@ class CyclicLR(LRScheduler):
     def state_dict(self) -> dict[str, Any]:  # noqa: D102
         """Return the state of the scheduler as a :class:`dict`.
 
-        It contains an entry for every variable in self.__dict__ which
+        It contains an entry for every variable in ``self.__dict__`` which
         is not the optimizer.
         The learning rate lambda functions will only be saved if they are callable objects
         and not if they are functions or lambdas.
@@ -1795,8 +2151,36 @@ class CosineAnnealingWarmRestarts(LRScheduler):
         super().__init__(optimizer, last_epoch)
 
     @override
-    def get_lr(self) -> list[float]:
-        """Compute the initial learning rate."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        Computes learning rates for the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups` following:
+
+        .. math::
+            \texttt{eta\_min} + \frac{1}{2}(\texttt{base\_lr} -
+            \texttt{eta\_min})\left(1 + \cos\left(\pi \cdot
+            \frac{\texttt{T\_cur}}{\texttt{T\_i}}\right)\right)
+
+        Where :attr:`T_cur` is the number of epochs since the last restart and
+        :attr:`T_i` is the number of epochs between two restarts. Both
+        :attr:`T_cur` and :attr:`T_i` are updated in :meth:`step`, and
+        :attr:`T_i` becomes :attr:`T_mult` times larger after each restart.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+            the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+            same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         _warn_get_lr_called_within_step(self)
 
         return [
@@ -1869,7 +2253,7 @@ class CosineAnnealingWarmRestarts(LRScheduler):
             for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
                 _update_param_group_val(param_group, "lr", lr)
 
-        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
+        self._last_lr = _param_groups_val_list(self.optimizer, "lr")
 
 
 class _SchedulePhase(TypedDict):
@@ -2141,8 +2525,31 @@ class OneCycleLR(LRScheduler):
         return (end - start) * pct + start
 
     @override
-    def get_lr(self) -> list[float]:
-        """Compute the learning rate of each parameter group."""
+    def get_lr(self) -> list[float | Tensor]:
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        Finds the appropriate :attr:`_schedule_phases` entry for the current
+        step and interpolates between its ``start_lr`` and ``end_lr`` using
+        :meth:`_anneal_func`.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+            the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+            same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+
+        .. note::
+            When :attr:`cycle_momentum` is ``True``, this method has a side
+            effect of updating the optimizer's momentum.
+        """
         _warn_get_lr_called_within_step(self)
 
         lrs = []
```

`torch/optim/swa_utils.py`:

```diff
@@ -454,8 +454,29 @@ class SWALR(LRScheduler):
             return swa_lr
         return (lr - alpha * swa_lr) / (1 - alpha)
 
+    @override
     def get_lr(self):
-        """Get learning rate."""
+        r"""Compute the next learning rate for each of the optimizer's
+        :attr:`~torch.optim.Optimizer.param_groups`.
+
+        Uses :attr:`anneal_func` to interpolate between each group's
+        ``group["lr"]`` and ``group["swa_lr"]`` over :attr:`anneal_epochs`
+        epochs. Once :attr:`anneal_epochs` is reached, keeps the learning rate
+        fixed at ``group["swa_lr"]``.
+
+        Returns:
+            list[float | Tensor]: A :class:`list` of learning rates for each of
+            the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
+            same types as their current ``group["lr"]``\s.
+
+        .. note::
+            If you're trying to inspect the most recent learning rate, use
+            :meth:`get_last_lr()` instead.
+
+        .. note::
+            The returned :class:`~torch.Tensor`\s are copies, and never alias
+            the optimizer's ``group["lr"]``\s.
+        """
         # `_get_lr_called_within_step` is only available `_enable_get_lr_call`,
         # so we ignore the type error here. See `LRScheduler.step()` for more details.
         if not self._get_lr_called_within_step:
```