Add beta1 support to CyclicLR momentum (#113548)

Fixes #73910

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113548
Approved by: https://github.com/janeyx99
Authored by rockerBOO, 2024-01-23 01:16:54 +00:00; committed by PyTorch MergeBot
parent d01ba4e94e
commit d810b10232
2 changed files with 24 additions and 12 deletions
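
In short: CyclicLR's `cycle_momentum=True` now also works with optimizers that expose `betas` (e.g. Adam) by cycling beta1, instead of requiring a `momentum` default. A minimal usage sketch; the model and hyperparameter values below are illustrative, not taken from the PR:

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import CyclicLR

model = torch.nn.Linear(10, 2)
optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

# Before this change, Adam raised ValueError here because it has no
# 'momentum' default; now the scheduler cycles beta1 instead.
scheduler = CyclicLR(
    optimizer,
    base_lr=1e-4,
    max_lr=1e-2,
    base_momentum=0.8,
    max_momentum=0.9,
    cycle_momentum=True,
)

for _ in range(10):
    optimizer.step()
    scheduler.step()
    beta1 = optimizer.param_groups[0]["betas"][0]  # moves with the cycle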

Changed file: LR scheduler tests (class TestLRScheduler)

@@ -8,7 +8,7 @@ from functools import partial
 import torch
 import torch.nn.functional as F
 from torch.nn import Parameter
-from torch.optim import Adam, SGD
+from torch.optim import Adam, SGD, Rprop
 from torch.optim.lr_scheduler import (
     LambdaLR,
     MultiplicativeLR,
@@ -1510,8 +1510,12 @@ class TestLRScheduler(TestCase):
     def test_cycle_lr_cycle_momentum_fail_with_momentumless_optimizer(self):
         with self.assertRaises(ValueError):
-            adam_opt = Adam(self.net.parameters())
-            scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True)
+            rprop_opt = Rprop(self.net.parameters())
+            scheduler = CyclicLR(rprop_opt, base_lr=1, max_lr=5, cycle_momentum=True)
+
+    def test_cycle_lr_cycle_momentum_with_beta1_optimizer(self):
+        adam_opt = Adam(self.net.parameters())
+        scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True)

     def test_cycle_lr_removed_after_out_of_scope(self):
         import gc
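
Note the failure test switches from Adam to Rprop: Adam now passes the check through its `betas` default, so a truly momentum-less optimizer is needed to exercise the ValueError. A standalone check of the relevant defaults (illustrative only):

import torch
from torch.optim import Adam, Rprop, SGD

params = [torch.nn.Parameter(torch.zeros(1))]
print("momentum" in SGD(params, lr=0.1).defaults)    # True
print("betas" in Adam(params, lr=0.1).defaults)      # True  -> now accepted
print("momentum" in Rprop(params, lr=0.1).defaults)  # False
print("betas" in Rprop(params, lr=0.1).defaults)     # False -> still raises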

Changed file: LR scheduler implementation (torch/optim/lr_scheduler.py)

@@ -1268,15 +1268,20 @@ class CyclicLR(LRScheduler):
         self.cycle_momentum = cycle_momentum
         if cycle_momentum:
-            if 'momentum' not in optimizer.defaults:
-                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
+            if 'momentum' not in optimizer.defaults and 'betas' not in optimizer.defaults:
+                raise ValueError('optimizer must support momentum or beta1 with `cycle_momentum` option enabled')

-            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
-            if last_epoch == -1:
-                for momentum, group in zip(base_momentums, optimizer.param_groups):
-                    group['momentum'] = momentum
-            self.base_momentums = [group['momentum'] for group in optimizer.param_groups]
+            self.use_beta1 = 'betas' in self.optimizer.defaults
+            self.base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
             self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
+            if last_epoch == -1:
+                for m_momentum, b_momentum, group in zip(self.max_momentums, self.base_momentums, optimizer.param_groups):
+                    if self.use_beta1:
+                        group['betas'] = (m_momentum, *group['betas'][1:])
+                    else:
+                        group['momentum'] = m_momentum
+                    group['max_momentum'] = m_momentum
+                    group['base_momentum'] = b_momentum

         super().__init__(optimizer, last_epoch, verbose)
         self.base_lrs = base_lrs
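
The `(m_momentum, *group['betas'][1:])` expression above replaces only beta1 and leaves beta2 (and any further entries) untouched, for example:

betas = (0.9, 0.999)
m_momentum = 0.85
# Rebuild the tuple with the cycled value in the beta1 slot only.
assert (m_momentum, *betas[1:]) == (0.85, 0.999)
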
@@ -1359,7 +1364,10 @@ class CyclicLR(LRScheduler):
                 momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
                 momentums.append(momentum)
             for param_group, momentum in zip(self.optimizer.param_groups, momentums):
-                param_group['momentum'] = momentum
+                if self.use_beta1:
+                    param_group['betas'] = (momentum, *param_group['betas'][1:])
+                else:
+                    param_group['momentum'] = momentum

         return lrs
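
So after each `scheduler.step()` the cycled value lands in `param_group['betas'][0]` for beta-based optimizers and in `param_group['momentum']` otherwise. A small sketch of the two paths (parameter and hyperparameter values are arbitrary):

import torch
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import CyclicLR

params = [torch.nn.Parameter(torch.zeros(3))]

sgd_sched = CyclicLR(SGD(params, lr=1e-3, momentum=0.9),
                     base_lr=1e-4, max_lr=1e-2, cycle_momentum=True)
adam_sched = CyclicLR(Adam(params, lr=1e-3),
                      base_lr=1e-4, max_lr=1e-2, cycle_momentum=True)

sgd_sched.step()
adam_sched.step()
print(sgd_sched.optimizer.param_groups[0]["momentum"])   # classic momentum path
print(adam_sched.optimizer.param_groups[0]["betas"][0])  # new beta1 path
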
@@ -1721,7 +1729,7 @@ class OneCycleLR(LRScheduler):
         self.cycle_momentum = cycle_momentum
         if self.cycle_momentum:
             if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
-                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
+                raise ValueError('optimizer must support momentum or beta1 with `cycle_momentum` option enabled')
             self.use_beta1 = 'betas' in self.optimizer.defaults
             max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
             base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
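
OneCycleLR already cycled beta1 through the same `use_beta1` flag; the hunk above only brings its error message in line with the new CyclicLR wording. A sketch of that error path with a momentum-less optimizer (values are arbitrary):

import torch
from torch.optim import Rprop
from torch.optim.lr_scheduler import OneCycleLR

opt = Rprop([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
try:
    OneCycleLR(opt, max_lr=0.1, total_steps=10, cycle_momentum=True)
except ValueError as err:
    # optimizer must support momentum or beta1 with `cycle_momentum` option enabled
    print(err)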