diff --git a/test/optim/test_lrscheduler.py b/test/optim/test_lrscheduler.py index fbc9e691777..cea85b07646 100644 --- a/test/optim/test_lrscheduler.py +++ b/test/optim/test_lrscheduler.py @@ -77,7 +77,7 @@ class TestLRScheduler(TestCase): self.opt = SGD( [ {"params": self.net.conv1.parameters()}, - {"params": self.net.conv2.parameters(), "lr": 0.5}, + {"params": self.net.conv2.parameters(), "lr": torch.tensor(0.5)}, ], lr=0.05, ) @@ -2530,7 +2530,7 @@ class TestLRScheduler(TestCase): ], ) def test_constant_initial_lr(self, LRClass): - # Test that the initial learning rate is constant + # Test that the initial learning rate is constant and that it does not alias base_lrs lr = torch.as_tensor(0.1) opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr) sch = LRClass(opt) @@ -2544,6 +2544,7 @@ class TestLRScheduler(TestCase): for group, ori_group in zip(opt.param_groups, ori_param_groups): self.assertEqual(group["initial_lr"], ori_group["initial_lr"]) self.assertEqual(sch.base_lrs, [0.1]) + self.assertIsNot(sch.base_lrs[0], group["initial_lr"]) def test_constant_initial_params_cyclelr(self): # Test that the initial learning rate is constant diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 92288d0cbdf..fdbab432ff4 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -79,9 +79,20 @@ def _format_param(name: str, optimizer: Optimizer, param): return list(map(_copy, param)) +def _param_groups_val_list(optimizer: Optimizer, key: str) -> list[Any]: + """Create a list containing group[key] for each optimizer param_group. + Prevents aliasing when group[key] could be a Tensor. + Raises a KeyError when group[key] does not exist. + """ + return [ + group[key].clone() if isinstance(group[key], Tensor) else group[key] + for group in optimizer.param_groups + ] + + def _update_param_group_val(param_group: dict[str, Any], key: str, val: float | Tensor): - """Set param_group[key] to val without aliasing or assignment when they're both tensors. - Raises a KeyError if param_group[key] does not exist. + """Set param_group[key] to val without aliasing or assignment when they're + both tensors. Raises a KeyError if param_group[key] does not exist. """ if isinstance(param_group[key], Tensor): param_group[key].fill_(_to_scalar(val)) @@ -90,7 +101,25 @@ def _update_param_group_val(param_group: dict[str, Any], key: str, val: float | class LRScheduler: - r"""Adjusts the learning rate during optimization.""" + r"""Base class for all learning rate schedulers. + + Subclasses implement :meth:`get_lr` and optionally override :meth:`step` to + define scheduling behavior. + + Args: + optimizer (Optimizer): The optimizer this scheduler will adjust the + learning rates of. + last_epoch (int): Index of the last epoch seen by the scheduler. Use + ``-1`` (default) to initialize the scheduler. Only use a non-default + value when restoring this scheduler from a saved checkpoint. + + .. warning:: + Initializing a scheduler overwrites its optimizer's + ``param_group["lr"]``\s. When restoring a checkpoint, initialize the + scheduler **before** calling your optimizer's + :meth:`~torch.optim.Optimizer.load_state_dict` to avoid overwriting the + loaded learning rates. + """ _get_lr_called_within_step: bool = False _is_initial: bool = False @@ -121,9 +150,9 @@ class LRScheduler: "1. You're trying to resume training from a checkpoint but haven't properly loaded the optimizer state\n" "2. 
You're using last_epoch >= 0 for a fresh training run (not recommended)" ) - self.base_lrs: list[float] = [ - group["initial_lr"] for group in optimizer.param_groups - ] + self.base_lrs: list[float | Tensor] = _param_groups_val_list( + optimizer, "initial_lr" + ) self.last_epoch = last_epoch # Following https://github.com/pytorch/pytorch/issues/20124 @@ -161,7 +190,7 @@ class LRScheduler: def state_dict(self) -> dict[str, Any]: """Return the state of the scheduler as a :class:`dict`. - It contains an entry for every variable in self.__dict__ which + It contains an entry for every variable in ``self.__dict__`` which is not the optimizer. """ return { @@ -177,16 +206,58 @@ class LRScheduler: """ self.__dict__.update(state_dict) - def get_last_lr(self) -> list[float]: - """Return last computed learning rate by current scheduler.""" + def get_last_lr(self) -> list[float | Tensor]: + r"""Get the most recent learning rates computed by this scheduler. + + Returns: + list[float | Tensor]: A :class:`list` of learning rates with entries + for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups`, with the same types as + their ``group["lr"]``\s. + + .. note:: + The returned :class:`~torch.Tensor`\s are copies, and never alias + the optimizer's ``group["lr"]``\s. + """ + # We always update self._last_lr with _param_groups_val_list, so it's a + # .clone() of the group["lr"]s. If we didn't do this, the user could + # corrupt their learning rates by modifying the outputs in place. return self._last_lr - def get_lr(self) -> list[float]: - """Compute learning rate using chainable form of the scheduler.""" + def get_lr(self) -> list[float | Tensor]: + r"""Compute the next learning rate for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups`. + + Returns: + list[float | Tensor]: A :class:`list` of learning rates for each of + the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the + same types as their current ``group["lr"]``\s. + + .. note:: + If you're trying to inspect the most recent learning rate, use + :meth:`get_last_lr()` instead. + + .. note:: + The returned :class:`~torch.Tensor`\s are copies, and never alias + the optimizer's ``group["lr"]``\s. + """ raise NotImplementedError def step(self, epoch: Optional[int] = None) -> None: - """Perform a step.""" + """Step the scheduler. + + Args: + epoch (int, optional): + .. deprecated:: 1.4 + If provided, sets :attr:`last_epoch` to ``epoch`` and uses + :meth:`_get_closed_form_lr` if it is available. This is not + universally supported. Use :meth:`step` without arguments + instead. + + .. note:: + Call this method after calling the optimizer's + :meth:`~torch.optim.Optimizer.step`. 
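+
+        Example:
+            A minimal sketch of the intended call order; ``model``, ``inputs``,
+            ``targets``, ``loss_fn``, ``optimizer`` and ``scheduler`` are
+            assumed to already exist::
+
+                for epoch in range(20):
+                    optimizer.zero_grad()
+                    loss_fn(model(inputs), targets).backward()
+                    optimizer.step()
+                    scheduler.step()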
+ """ # Raise a warning if old pattern is detected # https://github.com/pytorch/pytorch/issues/20124 if self._step_count == 1: @@ -224,16 +295,18 @@ class LRScheduler: else: self.last_epoch = epoch if hasattr(self, "_get_closed_form_lr"): - values = cast(list[float], self._get_closed_form_lr()) + values = cast( + list[Union[float, Tensor]], self._get_closed_form_lr() + ) else: values = self.get_lr() for param_group, lr in zip(self.optimizer.param_groups, values): _update_param_group_val(param_group, "lr", lr) - self._last_lr: list[float] = [ - group["lr"] for group in self.optimizer.param_groups - ] + self._last_lr: list[float | Tensor] = _param_groups_val_list( + self.optimizer, "lr" + ) def _warn_get_lr_called_within_step(lr_scheduler: LRScheduler) -> None: @@ -333,8 +406,7 @@ class LambdaLR(LRScheduler): def state_dict(self) -> dict[str, Any]: """Return the state of the scheduler as a :class:`dict`. - It contains an entry for every variable in self.__dict__ which - is not the optimizer. + It contains an entry for every variable in ``self.__dict__`` which is not the optimizer. The learning rate lambda functions will only be saved if they are callable objects and not if they are functions or lambdas. @@ -374,8 +446,26 @@ class LambdaLR(LRScheduler): self.lr_lambdas[idx].__dict__.update(fn) @override - def get_lr(self) -> list[float]: - """Compute learning rate.""" + def get_lr(self) -> list[float | Tensor]: + r"""Compute the next learning rate for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups`. + + Scales the :attr:`base_lrs` by the outputs of the :attr:`lr_lambdas` at + :attr:`last_epoch`. + + Returns: + list[float | Tensor]: A :class:`list` of learning rates for each of + the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the + same types as their current ``group["lr"]``\s. + + .. note:: + If you're trying to inspect the most recent learning rate, use + :meth:`get_last_lr()` instead. + + .. note:: + The returned :class:`~torch.Tensor`\s are copies, and never alias + the optimizer's ``group["lr"]``\s. + """ _warn_get_lr_called_within_step(self) return [ @@ -436,7 +526,7 @@ class MultiplicativeLR(LRScheduler): def state_dict(self) -> dict[str, Any]: """Return the state of the scheduler as a :class:`dict`. - It contains an entry for every variable in self.__dict__ which + It contains an entry for every variable in ``self.__dict__`` which is not the optimizer. The learning rate lambda functions will only be saved if they are callable objects and not if they are functions or lambdas. @@ -473,8 +563,27 @@ class MultiplicativeLR(LRScheduler): self.lr_lambdas[idx].__dict__.update(fn) @override - def get_lr(self) -> list[float]: - """Compute the learning rate of each parameter group.""" + def get_lr(self) -> list[float | Tensor]: + r"""Compute the next learning rate for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups`. + + Scales the current ``group["lr"]``\s in each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups` by the outputs of the + :attr:`lr_lambdas` at :attr:`last_epoch`. + + Returns: + list[float | Tensor]: A :class:`list` of learning rates for each of + the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the + same types as their current ``group["lr"]``\s. + + .. note:: + If you're trying to inspect the most recent learning rate, use + :meth:`get_last_lr()` instead. + + .. note:: + The returned :class:`~torch.Tensor`\s are copies, and never alias + the optimizer's ``group["lr"]``\s. 
+ """ _warn_get_lr_called_within_step(self) if not self._is_initial: @@ -483,7 +592,7 @@ class MultiplicativeLR(LRScheduler): for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups) ] else: - return [group["lr"] for group in self.optimizer.param_groups] + return _param_groups_val_list(self.optimizer, "lr") class StepLR(LRScheduler): @@ -527,15 +636,46 @@ class StepLR(LRScheduler): super().__init__(optimizer, last_epoch) @override - def get_lr(self) -> list[float]: - """Compute the learning rate of each parameter group.""" + def get_lr(self) -> list[float | Tensor]: + r"""Compute the next learning rate for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups`. + + If the current epoch is a non-zero multiple of :attr:`step_size`, we + scale the current ``group["lr"]``\s in the optimizer's + :attr:`~torch.optim.Optimizer.param_groups` by :attr:`gamma`. + + Returns: + list[float | Tensor]: A :class:`list` of learning rates for each of + the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the + same types as their current ``group["lr"]``\s. + + .. note:: + If you're trying to inspect the most recent learning rate, use + :meth:`get_last_lr()` instead. + + .. note:: + The returned :class:`~torch.Tensor`\s are copies, and never alias + the optimizer's ``group["lr"]``\s. + """ _warn_get_lr_called_within_step(self) if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): - return [group["lr"] for group in self.optimizer.param_groups] + return _param_groups_val_list(self.optimizer, "lr") return [group["lr"] * self.gamma for group in self.optimizer.param_groups] - def _get_closed_form_lr(self) -> list[float]: + def _get_closed_form_lr(self) -> list[float | Tensor]: + r"""Compute learning rates for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using + a closed-form formula. + + Uses :attr:`base_lrs` to compute learning rates. This method is called + when an epoch is passed to :meth:`step`. + + Returns: + list[float | Tensor]: A :class:`list` of learning rates for each of + the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the + same types as their current ``group["lr"]``\s. + """ return [ base_lr * self.gamma ** (self.last_epoch // self.step_size) for base_lr in self.base_lrs @@ -582,18 +722,53 @@ class MultiStepLR(LRScheduler): super().__init__(optimizer, last_epoch) @override - def get_lr(self) -> list[float]: - """Compute the learning rate of each parameter group.""" + def get_lr(self) -> list[float | Tensor]: + r"""Compute the next learning rate for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups`. + + If the current epoch is in :attr:`milestones`, decays the + ``group["lr"]``\s in the optimizer's + :attr:`~torch.optim.Optimizer.param_groups` by :attr:`gamma`. + + Returns: + list[float | Tensor]: A :class:`list` of learning rates for each of + the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the + same types as their current ``group["lr"]``\s. + + .. note:: + If you're trying to inspect the most recent learning rate, use + :meth:`get_last_lr()` instead. + + .. note:: + The returned :class:`~torch.Tensor`\s are copies, and never alias + the optimizer's ``group["lr"]``\s. + + .. 
note:: + If the current epoch appears in :attr:`milestones` ``n`` times, we + scale by :attr:`gamma` to the power of ``n`` + """ _warn_get_lr_called_within_step(self) if self.last_epoch not in self.milestones: - return [group["lr"] for group in self.optimizer.param_groups] + return _param_groups_val_list(self.optimizer, "lr") return [ group["lr"] * self.gamma ** self.milestones[self.last_epoch] for group in self.optimizer.param_groups ] def _get_closed_form_lr(self): + r"""Compute learning rates for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using + a closed-form formula. + + Uses :attr:`base_lrs` to compute learning rates. This method is called + when an epoch is passed to :meth:`step`. + + Returns: + list[float | Tensor]: A :class:`list` of learning rates for each of + the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the + same types as their current ``group["lr"]``\s. + """ milestones = sorted(self.milestones.elements()) return [ base_lr * self.gamma ** bisect_right(milestones, self.last_epoch) @@ -651,21 +826,53 @@ class ConstantLR(LRScheduler): super().__init__(optimizer, last_epoch) @override - def get_lr(self) -> list[float]: - """Compute the learning rate of each parameter group.""" + def get_lr(self) -> list[float | Tensor]: + r"""Compute the next learning rate for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups`. + + When :attr:`last_epoch` is 0, this method scales the ``group["lr"]``\s + in each of the optimizer's :attr:`~torch.optim.Optimizer.param_groups` + by :attr:`factor`. Once :attr:`total_iters` is reached, it undoes this, + scaling by ``1 / factor``. + + Returns: + list[float | Tensor]: A :class:`list` of learning rates for each of + the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the + same types as their current ``group["lr"]``\s. + + .. note:: + If you're trying to inspect the most recent learning rate, use + :meth:`get_last_lr()` instead. + + .. note:: + The returned :class:`~torch.Tensor`\s are copies, and never alias + the optimizer's ``group["lr"]``\s. + """ _warn_get_lr_called_within_step(self) if self.last_epoch == 0: return [group["lr"] * self.factor for group in self.optimizer.param_groups] if self.last_epoch != self.total_iters: - return [group["lr"] for group in self.optimizer.param_groups] + return _param_groups_val_list(self.optimizer, "lr") return [ group["lr"] * (1.0 / self.factor) for group in self.optimizer.param_groups ] def _get_closed_form_lr(self): + r"""Compute learning rates for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using + a closed-form formula. + + Uses :attr:`base_lrs` to compute learning rates. This method is called + when an epoch is passed to :meth:`step`. + + Returns: + list[float | Tensor]: A :class:`list` of learning rates for each of + the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the + same types as their current ``group["lr"]``\s. + """ return [ base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor)) @@ -733,8 +940,28 @@ class LinearLR(LRScheduler): super().__init__(optimizer, last_epoch) @override - def get_lr(self) -> list[float]: - """Compute the learning rate.""" + def get_lr(self) -> list[float | Tensor]: + r"""Compute the next learning rate for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups`. 
+ + Scales the ``group["lr"]``\s in the optimizer's + :attr:`~torch.optim.Optimizer.param_groups` such that successive steps + interpolate linearly from :attr:`start_factor` up to :attr:`end_factor` + across :attr:`total_iters` steps. + + Returns: + list[float | Tensor]: A :class:`list` of learning rates for each of + the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the + same types as their current ``group["lr"]``\s. + + .. note:: + If you're trying to inspect the most recent learning rate, use + :meth:`get_last_lr()` instead. + + .. note:: + The returned :class:`~torch.Tensor`\s are copies, and never alias + the optimizer's ``group["lr"]``\s. + """ _warn_get_lr_called_within_step(self) if self.last_epoch == 0: @@ -743,7 +970,7 @@ class LinearLR(LRScheduler): ] if self._is_initial or self.last_epoch > self.total_iters: - return [group["lr"] for group in self.optimizer.param_groups] + return _param_groups_val_list(self.optimizer, "lr") return [ group["lr"] @@ -759,6 +986,18 @@ class LinearLR(LRScheduler): ] def _get_closed_form_lr(self): + r"""Compute learning rates for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using + a closed-form formula. + + Uses :attr:`base_lrs` to compute learning rates. This method is called + when an epoch is passed to :meth:`step`. + + Returns: + list[float | Tensor]: A :class:`list` of learning rates for each of + the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the + same types as their current ``group["lr"]``\s. + """ return [ base_lr * ( @@ -802,17 +1041,47 @@ class ExponentialLR(LRScheduler): super().__init__(optimizer, last_epoch) @override - def get_lr(self) -> list[float]: - """Compute the learning rate of each parameter group.""" + def get_lr(self) -> list[float | Tensor]: + r"""Compute the next learning rate for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups`. + + Multiplies the current ``group["lr"]``\s in the optimizer's + :attr:`~torch.optim.Optimizer.param_groups` by :attr:`gamma`. + + Returns: + list[float | Tensor]: A :class:`list` of learning rates for each of + the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the + same types as their current ``group["lr"]``\s. + + .. note:: + If you're trying to inspect the most recent learning rate, use + :meth:`get_last_lr()` instead. + + .. note:: + The returned :class:`~torch.Tensor`\s are copies, and never alias + the optimizer's ``group["lr"]``\s. + """ _warn_get_lr_called_within_step(self) # when loading from a checkpoint, we don't want _initial_step (called from the constructor) # to update the lr one more step ahead of itself. if self._is_initial: - return [group["lr"] for group in self.optimizer.param_groups] + return _param_groups_val_list(self.optimizer, "lr") return [group["lr"] * self.gamma for group in self.optimizer.param_groups] def _get_closed_form_lr(self): + r"""Compute learning rates for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using + a closed-form formula. + + Uses :attr:`base_lrs` to compute learning rates. This method is called + when an epoch is passed to :meth:`step`. + + Returns: + list[float | Tensor]: A :class:`list` of learning rates for each of + the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the + same types as their current ``group["lr"]``\s. 
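+
+        Example:
+            Illustrative arithmetic only (the values below are assumptions)::
+
+                >>> base_lr, gamma, last_epoch = 0.1, 0.9, 2
+                >>> round(base_lr * gamma ** last_epoch, 6)  # closed-form lr at epoch 2
+                0.081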
+ """ return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs] @@ -935,7 +1204,7 @@ class SequentialLR(LRScheduler): def state_dict(self) -> dict[str, Any]: """Return the state of the scheduler as a :class:`dict`. - It contains an entry for every variable in self.__dict__ which + It contains an entry for every variable in ``self.__dict__`` which is not the optimizer. The wrapped scheduler states will also be saved. """ @@ -1008,12 +1277,38 @@ class PolynomialLR(LRScheduler): super().__init__(optimizer, last_epoch) @override - def get_lr(self) -> list[float]: - """Compute the learning rate.""" + def get_lr(self) -> list[float | Tensor]: + r"""Compute the next learning rate for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups`. + + Scales the ``group["lr"]``\s in the optimizer's + :attr:`~torch.optim.Optimizer.param_groups` such that the learning rates + follow + + .. math:: + \texttt{base\_lr} \cdot \left(1 - \frac{\texttt{last\_epoch}} + {\texttt{total\_iters}} \right)^\texttt{power} + + Returns the current learning rates unchanged after :attr:`total_iters` + is reached. + + Returns: + list[float | Tensor]: A :class:`list` of learning rates for each of + the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the + same types as their current ``group["lr"]``\s. + + .. note:: + If you're trying to inspect the most recent learning rate, use + :meth:`get_last_lr()` instead. + + .. note:: + The returned :class:`~torch.Tensor`\s are copies, and never alias + the optimizer's ``group["lr"]``\s. + """ _warn_get_lr_called_within_step(self) if self._is_initial or self.last_epoch > self.total_iters: - return [group["lr"] for group in self.optimizer.param_groups] + return _param_groups_val_list(self.optimizer, "lr") decay_factor = ( (1.0 - self.last_epoch / self.total_iters) @@ -1021,7 +1316,19 @@ class PolynomialLR(LRScheduler): ) ** self.power return [group["lr"] * decay_factor for group in self.optimizer.param_groups] - def _get_closed_form_lr(self): + def _get_closed_form_lr(self) -> list[float | Tensor]: + r"""Compute learning rates for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using + a closed-form formula. + + Uses :attr:`base_lrs` to compute learning rates. This method is called + when an epoch is passed to :meth:`step`. + + Returns: + list[float | Tensor]: A :class:`list` of learning rates for each of + the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the + same types as their current ``group["lr"]``\s. + """ return [ ( base_lr @@ -1094,12 +1401,36 @@ class CosineAnnealingLR(LRScheduler): super().__init__(optimizer, last_epoch) @override - def get_lr(self) -> list[float]: - """Retrieve the learning rate of each parameter group.""" + def get_lr(self) -> list[float | Tensor]: + r"""Compute the next learning rate for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups`. + + Scales the ``group["lr"]``\s in the optimizer's + :attr:`~torch.optim.Optimizer.param_groups` such that their learning + rates approximate + + .. math:: + \texttt{eta\_min} + \frac{1}{2} (\texttt{base\_lr} - + \texttt{eta\_min}) \left(1 + \cos\left(\pi \cdot + \frac{\texttt{last\_epoch}}{\texttt{T\_max}}\right) \right) + + Returns: + list[float | Tensor]: A :class:`list` of learning rates for each of + the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the + same types as their current ``group["lr"]``\s. + + .. 
note:: + If you're trying to inspect the most recent learning rate, use + :meth:`get_last_lr()` instead. + + .. note:: + The returned :class:`~torch.Tensor`\s are copies, and never alias + the optimizer's ``group["lr"]``\s. + """ _warn_get_lr_called_within_step(self) if self._is_initial: - return [group["lr"] for group in self.optimizer.param_groups] + return _param_groups_val_list(self.optimizer, "lr") elif self._step_count == 1 and self.last_epoch > 0: return [ self.eta_min @@ -1122,7 +1453,19 @@ class CosineAnnealingLR(LRScheduler): for group in self.optimizer.param_groups ] - def _get_closed_form_lr(self) -> list[float]: + def _get_closed_form_lr(self) -> list[float | Tensor]: + r"""Compute learning rates for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using + a closed-form formula. + + Uses :attr:`base_lrs` to compute learning rates. This method is called + when an epoch is passed to :meth:`step`. + + Returns: + list[float | Tensor]: A :class:`list` of learning rates for each of + the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the + same types as their current ``group["lr"]``\s. + """ return [ self.eta_min + (base_lr - self.eta_min) @@ -1191,23 +1534,19 @@ class ChainedScheduler(LRScheduler): ) self._schedulers = schedulers self.optimizer = optimizer - self._last_lr = [ - group["lr"] for group in self._schedulers[-1].optimizer.param_groups - ] + self._last_lr = _param_groups_val_list(self._schedulers[-1].optimizer, "lr") def step(self) -> None: # type: ignore[override] """Perform a step.""" for scheduler in self._schedulers: scheduler.step() - self._last_lr = [ - group["lr"] for group in self._schedulers[-1].optimizer.param_groups - ] + self._last_lr = _param_groups_val_list(self._schedulers[-1].optimizer, "lr") @override def state_dict(self) -> dict[str, Any]: """Return the state of the scheduler as a :class:`dict`. - It contains an entry for every variable in self.__dict__ which + It contains an entry for every variable in ``self.__dict__`` which is not the optimizer. The wrapped scheduler states will also be saved. """ @@ -1334,7 +1673,7 @@ class ReduceLROnPlateau(LRScheduler): self.cooldown = cooldown self.eps = eps self.last_epoch = 0 - self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + self._last_lr = _param_groups_val_list(self.optimizer, "lr") self._init_is_better( mode=mode, threshold=threshold, threshold_mode=threshold_mode ) @@ -1371,7 +1710,7 @@ class ReduceLROnPlateau(LRScheduler): self.cooldown_counter = self.cooldown self.num_bad_epochs = 0 - self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + self._last_lr = _param_groups_val_list(self.optimizer, "lr") def _reduce_lr(self, epoch): if len(self.optimizer.param_groups) != len(self.min_lrs): @@ -1561,11 +1900,7 @@ class CyclicLR(LRScheduler): base_lrs = _format_param("base_lr", optimizer, base_lr) if last_epoch == -1: for lr, group in zip(base_lrs, optimizer.param_groups): - if isinstance(group["lr"], Tensor): - lr_val = lr.item() if isinstance(lr, Tensor) else lr - group["lr"].fill_(lr_val) - else: - group["lr"] = lr + _update_param_group_val(group, "lr", lr) self.max_lrs = _format_param("max_lr", optimizer, max_lr) @@ -1649,13 +1984,34 @@ class CyclicLR(LRScheduler): return gamma**x @override - def get_lr(self) -> list[float]: - """Calculate the learning rate at batch index. 
+ def get_lr(self) -> list[float | Tensor]: + r"""Compute the next learning rate for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups`. - This function treats `self.last_epoch` as the last batch index. + Advances each ``group["lr"]`` in the optimizer's + :attr:`~torch.optim.Optimizer.param_groups` along a cycle between the + group's ``base_lr`` and ``max_lr`` using :meth:`scale_fn`. - If `self.cycle_momentum` is ``True``, this function has a side effect of - updating the optimizer's momentum. + Returns: + list[float | Tensor]: A :class:`list` of learning rates for each of + the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the + same types as their current ``group["lr"]``\s. + + .. note:: + If you're trying to inspect the most recent learning rate, use + :meth:`get_last_lr()` instead. + + .. note:: + The returned :class:`~torch.Tensor`\s are copies, and never alias + the optimizer's ``group["lr"]``\s. + + .. note:: + This method treats :attr:`last_epoch` as the index of the previous + batch. + + .. note:: + When :attr:`cycle_momentum` is ``True``, this method has a side + effect of updating the optimizer's momentum. """ _warn_get_lr_called_within_step(self) @@ -1700,7 +2056,7 @@ class CyclicLR(LRScheduler): def state_dict(self) -> dict[str, Any]: # noqa: D102 """Return the state of the scheduler as a :class:`dict`. - It contains an entry for every variable in self.__dict__ which + It contains an entry for every variable in ``self.__dict__`` which is not the optimizer. The learning rate lambda functions will only be saved if they are callable objects and not if they are functions or lambdas. @@ -1795,8 +2151,36 @@ class CosineAnnealingWarmRestarts(LRScheduler): super().__init__(optimizer, last_epoch) @override - def get_lr(self) -> list[float]: - """Compute the initial learning rate.""" + def get_lr(self) -> list[float | Tensor]: + r"""Compute the next learning rate for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups`. + + Computes learning rates for the optimizer's + :attr:`~torch.optim.Optimizer.param_groups` following: + + .. math:: + \texttt{eta\_min} + \frac{1}{2}(\texttt{base\_lr} - + \texttt{eta\_min})\left(1 + \cos\left(\pi \cdot + \frac{\texttt{T\_cur}}{\texttt{T\_i}}\right)\right) + + Where :attr:`T_cur` is the number of epochs since the last restart and + :attr:`T_i` is the number of epochs between two restarts. Both + :attr:`T_cur` and :attr:`T_i` are updated in :meth:`step`, and + :attr:`T_i` becomes :attr:`T_mult` times larger after each restart. + + Returns: + list[float | Tensor]: A :class:`list` of learning rates for each of + the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the + same types as their current ``group["lr"]``\s. + + .. note:: + If you're trying to inspect the most recent learning rate, use + :meth:`get_last_lr()` instead. + + .. note:: + The returned :class:`~torch.Tensor`\s are copies, and never alias + the optimizer's ``group["lr"]``\s. 
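+
+        Example:
+            A minimal sketch; ``model`` and the hyperparameters below are
+            assumptions::
+
+                optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+                scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
+                for epoch in range(30):
+                    optimizer.step()
+                    scheduler.step()  # restarts after epoch 10, then 20 epochs later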
+ """ _warn_get_lr_called_within_step(self) return [ @@ -1869,7 +2253,7 @@ class CosineAnnealingWarmRestarts(LRScheduler): for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): _update_param_group_val(param_group, "lr", lr) - self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + self._last_lr = _param_groups_val_list(self.optimizer, "lr") class _SchedulePhase(TypedDict): @@ -2141,8 +2525,31 @@ class OneCycleLR(LRScheduler): return (end - start) * pct + start @override - def get_lr(self) -> list[float]: - """Compute the learning rate of each parameter group.""" + def get_lr(self) -> list[float | Tensor]: + r"""Compute the next learning rate for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups`. + + Finds the appropriate :attr:`_schedule_phases` entry for the current + step and interpolates between its ``start_lr`` and ``end_lr`` using + :meth:`_anneal_func`. + + Returns: + list[float | Tensor]: A :class:`list` of learning rates for each of + the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the + same types as their current ``group["lr"]``\s. + + .. note:: + If you're trying to inspect the most recent learning rate, use + :meth:`get_last_lr()` instead. + + .. note:: + The returned :class:`~torch.Tensor`\s are copies, and never alias + the optimizer's ``group["lr"]``\s. + + .. note:: + When :attr:`cycle_momentum` is ``True``, this method has a side + effect of updating the optimizer's momentum. + """ _warn_get_lr_called_within_step(self) lrs = [] diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py index e15b796cdbe..d19760cfeab 100644 --- a/torch/optim/swa_utils.py +++ b/torch/optim/swa_utils.py @@ -454,8 +454,29 @@ class SWALR(LRScheduler): return swa_lr return (lr - alpha * swa_lr) / (1 - alpha) + @override def get_lr(self): - """Get learning rate.""" + r"""Compute the next learning rate for each of the optimizer's + :attr:`~torch.optim.Optimizer.param_groups`. + + Uses :attr:`anneal_func` to interpolate between each group's + ``group["lr"]`` and ``group["swa_lr"]`` over :attr:`anneal_epochs` + epochs. Once :attr:`anneal_epochs` is reached, keeps the learning rate + fixed at ``group["swa_lr"]``. + + Returns: + list[float | Tensor]: A :class:`list` of learning rates for each of + the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the + same types as their current ``group["lr"]``\s. + + .. note:: + If you're trying to inspect the most recent learning rate, use + :meth:`get_last_lr()` instead. + + .. note:: + The returned :class:`~torch.Tensor`\s are copies, and never alias + the optimizer's ``group["lr"]``\s. + """ # `_get_lr_called_within_step` is only available `_enable_get_lr_call`, # so we ignore the type error here. See `LRScheduler.step()` for more details. if not self._get_lr_called_within_step: