added functionality for state_dict/load_state_dict for lr_scheduler (Fixes: #3026) (#6342)

* added functionality for state_dict/load_state_dict for lr_scheduler

* fixed linting issues/removed unused import

* refactor lr_scheduler state_dicts: state_dict now holds everything in __dict__ except the optimizer

* changed documentation in lr_scheduler

* Update lr_scheduler.py
This commit is contained in:
Armen 2018-04-19 04:09:03 -07:00 committed by Soumith Chintala
parent 072d49f787
commit e44f901b55
2 changed files with 58 additions and 1 deletions

View File

@ -624,6 +624,40 @@ class TestLRScheduler(TestCase):
lr_lambda=[lambda x1: 0.9 ** x1, lambda x2: 0.8 ** x2])
self._test(scheduler, targets, epochs)
def test_step_lr_state_dict(self):
    """StepLR state survives a state_dict/load_state_dict round trip."""
    def make_source():
        return StepLR(self.opt, gamma=0.1, step_size=3)

    def make_target():
        # Deliberately different hyperparameters: loading must overwrite them.
        return StepLR(self.opt, gamma=0.01 / 2, step_size=1)

    self._check_scheduler_state_dict(make_source, make_target)
def test_multi_step_lr_state_dict(self):
    """MultiStepLR state survives a state_dict/load_state_dict round trip."""
    def make_source():
        return MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])

    def make_target():
        # Deliberately different hyperparameters: loading must overwrite them.
        return MultiStepLR(self.opt, gamma=0.01, milestones=[1, 4, 6])

    self._check_scheduler_state_dict(make_source, make_target)
def test_exp_step_lr_state_dict(self):
    """ExponentialLR state survives a state_dict/load_state_dict round trip."""
    def make_source():
        return ExponentialLR(self.opt, gamma=0.1)

    def make_target():
        # Deliberately different gamma: loading must overwrite it.
        return ExponentialLR(self.opt, gamma=0.01)

    self._check_scheduler_state_dict(make_source, make_target)
def test_cosine_lr_state_dict(self):
    """CosineAnnealingLR state survives a state_dict/load_state_dict round trip."""
    epochs = 10
    eta_min = 1e-10

    def make_source():
        return CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)

    def make_target():
        # Deliberately different hyperparameters: loading must overwrite them.
        return CosineAnnealingLR(self.opt, T_max=epochs // 2, eta_min=eta_min / 2)

    self._check_scheduler_state_dict(make_source, make_target, epochs=epochs)
def _check_scheduler_state_dict(self, constr, constr2, epochs=10):
    """Step one scheduler, load its state into a differently-configured
    scheduler, and assert every non-optimizer attribute (and the resulting
    learning rates) match between the two.
    """
    source = constr()
    for _ in range(epochs):
        source.step()
    target = constr2()
    target.load_state_dict(source.state_dict())
    for attr, value in source.__dict__.items():
        # The optimizer is intentionally excluded from scheduler state.
        if attr == 'optimizer':
            continue
        self.assertAlmostEqual(value, target.__dict__[attr])
    self.assertAlmostEqual(source.get_lr(), target.get_lr())
def _test(self, scheduler, targets, epochs=10):
for epoch in range(epochs):
scheduler.step(epoch)

View File

@ -1,7 +1,6 @@
import math
from bisect import bisect_right
from functools import partial
from .optimizer import Optimizer
@ -23,6 +22,29 @@ class _LRScheduler(object):
self.step(last_epoch + 1)
self.last_epoch = last_epoch
def __getstate__(self):
    """Support pickling by delegating to :meth:`state_dict`, which
    returns everything in ``self.__dict__`` except the optimizer."""
    return self.state_dict()
def __setstate__(self, state):
    """Support unpickling by delegating to :meth:`load_state_dict`."""
    self.load_state_dict(state)
def state_dict(self):
    """Return the state of the scheduler as a :class:`dict`.

    The result is a shallow copy of ``self.__dict__`` with the
    ``optimizer`` entry removed, so the optimizer itself is never
    serialized along with the scheduler.
    """
    state = dict(self.__dict__)
    state.pop('optimizer', None)
    return state
def load_state_dict(self, state_dict):
    """Load the scheduler's state.

    Arguments:
        state_dict (dict): scheduler state, as returned by a prior
            call to :meth:`state_dict`.
    """
    # Write each saved entry straight into the instance dict; the
    # optimizer attribute is untouched because state_dict() never
    # contains it.
    for name, value in state_dict.items():
        self.__dict__[name] = value
def get_lr(self):
    # Abstract hook: subclasses must compute and return the learning
    # rate(s) for the current epoch; the base class defines no schedule.
    raise NotImplementedError
@ -55,6 +77,7 @@ class LambdaLR(_LRScheduler):
>>> train(...)
>>> validate(...)
"""
def __init__(self, optimizer, lr_lambda, last_epoch=-1):
self.optimizer = optimizer
if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):