mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix SequentialLR deprecate warning about invoke step(epoch) (#149392)
Fixes #116776 #76113 #113222 #67958 ## Changes - Refactor `LRScheduler.step` method, leave `epoch` check logic in public method `step` - Move update `lr` logic to `_update_lr` method - Make `SequentialLR` use `_update_lr` to avoid unnecessary warning message ## Test Result ```bash pytest test/optim/test_lrscheduler.py -vv ```  Pull Request resolved: https://github.com/pytorch/pytorch/pull/149392 Approved by: https://github.com/janeyx99
This commit is contained in:
parent
ed370ae4b0
commit
448a7e7e31
|
|
@ -784,6 +784,19 @@ class TestLRScheduler(TestCase):
|
|||
scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
|
||||
self._test(scheduler, targets, epochs)
|
||||
|
||||
def test_sequentiallr_no_warnings(self):
|
||||
scheduler1 = LinearLR(self.opt, start_factor=0.5, end_factor=0.1, total_iters=5)
|
||||
scheduler2 = ExponentialLR(self.opt, gamma=0.9)
|
||||
scheduler = SequentialLR(
|
||||
self.opt, schedulers=[scheduler1, scheduler2], milestones=[5]
|
||||
)
|
||||
|
||||
for _ in range(10):
|
||||
self.opt.step()
|
||||
with warnings.catch_warnings(record=True) as ws:
|
||||
scheduler.step()
|
||||
self.assertTrue(len(ws) == 0, "No warning should be raised")
|
||||
|
||||
def test_get_last_lr_sequentiallr(self):
|
||||
epochs = 12
|
||||
milestones = [3, 6]
|
||||
|
|
|
|||
|
|
@ -200,13 +200,16 @@ class LRScheduler:
|
|||
)
|
||||
|
||||
self._step_count += 1
|
||||
if epoch is not None:
|
||||
warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
|
||||
self._update_lr(epoch)
|
||||
|
||||
def _update_lr(self, epoch: Optional[int] = None):
|
||||
with _enable_get_lr_call(self):
|
||||
if epoch is None:
|
||||
self.last_epoch += 1
|
||||
values = self.get_lr()
|
||||
else:
|
||||
warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
|
||||
self.last_epoch = epoch
|
||||
if hasattr(self, "_get_closed_form_lr"):
|
||||
values = cast(list[float], self._get_closed_form_lr())
|
||||
|
|
@ -913,7 +916,7 @@ class SequentialLR(LRScheduler):
|
|||
idx = bisect_right(self._milestones, self.last_epoch)
|
||||
scheduler = self._schedulers[idx]
|
||||
if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
|
||||
scheduler.step(0)
|
||||
scheduler._update_lr(0)
|
||||
else:
|
||||
scheduler.step()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user