mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
* added functionality for state_dict/load_state_dict for lr_scheduler * fixed linting issues/removed unused import * refactor lr_scheduler state_dicts/state_dict holds everything __dict__ but optimizer * changed documentation in lr_scheduler * Update lr_scheduler.py
This commit is contained in:
parent
072d49f787
commit
e44f901b55
|
|
@ -624,6 +624,40 @@ class TestLRScheduler(TestCase):
|
|||
lr_lambda=[lambda x1: 0.9 ** x1, lambda x2: 0.8 ** x2])
|
||||
self._test(scheduler, targets, epochs)
|
||||
|
||||
def test_step_lr_state_dict(self):
|
||||
self._check_scheduler_state_dict(
|
||||
lambda: StepLR(self.opt, gamma=0.1, step_size=3),
|
||||
lambda: StepLR(self.opt, gamma=0.01 / 2, step_size=1))
|
||||
|
||||
def test_multi_step_lr_state_dict(self):
|
||||
self._check_scheduler_state_dict(
|
||||
lambda: MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]),
|
||||
lambda: MultiStepLR(self.opt, gamma=0.01, milestones=[1, 4, 6]))
|
||||
|
||||
def test_exp_step_lr_state_dict(self):
|
||||
self._check_scheduler_state_dict(
|
||||
lambda: ExponentialLR(self.opt, gamma=0.1),
|
||||
lambda: ExponentialLR(self.opt, gamma=0.01))
|
||||
|
||||
def test_cosine_lr_state_dict(self):
|
||||
epochs = 10
|
||||
eta_min = 1e-10
|
||||
self._check_scheduler_state_dict(
|
||||
lambda: CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min),
|
||||
lambda: CosineAnnealingLR(self.opt, T_max=epochs // 2, eta_min=eta_min / 2),
|
||||
epochs=epochs)
|
||||
|
||||
def _check_scheduler_state_dict(self, constr, constr2, epochs=10):
|
||||
scheduler = constr()
|
||||
for _ in range(epochs):
|
||||
scheduler.step()
|
||||
scheduler_copy = constr2()
|
||||
scheduler_copy.load_state_dict(scheduler.state_dict())
|
||||
for key in scheduler.__dict__.keys():
|
||||
if key != 'optimizer':
|
||||
self.assertAlmostEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])
|
||||
self.assertAlmostEqual(scheduler.get_lr(), scheduler_copy.get_lr())
|
||||
|
||||
def _test(self, scheduler, targets, epochs=10):
|
||||
for epoch in range(epochs):
|
||||
scheduler.step(epoch)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import math
|
||||
from bisect import bisect_right
|
||||
from functools import partial
|
||||
|
||||
from .optimizer import Optimizer
|
||||
|
||||
|
||||
|
|
@ -23,6 +22,29 @@ class _LRScheduler(object):
|
|||
self.step(last_epoch + 1)
|
||||
self.last_epoch = last_epoch
|
||||
|
||||
def __getstate__(self):
|
||||
return self.state_dict()
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.load_state_dict(state)
|
||||
|
||||
def state_dict(self):
|
||||
"""Returns the state of the scheduler as a :class:`dict`.
|
||||
|
||||
It contains an entry for every variable in self.__dict__ which
|
||||
is not the optimizer.
|
||||
"""
|
||||
return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Loads the schedulers state.
|
||||
|
||||
Arguments:
|
||||
state_dict (dict): scheduler state. Should be an object returned
|
||||
from a call to :meth:`state_dict`.
|
||||
"""
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
def get_lr(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
@ -55,6 +77,7 @@ class LambdaLR(_LRScheduler):
|
|||
>>> train(...)
|
||||
>>> validate(...)
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, lr_lambda, last_epoch=-1):
|
||||
self.optimizer = optimizer
|
||||
if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user