Addresses bad behavior with overridden optimizer.step introduced by #20124 (#21460)

Summary:
This PR addresses the problem described in this comment: https://github.com/pytorch/pytorch/pull/20203#issuecomment-499231276
and the previously coded bad behaviour:
- a warning was raised every time an lr scheduler was initialized

Now the code checks that:
- on the second call of `lr_scheduler.step`, `optimizer.step` has already been called, and otherwise raises a warning (as was done in #20203)
- if the optimizer's `step` is overridden, it raises another warning once, to make the user aware of the new pattern `opt.step()` -> `lrs.step()`, since the call order can no longer be checked in this case (see the sketch below).
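For reference, a minimal sketch of the new call order; the model, optimizer, and loop here are illustrative, not taken from this PR:

```python
import torch
from torch.optim.lr_scheduler import StepLR

model = torch.nn.Linear(10, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.05)
lrs = StepLR(opt, step_size=3, gamma=0.1)

for epoch in range(10):
    # the backward pass would go here in real training
    opt.step()  # new pattern: step the optimizer first ...
    lrs.step()  # ... then the scheduler
```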

Now the tests check that:
- at initialization (`lrs = StepLR(...)`) no warning is raised
- if we replace `optimizer.step` with something else (similarly to the [code of nvidia/apex](https://github.com/NVIDIA/apex/blob/master/apex/amp/_process_optimizer.py#L287)), the other warning is raised (a standalone sketch of such an override follows this list).
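A self-contained sketch of an apex-style override, assuming this patched build; rebinding `opt.step` discards the counting wrapper installed by the scheduler, along with its `_with_counter` marker, which is exactly what the new warning detects:

```python
import types
import torch
from torch.optim.lr_scheduler import StepLR

opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.05)
lrs = StepLR(opt, step_size=3)  # installs the counting wrapper on opt.step

old_step = opt.step  # the wrapper carrying `_with_counter`

def new_step(self, *args, **kwargs):
    # an apex-style override would, e.g., unscale gradients around the real step
    return old_step(*args, **kwargs)

opt.step = types.MethodType(new_step, opt)  # the marker is gone now
assert not hasattr(opt.step, "_with_counter")  # so `lrs.step()` warns once
```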

cc ezyang

PS. Honestly, I would say that a lot of overhead is introduced for simple warnings. I hope all these checks can be removed in `1.2.0` or a later version...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21460

Differential Revision: D15701776

Pulled By: ezyang

fbshipit-source-id: eac5712b9146d9d3392a30f6339cd33d90c497c7
Author: vfn, committed by Facebook GitHub Bot on 2019-06-06 13:43:50 -07:00
parent 51d0da2802
commit 8ece538a79
2 changed files with 101 additions and 21 deletions

test/test_optim.py

@@ -1,3 +1,4 @@
import warnings
import math
import unittest
import functools
@@ -529,7 +530,10 @@ class TestLRScheduler(TestCase):
    def test_old_pattern_warning(self):
        epochs = 35
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        def old_pattern():
            for e in range(epochs):
@@ -540,7 +544,10 @@ class TestLRScheduler(TestCase):
    def test_old_pattern_warning_with_arg(self):
        epochs = 35
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        def old_pattern2():
            for e in range(epochs):
@@ -554,7 +561,10 @@ class TestLRScheduler(TestCase):
        for i, group in enumerate(self.opt.param_groups):
            group['initial_lr'] = 0.01

        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        def old_pattern():
            for e in range(epochs):
@@ -568,7 +578,38 @@ class TestLRScheduler(TestCase):
        for i, group in enumerate(self.opt.param_groups):
            group['initial_lr'] = 0.01

        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        def old_pattern2():
            for e in range(epochs):
                scheduler.step(e)
                self.opt.step()

        self.assertWarnsRegex(old_pattern2, r'how-to-adjust-learning-rate')

    def test_old_pattern_warning_with_overriden_optim_step(self):
        epochs = 35
        for i, group in enumerate(self.opt.param_groups):
            group['initial_lr'] = 0.01

        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        # emulate use-case with optimizer.step overridden
        import types

        old_step = self.opt.step

        def new_step(o, *args, **kwargs):
            retval = old_step(*args, **kwargs)
            return retval

        self.opt.step = types.MethodType(new_step, self.opt)

        def old_pattern2():
            for e in range(epochs):
@@ -578,10 +619,11 @@ class TestLRScheduler(TestCase):
        self.assertWarnsRegex(old_pattern2, r'how-to-adjust-learning-rate')

    def test_new_pattern_no_warning(self):
        epochs = 35
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
@@ -591,10 +633,11 @@ class TestLRScheduler(TestCase):
            self.assertTrue(len(ws) == 0, "No warning should be raised")

    def test_new_pattern_no_warning_with_arg(self):
        epochs = 35
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
@@ -603,6 +646,31 @@ class TestLRScheduler(TestCase):
                scheduler.step(e)
        self.assertTrue(len(ws) == 0, "No warning should be raised")

    def test_new_pattern_no_warning_with_overriden_optim_step(self):
        epochs = 35
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        # emulate use-case with optimizer.step overridden
        import types

        old_step = self.opt.step

        def new_step(o, *args, **kwargs):
            retval = old_step(*args, **kwargs)
            return retval

        self.opt.step = types.MethodType(new_step, self.opt)

        def new_pattern():
            for e in range(epochs):
                self.opt.step()
                scheduler.step()

        self.assertWarnsRegex(new_pattern, r'`optimizer.step\(\)` has been overridden')

    def test_step_lr(self):
        # lr = 0.05   if epoch < 3
        # lr = 0.005  if 3 <= epoch < 6

torch/optim/lr_scheduler.py

@@ -29,15 +29,17 @@ class _LRScheduler(object):
        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `lr_scheduler.step()` is called after
        # `optimizer.step()`
        def with_counter(func, opt):
            @wraps(func)
            def wrapper(*args, **kwargs):
                opt._step_count += 1  # count calls to the real optimizer.step()
                return func(*args, **kwargs)
            # mark the wrapper so step() below can tell whether user code has
            # replaced optimizer.step after scheduler construction
            wrapper._with_counter = True
            return wrapper

        self.optimizer.step = with_counter(self.optimizer.step, self.optimizer)
        self.optimizer._step_count = 0
        self._step_count = 0
        self.step(last_epoch)

    def state_dict(self):
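As a side note, the counting-wrapper technique above works outside PyTorch as well; a self-contained sketch with an illustrative `FakeOptimizer`:

```python
from functools import wraps

class FakeOptimizer:
    def step(self):
        print("optimizer step")

def with_counter(func, opt):
    @wraps(func)
    def wrapper(*args, **kwargs):
        opt._step_count += 1  # count calls to the real step
        return func(*args, **kwargs)
    wrapper._with_counter = True  # marker the scheduler looks for
    return wrapper

opt = FakeOptimizer()
opt._step_count = 0
opt.step = with_counter(opt.step, opt)  # shadow the bound method
opt.step()
assert opt._step_count == 1 and hasattr(opt.step, "_with_counter")
```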
@@ -63,13 +65,23 @@ class _LRScheduler(object):
    def step(self, epoch=None):
        # Raise a warning if old pattern is detected
        # https://github.com/pytorch/pytorch/issues/20124
        if self._step_count == 1:
            if not hasattr(self.optimizer.step, "_with_counter"):
                warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                              "initialization. Please, make sure to call `optimizer.step()` before "
                              "`lr_scheduler.step()`. See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

            # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
            elif self.optimizer._step_count < 1:
                warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                              "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                              "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
                              "will result in PyTorch skipping the first value of the learning rate schedule. "
                              "See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
        self._step_count += 1

        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
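Finally, a hedged end-to-end demo of the second warning against this patched build (the single-parameter setup is illustrative):

```python
import warnings
import torch
from torch.optim.lr_scheduler import StepLR

opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.05)
lrs = StepLR(opt, step_size=3, gamma=0.1)

with warnings.catch_warnings(record=True) as ws:
    warnings.simplefilter("always")
    lrs.step()  # old pattern: scheduler stepped before the optimizer
assert any("how-to-adjust-learning-rate" in str(w.message) for w in ws)
```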