Fix SequentialLR deprecate warning about invoke step(epoch) (#149392)

Fixes #116776 #76113 #113222 #67958
## Changes

- Refactor `LRScheduler.step` method, leave `epoch` check logic in public method `step`
- Move update `lr` logic to `_update_lr` method
- Make `SequentialLR` use `_update_lr` to avoid unnecessary warning message

## Test Result

```bash
pytest test/optim/test_lrscheduler.py -vv
```

![image](https://github.com/user-attachments/assets/e1c5527e-193e-4328-bf95-023139ea0416)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149392
Approved by: https://github.com/janeyx99
This commit is contained in:
zeshengzong 2025-08-29 11:45:07 +00:00 committed by PyTorch MergeBot
parent ed370ae4b0
commit 448a7e7e31
2 changed files with 18 additions and 2 deletions

View File

@ -784,6 +784,19 @@ class TestLRScheduler(TestCase):
scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
self._test(scheduler, targets, epochs)
def test_sequentiallr_no_warnings(self):
scheduler1 = LinearLR(self.opt, start_factor=0.5, end_factor=0.1, total_iters=5)
scheduler2 = ExponentialLR(self.opt, gamma=0.9)
scheduler = SequentialLR(
self.opt, schedulers=[scheduler1, scheduler2], milestones=[5]
)
for _ in range(10):
self.opt.step()
with warnings.catch_warnings(record=True) as ws:
scheduler.step()
self.assertTrue(len(ws) == 0, "No warning should be raised")
def test_get_last_lr_sequentiallr(self):
epochs = 12
milestones = [3, 6]

View File

@ -200,13 +200,16 @@ class LRScheduler:
)
self._step_count += 1
if epoch is not None:
warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
self._update_lr(epoch)
def _update_lr(self, epoch: Optional[int] = None):
with _enable_get_lr_call(self):
if epoch is None:
self.last_epoch += 1
values = self.get_lr()
else:
warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
self.last_epoch = epoch
if hasattr(self, "_get_closed_form_lr"):
values = cast(list[float], self._get_closed_form_lr())
@ -913,7 +916,7 @@ class SequentialLR(LRScheduler):
idx = bisect_right(self._milestones, self.last_epoch)
scheduler = self._schedulers[idx]
if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
scheduler.step(0)
scheduler._update_lr(0)
else:
scheduler.step()