Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Summary: This PR addresses the problem described in the comment https://github.com/pytorch/pytorch/pull/20203#issuecomment-499231276 and the previously bad behaviour, where a warning was raised every time an LR scheduler was initialized.

The code now checks that:
- on the second call of `lr_scheduler.step`, `optimizer.step` has already been called, and otherwise raises a warning (as was done in #20203);
- if the optimizer's `step` is overridden, it raises another warning once to make the user aware of the new pattern `opt.step()` -> `lrs.step()`, since we cannot perform the check in that case.

The tests now check that:
- at initialization (`lrs = StepLR(...)`) there are no warnings;
- if we replace `optimizer.step` with something else (similarly to the [code of nvidia/apex](https://github.com/NVIDIA/apex/blob/master/apex/amp/_process_optimizer.py#L287)), the other warning is raised.

cc ezyang

PS. Honestly, I would say that a lot of overhead is introduced for simple warnings. I hope all these checks can be removed in `1.2.0` or a later version...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/21460
Differential Revision: D15701776
Pulled By: ezyang
fbshipit-source-id: eac5712b9146d9d3392a30f6339cd33d90c497c7
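For context, a minimal sketch (illustrative only, not part of the commit) of the new `opt.step()` -> `lrs.step()` ordering that the warnings point users toward; the model, optimizer, and training loop below are placeholders:

# Illustrative only: recommended call order in PyTorch >= 1.1.0 (the "new pattern").
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

model = torch.nn.Linear(2, 2)                    # placeholder model
optimizer = SGD(model.parameters(), lr=0.05)
scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

for epoch in range(10):
    # ... forward pass and loss.backward() would go here ...
    optimizer.step()    # optimizer step first
    scheduler.step()    # then the LR scheduler step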
This commit is contained in:
parent 51d0da2802
commit 8ece538a79
@@ -1,3 +1,4 @@
import warnings
import math
import unittest
import functools
@@ -529,7 +530,10 @@ class TestLRScheduler(TestCase):

    def test_old_pattern_warning(self):
        epochs = 35
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        def old_pattern():
            for e in range(epochs):
@@ -540,7 +544,10 @@ class TestLRScheduler(TestCase):

    def test_old_pattern_warning_with_arg(self):
        epochs = 35
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        def old_pattern2():
            for e in range(epochs):
@@ -554,7 +561,10 @@ class TestLRScheduler(TestCase):
        for i, group in enumerate(self.opt.param_groups):
            group['initial_lr'] = 0.01

        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        def old_pattern():
            for e in range(epochs):
@@ -568,7 +578,38 @@ class TestLRScheduler(TestCase):
        for i, group in enumerate(self.opt.param_groups):
            group['initial_lr'] = 0.01

        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        def old_pattern2():
            for e in range(epochs):
                scheduler.step(e)
                self.opt.step()

        self.assertWarnsRegex(old_pattern2, r'how-to-adjust-learning-rate')

    def test_old_pattern_warning_with_overriden_optim_step(self):
        epochs = 35
        for i, group in enumerate(self.opt.param_groups):
            group['initial_lr'] = 0.01

        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        # emulate use-case with optimizer.step overriden
        import types

        old_step = self.opt.step

        def new_step(o, *args, **kwargs):
            retval = old_step(*args, **kwargs)
            return retval

        self.opt.step = types.MethodType(new_step, self.opt)

        def old_pattern2():
            for e in range(epochs):
@@ -578,10 +619,11 @@ class TestLRScheduler(TestCase):
        self.assertWarnsRegex(old_pattern2, r'how-to-adjust-learning-rate')

    def test_new_pattern_no_warning(self):
        import warnings

        epochs = 35
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
@@ -591,10 +633,11 @@ class TestLRScheduler(TestCase):
            self.assertTrue(len(ws) == 0, "No warning should be raised")

    def test_new_pattern_no_warning_with_arg(self):
        import warnings

        epochs = 35
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
@@ -603,6 +646,31 @@ class TestLRScheduler(TestCase):
                scheduler.step(e)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

    def test_new_pattern_no_warning_with_overriden_optim_step(self):
        epochs = 35
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        # emulate use-case with optimizer.step overriden
        import types

        old_step = self.opt.step

        def new_step(o, *args, **kwargs):
            retval = old_step(*args, **kwargs)
            return retval

        self.opt.step = types.MethodType(new_step, self.opt)

        def new_pattern():
            for e in range(epochs):
                self.opt.step()
                scheduler.step()

        self.assertWarnsRegex(new_pattern, r'`optimizer.step\(\)` has been overridden')

    def test_step_lr(self):
        # lr = 0.05 if epoch < 3
        # lr = 0.005 if 30 <= epoch < 6
@@ -29,15 +29,17 @@ class _LRScheduler(object):
        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `lr_scheduler.step()` is called after
        # `optimizer.step()`
        def with_counter(func):
        def with_counter(func, opt):
            @wraps(func)
            def wrapper(*args, **kwargs):
                wrapper.called += 1
                opt._step_count += 1
                return func(*args, **kwargs)
            wrapper.called = 0
            wrapper._with_counter = True
            return wrapper

        self.optimizer.step = with_counter(self.optimizer.step)
        self.optimizer.step = with_counter(self.optimizer.step, self.optimizer)
        self.optimizer._step_count = 0
        self._step_count = 0
        self.step(last_epoch)

    def state_dict(self):
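An aside, not part of the diff: a standalone sketch of the counting trick used in the hunk above, with a dummy class standing in for `torch.optim.Optimizer`:

from functools import wraps

class DummyOptimizer(object):          # stand-in for torch.optim.Optimizer
    def step(self):
        pass

def with_counter(func, opt):
    @wraps(func)
    def wrapper(*args, **kwargs):
        opt._step_count += 1           # count every call to the wrapped step()
        return func(*args, **kwargs)
    wrapper._with_counter = True       # marker: lets the scheduler detect later overrides
    return wrapper

opt = DummyOptimizer()
opt.step = with_counter(opt.step, opt)
opt._step_count = 0

opt.step()
print(opt._step_count)                        # 1 -> step() was called once
print(hasattr(opt.step, "_with_counter"))     # True; False would mean step() was replaced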
@@ -63,13 +65,23 @@
    def step(self, epoch=None):
        # Raise a warning if old pattern is detected
        # https://github.com/pytorch/pytorch/issues/20124
        if self.optimizer.step.called < 1:
        if self._step_count == 1:
            if not hasattr(self.optimizer.step, "_with_counter"):
                warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                              "initialization. Please, make sure to call `optimizer.step()` before "
                              "`lr_scheduler.step()`. See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

            # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
            elif self.optimizer._step_count < 1:
                warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                              "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                              "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
                              "will result in PyTorch skipping the first value of the learning rate schedule."
                              "See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
        self._step_count += 1

        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
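As a rough usage note (not part of the diff): with this change, calling `scheduler.step()` before any `optimizer.step()` should emit the "Detected call of `lr_scheduler.step()` before `optimizer.step()`" warning, while replacing `optimizer.step` (as apex does) should instead emit the "has been overridden" warning on the scheduler's next step. A minimal sketch of triggering the first case; the model and optimizer are placeholders:

import warnings
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

model = torch.nn.Linear(2, 2)                      # placeholder model
optimizer = SGD(model.parameters(), lr=0.05)
scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

with warnings.catch_warnings(record=True) as ws:
    warnings.simplefilter("always")
    scheduler.step()                               # old pattern: scheduler stepped before the optimizer
print([str(w.message) for w in ws])                # expected to contain the how-to-adjust-learning-rate warning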