Remove state_dict from AveragedModel and use buffers instead (#71763)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/66686
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71763
Reviewed By: anjali411
Differential Revision: D33770907
Pulled By: prabhat00155
fbshipit-source-id: ee32f2cb8475c9add4e1a9a5d3d784ef95825efc
(cherry picked from commit a15898b072)
This commit is contained in:
parent 40e88b75c4
commit 942a084c46
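For orientation, a minimal usage sketch of the new `use_buffers` flag this change introduces (the model, the EMA coefficient, and the commented training loop are illustrative placeholders, not part of the commit):

```python
import torch
from torch.optim.swa_utils import AveragedModel

# Any module with buffers (e.g. BatchNorm running statistics) demonstrates the flag.
model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 5, kernel_size=3),
    torch.nn.BatchNorm2d(5),
)

# EMA-style averaging function: new_avg = 0.9 * avg + 0.1 * current.
ema_avg = lambda avg_p, p, num_averaged: 0.9 * avg_p + 0.1 * p

# use_buffers=True averages the buffers alongside the parameters on every update,
# replacing the removed mode='state_dict' option.
ema_model = AveragedModel(model, avg_fn=ema_avg, use_buffers=True)

# Inside a training loop one would call, after each optimizer step:
#     ema_model.update_parameters(model)
```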
Changes to the optimizer test suite (class TestSWAUtils):

```diff
@@ -4,7 +4,9 @@ import warnings
 import math
 import unittest
 import functools
+import itertools
 from copy import deepcopy
+
 import torch
 from torch._six import inf
 import torch.optim as optim
```
```diff
@@ -2475,8 +2477,8 @@ class TestSWAUtils(TestCase):
         for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
             self.assertEqual(p_avg, p_swa)
 
-    def test_averaged_model_exponential_use_state_dict(self):
-        # Test AveragedModel with EMA as avg_fn and use_state_dict as True.
+    def test_averaged_model_exponential_buffers(self):
+        # Test AveragedModel with EMA as avg_fn and use_buffers as True.
         dnn = torch.nn.Sequential(
             torch.nn.Conv2d(1, 5, kernel_size=3),
             torch.nn.BatchNorm2d(5, momentum=0.3),
```
```diff
@@ -2486,13 +2488,14 @@ class TestSWAUtils(TestCase):
 
         def avg_fn(p_avg, p, n_avg):
             return alpha * p_avg + (1 - alpha) * p
-        averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, mode='state_dict')
-        averaged_params = [torch.zeros_like(param) for param in dnn.state_dict().values()
+        averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, use_buffers=True)
+        dnn_params = itertools.chain(dnn.parameters(), dnn.buffers())
+        averaged_params = [torch.zeros_like(param) for param in dnn_params
                            if param.size() != torch.Size([])]
         n_updates = 10
         for i in range(n_updates):
             updated_averaged_params = []
-            for p, p_avg in zip(dnn.state_dict().values(), averaged_params):
+            for p, p_avg in zip(dnn_params, averaged_params):
                 if p.size() == torch.Size([]):
                     continue
                 p.detach().add_(torch.randn_like(p))
```
```diff
@@ -2504,7 +2507,8 @@ class TestSWAUtils(TestCase):
             averaged_dnn.update_parameters(dnn)
             averaged_params = updated_averaged_params
 
-        for p_avg, p_swa in zip(averaged_params, averaged_dnn.module.state_dict().values()):
+        for p_avg, p_swa in zip(
+                averaged_params, itertools.chain(averaged_dnn.module.parameters(), averaged_dnn.module.buffers())):
             self.assertEqual(p_avg, p_swa)
 
     def _test_update_bn(self, dnn, dl_x, dl_xy, cuda):
```
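A condensed illustration of what the renamed test verifies: with `use_buffers=True`, a manually maintained EMA over the chained parameters and buffers should match what AveragedModel accumulates. The layer sizes, alpha, and step count below are illustrative stand-ins, not the test's exact values:

```python
import itertools
import torch
from torch.optim.swa_utils import AveragedModel

dnn = torch.nn.Sequential(
    torch.nn.Conv2d(1, 5, kernel_size=3),
    torch.nn.BatchNorm2d(5, momentum=0.3),
)
alpha = 0.9
ema = lambda p_avg, p, n: alpha * p_avg + (1 - alpha) * p
averaged = AveragedModel(dnn, avg_fn=ema, use_buffers=True)

# Track the expected average over non-scalar tensors (parameters + buffers),
# skipping scalar buffers such as BatchNorm's num_batches_tracked.
tracked = [t for t in itertools.chain(dnn.parameters(), dnn.buffers())
           if t.size() != torch.Size([])]
expected = None

for step in range(10):
    for t in tracked:
        t.detach().add_(torch.randn_like(t))  # stand-in for a training update
    if expected is None:
        expected = [t.detach().clone() for t in tracked]  # first update copies the model
    else:
        expected = [ema(e, t.detach(), step) for e, t in zip(expected, tracked)]
    averaged.update_parameters(dnn)

avg_tensors = itertools.chain(averaged.module.parameters(), averaged.module.buffers())
for e, a in zip(expected, avg_tensors):
    torch.testing.assert_close(e, a.detach())
```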
Changes to torch/optim/swa_utils.py (AveragedModel):
```diff
@@ -1,7 +1,9 @@
-import torch
+import itertools
 import math
-from torch.nn import Module
 from copy import deepcopy
 
+import torch
+from torch.nn import Module
 from torch.optim.lr_scheduler import _LRScheduler
 
+
```
```diff
@@ -26,8 +28,8 @@ class AveragedModel(Module):
             :class:`AveragedModel` parameter, the current value of :attr:`model`
             parameter and the number of models already averaged; if None,
             equally weighted average is used (default: None)
-        mode (str, optional): whether to use ``'parameters'`` or ``'state_dict'`` for update
-            (default: ``'parameters'``)
+        use_buffers (bool): if ``True``, it will compute running averages for
+            both the parameters and the buffers of the model. (default: ``False``)
 
     Example:
         >>> loader, optimizer, model, loss_fn = ...
```
```diff
@@ -55,15 +57,21 @@ class AveragedModel(Module):
     equally-weighted average of the weights.
 
     Example:
-        >>> # Compute exponential moving averages of the weights
+        >>> # Compute exponential moving averages of the weights and buffers
         >>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
                 0.1 * averaged_model_parameter + 0.9 * model_parameter
-        >>> swa_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)
+        >>> swa_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg, use_buffers=True)
 
     .. note::
         When using SWA with models containing Batch Normalization you may
         need to update the activation statistics for Batch Normalization.
-        You can do so by using :meth:`torch.optim.swa_utils.update_bn` utility.
+        This can be done either by using the :meth:`torch.optim.swa_utils.update_bn`
+        or by setting :attr:`use_buffers` to `True`. The first approach updates the
+        statistics in a post-training step by passing data through the model. The
+        second does it during the parameter update phase by averaging all buffers.
+        Empirical evidence has shown that updating the statistics in normalization
+        layers increases accuracy, but you may wish to empirically test which
+        approach yields the best results in your problem.
 
     .. note::
         :attr:`avg_fn` is not saved in the :meth:`state_dict` of the model.
```
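The docstring note above describes two ways to keep BatchNorm statistics consistent with the averaged weights; a hedged sketch of both, where the model and loader are placeholders:

```python
import torch
from torch.optim.swa_utils import AveragedModel, update_bn

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),
)
loader = [torch.randn(4, 3, 16, 16) for _ in range(5)]  # stand-in for a DataLoader

# Option 1: average only the parameters, then recompute the BatchNorm
# statistics in a separate pass over the data once training is done.
swa_model = AveragedModel(model)
# ... training loop calling swa_model.update_parameters(model) ...
update_bn(loader, swa_model)

# Option 2: fold the buffers into the running average on every update,
# so no separate update_bn pass is required.
ema_model = AveragedModel(model, use_buffers=True)
# ... training loop calling ema_model.update_parameters(model) ...
```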
```diff
@@ -86,7 +94,7 @@ class AveragedModel(Module):
         Generalizes Well:
         https://arxiv.org/abs/2001.02312
     """
-    def __init__(self, model, device=None, avg_fn=None, mode='parameters'):
+    def __init__(self, model, device=None, avg_fn=None, use_buffers=False):
         super(AveragedModel, self).__init__()
         self.module = deepcopy(model)
         if device is not None:
```
```diff
@@ -98,17 +106,20 @@ class AveragedModel(Module):
                 return averaged_model_parameter + \
                     (model_parameter - averaged_model_parameter) / (num_averaged + 1)
         self.avg_fn = avg_fn
-        modes = ['parameters', 'state_dict']
-        if mode not in modes:
-            raise ValueError(f'Invalid mode passed, valid values are {", ".join(modes)}.')
-        self.use_state_dict = mode == 'state_dict'
+        self.use_buffers = use_buffers
 
     def forward(self, *args, **kwargs):
         return self.module(*args, **kwargs)
 
     def update_parameters(self, model):
-        self_param = self.module.state_dict().values() if self.use_state_dict else self.parameters()
-        model_param = model.state_dict().values() if self.use_state_dict else model.parameters()
+        self_param = (
+            itertools.chain(self.module.parameters(), self.module.buffers())
+            if self.use_buffers else self.parameters()
+        )
+        model_param = (
+            itertools.chain(model.parameters(), model.buffers())
+            if self.use_buffers else model.parameters()
+        )
         for p_swa, p_model in zip(self_param, model_param):
             device = p_swa.device
             p_model_ = p_model.detach().to(device)
```
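The default avg_fn kept in the hunk above implements an equal-weight running mean, avg_{n+1} = avg_n + (x - avg_n) / (n + 1). A small numeric check of that formula (the values are made up):

```python
import torch

def default_avg(avg, new, num_averaged):
    # Same formula as the default avg_fn: equal-weight running mean.
    return avg + (new - avg) / (num_averaged + 1)

avg = None
for n, x in enumerate([2.0, 4.0, 6.0]):
    x = torch.tensor(x)
    # The first update copies the value, mirroring n_averaged == 0 in update_parameters.
    avg = x.clone() if n == 0 else default_avg(avg, x, n)
    print(float(avg))  # 2.0, 3.0, 4.0 -> mean of the values seen so far
```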