pytorch/torch/distributed/algorithms
Andrew Gu fc429512d5 [FSDP] Clean up FlatParamHandle dtypes, post-backward hook (#90660)
This PR reworks the internal handling of parameter and gradient reduction mixed precision, cleans up the post-backward hook logic, and adds some minor changes to the communication hooks.

**Overview**
This PR addresses everything in https://github.com/pytorch/pytorch/issues/90657 except renaming `keep_low_precision_grads` to `keep_grads_in_reduce_dtype`, since that is BC breaking. I recommend reading the issue before proceeding.

For `MixedPrecision(param_dtype, reduce_dtype, ...)`, the exact rule for parameter and gradient reduction mixed precision that we are following is:
> If `param_dtype is not None` and `reduce_dtype is None`, then we infer `reduce_dtype = param_dtype`. Otherwise, we take `param_dtype` and `reduce_dtype` as is.
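
As a rough illustration of the rule above, here is a minimal sketch and not the actual FSDP code. The function name `resolve_dtypes` is hypothetical, plain strings stand in for `torch.dtype` values so the snippet runs without PyTorch, and the fallback to the original parameter dtype when a dtype is left as `None` is an assumption about how the handle-level values end up non-`None`.

```python
from typing import Optional, Tuple, TypeVar

# Stand-in for torch.dtype; any comparable value works in this sketch.
D = TypeVar("D")


def resolve_dtypes(
    orig_param_dtype: D,
    param_dtype: Optional[D],
    reduce_dtype: Optional[D],
) -> Tuple[D, D]:
    """Hypothetical helper: returns (fwd_bwd_param_dtype, reduce_dtype)."""
    # Rule from the PR description: infer reduce_dtype from param_dtype
    # only when param_dtype is given and reduce_dtype is not.
    if param_dtype is not None and reduce_dtype is None:
        reduce_dtype = param_dtype
    # Assumption: anything still None falls back to the original dtype,
    # so neither returned value is ever None.
    fwd_bwd = param_dtype if param_dtype is not None else orig_param_dtype
    reduce = reduce_dtype if reduce_dtype is not None else orig_param_dtype
    return fwd_bwd, reduce


# param_dtype given, reduce_dtype omitted: reduce_dtype is inferred.
print(resolve_dtypes("float32", "float16", None))
# Only reduce_dtype given: both are taken as is, parameters keep their dtype.
print(resolve_dtypes("float32", None, "bfloat16"))
```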

This PR enforces that, at the `FlatParamHandle` level, `handle._config.fwd_bwd_param_dtype` and `handle._config.reduce_dtype` are never `None`. To check whether mixed precision is enabled, compare against the original parameter dtype, which is now stored in `handle._orig_param_dtype`; checking against `None` is no longer valid.

This avoids ambiguous cases such as when the user passes `MixedPrecision(param_dtype=torch.float32)`. In that case, our existing implementation mistakenly thinks that parameter mixed precision is enabled and either silently relies on no-ops or errors (such as one case reported by MosaicML).
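
The difference between the two checks can be sketched as follows. This is an illustrative comparison only: both function names are hypothetical, strings stand in for `torch.dtype` values so the snippet runs without PyTorch, and the original parameter dtype is assumed to be `float32`.

```python
from typing import Optional


def enabled_by_none_check(param_dtype: Optional[str]) -> bool:
    """Old-style check (hypothetical name): any non-None dtype counts."""
    return param_dtype is not None


def enabled_by_dtype_compare(fwd_bwd_param_dtype: str, orig_param_dtype: str) -> bool:
    """New-style check (hypothetical name): compare with the original dtype."""
    return fwd_bwd_param_dtype != orig_param_dtype


orig = "float32"  # assumed original parameter dtype

# User passes param_dtype=float32, which matches the original dtype.
print(enabled_by_none_check("float32"))           # True: misreported as enabled
print(enabled_by_dtype_compare("float32", orig))  # False: no mixed precision

# User passes param_dtype=float16: both checks agree.
print(enabled_by_none_check("float16"))           # True
print(enabled_by_dtype_compare("float16", orig))  # True
```

Comparing against the stored original dtype means a `param_dtype` equal to the original is treated the same as not requesting parameter mixed precision at all.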

**Additional Details**
- We remove `FullyShardedDataParallel._mixed_precision_enabled_for_params`, `FullyShardedDataParallel._mixed_precision_enabled_for_reduce`, and `FullyShardedDataParallel._mixed_precision_keep_low_precision_grads` since they are not used.
- The unit test `test_meta_device_with_mixed_precision()` exercises a tricky edge case with meta device initialization, `apply()` (calling into `summon_full_params()`), and `param_dtype=torch.float32` for a nested wrapping case, where each nested instance has parameters.
- We include some minor fixes/improvements to the communication hook implementation.

**Follow-Ups**
- We should get rid of `HandleConfig` and store its fields as attributes on `FlatParamHandle` directly.
- Rename `keep_low_precision_grads` to `keep_grads_in_reduce_dtype`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90660
Approved by: https://github.com/zhaojuanmao
2022-12-13 07:34:59 +00:00
..
_checkpoint [AC] Add trailing "." to _CHECKPOINT_PREFIX like FSDP (#87951) 2022-10-28 22:05:29 +00:00
_comm_hooks [FSDP] Clean up FlatParamHandle dtypes, post-backward hook (#90660) 2022-12-13 07:34:59 +00:00
_optimizer_overlap
_quantization Change docstring type callable to Callable for consistency (#82487) 2022-08-01 17:26:09 +00:00
ddp_comm_hooks Remove DDP import (#89982) 2022-12-01 14:56:48 +00:00
model_averaging Update hierarchical_model_averager.py (#85648) 2022-10-03 06:15:20 +00:00
__init__.py
join.py Integrate xdoctest - Rebased (#82797) 2022-08-12 02:08:01 +00:00