Remove state_dict from AveragedModel and use buffers instead (#71763)

Summary:
Fixes [https://github.com/pytorch/pytorch/issues/66686](https://github.com/pytorch/pytorch/issues/66686)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71763

Reviewed By: anjali411

Differential Revision: D33770907

Pulled By: prabhat00155

fbshipit-source-id: ee32f2cb8475c9add4e1a9a5d3d784ef95825efc
(cherry picked from commit a15898b072)
Authored by Prabhat Roy on 2022-01-26 04:33:31 -08:00; committed by PyTorch MergeBot
parent 40e88b75c4
commit 942a084c46
2 changed files with 35 additions and 20 deletions
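In short, the public API change is that `AveragedModel(model, mode='state_dict')` becomes `AveragedModel(model, use_buffers=True)`. A minimal before/after sketch (the `model` below is a placeholder module used only for illustration):

```python
import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Linear(10, 2)  # placeholder module for illustration

# Before this commit (old, now-removed keyword):
#   swa_model = AveragedModel(model, mode='state_dict')
# After this commit, buffers are averaged alongside parameters via a flag:
swa_model = AveragedModel(model, use_buffers=True)
```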

test/test_optim.py

@@ -4,7 +4,9 @@ import warnings
import math
import unittest
import functools
+import itertools
from copy import deepcopy
import torch
from torch._six import inf
import torch.optim as optim
@@ -2475,8 +2477,8 @@ class TestSWAUtils(TestCase):
for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
self.assertEqual(p_avg, p_swa)
-def test_averaged_model_exponential_use_state_dict(self):
-# Test AveragedModel with EMA as avg_fn and use_state_dict as True.
+def test_averaged_model_exponential_buffers(self):
+# Test AveragedModel with EMA as avg_fn and use_buffers as True.
dnn = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3),
torch.nn.BatchNorm2d(5, momentum=0.3),
@@ -2486,13 +2488,14 @@ class TestSWAUtils(TestCase):
def avg_fn(p_avg, p, n_avg):
return alpha * p_avg + (1 - alpha) * p
-averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, mode='state_dict')
-averaged_params = [torch.zeros_like(param) for param in dnn.state_dict().values()
+averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, use_buffers=True)
+dnn_params = itertools.chain(dnn.parameters(), dnn.buffers())
+averaged_params = [torch.zeros_like(param) for param in dnn_params
if param.size() != torch.Size([])]
n_updates = 10
for i in range(n_updates):
updated_averaged_params = []
-for p, p_avg in zip(dnn.state_dict().values(), averaged_params):
+for p, p_avg in zip(dnn_params, averaged_params):
if p.size() == torch.Size([]):
continue
p.detach().add_(torch.randn_like(p))
@@ -2504,7 +2507,8 @@ class TestSWAUtils(TestCase):
averaged_dnn.update_parameters(dnn)
averaged_params = updated_averaged_params
-for p_avg, p_swa in zip(averaged_params, averaged_dnn.module.state_dict().values()):
+for p_avg, p_swa in zip(
+averaged_params, itertools.chain(averaged_dnn.module.parameters(), averaged_dnn.module.buffers())):
self.assertEqual(p_avg, p_swa)
def _test_update_bn(self, dnn, dl_x, dl_xy, cuda):
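Pulled out of the test harness, the behaviour exercised by the updated test looks roughly like the following sketch (it assumes the post-change `AveragedModel` API shown above; the random perturbations are purely illustrative):

```python
import itertools
import torch
from torch.optim.swa_utils import AveragedModel

dnn = torch.nn.Sequential(
    torch.nn.Conv2d(1, 5, kernel_size=3),
    torch.nn.BatchNorm2d(5, momentum=0.3),
)
alpha = 0.9
# EMA-style avg_fn: keep `alpha` of the running average, mix in the rest.
ema = lambda p_avg, p, n_avg: alpha * p_avg + (1 - alpha) * p
averaged_dnn = AveragedModel(dnn, avg_fn=ema, use_buffers=True)

for _ in range(10):
    # Perturb parameters and buffers, skipping scalar buffers such as
    # BatchNorm's num_batches_tracked (exactly what the test does).
    for p in itertools.chain(dnn.parameters(), dnn.buffers()):
        if p.size() != torch.Size([]):
            p.detach().add_(torch.randn_like(p))
    averaged_dnn.update_parameters(dnn)  # now averages buffers as well
```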

torch/optim/swa_utils.py

@@ -1,7 +1,9 @@
-import torch
+import itertools
import math
-from torch.nn import Module
from copy import deepcopy
+import torch
+from torch.nn import Module
from torch.optim.lr_scheduler import _LRScheduler
@@ -26,8 +28,8 @@ class AveragedModel(Module):
:class:`AveragedModel` parameter, the current value of :attr:`model`
parameter and the number of models already averaged; if None,
equally weighted average is used (default: None)
-mode (str, optional): whether to use ``'parameters'`` or ``'state_dict'`` for update
-(default: ``'parameters'``)
+use_buffers (bool): if ``True``, it will compute running averages for
+both the parameters and the buffers of the model. (default: ``False``)
Example:
>>> loader, optimizer, model, loss_fn = ...
@@ -55,15 +57,21 @@ class AveragedModel(Module):
equally-weighted average of the weights.
Example:
->>> # Compute exponential moving averages of the weights
+>>> # Compute exponential moving averages of the weights and buffers
>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
0.1 * averaged_model_parameter + 0.9 * model_parameter
->>> swa_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)
+>>> swa_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg, use_buffers=True)
.. note::
When using SWA with models containing Batch Normalization you may
need to update the activation statistics for Batch Normalization.
-You can do so by using :meth:`torch.optim.swa_utils.update_bn` utility.
+This can be done either by using the :meth:`torch.optim.swa_utils.update_bn`
+or by setting :attr:`use_buffers` to `True`. The first approach updates the
+statistics in a post-training step by passing data through the model. The
+second does it during the parameter update phase by averaging all buffers.
+Empirical evidence has shown that updating the statistics in normalization
+layers increases accuracy, but you may wish to empirically test which
+approach yields the best results in your problem.
.. note::
:attr:`avg_fn` is not saved in the :meth:`state_dict` of the model.
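The note added in this hunk describes two ways to keep BatchNorm statistics consistent with the averaged weights. A minimal sketch of both approaches (the toy model and stand-in loader below are assumptions for illustration only):

```python
import torch
from torch.optim.swa_utils import AveragedModel, update_bn

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.BatchNorm1d(8))
loader = [torch.randn(16, 8) for _ in range(4)]  # stand-in data loader

# Approach 1: average only the parameters during training, then recompute
# the BatchNorm running statistics in a post-training pass over the data.
swa_model = AveragedModel(model)
# ... swa_model.update_parameters(model) inside the training loop ...
update_bn(loader, swa_model)

# Approach 2: average the buffers (including the BatchNorm running stats)
# together with the parameters at every update.
ema_model = AveragedModel(model, use_buffers=True)
# ... ema_model.update_parameters(model) inside the training loop ...
```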
@@ -86,7 +94,7 @@ class AveragedModel(Module):
Generalizes Well:
https://arxiv.org/abs/2001.02312
"""
-def __init__(self, model, device=None, avg_fn=None, mode='parameters'):
+def __init__(self, model, device=None, avg_fn=None, use_buffers=False):
super(AveragedModel, self).__init__()
self.module = deepcopy(model)
if device is not None:
@@ -98,17 +106,20 @@ class AveragedModel(Module):
return averaged_model_parameter + \
(model_parameter - averaged_model_parameter) / (num_averaged + 1)
self.avg_fn = avg_fn
-modes = ['parameters', 'state_dict']
-if mode not in modes:
-raise ValueError(f'Invalid mode passed, valid values are {", ".join(modes)}.')
-self.use_state_dict = mode == 'state_dict'
+self.use_buffers = use_buffers
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
def update_parameters(self, model):
-self_param = self.module.state_dict().values() if self.use_state_dict else self.parameters()
-model_param = model.state_dict().values() if self.use_state_dict else model.parameters()
+self_param = (
+itertools.chain(self.module.parameters(), self.module.buffers())
+if self.use_buffers else self.parameters()
+)
+model_param = (
+itertools.chain(model.parameters(), model.buffers())
+if self.use_buffers else model.parameters()
+)
for p_swa, p_model in zip(self_param, model_param):
device = p_swa.device
p_model_ = p_model.detach().to(device)
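For reference, the default `avg_fn` kept as context in this hunk implements an equally weighted running mean, `avg + (x - avg) / (n + 1)`. As standalone tensor math (values chosen only for illustration):

```python
import torch

avg = torch.zeros(3)
samples = [torch.full((3,), v) for v in (1.0, 2.0, 3.0)]
for n, x in enumerate(samples):
    avg = avg + (x - avg) / (n + 1)   # running mean after n + 1 samples
print(avg)  # tensor([2., 2., 2.]) -- the mean of 1.0, 2.0 and 3.0
```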