mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
release/1.6
1 Commits
| Author | SHA1 | Message | Date | |
|---|---|---|---|---|
|
|
22ac071d9a |
Add SWA to PyTorch mainline (#35032)
Summary: This PR is based on the issue https://github.com/pytorch/pytorch/issues/29994#issue-524418771 and the discussion in the previous version of the PR https://github.com/pytorch/pytorch/pull/30559. Specifically, I followed the interface outlined in this [comment](https://github.com/pytorch/pytorch/pull/30559#issuecomment-574864768). ## Structure - `torch/optim/swa_utils.py` contains the implementation of `AveragedModel` class, `SWALR` learning rate scheduler and `update_bn` utility - `test/test_optim.py` contains unit tests for the three components of SWA - `torch/optim/swa_utils.pyi` describes the interface of `torch/optim/swa_utils.py` The new implementation consists of - `AveragedModel` class; this class creates a copy of a given model and allows to compute running averages of the parameters. - `SWALR` learning rate scheduler; after a certain number of epochs switches to a constant learning rate; this scheduler is supposed to be chained with other schedulers. - `update_bn` utility; updates the Batch Normalization activation statistics for a given model and dataloader; this utility is meant to be applied to `AveragedModel` instances. For `update_bn` I simplified the implementation compared to the [original PR](https://github.com/pytorch/pytorch/pull/30559) according to the sugestions by vadimkantorov. ## Example ```python loader, optimizer, model = ... swa_model = torch.optim.swa_utils.AveragedModel(model) # You can use custom averaging functions with `avg_fun` parameter ema_avg = lambda p_avg, p, n_avg: 0.1 * p_avg + 0.9 * p ema_model = torch.optim.swa_utils.AveragedModel(model, avg_function=ema_avg) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300) swa_start = 160 swa_scheduler = SWALR(optimizer, start_epoch=swa_start, swa_lr=0.05) for i in range(300): for input, target in loader: optimizer.zero_grad() loss_fn(model(input), target).backward() optimizer.step() scheduler.step() swa_scheduler.step() if i > swa_start: swa_model.update_parameters(model) # Update bn statistics for the swa_model at the end torch.optim.swa_utils.update_bn(loader, swa_model) ``` UPDATED: ```python3 loader, optimizer, model, loss_fn = ... swa_model = torch.optim.swa_utils.AveragedModel(model) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300) swa_start = 160 swa_scheduler = SWALR(optimizer, swa_lr=0.05) for i in range(300): for input, target in loader: optimizer.zero_grad() loss_fn(model(input), target).backward() optimizer.step() if i > swa_start: swa_model.update_parameters(model) swa_scheduler.step() else: scheduler.step() # Update bn statistics for the swa_model at the end torch.optim.swa_utils.update_bn(loader, swa_model) ``` Fixes https://github.com/pytorch/pytorch/issues/29994 cc soumith vincentqb andrewgordonwilson vadimkantorov Pull Request resolved: https://github.com/pytorch/pytorch/pull/35032 Differential Revision: D21079606 Pulled By: vincentqb fbshipit-source-id: e07f5e821f72ada63789814c2dcbdc31f0160c37 |