Add lr_lambda type check in MultiplicativeLR (#151973)
Fixes #81554

## TestResult

### Before

```python
In [3]: import torch
   ...: class SimpleLinearModel(torch.nn.Module):
   ...:     def __init__(self):
   ...:         super(SimpleLinearModel, self).__init__()
   ...:         self.linear = torch.nn.Linear(10, 1)
   ...:
   ...:     def forward(self, x):
   ...:         return self.linear(x)
   ...:
   ...: net = SimpleLinearModel()
   ...: optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
   ...: scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, 0.95)
   ...: for i in range(10):
   ...:     print(i, scheduler.get_last_lr())
   ...:     scheduler.step()
TypeError: 'float' object is not callable
```

### After

```python
   ...: scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, 0.95)
TypeError: lr_lambda should be a function, but got float
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151973
Approved by: https://github.com/janeyx99
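For contrast with the failing calls above, here is a minimal sketch of the intended usage, where `lr_lambda` is a callable returning the per-epoch multiplicative factor; the model and optimizer below are illustrative, not taken from the PR:

```python
import torch

# MultiplicativeLR expects a callable (or one callable per param group)
# mapping the epoch index to a factor that multiplies the current lr.
model = torch.nn.Linear(10, 1)  # illustrative stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
    optimizer, lr_lambda=lambda epoch: 0.95  # a callable, not a bare 0.95
)
for i in range(10):
    print(i, scheduler.get_last_lr())
    optimizer.step()
    scheduler.step()
```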
This commit is contained in:
parent dcd9a444b3
commit eb69f4e609
```diff
@@ -1837,6 +1837,15 @@ class TestLRScheduler(TestCase):
         )
         self._test(scheduler, targets, epochs)
 
+    def test_multiplicative_lr_with_lr_lambda(self):
+        lr_lambda = 0.95
+        with self.assertRaisesRegex(TypeError, "lr_lambda should be a function"):
+            MultiplicativeLR(self.opt, lr_lambda)
+
+        lr_lambda2 = 0.95
+        with self.assertRaisesRegex(TypeError, "lr_lambda should be a function"):
+            MultiplicativeLR(self.opt, [lr_lambda, lr_lambda2])
+
     @parametrize("T_mult", [1, 2, 4])
     def test_CosineAnnealingWarmRestarts_lr1(self, T_mult):
         iters = 100
```
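Outside the test harness, `self.opt` is an optimizer fixture set up by `TestLRScheduler`. A rough standalone equivalent of the new assertions, assuming the patched scheduler and substituting a hypothetical two-group SGD optimizer for the fixture:

```python
import torch
from torch.optim.lr_scheduler import MultiplicativeLR

# Hypothetical stand-in for the test fixture's self.opt (two param groups).
p1, p2 = torch.nn.Parameter(torch.zeros(1)), torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([{"params": [p1]}, {"params": [p2], "lr": 0.5}], lr=0.05)

try:
    MultiplicativeLR(opt, 0.95)  # bare float rather than a callable
except TypeError as e:
    print(e)  # lr_lambda should be a function, but got float

try:
    MultiplicativeLR(opt, [0.95, 0.95])  # list entries must be callable too
except TypeError as e:
    print(e)  # lr_lambda should be a function, but got float
```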
```diff
@@ -398,6 +398,11 @@ class MultiplicativeLR(LRScheduler):
                     f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}"
                 )
             self.lr_lambdas = list(lr_lambda)
+        for lr_lambda in self.lr_lambdas:
+            if not callable(lr_lambda):
+                raise TypeError(
+                    f"lr_lambda should be a function, but got {type(lr_lambda).__name__}"
+                )
         super().__init__(optimizer, last_epoch)
 
     @override
```
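Because the loop iterates over `self.lr_lambdas` after a single callable has been broadcast to every param group, the check covers both call forms. A sketch of the valid per-group usage; the two-layer network and learning rates here are illustrative:

```python
import torch
from torch.optim.lr_scheduler import MultiplicativeLR

# Illustrative two-group optimizer: one callable per param group.
net = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 1))
opt = torch.optim.SGD(
    [
        {"params": net[0].parameters(), "lr": 0.1},
        {"params": net[1].parameters(), "lr": 0.01},
    ]
)
scheduler = MultiplicativeLR(opt, lr_lambda=[lambda e: 0.9, lambda e: 0.95])
print(scheduler.get_last_lr())  # [0.1, 0.01]

# Mixing a non-callable into the list now fails fast at construction:
# MultiplicativeLR(opt, lr_lambda=[lambda e: 0.9, 0.95])
#   -> TypeError: lr_lambda should be a function, but got float
```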