Fix lr_scheduler unexpectedly calls step() when init argument last_epoch is larger than -1 (#149312)

Fixes #102261

## Changes

- Use flag `_is_initial` to replace `self.last_epoch == 0` condition to judge whether `lr` should be initial value
- Add test for `ExponentialLR` checkpoint usecase

## Test Result

```python
pytest -s test/optim/test_lrscheduler.py  -vv
```

![image](https://github.com/user-attachments/assets/6fd32bcc-b4fb-4421-b891-620bd4900dc1)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149312
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
This commit is contained in:
zeshengzong 2025-05-22 08:42:33 +00:00 committed by PyTorch MergeBot
parent 423fc671e9
commit d7a83ab67b
2 changed files with 72 additions and 6 deletions

View File

@ -2575,6 +2575,56 @@ class TestLRScheduler(TestCase):
self.assertEqual(group["swa_lr"], 0.05)
self.assertEqual(sch.base_lrs, [0.1])
@parametrize(
"LRClass",
[
partial(ExponentialLR, gamma=0.999),
partial(LambdaLR, lr_lambda=lambda epoch: epoch // 30),
partial(MultiplicativeLR, lr_lambda=lambda epoch: 0.95),
partial(StepLR, step_size=30),
partial(MultiStepLR, milestones=[30, 80]),
ConstantLR,
LinearLR,
PolynomialLR,
partial(CosineAnnealingLR, T_max=10),
partial(CosineAnnealingWarmRestarts, T_0=20),
partial(CyclicLR, base_lr=0.01, max_lr=0.1),
partial(OneCycleLR, max_lr=0.01, total_steps=10),
partial(SWALR, swa_lr=0.01),
],
)
def test_lr_scheduler_checkpoint(self, LRClass):
model = torch.nn.Linear(3, 3)
optim = torch.optim.AdamW(model.parameters())
sch = LRClass(optim)
optim.step()
sch.step()
optim2 = torch.optim.AdamW(model.parameters())
optim2.load_state_dict(optim.state_dict())
sch2 = LRClass(optim2, last_epoch=0)
self.assertEqual(
sch2._get_closed_form_lr()[0]
if hasattr(self, "_get_closed_form_lr")
else sch2.get_last_lr()[0],
optim.param_groups[0]["lr"],
)
def test_lr_scheduler_checkpoint_on_plateau(self):
model = torch.nn.Linear(3, 3)
optim = torch.optim.AdamW(model.parameters())
sch = ReduceLROnPlateau(optim, mode="min")
optim.step()
sch.step(1)
optim2 = torch.optim.AdamW(model.parameters())
optim2.load_state_dict(optim.state_dict())
sch2 = ReduceLROnPlateau(optim2, mode="min")
self.assertEqual(
sch2._get_closed_form_lr()[0]
if hasattr(self, "_get_closed_form_lr")
else sch2.get_last_lr()[0],
optim.param_groups[0]["lr"],
)
instantiate_parametrized_tests(TestLRScheduler)

View File

@ -82,6 +82,7 @@ class LRScheduler:
r"""Adjusts the learning rate during optimization."""
_get_lr_called_within_step: bool = False
_is_initial: bool = False
def __init__(
self,
@ -141,7 +142,8 @@ class LRScheduler:
def _initial_step(self) -> None:
"""Initialize step counts and perform a step."""
self._step_count = 0
self.step()
with _initial_mode(self):
self.step()
def state_dict(self) -> dict[str, Any]:
"""Return the state of the scheduler as a :class:`dict`.
@ -195,6 +197,7 @@ class LRScheduler:
"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
UserWarning,
)
self._step_count += 1
with _enable_get_lr_call(self):
@ -248,6 +251,17 @@ class _enable_get_lr_call:
self.o._get_lr_called_within_step = False
class _initial_mode:
def __init__(self, o: LRScheduler):
self.o = o
def __enter__(self):
self.o._is_initial = True
def __exit__(self, type, value, traceback):
self.o._is_initial = False
class LambdaLR(LRScheduler):
"""Sets the initial learning rate.
@ -450,7 +464,7 @@ class MultiplicativeLR(LRScheduler):
"""Compute the learning rate of each parameter group."""
_warn_get_lr_called_within_step(self)
if self.last_epoch > 0:
if not self._is_initial:
return [
group["lr"] * lmbda(self.last_epoch)
for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)
@ -715,7 +729,7 @@ class LinearLR(LRScheduler):
group["lr"] * self.start_factor for group in self.optimizer.param_groups
]
if self.last_epoch > self.total_iters:
if self._is_initial or self.last_epoch > self.total_iters:
return [group["lr"] for group in self.optimizer.param_groups]
return [
@ -779,7 +793,9 @@ class ExponentialLR(LRScheduler):
"""Compute the learning rate of each parameter group."""
_warn_get_lr_called_within_step(self)
if self.last_epoch == 0:
# when loading from a checkpoint, we don't want _initial_step (called from the constructor)
# to update the lr one more step ahead of itself.
if self._is_initial:
return [group["lr"] for group in self.optimizer.param_groups]
return [group["lr"] * self.gamma for group in self.optimizer.param_groups]
@ -979,7 +995,7 @@ class PolynomialLR(LRScheduler):
"""Compute the learning rate."""
_warn_get_lr_called_within_step(self)
if self.last_epoch == 0 or self.last_epoch > self.total_iters:
if self._is_initial or self.last_epoch > self.total_iters:
return [group["lr"] for group in self.optimizer.param_groups]
decay_factor = (
@ -1065,7 +1081,7 @@ class CosineAnnealingLR(LRScheduler):
"""Retrieve the learning rate of each parameter group."""
_warn_get_lr_called_within_step(self)
if self.last_epoch == 0:
if self._is_initial:
return [group["lr"] for group in self.optimizer.param_groups]
elif self._step_count == 1 and self.last_epoch > 0:
return [