from pathlib import Path

import matplotlib
from matplotlib import pyplot as plt

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import (
    ChainedScheduler,
    ConstantLR,
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    CyclicLR,
    ExponentialLR,
    LambdaLR,
    LinearLR,
    MultiplicativeLR,
    MultiStepLR,
    OneCycleLR,
    PolynomialLR,
    ReduceLROnPlateau,
    SequentialLR,
    StepLR,
)

# Use a non-interactive backend so the script can run headless (e.g. in CI).
matplotlib.use("Agg")

LR_SCHEDULER_IMAGE_PATH = Path(__file__).parent / "lr_scheduler_images"
LR_SCHEDULER_IMAGE_PATH.mkdir(parents=True, exist_ok=True)

# A single throwaway model/optimizer pair is shared by every scheduler;
# plot_function resets the learning rate before each run.
model = torch.nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.05)

num_epochs = 100

# Each entry is a factory that builds a fresh scheduler for the shared
# optimizer. ChainedScheduler and SequentialLR construct their child
# schedulers inside the lambda so that stepped state from one plot cannot
# leak into the next.
schedulers = [
    (lambda opt: LambdaLR(opt, lr_lambda=lambda epoch: epoch // 30)),
    (lambda opt: MultiplicativeLR(opt, lr_lambda=lambda epoch: 0.95)),
    (lambda opt: StepLR(opt, step_size=30, gamma=0.1)),
    (lambda opt: MultiStepLR(opt, milestones=[30, 80], gamma=0.1)),
    (lambda opt: ConstantLR(opt, factor=0.5, total_iters=40)),
    (lambda opt: LinearLR(opt, start_factor=0.05, total_iters=40)),
    (lambda opt: ExponentialLR(opt, gamma=0.95)),
    (lambda opt: PolynomialLR(opt, total_iters=num_epochs // 2, power=0.9)),
    (lambda opt: CosineAnnealingLR(opt, T_max=num_epochs)),
    (lambda opt: CosineAnnealingWarmRestarts(opt, T_0=20)),
    (lambda opt: CyclicLR(opt, base_lr=0.01, max_lr=0.1, step_size_up=10)),
    (lambda opt: OneCycleLR(opt, max_lr=0.01, epochs=10, steps_per_epoch=10)),
    (lambda opt: ReduceLROnPlateau(opt, mode="min")),
    (
        lambda opt: ChainedScheduler(
            [
                ConstantLR(opt, factor=0.1, total_iters=num_epochs // 5),
                ExponentialLR(opt, gamma=0.9),
            ]
        )
    ),
    (
        lambda opt: SequentialLR(
            opt,
            schedulers=[
                ConstantLR(opt, factor=0.1, total_iters=num_epochs // 5),
                ExponentialLR(opt, gamma=0.9),
            ],
            milestones=[num_epochs // 5],
        )
    ),
]


def plot_function(scheduler_factory):
    plt.clf()
    plt.grid(color="k", alpha=0.2, linestyle="--")

    lrs = []
    # Reset the shared optimizer's learning rate before building the
    # scheduler, since scheduler construction itself may modify it.
    optimizer.param_groups[0]["lr"] = 0.05
    scheduler = scheduler_factory(optimizer)

    plot_path = LR_SCHEDULER_IMAGE_PATH / f"{scheduler.__class__.__name__}.png"
    if plot_path.exists():
        return

    for _ in range(num_epochs):
        lrs.append(optimizer.param_groups[0]["lr"])
        if isinstance(scheduler, ReduceLROnPlateau):
            # ReduceLROnPlateau steps on a metric; feed it a random "loss".
            val_loss = torch.randn(1).item()
            scheduler.step(val_loss)
        else:
            scheduler.step()

    plt.plot(range(num_epochs), lrs)
    plt.title(f"Learning Rate: {scheduler.__class__.__name__}")
    plt.xlabel("Epoch")
    plt.ylabel("Learning Rate")
    plt.xlim([0, num_epochs])
    plt.savefig(plot_path)
    print(
        f"Saved learning rate scheduler image for {scheduler.__class__.__name__} at {plot_path}"
    )


for scheduler_factory in schedulers:
    plot_function(scheduler_factory)
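
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not used by the image generation above):
# how a scheduler from the list plugs into an actual training loop. The
# synthetic batch, MSE loss, and choice of StepLR here are assumptions made
# for demonstration; the key pattern is that optimizer.step() runs before
# scheduler.step() each epoch. (ReduceLROnPlateau would instead be stepped
# with a validation metric, and OneCycleLR/CyclicLR step once per batch.)
def _train_with_scheduler_demo(epochs: int = 100) -> None:
    demo_model = torch.nn.Linear(10, 1)
    demo_optimizer = optim.SGD(demo_model.parameters(), lr=0.05)
    demo_scheduler = StepLR(demo_optimizer, step_size=30, gamma=0.1)
    inputs, targets = torch.randn(8, 10), torch.randn(8, 1)  # synthetic data
    loss_fn = torch.nn.MSELoss()
    for _ in range(epochs):
        demo_optimizer.zero_grad()
        loss = loss_fn(demo_model(inputs), targets)
        loss.backward()
        demo_optimizer.step()  # update the weights first...
        demo_scheduler.step()  # ...then advance the learning-rate schedule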