Summary:
Context: https://github.com/pytorch/pytorch/pull/53299#discussion_r587882857
These are the only hand-written parts of this diff:
- the addition to `.github/workflows/lint.yml`
- the file endings changed in these four files (to appease FB-internal land-blocking lints):
  - `GLOSSARY.md`
  - `aten/src/ATen/core/op_registration/README.md`
  - `scripts/README.md`
  - `torch/csrc/jit/codegen/fuser/README.md`
The rest was generated by running this command (on macOS):
```
git grep -I -l ' $' -- . ':(exclude)**/contrib/**' ':(exclude)third_party' | xargs gsed -i 's/ *$//'
```
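For readers without `gsed`, the effect of the `sed` expression can be sketched in Python. This is only an illustration, not what was run for this PR; the helper names `strip_trailing_spaces` and `has_trailing_space` are hypothetical.

```python
import re

# Mirrors the sed expression 's/ *$//': drop any run of spaces at end of line.
# Only the space character is matched, so trailing tabs are left alone.
_TRAILING_SPACES = re.compile(r" +$", re.MULTILINE)


def strip_trailing_spaces(text: str) -> str:
    """Remove trailing space characters from every line of `text`."""
    return _TRAILING_SPACES.sub("", text)


def has_trailing_space(text: str) -> bool:
    """Same test as the `git grep ' $'` filter: some line ends with a space."""
    return re.search(r" $", text, re.MULTILINE) is not None
```

A lint along the lines of the one added to `.github/workflows/lint.yml` would fail whenever `has_trailing_space` is true for any tracked text file.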
I looked over the auto-generated changes and didn't see anything that looked problematic.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53406
Test Plan:
This run (after adding the lint but before removing existing trailing spaces) failed:
- https://github.com/pytorch/pytorch/runs/2043032377
This run (on the tip of this PR) succeeded:
- https://github.com/pytorch/pytorch/runs/2043296348
Reviewed By: walterddr, seemethere
Differential Revision: D26856620
Pulled By: samestep
fbshipit-source-id: 3f0de7f7c2e4b0f1c089eac9b5085a58dd7e0d97
Summary:
Change `avg_fun -> avg_fn` to match the spelling in the `.py` file.
(`swa_utils.pyi` should match `swa_utils.py`)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51608
Reviewed By: glaringlee
Differential Revision: D26224779
Pulled By: zou3519
fbshipit-source-id: 01ff7173ba0a996f1b7a653438acb6b6b4659de6
Summary:
This PR is based on the issue https://github.com/pytorch/pytorch/issues/29994#issue-524418771 and the discussion in the previous version of the PR https://github.com/pytorch/pytorch/pull/30559. Specifically, I followed the interface outlined in this [comment](https://github.com/pytorch/pytorch/pull/30559#issuecomment-574864768).
## Structure
- `torch/optim/swa_utils.py` contains the implementation of `AveragedModel` class, `SWALR` learning rate scheduler and `update_bn` utility
- `test/test_optim.py` contains unit tests for the three components of SWA
- `torch/optim/swa_utils.pyi` describes the interface of `torch/optim/swa_utils.py`
The new implementation consists of
- `AveragedModel` class; this class creates a copy of a given model and maintains running averages of its parameters.
- `SWALR` learning rate scheduler; after a certain number of epochs switches to a constant learning rate; this scheduler is supposed to be chained with other schedulers.
- `update_bn` utility; updates the Batch Normalization activation statistics for a given model and dataloader; this utility is meant to be applied to `AveragedModel` instances.
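The running average kept by `AveragedModel` can be illustrated with a minimal sketch on plain floats. The names `RunningAverage` and `equal_avg` are hypothetical and exist only for this illustration; the real class operates on a copy of the model's parameter tensors, and the equal-weight update shown here is assumed to be the default averaging rule.

```python
def equal_avg(p_avg: float, p: float, n_avg: int) -> float:
    """Equal-weight running mean: fold the new value `p` into `p_avg`."""
    return p_avg + (p - p_avg) / (n_avg + 1)


def ema_avg(p_avg: float, p: float, n_avg: int) -> float:
    """Exponential moving average, as in the `ema_avg` lambda of the example."""
    return 0.1 * p_avg + 0.9 * p


class RunningAverage:
    """Hypothetical stand-in for AveragedModel, using floats instead of tensors."""

    def __init__(self, avg_fn=equal_avg):
        self.avg_fn = avg_fn
        self.n_averaged = 0
        self.values = None

    def update_parameters(self, params):
        if self.n_averaged == 0:
            # The first call simply copies the current parameters.
            self.values = list(params)
        else:
            self.values = [
                self.avg_fn(p_avg, p, self.n_averaged)
                for p_avg, p in zip(self.values, params)
            ]
        self.n_averaged += 1


swa = RunningAverage()
for params in ([1.0, 10.0], [3.0, 20.0], [5.0, 30.0]):
    swa.update_parameters(params)
print(swa.values)  # [3.0, 20.0], the mean of the three snapshots
```

Passing a different averaging function, such as `ema_avg`, changes only the per-parameter update rule; the bookkeeping stays the same.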
For `update_bn` I simplified the implementation compared to the [original PR](https://github.com/pytorch/pytorch/pull/30559), following the suggestions by vadimkantorov.
## Example
```python
loader, optimizer, model = ...
swa_model = torch.optim.swa_utils.AveragedModel(model)
# You can use custom averaging functions with the `avg_fn` parameter
ema_avg = lambda p_avg, p, n_avg: 0.1 * p_avg + 0.9 * p
ema_model = torch.optim.swa_utils.AveragedModel(model,
                                                avg_fn=ema_avg)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       T_max=300)
swa_start = 160
swa_scheduler = SWALR(optimizer, start_epoch=swa_start, swa_lr=0.05)

for i in range(300):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    scheduler.step()
    swa_scheduler.step()

    if i > swa_start:
        swa_model.update_parameters(model)

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
```
UPDATED:
```python3
loader, optimizer, model, loss_fn = ...
swa_model = torch.optim.swa_utils.AveragedModel(model)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
swa_start = 160
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

for i in range(300):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    if i > swa_start:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
```
Fixes https://github.com/pytorch/pytorch/issues/29994
cc soumith vincentqb andrewgordonwilson vadimkantorov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35032
Differential Revision: D21079606
Pulled By: vincentqb
fbshipit-source-id: e07f5e821f72ada63789814c2dcbdc31f0160c37