mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Add lr_lambda type check in MultiplicativeLR (#151973)
Fixes #81554 ## TestResult ### Before ```python In [3]: import torch ...: class SimpleLinearModel(torch.nn.Module): ...: def __init__(self): ...: super(SimpleLinearModel, self).__init__() ...: self.linear = torch.nn.Linear(10, 1) ...: ...: def forward(self, x): ...: return self.linear(x) ...: ...: net = SimpleLinearModel() ...: optimizer = torch.optim.Adam(net.parameters(), lr=0.01) ...: scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, 0.95) ...: for i in range(10): ...: print(i, scheduler.get_last_lr()) ...: scheduler.step() TypeError: 'float' object is not callable ### After ```python ...: scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, 0.95) TypeError: lr_lambda should be a function, but got float ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/151973 Approved by: https://github.com/janeyx99
This commit is contained in:
parent
dcd9a444b3
commit
eb69f4e609
|
|
@ -1837,6 +1837,15 @@ class TestLRScheduler(TestCase):
|
|||
)
|
||||
self._test(scheduler, targets, epochs)
|
||||
|
||||
def test_multiplicative_lr_with_lr_lambda(self):
|
||||
lr_lambda = 0.95
|
||||
with self.assertRaisesRegex(TypeError, "lr_lambda should be a function"):
|
||||
MultiplicativeLR(self.opt, lr_lambda)
|
||||
|
||||
lr_lambda2 = 0.95
|
||||
with self.assertRaisesRegex(TypeError, "lr_lambda should be a function"):
|
||||
MultiplicativeLR(self.opt, [lr_lambda, lr_lambda2])
|
||||
|
||||
@parametrize("T_mult", [1, 2, 4])
|
||||
def test_CosineAnnealingWarmRestarts_lr1(self, T_mult):
|
||||
iters = 100
|
||||
|
|
|
|||
|
|
@ -398,6 +398,11 @@ class MultiplicativeLR(LRScheduler):
|
|||
f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}"
|
||||
)
|
||||
self.lr_lambdas = list(lr_lambda)
|
||||
for lr_lambda in self.lr_lambdas:
|
||||
if not callable(lr_lambda):
|
||||
raise TypeError(
|
||||
f"lr_lambda should be a function, but got {type(lr_lambda).__name__}"
|
||||
)
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
@override
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user