diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py
index 42f194b0f40..8de3af4b87f 100644
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -5,6 +5,7 @@ import weakref
 from bisect import bisect_right
 from collections import Counter
 from functools import partial, wraps
+from typing import Optional, Sequence
 
 from torch import inf
 
@@ -781,18 +782,29 @@ class SequentialLR(LRScheduler):
     def __init__(
         self, optimizer, schedulers, milestones, last_epoch=-1, verbose="deprecated"
     ):
-        for scheduler_idx in range(len(schedulers)):
-            if schedulers[scheduler_idx].optimizer != optimizer:
+        if len(schedulers) < 1:
+            raise ValueError(
+                f"{self.__class__.__name__} expects at least one scheduler, but got no scheduler."
+            )
+
+        for scheduler_idx, scheduler in enumerate(schedulers):
+            if not hasattr(scheduler, "optimizer"):
+                raise TypeError(
+                    f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute."
+                )
+            if isinstance(scheduler, ReduceLROnPlateau):
                 raise ValueError(
-                    "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
-                    f"got schedulers at index {scheduler_idx} to be different than the optimizer passed in."
+                    f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it "
+                    "requires additional kwargs to be specified when calling `step`, "
+                    f"but got one at index {scheduler_idx} in the given schedulers sequence."
+                )
+            if optimizer != scheduler.optimizer:
+                raise ValueError(
+                    f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but "
+                    f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} has {scheduler.optimizer}, "
+                    f"which is different from {optimizer.__class__.__name__}."
                 )
 
-            if schedulers[scheduler_idx].optimizer != schedulers[0].optimizer:
-                raise ValueError(
-                    "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
-                    f"got schedulers at index {0} and {scheduler_idx} to be different."
-                )
         if len(milestones) != len(schedulers) - 1:
             raise ValueError(
                 "Sequential Schedulers expects number of schedulers provided to be one more "
@@ -1024,12 +1036,13 @@ class CosineAnnealingLR(LRScheduler):
 
 
 class ChainedScheduler(LRScheduler):
-    """Chains list of learning rate schedulers. It takes a list of chainable learning
+    """Chains list of learning rate schedulers. It takes a sequence of chainable learning
     rate schedulers and performs consecutive step() functions belonging to them by just
     one call.
 
     Args:
-        schedulers (list): List of chained schedulers.
+        schedulers (sequence): sequence of chained schedulers.
+        optimizer (Optimizer, optional): Wrapped optimizer. Default: None.
 
     Example:
         >>> # xdoctest: +SKIP
@@ -1041,22 +1054,41 @@ class ChainedScheduler(LRScheduler):
         >>> # lr = 0.59049 if epoch >= 4
         >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
         >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
-        >>> scheduler = ChainedScheduler([scheduler1, scheduler2])
+        >>> scheduler = ChainedScheduler([scheduler1, scheduler2], optimizer=optimizer)
         >>> for epoch in range(100):
         >>>     train(...)
        >>>     validate(...)
         >>>     scheduler.step()
     """
 
-    def __init__(self, schedulers):
-        for scheduler_idx in range(1, len(schedulers)):
-            if schedulers[scheduler_idx].optimizer != schedulers[0].optimizer:
-                raise ValueError(
-                    "ChainedScheduler expects all schedulers to belong to the same optimizer, but "
-                    f"got schedulers at index {0} and {scheduler_idx} to be different"
+    def __init__(
+        self, schedulers: Sequence[LRScheduler], optimizer: Optional[Optimizer] = None
+    ):
+        if len(schedulers) < 1:
+            raise ValueError(
+                f"{self.__class__.__name__} expects at least one scheduler to be chained, but got no scheduler."
+            )
+
+        optimizer = optimizer or schedulers[0].optimizer
+        for scheduler_idx, scheduler in enumerate(schedulers):
+            if not hasattr(scheduler, "optimizer"):
+                raise TypeError(
+                    f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute."
                 )
-        self._schedulers = list(schedulers)
-        self.optimizer = schedulers[0].optimizer
+            if isinstance(scheduler, ReduceLROnPlateau):
+                raise ValueError(
+                    f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it "
+                    "requires additional kwargs to be specified when calling `step`, "
+                    f"but got one at index {scheduler_idx} in the given schedulers sequence."
+                )
+            if optimizer != scheduler.optimizer:
+                raise ValueError(
+                    f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but "
+                    f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} has {scheduler.optimizer}, "
+                    f"which is different from {optimizer.__class__.__name__}."
+                )
+        self._schedulers = schedulers
+        self.optimizer = optimizer
         self._last_lr = [
             group["lr"] for group in self._schedulers[-1].optimizer.param_groups
         ]
diff --git a/torch/optim/lr_scheduler.pyi b/torch/optim/lr_scheduler.pyi
index 2446c80bc54..ce3f26a4d21 100644
--- a/torch/optim/lr_scheduler.pyi
+++ b/torch/optim/lr_scheduler.pyi
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union
 
 from .optimizer import Optimizer
 
@@ -109,7 +109,11 @@ class ExponentialLR(LRScheduler):
     ) -> None: ...
 
 class ChainedScheduler(LRScheduler):
-    def __init__(self, schedulers: List[LRScheduler]) -> None: ...
+    def __init__(
+        self,
+        schedulers: Sequence[LRScheduler],
+        optimizer: Optional[Optimizer] = ...,
+    ) -> None: ...
 
 class SequentialLR(LRScheduler):
     def __init__(
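
A rough usage sketch of the constructor behaviour introduced by this patch; the toy model, learning rates, and the try/except are illustrative placeholders, and only the `optimizer` keyword argument and the new validation errors come from the diff above:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import (
    ChainedScheduler,
    ConstantLR,
    ExponentialLR,
    ReduceLROnPlateau,
)

model = torch.nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=1.0)

scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
scheduler2 = ExponentialLR(optimizer, gamma=0.9)

# The wrapped optimizer can now be passed explicitly; when omitted it falls
# back to schedulers[0].optimizer, matching the previous behaviour.
scheduler = ChainedScheduler([scheduler1, scheduler2], optimizer=optimizer)

# The added validation rejects an empty schedulers sequence, any
# ReduceLROnPlateau instance, and schedulers bound to a different optimizer.
try:
    ChainedScheduler([ReduceLROnPlateau(optimizer)], optimizer=optimizer)
except ValueError as err:
    print(err)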