Fix lr_scheduler unexpectedly calling step() when init argument last_epoch is larger than -1 (#149312)

Fixes #102261

## Changes

- Use the flag `_is_initial` instead of the `self.last_epoch == 0` condition to decide whether `lr` should keep its initial value
- Add a test for the `ExponentialLR` checkpoint use case

## Test Result

```python
pytest -s test/optim/test_lrscheduler.py -vv
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149312
Approved by: https://github.com/janeyx99
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
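For context, here is a minimal, self-contained sketch (not part of the diff) of the checkpoint-resume pattern the new test exercises. It mirrors `test_lr_scheduler_checkpoint` using `ExponentialLR`; the linear model and AdamW optimizer are just illustrative stand-ins:

```python
import torch
from torch.optim.lr_scheduler import ExponentialLR

# First run: one optimizer step followed by one scheduler step.
model = torch.nn.Linear(3, 3)
optim = torch.optim.AdamW(model.parameters(), lr=0.01)
sched = ExponentialLR(optim, gamma=0.999)
optim.step()
sched.step()  # lr is now 0.01 * 0.999

# "Resume": a fresh optimizer restored from state_dict, and a scheduler
# constructed with last_epoch pointing at the already-completed step.
optim2 = torch.optim.AdamW(model.parameters(), lr=0.01)
optim2.load_state_dict(optim.state_dict())
sched2 = ExponentialLR(optim2, gamma=0.999, last_epoch=0)

# Before this fix, _initial_step() in the constructor applied gamma once more,
# leaving the restored lr at 0.01 * 0.999**2; with _is_initial it stays at
# 0.01 * 0.999, matching the lr the first run ended with.
print(sched2.get_last_lr()[0], optim.param_groups[0]["lr"])
```

With the change in this commit the two printed values match; previously the restored scheduler drifted one decay step ahead of the optimizer it was resumed from.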
parent 423fc671e9
commit d7a83ab67b
test/optim/test_lrscheduler.py

@@ -2575,6 +2575,56 @@ class TestLRScheduler(TestCase):
         self.assertEqual(group["swa_lr"], 0.05)
         self.assertEqual(sch.base_lrs, [0.1])
 
+    @parametrize(
+        "LRClass",
+        [
+            partial(ExponentialLR, gamma=0.999),
+            partial(LambdaLR, lr_lambda=lambda epoch: epoch // 30),
+            partial(MultiplicativeLR, lr_lambda=lambda epoch: 0.95),
+            partial(StepLR, step_size=30),
+            partial(MultiStepLR, milestones=[30, 80]),
+            ConstantLR,
+            LinearLR,
+            PolynomialLR,
+            partial(CosineAnnealingLR, T_max=10),
+            partial(CosineAnnealingWarmRestarts, T_0=20),
+            partial(CyclicLR, base_lr=0.01, max_lr=0.1),
+            partial(OneCycleLR, max_lr=0.01, total_steps=10),
+            partial(SWALR, swa_lr=0.01),
+        ],
+    )
+    def test_lr_scheduler_checkpoint(self, LRClass):
+        model = torch.nn.Linear(3, 3)
+        optim = torch.optim.AdamW(model.parameters())
+        sch = LRClass(optim)
+        optim.step()
+        sch.step()
+        optim2 = torch.optim.AdamW(model.parameters())
+        optim2.load_state_dict(optim.state_dict())
+        sch2 = LRClass(optim2, last_epoch=0)
+        self.assertEqual(
+            sch2._get_closed_form_lr()[0]
+            if hasattr(self, "_get_closed_form_lr")
+            else sch2.get_last_lr()[0],
+            optim.param_groups[0]["lr"],
+        )
+
+    def test_lr_scheduler_checkpoint_on_plateau(self):
+        model = torch.nn.Linear(3, 3)
+        optim = torch.optim.AdamW(model.parameters())
+        sch = ReduceLROnPlateau(optim, mode="min")
+        optim.step()
+        sch.step(1)
+        optim2 = torch.optim.AdamW(model.parameters())
+        optim2.load_state_dict(optim.state_dict())
+        sch2 = ReduceLROnPlateau(optim2, mode="min")
+        self.assertEqual(
+            sch2._get_closed_form_lr()[0]
+            if hasattr(self, "_get_closed_form_lr")
+            else sch2.get_last_lr()[0],
+            optim.param_groups[0]["lr"],
+        )
+
 
 instantiate_parametrized_tests(TestLRScheduler)
 
torch/optim/lr_scheduler.py

@@ -82,6 +82,7 @@ class LRScheduler:
     r"""Adjusts the learning rate during optimization."""
 
     _get_lr_called_within_step: bool = False
+    _is_initial: bool = False
 
     def __init__(
         self,
@@ -141,7 +142,8 @@ class LRScheduler:
     def _initial_step(self) -> None:
         """Initialize step counts and perform a step."""
         self._step_count = 0
-        self.step()
+        with _initial_mode(self):
+            self.step()
 
     def state_dict(self) -> dict[str, Any]:
         """Return the state of the scheduler as a :class:`dict`.
@@ -195,6 +197,7 @@ class LRScheduler:
                     "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
                     UserWarning,
                 )
+
         self._step_count += 1
 
         with _enable_get_lr_call(self):
@@ -248,6 +251,17 @@ class _enable_get_lr_call:
         self.o._get_lr_called_within_step = False
 
 
+class _initial_mode:
+    def __init__(self, o: LRScheduler):
+        self.o = o
+
+    def __enter__(self):
+        self.o._is_initial = True
+
+    def __exit__(self, type, value, traceback):
+        self.o._is_initial = False
+
+
 class LambdaLR(LRScheduler):
     """Sets the initial learning rate.
 
@@ -450,7 +464,7 @@ class MultiplicativeLR(LRScheduler):
         """Compute the learning rate of each parameter group."""
         _warn_get_lr_called_within_step(self)
 
-        if self.last_epoch > 0:
+        if not self._is_initial:
             return [
                 group["lr"] * lmbda(self.last_epoch)
                 for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)
@@ -715,7 +729,7 @@ class LinearLR(LRScheduler):
                 group["lr"] * self.start_factor for group in self.optimizer.param_groups
             ]
 
-        if self.last_epoch > self.total_iters:
+        if self._is_initial or self.last_epoch > self.total_iters:
             return [group["lr"] for group in self.optimizer.param_groups]
 
         return [
@@ -779,7 +793,9 @@ class ExponentialLR(LRScheduler):
         """Compute the learning rate of each parameter group."""
         _warn_get_lr_called_within_step(self)
 
-        if self.last_epoch == 0:
+        # when loading from a checkpoint, we don't want _initial_step (called from the constructor)
+        # to update the lr one more step ahead of itself.
+        if self._is_initial:
             return [group["lr"] for group in self.optimizer.param_groups]
         return [group["lr"] * self.gamma for group in self.optimizer.param_groups]
 
@@ -979,7 +995,7 @@ class PolynomialLR(LRScheduler):
         """Compute the learning rate."""
         _warn_get_lr_called_within_step(self)
 
-        if self.last_epoch == 0 or self.last_epoch > self.total_iters:
+        if self._is_initial or self.last_epoch > self.total_iters:
             return [group["lr"] for group in self.optimizer.param_groups]
 
         decay_factor = (
@@ -1065,7 +1081,7 @@ class CosineAnnealingLR(LRScheduler):
         """Retrieve the learning rate of each parameter group."""
         _warn_get_lr_called_within_step(self)
 
-        if self.last_epoch == 0:
+        if self._is_initial:
             return [group["lr"] for group in self.optimizer.param_groups]
         elif self._step_count == 1 and self.last_epoch > 0:
             return [
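The same pattern generalizes to downstream code: a custom `LRScheduler` subclass can consult `self._is_initial` in `get_lr()` so that the constructor's initial step leaves a restored learning rate untouched. A minimal sketch follows; the `HalvingLR` class is hypothetical and only illustrates the flag introduced by this commit:

```python
from torch.optim.lr_scheduler import LRScheduler


class HalvingLR(LRScheduler):
    """Hypothetical scheduler that halves the lr on every step."""

    def get_lr(self):
        # Mirror the built-in schedulers above: while _initial_step() runs under
        # _initial_mode, _is_initial is True and the current lr is returned
        # unchanged, so constructing with last_epoch >= 0 (checkpoint resume)
        # does not apply an extra decay step.
        if self._is_initial:
            return [group["lr"] for group in self.optimizer.param_groups]
        return [group["lr"] * 0.5 for group in self.optimizer.param_groups]
```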