mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
* added functionality for state_dict/load_state_dict for lr_scheduler * fixed linting issues/removed unused import * refactor lr_scheduler state_dicts/state_dict holds everything __dict__ but optimizer * changed documentation in lr_scheduler * Update lr_scheduler.py
This commit is contained in:
parent
072d49f787
commit
e44f901b55
|
|
@ -624,6 +624,40 @@ class TestLRScheduler(TestCase):
|
||||||
lr_lambda=[lambda x1: 0.9 ** x1, lambda x2: 0.8 ** x2])
|
lr_lambda=[lambda x1: 0.9 ** x1, lambda x2: 0.8 ** x2])
|
||||||
self._test(scheduler, targets, epochs)
|
self._test(scheduler, targets, epochs)
|
||||||
|
|
||||||
|
def test_step_lr_state_dict(self):
|
||||||
|
self._check_scheduler_state_dict(
|
||||||
|
lambda: StepLR(self.opt, gamma=0.1, step_size=3),
|
||||||
|
lambda: StepLR(self.opt, gamma=0.01 / 2, step_size=1))
|
||||||
|
|
||||||
|
def test_multi_step_lr_state_dict(self):
|
||||||
|
self._check_scheduler_state_dict(
|
||||||
|
lambda: MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]),
|
||||||
|
lambda: MultiStepLR(self.opt, gamma=0.01, milestones=[1, 4, 6]))
|
||||||
|
|
||||||
|
def test_exp_step_lr_state_dict(self):
|
||||||
|
self._check_scheduler_state_dict(
|
||||||
|
lambda: ExponentialLR(self.opt, gamma=0.1),
|
||||||
|
lambda: ExponentialLR(self.opt, gamma=0.01))
|
||||||
|
|
||||||
|
def test_cosine_lr_state_dict(self):
|
||||||
|
epochs = 10
|
||||||
|
eta_min = 1e-10
|
||||||
|
self._check_scheduler_state_dict(
|
||||||
|
lambda: CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min),
|
||||||
|
lambda: CosineAnnealingLR(self.opt, T_max=epochs // 2, eta_min=eta_min / 2),
|
||||||
|
epochs=epochs)
|
||||||
|
|
||||||
|
def _check_scheduler_state_dict(self, constr, constr2, epochs=10):
|
||||||
|
scheduler = constr()
|
||||||
|
for _ in range(epochs):
|
||||||
|
scheduler.step()
|
||||||
|
scheduler_copy = constr2()
|
||||||
|
scheduler_copy.load_state_dict(scheduler.state_dict())
|
||||||
|
for key in scheduler.__dict__.keys():
|
||||||
|
if key != 'optimizer':
|
||||||
|
self.assertAlmostEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])
|
||||||
|
self.assertAlmostEqual(scheduler.get_lr(), scheduler_copy.get_lr())
|
||||||
|
|
||||||
def _test(self, scheduler, targets, epochs=10):
|
def _test(self, scheduler, targets, epochs=10):
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
scheduler.step(epoch)
|
scheduler.step(epoch)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
import math
|
import math
|
||||||
from bisect import bisect_right
|
from bisect import bisect_right
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from .optimizer import Optimizer
|
from .optimizer import Optimizer
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -23,6 +22,29 @@ class _LRScheduler(object):
|
||||||
self.step(last_epoch + 1)
|
self.step(last_epoch + 1)
|
||||||
self.last_epoch = last_epoch
|
self.last_epoch = last_epoch
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
return self.state_dict()
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
self.load_state_dict(state)
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
"""Returns the state of the scheduler as a :class:`dict`.
|
||||||
|
|
||||||
|
It contains an entry for every variable in self.__dict__ which
|
||||||
|
is not the optimizer.
|
||||||
|
"""
|
||||||
|
return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
"""Loads the schedulers state.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
state_dict (dict): scheduler state. Should be an object returned
|
||||||
|
from a call to :meth:`state_dict`.
|
||||||
|
"""
|
||||||
|
self.__dict__.update(state_dict)
|
||||||
|
|
||||||
def get_lr(self):
|
def get_lr(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
@ -55,6 +77,7 @@ class LambdaLR(_LRScheduler):
|
||||||
>>> train(...)
|
>>> train(...)
|
||||||
>>> validate(...)
|
>>> validate(...)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, optimizer, lr_lambda, last_epoch=-1):
|
def __init__(self, optimizer, lr_lambda, last_epoch=-1):
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
|
if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user