Recently, `register_comm_hook` was introduced to `FSDP`; at the moment it supports only the `NO_SHARD` strategy and comes with a default `all_reduce` hook. This PR adds two lower-precision hooks alongside the existing default hook. I've also made slight adjustments to the existing implementation of the `all_reduce` hook, including:

- `AllReduceState` -> `DefaultState`. Motivation: `AllReduceState` is not specific to `all_reduce`; the gradients' pre- and post-division factors are also useful for other hooks that require pre- and post-division, e.g. `fp16_hook` and `bf16_hook`.
- All 3 hooks now live in `default_hooks.py`.

Additionally, `FSDP` supports `MixedPrecision`, so it is theoretically possible to specify `MixedPrecision` for gradients and also attach a lower-precision hook to the model. To avoid double-casting, I've added a couple of checks to `fully_sharded_data_parallel`, i.e. casting to the lower precision and back is performed by the lower-precision hook only. As a next step, I think it would be nice to ensure that a user can't specify both a lower-precision hook and `MixedPrecision(reduce_dtype=<precision>)`, but I am happy to discuss this and adjust the current implementation.

As a test, I create two models, one with a lower-precision hook and one with `MixedPrecision(reduce_dtype=<precision>)` specified, perform one forward/backward pass and optimizer step, and compare gradients.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80557
Approved by: https://github.com/rohan-varma
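To make the cast, pre-divide, all-reduce, post-divide, cast-back flow described above concrete, here is a minimal sketch. `_SketchDefaultState` and `_sketch_bf16_hook` are illustrative names, not the classes or functions added by this PR; only the overall flow is taken from the description.

```python
# A hypothetical sketch of a lower-precision FSDP communication hook,
# assuming the hook receives the flattened gradient and a small state object.
import torch
import torch.distributed as dist


class _SketchDefaultState:
    """Holds the process group and the gradient pre-/post-division factors."""

    def __init__(self, process_group: dist.ProcessGroup):
        self.process_group = process_group
        world_size = dist.get_world_size(process_group)
        # For this sketch, average the gradient by pre-dividing by the world
        # size; a real implementation may split the division differently to
        # reduce overflow/underflow risk in low-precision dtypes.
        self.gradient_predivide_factor = world_size
        self.gradient_postdivide_factor = 1.0


def _sketch_bf16_hook(state: _SketchDefaultState, grad: torch.Tensor) -> None:
    """Cast the gradient to bfloat16, all-reduce it, then cast it back."""
    orig_dtype = grad.dtype
    grad.data = grad.data.to(torch.bfloat16)
    grad.div_(state.gradient_predivide_factor)
    dist.all_reduce(grad, group=state.process_group)
    grad.div_(state.gradient_postdivide_factor)
    # Cast back so the optimizer still sees gradients in the original dtype.
    grad.data = grad.data.to(orig_dtype)
```

If the registration entry point mirrors DDP's `register_comm_hook(state, hook)`, attaching such a hook would look roughly like `fsdp_model.register_comm_hook(_SketchDefaultState(pg), _sketch_bf16_hook)`; the real signatures are the ones introduced by this PR, so treat the above only as a conceptual outline.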
| Name |
|---|
| _checkpoint |
| _comm_hooks |
| _optimizer_overlap |
| _quantization |
| ddp_comm_hooks |
| model_averaging |
| __init__.py |
| join.py |