Add beta1 support to CyclicLR momentum (#113548)
Fixes #73910

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113548
Approved by: https://github.com/janeyx99
parent d01ba4e94e
commit d810b10232
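
In practical terms, this change lets CyclicLR with cycle_momentum=True be used with Adam-style optimizers by cycling beta1 in place of momentum. A minimal usage sketch of that behaviour; the toy model and learning rates below are illustrative, not taken from the PR:

    import torch
    from torch.optim import Adam
    from torch.optim.lr_scheduler import CyclicLR

    model = torch.nn.Linear(4, 2)              # throwaway model for illustration
    optimizer = Adam(model.parameters())       # Adam exposes 'betas', not 'momentum'

    # Previously this raised ValueError for Adam; now beta1 is cycled
    # between base_momentum and max_momentum alongside the learning rate.
    scheduler = CyclicLR(optimizer, base_lr=1e-4, max_lr=1e-2, cycle_momentum=True)
    print(optimizer.param_groups[0]['betas'][0])   # beta1 is now managed by the scheduler
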
@@ -8,7 +8,7 @@ from functools import partial
 import torch
 import torch.nn.functional as F
 from torch.nn import Parameter
-from torch.optim import Adam, SGD
+from torch.optim import Adam, SGD, Rprop
 from torch.optim.lr_scheduler import (
     LambdaLR,
     MultiplicativeLR,
@@ -1510,6 +1510,10 @@ class TestLRScheduler(TestCase):

+    def test_cycle_lr_cycle_momentum_fail_with_momentumless_optimizer(self):
+        with self.assertRaises(ValueError):
+            rprop_opt = Rprop(self.net.parameters())
+            scheduler = CyclicLR(rprop_opt, base_lr=1, max_lr=5, cycle_momentum=True)
+
+    def test_cycle_lr_cycle_momentum_with_beta1_optimizer(self):
+        adam_opt = Adam(self.net.parameters())
+        scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True)
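
The second test above is a construction smoke test. If one also wanted to check that beta1 is actually rewritten once the scheduler runs, a possible extension could look like the sketch below; the stand-in model, step count, and assertion are assumptions, not part of the PR:

    import torch
    from torch.optim import Adam
    from torch.optim.lr_scheduler import CyclicLR

    net = torch.nn.Linear(3, 3)                    # stand-in for self.net in the test class
    opt = Adam(net.parameters())
    scheduler = CyclicLR(opt, base_lr=1, max_lr=5, cycle_momentum=True)

    beta1_before = opt.param_groups[0]['betas'][0]
    for _ in range(5):
        opt.step()                                 # gradients omitted; the step is a no-op here
        scheduler.step()
    beta1_after = opt.param_groups[0]['betas'][0]

    assert beta1_after != beta1_before             # beta1 was rewritten by the scheduler
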
@@ -1268,15 +1268,20 @@ class CyclicLR(LRScheduler):

         self.cycle_momentum = cycle_momentum
         if cycle_momentum:
-            if 'momentum' not in optimizer.defaults:
-                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
+            if 'momentum' not in optimizer.defaults and 'betas' not in optimizer.defaults:
+                raise ValueError('optimizer must support momentum or beta1 with `cycle_momentum` option enabled')

-            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
-            if last_epoch == -1:
-                for momentum, group in zip(base_momentums, optimizer.param_groups):
-                    group['momentum'] = momentum
-            self.base_momentums = [group['momentum'] for group in optimizer.param_groups]
+            self.use_beta1 = 'betas' in self.optimizer.defaults
+            self.base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
             self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
+            if last_epoch == -1:
+                for m_momentum, b_momentum, group in zip(self.max_momentums, self.base_momentums, optimizer.param_groups):
+                    if self.use_beta1:
+                        group['betas'] = (m_momentum, *group['betas'][1:])
+                    else:
+                        group['momentum'] = m_momentum
+                    group['max_momentum'] = m_momentum
+                    group['base_momentum'] = b_momentum

         super().__init__(optimizer, last_epoch, verbose)
         self.base_lrs = base_lrs
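
Note that the betas update above rebuilds the tuple rather than mutating it, so beta2 (and any further elements) are left untouched, and each param group also gets 'max_momentum' and 'base_momentum' entries regardless of which path is taken. A tiny illustration of the tuple idiom, with assumed values:

    betas = (0.9, 0.999)                  # Adam's default (beta1, beta2)
    m_momentum = 0.85                     # assumed cycled value
    betas = (m_momentum, *betas[1:])      # replace beta1, keep the rest
    print(betas)                          # (0.85, 0.999)
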
@@ -1359,6 +1364,9 @@ class CyclicLR(LRScheduler):
                 momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
                 momentums.append(momentum)
             for param_group, momentum in zip(self.optimizer.param_groups, momentums):
-                param_group['momentum'] = momentum
+                if self.use_beta1:
+                    param_group['betas'] = (momentum, *param_group['betas'][1:])
+                else:
+                    param_group['momentum'] = momentum

         return lrs
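
As with plain momentum, the cycled beta1 moves inversely to the learning rate: it starts at max_momentum and falls toward base_momentum while the learning rate climbs, then reverses. A short sketch to observe this with a small step size; the model and hyperparameters are illustrative only:

    import torch
    from torch.optim import Adam
    from torch.optim.lr_scheduler import CyclicLR

    opt = Adam(torch.nn.Linear(2, 2).parameters())
    scheduler = CyclicLR(opt, base_lr=0.001, max_lr=0.01,
                         step_size_up=4, cycle_momentum=True)

    for i in range(8):
        opt.step()
        scheduler.step()
        group = opt.param_groups[0]
        # lr rises toward max_lr while beta1 falls toward base_momentum, then both reverse
        print(f"step {i}: lr={group['lr']:.4f}  beta1={group['betas'][0]:.3f}")
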
@@ -1721,7 +1729,7 @@ class OneCycleLR(LRScheduler):
         self.cycle_momentum = cycle_momentum
         if self.cycle_momentum:
             if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
-                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
+                raise ValueError('optimizer must support momentum or beta1 with `cycle_momentum` option enabled')
             self.use_beta1 = 'betas' in self.optimizer.defaults
             max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
             base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
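
OneCycleLR already handled Adam-style optimizers through the same use_beta1 path; the hunk above only brings its error message in line with CyclicLR's. For completeness, a brief reminder of that existing behaviour; the toy model and schedule length are illustrative:

    import torch
    from torch.optim import Adam
    from torch.optim.lr_scheduler import OneCycleLR

    opt = Adam(torch.nn.Linear(2, 2).parameters())
    scheduler = OneCycleLR(opt, max_lr=0.01, total_steps=10)   # cycle_momentum=True by default

    opt.step()
    scheduler.step()
    print(opt.param_groups[0]['betas'][0])    # beta1 is annealed between max_momentum and base_momentum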