Optimized EMA implementation (#94820)
This PR proposes an optimized implementation of Exponential Moving Average (EMA) that is faster than the current approach using `swa_utils.AveragedModel`, described in https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies.
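For context, EMA keeps a shadow copy of the weights that is updated after every optimizer step as a convex combination of itself and the live weights. A per-tensor sketch of that update (`ema_update` is a hypothetical helper for illustration, not part of this PR):

```python
import torch

@torch.no_grad()
def ema_update(ema_params, model_params, decay=0.9999):
    # In-place shadow-weight update: ema <- decay * ema + (1 - decay) * param
    for ema_p, p in zip(ema_params, model_params):
        ema_p.mul_(decay).add_(p, alpha=1 - decay)
```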

The implementation is asynchronous and built as an optimizer wrapper: the EMA weight update runs right after each optimizer step, without any additional CPU/GPU sync and with minimal code changes.
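To illustrate the wrapper idea, here is a minimal, synchronous sketch (`SimpleEMAOptimizer` is hypothetical and not the PR's actual code, which additionally keeps the update off the critical path so no extra sync is introduced):

```python
import torch

class SimpleEMAOptimizer:
    """Simplified (synchronous) sketch of an EMA optimizer wrapper."""

    def __init__(self, optimizer, device, decay=0.9999):
        self.optimizer = optimizer
        self.decay = decay
        self.params = [p for g in optimizer.param_groups for p in g["params"]]
        # Shadow copy of every parameter (this sketch assumes `device`
        # matches the model's device, since _foreach_ ops work pairwise).
        self.ema_params = [p.detach().clone().to(device) for p in self.params]

    @torch.no_grad()
    def step(self, closure=None):
        loss = self.optimizer.step(closure)
        # Fused multi-tensor update: ema = decay * ema + (1 - decay) * param
        torch._foreach_mul_(self.ema_params, self.decay)
        torch._foreach_add_(self.ema_params, self.params, alpha=1 - self.decay)
        return loss

    def __getattr__(self, name):
        # Delegate everything else (zero_grad, state_dict, ...) to the
        # wrapped optimizer.
        return getattr(self.optimizer, name)
```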

Example usage:
```python
import torch

model = Model().to(device)
opt = torch.optim.Adam(model.parameters())

# Wrap the optimizer; 0.9999 is the EMA decay coefficient.
opt = EMAOptimizer(opt, device, 0.9999)

for epoch in range(epochs):
    training_loop(model, opt)

    # Evaluate with the regular (non-averaged) weights.
    regular_eval_accuracy = evaluate(model)

    # Temporarily swap in the EMA weights for evaluation.
    with opt.swap_ema_weights():
        ema_eval_accuracy = evaluate(model)
```
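`swap_ema_weights()` exchanges the live model weights with the EMA weights for the duration of the `with` block, so the same `model` object can be evaluated with either set of weights. A sketch of how such a context manager can be written (`swapped_weights` and `_swap` are hypothetical names, not the PR's code):

```python
import contextlib
import torch

@torch.no_grad()
def _swap(params, ema_params):
    # Exchange the contents of each live weight with its EMA shadow copy.
    for p, ema_p in zip(params, ema_params):
        tmp = p.detach().clone()
        p.copy_(ema_p)
        ema_p.copy_(tmp)

@contextlib.contextmanager
def swapped_weights(params, ema_params):
    # Swap in the EMA weights on entry, restore the originals on exit.
    _swap(params, ema_params)
    try:
        yield
    finally:
        _swap(params, ema_params)
```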

Here are some benchmarks (time per iteration) on various torchvision models:

|model|this PR iteration time|`swa_utils.AveragedModel` iteration time|iteration speedup|
|-----|----------------------|----------------------------------------|-----------------|
|regnet_x_1_6gf|62.73                        |67.998                 |1.08                                         |
|regnet_x_3_2gf|101.75                       |109.422                |1.08                                         |
|regnet_x_400mf|25.13                        |32.005                 |1.27                                         |
|regnet_x_800mf|33.01                        |37.466                 |1.13                                         |
|regnet_x_8gf|128.13                       |134.868                |1.05                                         |
|regnet_y_16gf|252.91                       |261.292                |1.03                                         |
|regnet_y_1_6gf|72.14                        |84.22                  |1.17                                         |
|regnet_y_3_2gf|99.99                        |109.296                |1.09                                         |
|regnet_y_400mf|29.53                        |36.506                 |1.24                                         |
|regnet_y_800mf|37.82                        |43.634                 |1.15                                         |
|regnet_y_8gf|196.63                       |203.317                |1.03                                         |
|resnet101|128.80                       |137.434                |1.07                                         |
|resnet152|182.85                       |196.498                |1.07                                         |
|resnet18|29.06                        |29.975                 |1.03                                         |
|resnet34|50.73                        |53.443                 |1.05                                         |
|resnet50|76.88                        |80.602                 |1.05                                         |
|resnext101_32x8d|277.29                       |280.759                |1.01                                         |
|resnext101_64x4d|269.56                       |281.052                |1.04                                         |
|resnext50_32x4d|100.73                       |101.102                |1.00                                         |
|shufflenet_v2_x0_5|10.56                        |15.419                 |1.46                                         |
|shufflenet_v2_x1_0|13.11                        |18.525                 |1.41                                         |
|shufflenet_v2_x1_5|18.05                        |23.132                 |1.28                                         |
|shufflenet_v2_x2_0|25.04                        |30.008                 |1.20                                         |
|squeezenet1_1|14.26                        |14.325                 |1.00                                         |
|swin_b|264.52                       |274.613                |1.04                                         |
|swin_s|180.66                       |188.914                |1.05                                         |
|swin_t|108.62                       |112.632                |1.04                                         |
|swin_v2_s|220.29                       |231.153                |1.05                                         |
|swin_v2_t|127.27                       |133.586                |1.05                                         |
|vgg11|95.52                        |103.714                |1.09                                         |
|vgg11_bn|106.49                       |120.711                |1.13                                         |
|vgg13|132.94                       |147.063                |1.11                                         |
|vgg13_bn|149.73                       |165.256                |1.10                                         |
|vgg16|158.19                       |172.865                |1.09                                         |
|vgg16_bn|177.04                       |192.888                |1.09                                         |
|vgg19|184.76                       |194.194                |1.05                                         |
|vgg19_bn|203.30                       |213.334                |1.05                                         |
|vit_b_16|217.31                       |219.748                |1.01                                         |
|vit_b_32|69.47                        |75.692                 |1.09                                         |
|vit_l_32|223.20                       |258.487                |1.16                                         |
|wide_resnet101_2|267.38                       |279.836                |1.05                                         |
|wide_resnet50_2|145.06                       |154.918                |1.07                                         |

In all cases this implementation is faster than `AveragedModel`, and in many cases adding EMA introduces essentially no overhead, since the computation is hidden behind the usual iteration flow.
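For reference, per-iteration timings like the ones above can be collected with CUDA events; this is a generic measurement sketch (`time_per_iteration` is hypothetical and not necessarily the harness that produced the table):

```python
import torch

def time_per_iteration(step_fn, warmup=10, iters=100):
    # Average wall time of `step_fn` in milliseconds, measured with CUDA events.
    for _ in range(warmup):
        step_fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        step_fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```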

The implementation is similar to the one currently in [NVIDIA NeMo](https://github.com/NVIDIA/NeMo).

If the team is interested in merging this, let me know and I'll add documentation (similar to that of `swa_utils`) and tests.

Credits to @szmigacz for the implementation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94820
Approved by: https://github.com/janeyx99
2023-04-26 18:02:11 +00:00