Summary: This PR is based on the issue https://github.com/pytorch/pytorch/issues/29994#issue-524418771 and the discussion in the previous version of the PR https://github.com/pytorch/pytorch/pull/30559. Specifically, I followed the interface outlined in this [comment](https://github.com/pytorch/pytorch/pull/30559#issuecomment-574864768).

## Structure

- `torch/optim/swa_utils.py` contains the implementation of the `AveragedModel` class, the `SWALR` learning rate scheduler, and the `update_bn` utility
- `test/test_optim.py` contains unit tests for the three components of SWA
- `torch/optim/swa_utils.pyi` describes the interface of `torch/optim/swa_utils.py`

The new implementation consists of:

- the `AveragedModel` class, which creates a copy of a given model and computes running averages of its parameters;
- the `SWALR` learning rate scheduler, which switches to a constant learning rate after a given number of epochs and is meant to be chained with other schedulers;
- the `update_bn` utility, which updates the Batch Normalization activation statistics for a given model and dataloader, and is meant to be applied to `AveragedModel` instances.

For `update_bn` I simplified the implementation compared to the [original PR](https://github.com/pytorch/pytorch/pull/30559) following the suggestions by vadimkantorov.

## Example

```python
loader, optimizer, model, loss_fn = ...
swa_model = torch.optim.swa_utils.AveragedModel(model)
# You can use custom averaging functions via the `avg_fun` parameter
ema_avg = lambda p_avg, p, n_avg: 0.1 * p_avg + 0.9 * p
ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fun=ema_avg)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
swa_start = 160
swa_scheduler = SWALR(optimizer, start_epoch=swa_start, swa_lr=0.05)

for i in range(300):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    scheduler.step()
    swa_scheduler.step()

    if i > swa_start:
        swa_model.update_parameters(model)

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
```

UPDATED:

```python3
loader, optimizer, model, loss_fn = ...
swa_model = torch.optim.swa_utils.AveragedModel(model)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
swa_start = 160
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

for i in range(300):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    if i > swa_start:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
```

Fixes https://github.com/pytorch/pytorch/issues/29994

cc soumith vincentqb andrewgordonwilson vadimkantorov

Pull Request resolved: https://github.com/pytorch/pytorch/pull/35032

Differential Revision: D21079606

Pulled By: vincentqb

fbshipit-source-id: e07f5e821f72ada63789814c2dcbdc31f0160c37
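For intuition about what the `update_bn` utility does, here is a rough, simplified sketch: reset the running statistics of every BatchNorm layer and recompute them with a single forward pass over the loader. The helper name `update_bn_sketch` and its exact handling of loader batches are illustrative assumptions, not the code added by this PR.

```python
import torch

def update_bn_sketch(loader, model, device=None):
    """Simplified sketch of a BN-statistics refresh (for intuition only)."""
    bn_modules = [m for m in model.modules()
                  if isinstance(m, torch.nn.modules.batchnorm._BatchNorm)]
    if not bn_modules:
        return
    momenta = {}
    for bn in bn_modules:
        bn.reset_running_stats()   # clear the running mean/var
        momenta[bn] = bn.momentum
        bn.momentum = None         # momentum=None -> cumulative moving average
    was_training = model.training
    model.train()                  # BN updates running stats only in train mode
    with torch.no_grad():
        for batch in loader:
            if isinstance(batch, (list, tuple)):
                batch = batch[0]   # assume the input tensor comes first
            if device is not None:
                batch = batch.to(device)
            model(batch)
    for bn in bn_modules:
        bn.momentum = momenta[bn]  # restore the original momenta
    model.train(was_training)
```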
`torch/optim/swa_utils.pyi` (18 lines, 717 B, Python):
```python
from .optimizer import Optimizer
from ..nn.modules import Module
from .lr_scheduler import _LRScheduler
from .. import device, Tensor
from typing import Iterable, Any, Optional, Callable, Union, List


class AveragedModel(Module):
    def __init__(self, model: Module, device: Union[int, device] = ...,
                 avg_fun: Callable[[Tensor, Tensor, int], Tensor] = ...) -> None: ...

    def update_parameters(self, model: Module) -> None: ...


def update_bn(loader: Iterable, model: Module, device: Union[int, device] = ...) -> None: ...


class SWALR(_LRScheduler):
    def __init__(self, optimizer: Optimizer, swa_lr: float, anneal_epochs: int,
                 anneal_strategy: str, last_epoch: int = ...) -> None: ...
```
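As a quick illustration of the typed interface above, here is a hedged usage sketch; the placeholder model, the optimizer, and the concrete values `swa_lr=0.05`, `anneal_epochs=5`, `anneal_strategy="cos"` are illustrative assumptions rather than values mandated by the stub.

```python
# Illustrative only: the model, optimizer, and argument values are assumptions.
import torch
from torch.optim.swa_utils import AveragedModel, SWALR

model = torch.nn.Linear(10, 2)                          # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Keep a running average of the model's parameters.
swa_model = AveragedModel(model)

# Anneal the learning rate towards a constant SWA learning rate.
swa_scheduler = SWALR(optimizer, swa_lr=0.05, anneal_epochs=5,
                      anneal_strategy="cos")
```

In a training loop this pairs with `swa_model.update_parameters(model)` after the SWA start epoch and a final `update_bn` pass, as in the examples above.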