pytorch/torch/distributed/algorithms
Olga Andreeva f7d6828467 Adding fsdp fp16 and bf16 hooks (#80557)
Recently, `register_comm_hook` was introduced to `FSDP`; at the moment it supports only the `NO_SHARD` strategy and ships with a default `all_reduce` hook. This PR adds two lower precision hooks (fp16 and bf16) alongside the existing default hook.
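
For context, a minimal sketch of attaching a comm hook to a `NO_SHARD`-wrapped model is shown below. The `DefaultState`/`allreduce_hook` names refer to `default_hooks.py` as described in the rest of this PR; the model and process-group setup are illustrative only, and exact signatures may differ across versions.

```python
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.algorithms._comm_hooks import default_hooks

# Assumes the default process group has already been initialized (e.g. via torchrun).
model = FSDP(nn.Linear(16, 16), sharding_strategy=ShardingStrategy.NO_SHARD)
state = default_hooks.DefaultState(process_group=dist.group.WORLD)
model.register_comm_hook(state, default_hooks.allreduce_hook)
```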

I've also made slight adjustments to the existing implementation of the `all_reduce` hook, including:

- `AllReduceState` -> `DefaultState`. Motivation: the state is not specific to `all_reduce`; the gradient pre- and post-division factors are also useful for other hooks that require pre- and post-division, e.g. the fp16 and bf16 hooks (see the sketch after this list).
- I've put all 3 hooks into `default_hooks.py`.
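
A minimal sketch, assuming the `(state, grad)` signature of the default `all_reduce` hook, of how a lower precision hook can reuse `DefaultState`'s pre- and post-division factors: cast down, pre-divide, all-reduce, post-divide, cast back. The attribute names mirror the factors described above and may differ from the actual `default_hooks.py` code.

```python
import torch
import torch.distributed as dist


def lower_precision_hook_sketch(state, grad: torch.Tensor, prec: torch.dtype = torch.float16) -> None:
    orig_dtype = grad.dtype
    grad.data = grad.data.to(prec)               # cast down before communication
    grad.div_(state.gradient_predivide_factor)   # pre-divide to limit overflow in low precision
    dist.all_reduce(grad, group=state.process_group)
    grad.div_(state.gradient_postdivide_factor)  # finish averaging across ranks
    grad.data = grad.data.to(orig_dtype)         # cast back so the optimizer sees the original dtype
```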

Additionally, `FSDP` supports `MixedPrecision`, so in principle a user could specify `MixedPrecision` for gradients and also attach a lower precision hook to the model. To avoid double-casting, I've added a couple of checks to `fully_sharded_data_parallel`, so that casting gradients to the lower precision and back is performed by the lower precision hook only. As a next step, it would be nice to ensure that a user can't have both a lower precision hook and `MixedPrecision(reduce_dtype=<precision>)` specified, but I am happy to discuss this and adjust the current implementation.
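
To illustrate the intent of those checks (not the exact code in `fully_sharded_data_parallel`), here is a standalone sketch of the rule; the names `LOW_PRECISION_HOOK_NAMES` and `should_cast_to_reduce_dtype` are hypothetical:

```python
from typing import Optional

import torch

# Hypothetical names for illustration only; the real checks live in
# fully_sharded_data_parallel.py.
LOW_PRECISION_HOOK_NAMES = {"fp16_hook", "bf16_hook"}


def should_cast_to_reduce_dtype(
    comm_hook_name: Optional[str], reduce_dtype: Optional[torch.dtype]
) -> bool:
    # If a lower precision hook is registered, it casts the gradient to the
    # lower precision and back itself, so FSDP skips its own
    # MixedPrecision(reduce_dtype=...) cast to avoid double-casting.
    if comm_hook_name in LOW_PRECISION_HOOK_NAMES:
        return False
    return reduce_dtype is not None
```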

As a test, I create two models, one with a lower precision hook attached and one with `MixedPrecision(reduce_dtype=<precision>)` specified, perform one forward/backward pass and optimizer step on each, and compare the resulting gradients.
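
A rough sketch of that comparison, assuming the fp16 hook and its state end up named `fp16_compress_hook` and `LowPrecisionState` (the final names may differ) and that a GPU process group is already initialized; this mirrors the description above, not the exact test code:

```python
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.algorithms._comm_hooks import default_hooks


def compare_hook_vs_mixed_precision():
    # Two identically initialized models.
    torch.manual_seed(0)
    net_hook = nn.Linear(8, 8).cuda()
    torch.manual_seed(0)
    net_mp = nn.Linear(8, 8).cuda()

    model_hook = FSDP(net_hook, sharding_strategy=ShardingStrategy.NO_SHARD)
    state = default_hooks.LowPrecisionState(process_group=dist.group.WORLD)
    model_hook.register_comm_hook(state, default_hooks.fp16_compress_hook)

    model_mp = FSDP(
        net_mp,
        sharding_strategy=ShardingStrategy.NO_SHARD,
        mixed_precision=MixedPrecision(reduce_dtype=torch.float16),
    )

    inp = torch.ones(4, 8, device="cuda")
    for model in (model_hook, model_mp):
        opt = torch.optim.SGD(model.parameters(), lr=0.1)
        model(inp).sum().backward()
        opt.step()

    # Gradients reduced via the fp16 hook should match those reduced with
    # MixedPrecision(reduce_dtype=torch.float16).
    for p_hook, p_mp in zip(model_hook.parameters(), model_mp.parameters()):
        torch.testing.assert_close(p_hook.grad, p_mp.grad)
```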

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80557
Approved by: https://github.com/rohan-varma
2022-07-18 22:40:56 +00:00
_checkpoint Revert "Revert "[FSDP Optim State] Remove checkpoint prefix (#80480)"" (#80936) 2022-07-06 22:21:07 +00:00
_comm_hooks Adding fsdp fp16 and bf16 hooks (#80557) 2022-07-18 22:40:56 +00:00
_optimizer_overlap make fsdp folder to be public (#72084) 2022-02-02 15:50:14 +00:00
_quantization Enable test: distributed/algorithms/quantization/test_quantization (#80097) 2022-07-01 01:32:33 +00:00
ddp_comm_hooks Enable Zero1's ddp_with_overlap for hpu backend (#80438) 2022-07-18 15:05:27 +00:00
model_averaging Add __all__ to various submodules in torch.fx, distributions, distributed, package (#80367) 2022-06-27 21:27:30 +00:00
__init__.py Make _Join, _Joinable, _JoinHook public (#62605) 2021-08-03 12:20:11 -07:00
join.py Add __all__ to various submodules in torch.fx, distributions, distributed, package (#80367) 2022-06-27 21:27:30 +00:00