diff --git a/test/optim/test_lrscheduler.py b/test/optim/test_lrscheduler.py
index 2671f5b7c93..b44eb9058c4 100644
--- a/test/optim/test_lrscheduler.py
+++ b/test/optim/test_lrscheduler.py
@@ -1837,6 +1837,15 @@ class TestLRScheduler(TestCase):
         )
         self._test(scheduler, targets, epochs)

+    def test_multiplicative_lr_with_lr_lambda(self):
+        lr_lambda = 0.95
+        with self.assertRaisesRegex(TypeError, "lr_lambda should be a function"):
+            MultiplicativeLR(self.opt, lr_lambda)
+
+        lr_lambda2 = 0.95
+        with self.assertRaisesRegex(TypeError, "lr_lambda should be a function"):
+            MultiplicativeLR(self.opt, [lr_lambda, lr_lambda2])
+
     @parametrize("T_mult", [1, 2, 4])
     def test_CosineAnnealingWarmRestarts_lr1(self, T_mult):
         iters = 100
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py
index 06ae24ef08d..8c553fb4901 100644
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -398,6 +398,11 @@ class MultiplicativeLR(LRScheduler):
                     f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}"
                 )
             self.lr_lambdas = list(lr_lambda)
+        for lr_lambda in self.lr_lambdas:
+            if not callable(lr_lambda):
+                raise TypeError(
+                    f"lr_lambda should be a function, but got {type(lr_lambda).__name__}"
+                )
         super().__init__(optimizer, last_epoch)

     @override
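
For context, a minimal usage sketch of the behavior this patch enforces (the model and optimizer below are illustrative, not part of the change): MultiplicativeLR expects lr_lambda to be a callable returning a multiplicative factor, and with the added check a bare float is rejected with a TypeError at construction time instead of failing later during get_lr.

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiplicativeLR

model = torch.nn.Linear(2, 2)          # illustrative model
opt = SGD(model.parameters(), lr=0.1)  # illustrative optimizer

# Correct usage: lr_lambda is a function of the epoch returning a factor.
scheduler = MultiplicativeLR(opt, lr_lambda=lambda epoch: 0.95)

# With this patch, a non-callable factor now fails fast:
try:
    MultiplicativeLR(opt, 0.95)
except TypeError as e:
    print(e)  # lr_lambda should be a function, but got float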