Summary:
- Makes it possible to use non-sharded optimizer checkpoints (as long as the model/param groups are the same, of course)
- Makes it possible to save with a given world size, and load with another world size
- Use Torch Distributed built-in broadcast object list instead of a ad-hoc version
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50956
Reviewed By: malfet
Differential Revision: D26113953
Pulled By: blefaudeux
fbshipit-source-id: 030bfeee2c34c2d987590d45dc8efe05515f2e5c
Summary:
Implement the first stage of ZeRO, sharding of the optimizer state, as described in [this blog post](https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/) and [this paper](https://arxiv.org/abs/1910.02054). This implementation is completely independent from the [DeepSpeed](https://github.com/microsoft/DeepSpeed) framework, and aims at providing ZeRO-compliant building blocks within the PyTorch scheme of things.
This works by:
- acting as a wrapper to a pytorch optimizer. ZeROptimizer does not optimize anything by itself, it only shards optimizers for distributed jobs
- each rank distributes parameters according to a given partitioning scheme (could be updated), and owns the update of a given shard only
- the .step() is called on each rank as expected, the fact that the optimizer actually works on a shard of the model is not visible from the outside
- when the update is completed, each rank broadcasts the updated model shard to all the other ranks
This can be used with DDP, although some communications are wasted in that case (gradients are all-reduced to all ranks). This implementation was initially developed in [Fairscale](https://github.com/facebookresearch/fairscale), and can also be used with an optimized DDP which only reduces to the relevant ranks. More context on ZeRO and PyTorch can be found in [this RFC](https://github.com/pytorch/pytorch/issues/42849)
The API with respect to loading and saving the state is a known pain point and should probably be discussed an updated. Other possible follow ups include integrating more closely to a [modularized DDP](https://github.com/pytorch/pytorch/issues/37002), [making the checkpoints partition-agnostic](https://github.com/facebookresearch/fairscale/issues/164), [exposing a gradient clipping option](https://github.com/facebookresearch/fairscale/issues/98) and making sure that mixed precision states are properly handled.
original authors include msbaines, min-xu-ai and myself
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46750
Reviewed By: mruberry
Differential Revision: D25958918
Pulled By: blefaudeux
fbshipit-source-id: 14280f2fd90cf251eee8ef9ac0f1fa6025ae9c50