pytorch/torch
milesial 45bf3f6216 Optimized EMA implementation (#94820)
This PR proposes an optimized implementation of Exponential Moving Average (EMA) of model weights, i.e. a decayed running average `ema = decay * ema + (1 - decay) * param`, which is faster than the current approach using `swa_utils.AveragedModel` described in https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies.

The implementation is asynchronous and built as an optimizer wrapper: the EMA weight update is scheduled right after each optimizer step, without any additional CPU/GPU synchronization and with minimal code changes.

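For illustration, here is a minimal sketch of the kind of fused multi-tensor update such a wrapper can launch after each step (`ema_update` is a hypothetical helper, not the API proposed in this PR):

```python
import torch

@torch.no_grad()
def ema_update(ema_params, model_params, decay=0.9999):
    # One EMA step across the whole parameter list:
    #   ema = decay * ema + (1 - decay) * param
    # The _foreach_ ops launch a few fused kernels instead of one kernel
    # per tensor, and enqueue no host-side synchronization.
    torch._foreach_mul_(ema_params, decay)
    torch._foreach_add_(ema_params, model_params, alpha=1.0 - decay)
```
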
Example usage:
```python
# Model, training_loop, evaluate, device and epochs are user-defined
model = Model().to(device)
opt = torch.optim.Adam(model.parameters())

# Wrap the base optimizer; the EMA update runs after each opt.step()
opt = EMAOptimizer(opt, device, 0.9999)  # decay = 0.9999

for epoch in range(epochs):
    training_loop(model, opt)

    regular_eval_accuracy = evaluate(model)

    # Evaluate with the EMA weights swapped in; they are restored on exit
    with opt.swap_ema_weights():
        ema_eval_accuracy = evaluate(model)
```
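For intuition, `swap_ema_weights` can be thought of as a context manager along these lines (a simplified, hypothetical sketch; the actual wrapper also has to synchronize with the asynchronous EMA updates):

```python
import contextlib
import torch

@contextlib.contextmanager
def swapped_weights(model_params, ema_params):
    # Exchange the live weights with the EMA weights in place, yield for
    # evaluation, then swap back so training resumes from the originals.
    def swap():
        with torch.no_grad():
            for p, e in zip(model_params, ema_params):
                tmp = p.clone()
                p.copy_(e)
                e.copy_(tmp)
    swap()
    try:
        yield
    finally:
        swap()
```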

Here are some benchmarks (time per iteration; lower is better) on various torchvision models:

|model|iteration time, this PR|iteration time, `swa_utils.AveragedModel`|speedup|
|-----|-----------------------|------------------------------------------|-------|
|regnet_x_1_6gf|62.73                        |67.998                 |1.08                                         |
|regnet_x_3_2gf|101.75                       |109.422                |1.08                                         |
|regnet_x_400mf|25.13                        |32.005                 |1.27                                         |
|regnet_x_800mf|33.01                        |37.466                 |1.13                                         |
|regnet_x_8gf|128.13                       |134.868                |1.05                                         |
|regnet_y_16gf|252.91                       |261.292                |1.03                                         |
|regnet_y_1_6gf|72.14                        |84.22                  |1.17                                         |
|regnet_y_3_2gf|99.99                        |109.296                |1.09                                         |
|regnet_y_400mf|29.53                        |36.506                 |1.24                                         |
|regnet_y_800mf|37.82                        |43.634                 |1.15                                         |
|regnet_y_8gf|196.63                       |203.317                |1.03                                         |
|resnet101|128.80                       |137.434                |1.07                                         |
|resnet152|182.85                       |196.498                |1.07                                         |
|resnet18|29.06                        |29.975                 |1.03                                         |
|resnet34|50.73                        |53.443                 |1.05                                         |
|resnet50|76.88                        |80.602                 |1.05                                         |
|resnext101_32x8d|277.29                       |280.759                |1.01                                         |
|resnext101_64x4d|269.56                       |281.052                |1.04                                         |
|resnext50_32x4d|100.73                       |101.102                |1.00                                         |
|shufflenet_v2_x0_5|10.56                        |15.419                 |1.46                                         |
|shufflenet_v2_x1_0|13.11                        |18.525                 |1.41                                         |
|shufflenet_v2_x1_5|18.05                        |23.132                 |1.28                                         |
|shufflenet_v2_x2_0|25.04                        |30.008                 |1.20                                         |
|squeezenet1_1|14.26                        |14.325                 |1.00                                         |
|swin_b|264.52                       |274.613                |1.04                                         |
|swin_s|180.66                       |188.914                |1.05                                         |
|swin_t|108.62                       |112.632                |1.04                                         |
|swin_v2_s|220.29                       |231.153                |1.05                                         |
|swin_v2_t|127.27                       |133.586                |1.05                                         |
|vgg11|95.52                        |103.714                |1.09                                         |
|vgg11_bn|106.49                       |120.711                |1.13                                         |
|vgg13|132.94                       |147.063                |1.11                                         |
|vgg13_bn|149.73                       |165.256                |1.10                                         |
|vgg16|158.19                       |172.865                |1.09                                         |
|vgg16_bn|177.04                       |192.888                |1.09                                         |
|vgg19|184.76                       |194.194                |1.05                                         |
|vgg19_bn|203.30                       |213.334                |1.05                                         |
|vit_b_16|217.31                       |219.748                |1.01                                         |
|vit_b_32|69.47                        |75.692                 |1.09                                         |
|vit_l_32|223.20                       |258.487                |1.16                                         |
|wide_resnet101_2|267.38                       |279.836                |1.05                                         |
|wide_resnet50_2|145.06                       |154.918                |1.07                                         |

In all cases this is faster than using `AveragedModel`. In many cases adding EMA introduces no measurable overhead at all, since the EMA computation is hidden behind the regular iteration flow.
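One way to hide the update, sketched below under the assumption of a single-GPU setup, is to launch it on a side CUDA stream so it overlaps with the next iteration's work (`ema_params`/`model_params` as in the earlier snippet; an illustration, not the exact code in this PR):

```python
import torch

ema_stream = torch.cuda.Stream()

def ema_step(ema_params, model_params, decay):
    # Run the EMA update on a side stream so it can overlap with the next
    # iteration's forward/backward work on the default stream.
    ema_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(ema_stream):
        torch._foreach_mul_(ema_params, decay)
        torch._foreach_add_(ema_params, model_params, alpha=1.0 - decay)
    # Before anything mutates the weights again (e.g. the next opt.step()),
    # the default stream must wait on ema_stream:
    #   torch.cuda.current_stream().wait_stream(ema_stream)
```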

The implementation is similar to the one currently used in [NVIDIA NeMo](https://github.com/NVIDIA/NeMo).

If the team is interested in merging this, let me know and I'll add tests and documentation similar to that of `swa_utils`.

Credits to @szmigacz for the implementation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94820
Approved by: https://github.com/janeyx99
2023-04-26 18:02:11 +00:00
|Name|Last commit message|Last commit date|
|----|-------------------|----------------|
|_awaits|||
|_C|fix per-dispatchkey-mode caching bug (#98030)|2023-04-25 21:58:14 +00:00|
|_C_flatbuffer|||
|_decomp|[inductor] Lowering of rngprims philox_rand (#99289)|2023-04-26 01:22:41 +00:00|
|_dispatch|||
|_dynamo|Make sizevar addition work properly (#100015)|2023-04-26 15:59:26 +00:00|
|_export|[export] Constraints API (#98433)|2023-04-13 21:20:10 +00:00|
|_functorch|[philox_rand] Dynamic shape support (#99290)|2023-04-25 22:40:28 +00:00|
|_inductor|Add all_reduce_coalesced to functional collectives (#98640)|2023-04-26 17:05:54 +00:00|
|_lazy|||
|_logging|Add documentation for torch._logging.set_logs (#99219)|2023-04-24 08:06:57 +00:00|
|_prims|[philox_rand] Dynamic shape support (#99290)|2023-04-25 22:40:28 +00:00|
|_prims_common|[philox_rand] Dynamic shape support (#99290)|2023-04-25 22:40:28 +00:00|
|_refs|[pt2] add SymInt support for roll (#99114)|2023-04-15 18:01:39 +00:00|
|_subclasses|[BE] Enable C419 rule for any all shortcircuiting (#99890)|2023-04-25 15:02:13 +00:00|
|amp|refactor macro with AMP (#99285)|2023-04-19 01:00:00 +00:00|
|ao|[reland][quant][pt2e][refactor] Cleanup the logic for deciding whether to insert observer/fq or not (#99220) (#99767)|2023-04-25 16:53:02 +00:00|
|autograd|Change 'w.r.t.' to 'wrt' in function docstrings to fix doc rendering (#100028)|2023-04-25 23:53:26 +00:00|
|backends|Convert logging f-strings to use % format, part five (#98765)|2023-04-11 13:17:59 +00:00|
|contrib|||
|cpu|||
|csrc|add some missing includes (#100049)|2023-04-26 14:27:06 +00:00|
|cuda|[BE] Enable C419 rule for any all shortcircuiting (#99890)|2023-04-25 15:02:13 +00:00|
|distributed|Add all_reduce_coalesced to functional collectives (#98640)|2023-04-26 17:05:54 +00:00|
|distributions|Remove in-place operations in NegativeBinomial (#96748)|2023-04-26 14:45:08 +00:00|
|fft|||
|func|||
|futures|||
|fx|change torch._dynamo.export(aten_graph=...) to allow pre_autograd tracing (#98031)|2023-04-25 21:58:14 +00:00|
|jit|Add shape function for aten::cross_entropy_loss (#97875)|2023-04-12 22:11:56 +00:00|
|legacy|||
|lib|||
|linalg|||
|masked|||
|monitor|||
|mps|||
|multiprocessing|Reduce overhead in CUDAGraph Trees (#98529)|2023-04-07 05:46:08 +00:00|
|nested|[BE] Enable C419 rule for any all shortcircuiting (#99890)|2023-04-25 15:02:13 +00:00|
|nn|Move nn.module state dict pre hook (#98964)|2023-04-26 16:51:13 +00:00|
|onnx|[ONNX] Support aten::atan2 in torchscript exporter (#100040)|2023-04-26 04:00:47 +00:00|
|optim|Optimized EMA implementation (#94820)|2023-04-26 18:02:11 +00:00|
|package|Convert logging f-strings to use % format, part five (#98765)|2023-04-11 13:17:59 +00:00|
|profiler|[Profiler] Support HTML plot output for profiler export_memory_timeline API (#99751)|2023-04-22 04:21:58 +00:00|
|quantization|||
|signal|Fix flake8 lint errors reported by ruff - take 2 (#99798)|2023-04-23 23:09:51 +00:00|
|sparse|||
|special|||
|testing|Add all_reduce_coalesced to functional collectives (#98640)|2023-04-26 17:05:54 +00:00|
|utils|fix per-dispatchkey-mode caching bug (#98030)|2023-04-25 21:58:14 +00:00|
|__config__.py|||
|__future__.py|||
|__init__.py|[BE] Enable C419 rule for any all shortcircuiting (#99890)|2023-04-25 15:02:13 +00:00|
|_appdirs.py|||
|_classes.py|||
|_custom_op.py|Reland "Simple Custom Operator API, V0 (#98440)" (#99416)|2023-04-18 23:48:33 +00:00|
|_deploy.py|||
|_guards.py|In detect_fake_mode, assert that all detected fake modes are consistent (#99392)|2023-04-18 15:35:05 +00:00|
|_jit_internal.py|[JIT] Allow tuple and list generics (#98703)|2023-04-09 22:58:58 +00:00|
|_linalg_utils.py|||
|_lobpcg.py|||
|_lowrank.py|||
|_meta_registrations.py|[BE] Enable C419 rule for any all shortcircuiting (#99890)|2023-04-25 15:02:13 +00:00|
|_namedtensor_internals.py|||
|_ops.py|||
|_python_dispatcher.py|||
|_sources.py|||
|_storage_docs.py|||
|_tensor_docs.py|Fix Tensor.uniform_ documentation to mention generator argument (#99510)|2023-04-19 19:23:12 +00:00|
|_tensor_str.py|Fix FakeTensor printing (#99205)|2023-04-18 13:26:27 +00:00|
|_tensor.py|Change 'w.r.t.' to 'wrt' in function docstrings to fix doc rendering (#100028)|2023-04-25 23:53:26 +00:00|
|_torch_docs.py|Update torch.arange doc. (#99963)|2023-04-26 04:18:56 +00:00|
|_utils_internal.py|Log PT2 compile to Scuba (#98790)|2023-04-11 20:10:35 +00:00|
|_utils.py|add get_device_index for custom device (#98804)|2023-04-12 23:58:31 +00:00|
|_VF.py|||
|_vmap_internals.py|[BE] Enable C419 rule for any all shortcircuiting (#99890)|2023-04-25 15:02:13 +00:00|
|_weights_only_unpickler.py|||
|abi-check.cpp|||
|CMakeLists.txt|||
|custom_class_detail.h|||
|custom_class.h|||
|extension.h|||
|functional.py|||
|hub.py|torch.hub: add safe weights_only option to load_state_dict_from_url (#98479)|2023-04-11 12:44:25 +00:00|
|library.h|||
|library.py|torch.library.Library.impl: add missing param in docstring example (#98619)|2023-04-11 06:09:46 +00:00|
|overrides.py|Fix flake8 lint errors - part 2 - manual fixes (#99799)|2023-04-24 06:03:26 +00:00|
|py.typed|||
|quasirandom.py|||
|random.py|add rng_state support for custom device (#98069)|2023-04-10 22:36:55 +00:00|
|README.txt|||
|return_types.py|||
|script.h|||
|serialization.py|Support non-ASCII characters in model file paths (#99453)|2023-04-26 01:15:49 +00:00|
|storage.py|Fix loading data on different encoding (#94503)|2023-04-25 21:05:20 +00:00|
|torch_version.py|||
|types.py|||

Note [TH abstraction violation]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

TH/THC provide some hpp headers, which are proper C++ headers rather than
C headers.  These headers serve double duty as *internal implementation
detail* headers, whose contents should largely not be used by external
clients.

Ideally, we would not install these headers at all; instead, you should
use public functions (in headers like `THTensor.h`, NOT `THTensor.hpp`)
to manipulate these structs.  However, there are a few places
in torch/csrc where we violate this abstraction.  They are marked with
a pointer to this note.  Each of those sites will have to be refactored
when we refactor the guts of THTensor and related structures.