From e44f901b55873ebb6b1b0d3bab30fd89d487b71c Mon Sep 17 00:00:00 2001
From: Armen
Date: Thu, 19 Apr 2018 04:09:03 -0700
Subject: [PATCH] added functionality for state_dict/load_state_dict for
 lr_scheduler ( Fixes: #3026 ) (#6342)

* added functionality for state_dict/load_state_dict for lr_scheduler

* fixed linting issues/removed unused import

* refactor lr_scheduler state_dicts/state_dict holds everything __dict__ but optimizer

* changed documentation in lr_scheduler

* Update lr_scheduler.py
---
 test/test_optim.py          | 34 ++++++++++++++++++++++++++++++++++
 torch/optim/lr_scheduler.py | 25 ++++++++++++++++++++++++-
 2 files changed, 58 insertions(+), 1 deletion(-)

diff --git a/test/test_optim.py b/test/test_optim.py
index 8467add29f9..8d5e41c228a 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -624,6 +624,40 @@ class TestLRScheduler(TestCase):
                              lr_lambda=[lambda x1: 0.9 ** x1, lambda x2: 0.8 ** x2])
         self._test(scheduler, targets, epochs)
 
+    def test_step_lr_state_dict(self):
+        self._check_scheduler_state_dict(
+            lambda: StepLR(self.opt, gamma=0.1, step_size=3),
+            lambda: StepLR(self.opt, gamma=0.01 / 2, step_size=1))
+
+    def test_multi_step_lr_state_dict(self):
+        self._check_scheduler_state_dict(
+            lambda: MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]),
+            lambda: MultiStepLR(self.opt, gamma=0.01, milestones=[1, 4, 6]))
+
+    def test_exp_step_lr_state_dict(self):
+        self._check_scheduler_state_dict(
+            lambda: ExponentialLR(self.opt, gamma=0.1),
+            lambda: ExponentialLR(self.opt, gamma=0.01))
+
+    def test_cosine_lr_state_dict(self):
+        epochs = 10
+        eta_min = 1e-10
+        self._check_scheduler_state_dict(
+            lambda: CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min),
+            lambda: CosineAnnealingLR(self.opt, T_max=epochs // 2, eta_min=eta_min / 2),
+            epochs=epochs)
+
+    def _check_scheduler_state_dict(self, constr, constr2, epochs=10):
+        scheduler = constr()
+        for _ in range(epochs):
+            scheduler.step()
+        scheduler_copy = constr2()
+        scheduler_copy.load_state_dict(scheduler.state_dict())
+        for key in scheduler.__dict__.keys():
+            if key != 'optimizer':
+                self.assertAlmostEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])
+        self.assertAlmostEqual(scheduler.get_lr(), scheduler_copy.get_lr())
+
     def _test(self, scheduler, targets, epochs=10):
         for epoch in range(epochs):
             scheduler.step(epoch)
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py
index 9ce2988b201..1c9ae024f3a 100644
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -1,7 +1,6 @@
 import math
 from bisect import bisect_right
 from functools import partial
-
 from .optimizer import Optimizer
 
 
@@ -23,6 +22,29 @@ class _LRScheduler(object):
             self.step(last_epoch + 1)
         self.last_epoch = last_epoch
 
+    def __getstate__(self):
+        return self.state_dict()
+
+    def __setstate__(self, state):
+        self.load_state_dict(state)
+
+    def state_dict(self):
+        """Returns the state of the scheduler as a :class:`dict`.
+
+        It contains an entry for every variable in self.__dict__ which
+        is not the optimizer.
+        """
+        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
+
+    def load_state_dict(self, state_dict):
+        """Loads the schedulers state.
+
+        Arguments:
+            state_dict (dict): scheduler state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        self.__dict__.update(state_dict)
+
     def get_lr(self):
         raise NotImplementedError
 
@@ -55,6 +77,7 @@ class LambdaLR(_LRScheduler):
         >>> train(...)
         >>> validate(...)
     """
+
     def __init__(self, optimizer, lr_lambda, last_epoch=-1):
         self.optimizer = optimizer
         if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):