Add lr_lambda type check in MultiplicativeLR (#151973)

Fixes #81554

## TestResult

### Before

```python
In [3]: import torch
   ...: class SimpleLinearModel(torch.nn.Module):
   ...:     def __init__(self):
   ...:         super(SimpleLinearModel, self).__init__()
   ...:         self.linear = torch.nn.Linear(10, 1)
   ...:
   ...:     def forward(self, x):
   ...:         return self.linear(x)
   ...:
   ...: net = SimpleLinearModel()
   ...: optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
   ...: scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, 0.95)
   ...: for i in range(10):
   ...:     print(i, scheduler.get_last_lr())
   ...:     scheduler.step()
TypeError: 'float' object is not callable

### After

```python
   ...: scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, 0.95)
TypeError: lr_lambda should be a function, but got float
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151973
Approved by: https://github.com/janeyx99
This commit is contained in:
zeshengzong 2025-04-29 08:21:37 +00:00 committed by PyTorch MergeBot
parent dcd9a444b3
commit eb69f4e609
2 changed files with 14 additions and 0 deletions

View File

@ -1837,6 +1837,15 @@ class TestLRScheduler(TestCase):
) )
self._test(scheduler, targets, epochs) self._test(scheduler, targets, epochs)
def test_multiplicative_lr_with_lr_lambda(self):
lr_lambda = 0.95
with self.assertRaisesRegex(TypeError, "lr_lambda should be a function"):
MultiplicativeLR(self.opt, lr_lambda)
lr_lambda2 = 0.95
with self.assertRaisesRegex(TypeError, "lr_lambda should be a function"):
MultiplicativeLR(self.opt, [lr_lambda, lr_lambda2])
@parametrize("T_mult", [1, 2, 4]) @parametrize("T_mult", [1, 2, 4])
def test_CosineAnnealingWarmRestarts_lr1(self, T_mult): def test_CosineAnnealingWarmRestarts_lr1(self, T_mult):
iters = 100 iters = 100

View File

@ -398,6 +398,11 @@ class MultiplicativeLR(LRScheduler):
f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}" f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}"
) )
self.lr_lambdas = list(lr_lambda) self.lr_lambdas = list(lr_lambda)
for lr_lambda in self.lr_lambdas:
if not callable(lr_lambda):
raise TypeError(
f"lr_lambda should be a function, but got {type(lr_lambda).__name__}"
)
super().__init__(optimizer, last_epoch) super().__init__(optimizer, last_epoch)
@override @override