Add beta1 support to CyclicLR momentum (#113548)
Fixes #73910
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113548
Approved by: https://github.com/janeyx99
This commit is contained in:
parent d01ba4e94e
commit d810b10232
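For orientation, a minimal usage sketch of what this change enables (the model and hyperparameter values below are illustrative, not taken from the PR):

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import CyclicLR

model = torch.nn.Linear(10, 2)       # illustrative model
optimizer = Adam(model.parameters())

# Before this commit, CyclicLR rejected Adam when cycle_momentum=True because Adam
# has no 'momentum' hyperparameter; it now cycles Adam's beta1 (betas[0]) instead.
scheduler = CyclicLR(optimizer, base_lr=0.001, max_lr=0.01, cycle_momentum=True)
print(optimizer.param_groups[0]['betas'])  # beta1 starts at max_momentum (0.9 by default)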
@@ -8,7 +8,7 @@ from functools import partial
 import torch
 import torch.nn.functional as F
 from torch.nn import Parameter
-from torch.optim import Adam, SGD
+from torch.optim import Adam, SGD, Rprop
 from torch.optim.lr_scheduler import (
     LambdaLR,
     MultiplicativeLR,
@@ -1510,8 +1510,12 @@ class TestLRScheduler(TestCase):
 
     def test_cycle_lr_cycle_momentum_fail_with_momentumless_optimizer(self):
         with self.assertRaises(ValueError):
-            adam_opt = Adam(self.net.parameters())
-            scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True)
+            rprop_opt = Rprop(self.net.parameters())
+            scheduler = CyclicLR(rprop_opt, base_lr=1, max_lr=5, cycle_momentum=True)
+
+    def test_cycle_lr_cycle_momentum_with_beta1_optimizer(self):
+        adam_opt = Adam(self.net.parameters())
+        scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True)
 
     def test_cycle_lr_removed_after_out_of_scope(self):
         import gc
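For context (not part of the diff): Adam exposes its momentum terms through `betas`, so after this change it passes the `cycle_momentum` check, and the failure test above switches to Rprop, which has neither `momentum` nor `betas`. A small sketch of that distinction:

import torch
from torch.optim import Adam, Rprop

params = [torch.nn.Parameter(torch.zeros(1))]

print('betas' in Adam(params).defaults)      # True  -> accepted, beta1 is cycled
print('momentum' in Rprop(params).defaults)  # False
print('betas' in Rprop(params).defaults)     # False -> still raises ValueError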
@@ -1268,15 +1268,20 @@ class CyclicLR(LRScheduler):
 
         self.cycle_momentum = cycle_momentum
         if cycle_momentum:
-            if 'momentum' not in optimizer.defaults:
-                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
+            if 'momentum' not in optimizer.defaults and 'betas' not in optimizer.defaults:
+                raise ValueError('optimizer must support momentum or beta1 with `cycle_momentum` option enabled')
 
-            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
-            if last_epoch == -1:
-                for momentum, group in zip(base_momentums, optimizer.param_groups):
-                    group['momentum'] = momentum
-            self.base_momentums = [group['momentum'] for group in optimizer.param_groups]
+            self.use_beta1 = 'betas' in self.optimizer.defaults
+            self.base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
             self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
+            if last_epoch == -1:
+                for m_momentum, b_momentum, group in zip(self.max_momentums, self.base_momentums, optimizer.param_groups):
+                    if self.use_beta1:
+                        group['betas'] = (m_momentum, *group['betas'][1:])
+                    else:
+                        group['momentum'] = m_momentum
+                    group['max_momentum'] = m_momentum
+                    group['base_momentum'] = b_momentum
 
         super().__init__(optimizer, last_epoch, verbose)
         self.base_lrs = base_lrs
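A tiny sketch (with made-up values) of what the `group['betas'] = (m_momentum, *group['betas'][1:])` assignment in this hunk does: it replaces beta1 while leaving beta2 untouched.

group = {'betas': (0.9, 0.999)}   # mirrors Adam's default betas
m_momentum = 0.85
group['betas'] = (m_momentum, *group['betas'][1:])
print(group['betas'])             # (0.85, 0.999): beta1 replaced, beta2 preserved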
@@ -1359,7 +1364,10 @@ class CyclicLR(LRScheduler):
                 momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
                 momentums.append(momentum)
             for param_group, momentum in zip(self.optimizer.param_groups, momentums):
-                param_group['momentum'] = momentum
+                if self.use_beta1:
+                    param_group['betas'] = (momentum, *param_group['betas'][1:])
+                else:
+                    param_group['momentum'] = momentum
 
         return lrs
 
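A short sketch (illustrative hyperparameters, not from the PR) of observing the value this get_lr hunk writes back on each step; as the learning rate climbs toward max_lr, beta1 descends toward base_momentum, the usual inverse cycling of CyclicLR:

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import CyclicLR

opt = Adam(torch.nn.Linear(2, 2).parameters())
sched = CyclicLR(opt, base_lr=0.001, max_lr=0.1, step_size_up=4,
                 base_momentum=0.8, max_momentum=0.9, cycle_momentum=True)

for _ in range(4):
    opt.step()
    sched.step()
    group = opt.param_groups[0]
    print(f"lr={group['lr']:.4f}  beta1={group['betas'][0]:.4f}")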
@@ -1721,7 +1729,7 @@ class OneCycleLR(LRScheduler):
         self.cycle_momentum = cycle_momentum
         if self.cycle_momentum:
             if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
-                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
+                raise ValueError('optimizer must support momentum or beta1 with `cycle_momentum` option enabled')
             self.use_beta1 = 'betas' in self.optimizer.defaults
             max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
             base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
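Note that OneCycleLR already cycled beta1 for `betas`-based optimizers (the check and `use_beta1` lines above are unchanged context); this hunk only brings its error message in line with the new CyclicLR wording. A minimal sketch (illustrative values) for comparison:

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import OneCycleLR

opt = Adam(torch.nn.Linear(2, 2).parameters())
sched = OneCycleLR(opt, max_lr=0.01, total_steps=10, cycle_momentum=True)
print(opt.param_groups[0]['betas'][0])  # starts at max_momentum (0.95 by default)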