Better Error Message in ChainedScheduler and SequentialLR (#121633)

Fixes #121577

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121633
Approved by: https://github.com/janeyx99
Author: GdoongMathew
Date: 2024-04-19 13:37:41 +00:00
Committed by: PyTorch MergeBot
Parent: c9db59e9e4
Commit: 8b1ad51881
2 changed files with 58 additions and 22 deletions
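
To illustrate the intent of the change, here is a rough sketch of how the improved checks surface to users after this patch (the model/optimizer setup below is invented for the example and is not part of the commit):

import torch
from torch.optim.lr_scheduler import ConstantLR, ExponentialLR, ReduceLROnPlateau, SequentialLR

model = torch.nn.Linear(4, 4)  # placeholder module, only needed to build optimizers
opt_a = torch.optim.SGD(model.parameters(), lr=0.1)
opt_b = torch.optim.SGD(model.parameters(), lr=0.1)

# A scheduler bound to a different optimizer: the message now names the offending
# scheduler class and its index instead of the old generic wording.
try:
    SequentialLR(
        opt_a,
        schedulers=[ConstantLR(opt_a, factor=0.5, total_iters=2), ExponentialLR(opt_b, gamma=0.9)],
        milestones=[2],
    )
except ValueError as err:
    print(err)

# ReduceLROnPlateau is rejected up front, since its step() requires a metrics argument.
try:
    SequentialLR(opt_a, schedulers=[ConstantLR(opt_a), ReduceLROnPlateau(opt_a)], milestones=[2])
except ValueError as err:
    print(err)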


@@ -5,6 +5,7 @@ import weakref
 from bisect import bisect_right
 from collections import Counter
 from functools import partial, wraps
+from typing import Optional, Sequence
 
 from torch import inf
@@ -781,18 +782,29 @@ class SequentialLR(LRScheduler):
     def __init__(
         self, optimizer, schedulers, milestones, last_epoch=-1, verbose="deprecated"
     ):
-        for scheduler_idx in range(len(schedulers)):
-            if schedulers[scheduler_idx].optimizer != optimizer:
-                raise ValueError(
-                    "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
-                    f"got schedulers at index {scheduler_idx} to be different than the optimizer passed in."
-                )
-            if schedulers[scheduler_idx].optimizer != schedulers[0].optimizer:
-                raise ValueError(
-                    "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
-                    f"got schedulers at index {0} and {scheduler_idx} to be different."
-                )
+        if len(schedulers) < 1:
+            raise ValueError(
+                f"{self.__class__.__name__} expects at least one scheduler, but got no scheduler."
+            )
+        for scheduler_idx, scheduler in enumerate(schedulers):
+            if not hasattr(scheduler, "optimizer"):
+                raise TypeError(
+                    f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute."
+                )
+            if isinstance(scheduler, ReduceLROnPlateau):
+                raise ValueError(
+                    f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it "
+                    "requires additional kwargs to be specified when calling `step`, "
+                    f"but got one at index {scheduler_idx} in the given schedulers sequence."
+                )
+            if optimizer != scheduler.optimizer:
+                raise ValueError(
+                    f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but "
+                    f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} has {scheduler.optimizer}, "
+                    f"which is different from {optimizer.__class__.__name__}."
+                )
         if len(milestones) != len(schedulers) - 1:
             raise ValueError(
                 "Sequential Schedulers expects number of schedulers provided to be one more "
@@ -1024,12 +1036,13 @@ class CosineAnnealingLR(LRScheduler):
 class ChainedScheduler(LRScheduler):
-    """Chains list of learning rate schedulers. It takes a list of chainable learning
+    """Chains list of learning rate schedulers. It takes a sequence of chainable learning
     rate schedulers and performs consecutive step() functions belonging to them by just
     one call.
 
     Args:
-        schedulers (list): List of chained schedulers.
+        schedulers (sequence): sequence of chained schedulers.
+        optimizer (Optimizer, optional): Wrapped optimizer. Default: None.
 
     Example:
         >>> # xdoctest: +SKIP
@@ -1041,22 +1054,41 @@ class ChainedScheduler(LRScheduler):
         >>> # lr = 0.59049 if epoch >= 4
         >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
         >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
-        >>> scheduler = ChainedScheduler([scheduler1, scheduler2])
+        >>> scheduler = ChainedScheduler([scheduler1, scheduler2], optimizer=optimizer)
         >>> for epoch in range(100):
         >>>     train(...)
         >>>     validate(...)
         >>>     scheduler.step()
     """
 
-    def __init__(self, schedulers):
-        for scheduler_idx in range(1, len(schedulers)):
-            if schedulers[scheduler_idx].optimizer != schedulers[0].optimizer:
-                raise ValueError(
-                    "ChainedScheduler expects all schedulers to belong to the same optimizer, but "
-                    f"got schedulers at index {0} and {scheduler_idx} to be different"
-                )
-        self._schedulers = list(schedulers)
-        self.optimizer = schedulers[0].optimizer
+    def __init__(
+        self, schedulers: Sequence[LRScheduler], optimizer: Optional[Optimizer] = None
+    ):
+        if len(schedulers) < 1:
+            raise ValueError(
+                f"{self.__class__.__name__} expects at least one scheduler to be chained, but got no scheduler."
+            )
+        optimizer = optimizer or schedulers[0].optimizer
+        for scheduler_idx, scheduler in enumerate(schedulers):
+            if not hasattr(scheduler, "optimizer"):
+                raise TypeError(
+                    f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute."
+                )
+            if isinstance(scheduler, ReduceLROnPlateau):
+                raise ValueError(
+                    f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it "
+                    "requires additional kwargs to be specified when calling `step`, "
+                    f"but got one at index {scheduler_idx} in the given schedulers sequence."
+                )
+            if optimizer != scheduler.optimizer:
+                raise ValueError(
+                    f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but "
+                    f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} has {scheduler.optimizer}, "
+                    f"which is different from {optimizer.__class__.__name__}."
+                )
+        self._schedulers = schedulers
+        self.optimizer = optimizer
         self._last_lr = [
             group["lr"] for group in self._schedulers[-1].optimizer.param_groups
         ]
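
For reference, a minimal sketch of the two call forms ChainedScheduler accepts after this change (the model/optimizer setup is again assumed purely for illustration):

import torch
from torch.optim.lr_scheduler import ChainedScheduler, ConstantLR, ExponentialLR

model = torch.nn.Linear(4, 4)  # placeholder module
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
schedulers = [
    ConstantLR(optimizer, factor=0.1, total_iters=2),
    ExponentialLR(optimizer, gamma=0.9),
]

# Existing call form: the optimizer is inferred from the first scheduler.
scheduler = ChainedScheduler(schedulers)

# New optional keyword: the optimizer is passed explicitly and cross-checked against
# every scheduler's optimizer, raising the descriptive error on any mismatch.
scheduler = ChainedScheduler(schedulers, optimizer=optimizer)

The second file in the commit updates the .pyi stub accordingly, so the new signature is reflected for type checkers.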


@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union
 
 from .optimizer import Optimizer
@@ -109,7 +109,11 @@ class ExponentialLR(LRScheduler):
     ) -> None: ...
 
 class ChainedScheduler(LRScheduler):
-    def __init__(self, schedulers: List[LRScheduler]) -> None: ...
+    def __init__(
+        self,
+        schedulers: Sequence[LRScheduler],
+        optimizer: Optional[Optimizer] = ...,
+    ) -> None: ...
 
 class SequentialLR(LRScheduler):
     def __init__(