Implements native mixed precision support for DDP in a similar fashion to how it is enabled for FSDP. The implementation works as follows:
1. In DDP init, we save `_mp_param` and `_fp_param` variables to manage mixed precision parameter usage. In particular, _mp_param will represent the parameter in the reduced precision, while _fp_param will represent the param in regular precision. During forward/backward, we swap back and forth as needed.
2. The root module gets a root pre-forward hook that kicks off copies to the reduced precision for all submodules. An event is recorded for each submodule to allow for waiting, as we run these asynchronously.
3. Each module gets a pre-forward hook that waits on its corresponding event. note that modules might be reused during training, in this case the wait is only done for the first module. After this wait, the module's parameters are in reduced precision.
4. In the pre-forward hook, we register a backward hook on the lower precision parameters in order to run reduced precision allreduce + parameter upcast. We can't rely on the Reducer's constructor setting up these hooks because the gradient is accumulated on the low precision param, so we need to register them ourselves.
5. In the backward pass, when the hook runs, we first run allreduce + divide in the reduced precision. Next, we upcast parameters and gradients back to fp32 asynchronously. We also queue a callback at the end of backward to wait on these upcasts so that the upcast is complete before optim.step() runs.
6. Parameters that don't require grad are also cast since they may be used in computation, they are upcast back in the final autograd callback.
7. DDP Ignored parameters are not touched.
Follow-ups:
1. Unify comm hooks and make it work with apply optimizer in backward
2. implement keep_low_precision_grads,
3. allow BN, LN, or custom units to run in reduced precision,
4. support for cast_forward_inputs
5. Unify certain APIs / helpers with FSDP where possible, such as for _cast_forward_inputs
6. Integrate this with replicate() API.
7. The order in which we kick off copies and wait for them is set by the iteration order of module.modules(), but this might not be how the modules are used in the actual training. In the worst case, the last module in module.modules() could be used first which would result in waiting for all copies unnecessarily. For static graphs, we should record the module execution order and copy / wait in this order.
8. Entirely unused modules probably don't need to be cast.
Differential Revision: [D42515803](https://our.internmc.facebook.com/intern/diff/D42515803/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92882
Approved by: https://github.com/zhaojuanmao
Allow _apply_optim_in_backward to work with DDP.
Example:
```
dist.init_process_group("nccl", rank=rank, world_size=2)
torch.cuda.set_device(rank)
e = enc().cuda(rank)
_apply_optimizer_in_backward(
optimizer_class=torch.optim.SGD,
params=e.parameters(),
optimizer_kwargs={"lr": 0.03},
)
e = DDP(e, device_ids=[rank])
inp = torch.randn(1, 10, device=rank)
e(inp).sum().backward()
```
Constraints:
1. Custom communication hook not yet supported
2. _apply_optim_in_backward needs to be called _before_ wrapping model in DDP.
3. DDP will remove the gradient hooks _apply_optim_in_backward registers, so these gradient hooks will not be fired and cannot be used.
4. All DDP managed parameters have grads set to None by default once optimizer is applied. There is no support for setting only some parameter grads to None, this must be done manually by user (and DDP_OVERLAPPED_OPTIM_SET_GRADS_TO_NONE=0 needs to be set.)
Differential Revision: [D41329694](https://our.internmc.facebook.com/intern/diff/D41329694/)
**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D41329694/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89194
Approved by: https://github.com/zhaojuanmao
This PR reworks the internal handling of parameter and gradient reduction mixed precision, cleans up the post-backward hook logic, and adds some minor changes to the communication hooks.
**Overview**
This PR addresses everything in https://github.com/pytorch/pytorch/issues/90657 except renaming `keep_low_precision_grads` to `keep_grads_in_reduce_dtype` since that is BC breaking. I recommend reading the issue before preceding.
For `MixedPrecision(param_dtype, reduce_dtype, ...)`, the exact rule for parameter and gradient reduction mixed precision that we are following is:
> If `param_dtype is not None` and `reduce_dtype is None`, then we infer `reduce_dtype = param_dtype`. Otherwise, we take `param_dtype` and `reduce_dtype` as is.
This PR enforces that, at the `FlatParamHandle` level, `handle._config.fwd_bwd_param_dtype` and `handle._config.reduce_dtype` are never `None`. The way to check if mixed precision is enabled is to compare against the original parameter dtype, which is now stored in `handle._orig_param_dtype`. It is no longer to check against `None`.
This avoids ambiguous cases such as when the user passes `MixedPrecision(param_dtype=torch.float32)`. In that case, our existing implementation mistakenly thinks that parameter mixed precision is enabled and either relies on no-ops silently or errors (such as one case reported by MosaicML).
**Additional Details**
- We remove `FullyShardedDataParallel._mixed_precision_enabled_for_params`, `FullyShardedDataParallel._mixed_precision_enabled_for_reduce`, and `FullyShardedDataParallel._mixed_precision_keep_low_precision_grads` since they are not used.
- The unit test `test_meta_device_with_mixed_precision()` exercises a tricky edge case with meta device initialization, `apply()` (calling into `summon_full_params()`), and `param_dtype=torch.float32` for a nested wrapping case, where each nested instance has parameters.
- We include some minor fixes/improvements to the communication hook implementation.
**Follow-Ups**
- We should get rid of `HandleConfig` and store its fields as attributes on `FlatParamHandle` directly.
- Rename `keep_low_precision_grads` to `keep_grads_in_reduce_dtype`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90660
Approved by: https://github.com/zhaojuanmao
This is for consistency with FSDP.
- `_FSDP_WRAPPED_MODULE` and `_CHECKPOINT_WRAPPED_MODULE` are exactly the wrapped module variable name, meaning you can call `getattr(module, _FSDP_WRAPPED_MODULE)` or `getattr(module, _CHECKPOINT_WRAPPED_MODULE)`.
- `_FSDP_PREFIX` and `_CHECKPOINT_PREFIX` include the trailing `"."` and are only used for FQNs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87951
Approved by: https://github.com/zhaojuanmao
We change `.module` to pass through `ActivationWrapper` directly to the inner wrapped module. This should fix the state dict issues.
Given the invariant that `.module` always returns the inner wrapped module, FSDP always registers the `FlatParameter` on the inner wrapped module, regardless of if there is an intermediate `ActivationWrapper` or not. This avoids casing on whether `ActivationWrapper` is added before or after FSDP construction.
This PR removes the added unit test in `test_fsdp_misc.py` for changing the wrapped module because I would rather not complicated `_lazy_init()` logic just to support that kind of adversarial behavior. The user should not be swapping out the wrapped module arbitrarily or deleting the `FlatParameter`. I mainly had those tests to make sure that all branches of the code I added was correct.
Differential Revision: [D40799961](https://our.internmc.facebook.com/intern/diff/D40799961)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87950
Approved by: https://github.com/zhaojuanmao
`_recursive_wrap()` returns `Tuple[nn.Module, int]`, where the `nn.Module` is the in-place modified module and the `int` is the numel wrapped. In that sense, the return value is not meant to be publicly used. The `apply_activation_checkpointing()` docs already suggest that the function returns `None`, so this PR simply follows that.
**Test Plan**
CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87871
Approved by: https://github.com/zhaojuanmao
Passing in `offload_to_cpu=True` to checkpoint_wrapper is a bit confusing, because this causes the activation checkpoint args to be ignored and we do CPU offloading. This isn't ideal from API design perspective, so proposing to make `offload_wrapper` its own concept.
Now, offload to CPU + checkpoint can be composed together, such as
```
# apply AC to transformer layers
apply_ac_wrapper(model, checkpoint_wrapper, check_fn=lambda mod: isinstance(mod, TransformerLayer))
# offload the rest of activations to CPU
model = offload_wrapper(model)
```
Will polish / add tests if this proposal sounds good.
Differential Revision: [D39719854](https://our.internmc.facebook.com/intern/diff/D39719854/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85459
Approved by: https://github.com/awgu
This fixes the activation offload for checkpoint wrapper, which was previously broken. It was broken because it was tightly coupled with activation checkpoint, i.e. we did:
```
with save_on_cpu:
checkpoint(module_forward())
```
which would not offload any activation tensors to CPU, as those activations would already be not saved by autograd due to the checkpoint implementation taking priority.
Now, if `offload_to_cpu` is specified, we only do `save_on_cpu` and no checkpoint, so all intermediate tensors are offloaded to CPU instead of checkpointed.
These wrappers can be composed, i.e. if we have
`(Linear, Linear) -> (Linear, Linear) -> (Linear, Linear)`
we can do
`Offload( checkpoint(Linear, Linear) -> checkpoint(Linear, Linear) -> checkpoint(Linear, Linear))`
and inner tensors would be checkpointed while outers will be offloaded.
Differential Revision: [D39448882](https://our.internmc.facebook.com/intern/diff/D39448882/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84907
Approved by: https://github.com/awgu
Those functions enable membership introspection into a ProcessGroup. A common scenario
that needs this is library code that consumes a PG but doesn't create it, which means
it likely doesn't know the global ranks used to create it.
Translating from local to global is necessary when using c10d collectives like broadcast
so if your library code adopts the convention of using local rank 0, it needs
to the following:
```python
import torch.distributed as dist
my_pg: dist.ProcessGroup = ...
def my_library_bcast(tensor)
dist.broadcast(tensor, src=dist.get_global_rank(my_pg, local_rank=0), my_pg)
```
This implements some of the helpers needed to implement the `clone` API from: https://github.com/pytorch/pytorch/issues/81291
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82134
Approved by: https://github.com/rohan-varma
Fixes https://github.com/pytorch/pytorch/issues/79114
An implementation of a FSDP communication hook interface for a sharded strategies:
- Added `reduce_scatter_hook` to default hooks. Note the difference of `reduce_scatter` from `all_reduce`, it requires 2 tensors:`input_gradient` and `output` variables and stores result in `output`, which is further used as a summed gradient shard.
- Adjusted FSDP logic to return `reduce_scatter_hook` as a default communication hook for sharded strategies, `DefaultState` is the same for sharded and non-sharded strategies.
- Adjusted low-precision hooks to work with both `all_reduce` and `reduce_scatter` depending on whether `output` tensor is provided or not.
Test plan:
Added all existing sharded strategies as an input parameters to existing tests.
For`test_default_communication_hook_behaviour` double checked how a linear layer is sharded across workers. This test creates a simple net ``1 X N``, where ``N`` - is the number of workers. For sharded cases, ``N`` parameters are sharded across ``N`` workers. This test checks that after backward, each worker has a proper value in it's chunk of the gradient, or the whole gradient on every worker is equal to an expected value.
Checked that low-precision tests work for sharded cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83254
Approved by: https://github.com/rohan-varma, https://github.com/awgu
This is a new version of #15648 based on the latest master branch.
Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.
In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)
Fixes https://github.com/pytorch/pytorch/issues/71105
@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
Allow checkpoint_wrapper to take in the checkpoint functional impl. This decouples it from torch.utils.checkpoint and allows other checkpoint implementations to be used.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83035
Approved by: https://github.com/awgu
### Description
Across PyTorch's docstrings, both `callable` and `Callable` for variable types. The Callable should be capitalized as we are referring to the `Callable` type, and not the Python `callable()` function.
### Testing
There shouldn't be any testing required.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82487
Approved by: https://github.com/albanD
Recently, `register_comm_hook` was introduced to `FSDP`, which at the moment supports only `NO_SHARD` strategy and has a default `all_reduce` hook implemented. This PR adds two lower precision hooks to an existing default hook.
I've also made slight adjustments to existing implementation of an `all_reduce` hook including:
`AllReduceState` ->` DefaultState `, motivation: `AllReduceState` is not specific to all_reduce. Gradients' pre- and post-division factors are also useful for other hooks, that require pre- and post-division, e.g. `fp16_hook` and `bf16_hook`.
I've put all 3 hooks into `default_hooks.py`
Additionally, `FSDP` supports `MixedPrecision` and, theoretically, it is possible to specify MixedPrecision for gradients and attach a lower precision hook to the model. To avoid double-casting, I've added a couple of checks to `fully_sharded_data_parallel`, i.e. casting to precision and back is performed by a lower precision hook only. I think, as a next step, it would be nice to ensure that user can't have both lower precision hook and MixedPrecision(reduce_dtype=<precision>) specified, but I am happy to discuss this and adjust current implementation.
As a test, I create two models: one with a lower precision hook and one with a `MixedPrecision(reduce_dtype=<precision>)` specified, perform one forward/backward and optimizer step and compare gradients.
PS. first version of this PR was reverted, because added unittests didn't include NCCL version checks for `bf16_hook` (thus failed on trunk). In this version, I've added appropriate checks for tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81711
Approved by: https://github.com/rohan-varma
Recently, `register_comm_hook` was introduced to `FSDP`, which at the moment supports only `NO_SHARD` strategy and has a default `all_reduce` hook implemented. This PR adds two lower precision hooks to an existing default hook.
I've also made slight adjustments to existing implementation of an `all_reduce` hook including:
- `AllReduceState` -> `DefaultState` , motivation: `AllReduceState` is not specific to `all_reduce`. Gradients' pre- and post-division factors are also useful for other hooks, that require pre- and post-division, e.g. fp16_hook and bf16_hook.
- I've put all 3 hooks into `default_hooks.py`
Additionally, `FSDP` supports `MixedPrecision` and, theoretically, it is possible to specify `MixedPrecision` for gradients and attach a lower precision hook to the model. To avoid double-casting, I've added a couple of checks to `fully_sharded_data_parallel`, i.e. casting to precision and back is performed by a lower precision hook only. I think, as a next step, it would be nice to ensure that user can't have both lower precision hook and `MixedPrecision(reduce_dtype=<precision>)` specified, but I am happy to discuss this and adjust current implementation.
As a test, I create two models: one with a lower precision hook and one with a `MixedPrecision(reduce_dtype=<precision>)` specified, perform one forward/backward and optimizer step and compare gradients.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80557
Approved by: https://github.com/rohan-varma
Remove `_checkpoint_wrapped_module` prefixes when creating keys for optimizer state_dict.
Having these does not actually create an issue for optim_state_dict save / load, but we'd like to strip these keys out for downstream code that consumes these APIs typically expecting checkpointing prefixes to not exist (as checkpointing should be a transparent operation which should not change module / parameter names).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80480
Approved by: https://github.com/awgu, https://github.com/fegin
Fixes#75666
Current PR adds the functionality for `PostLocalSGD` communication hook and tests that communication hook can be properly saved and restored. Similar to https://github.com/pytorch/pytorch/pull/79334, where serialization was added to `PowerSGD`.
``__getstate__``
Returns:
```
``Dict[str, Any]`` which will be pickled and saved.
``process_group`` and ``subgroup`` are not serializable and excluded from
a returned state.
```
``__setstate__``
```
Takes provided ``state`` and retrieves ``PostLocalSGDState``.
``process_group`` and ``subgroup`` are set to default process_group and subgroup respectively.
Default subgroup is equivalent to the subgroup on each node.
```
Small adjustment to `PowerSGD`'s warning message.
Refactored unittest, i.e. separated parity and log checks.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80435
Approved by: https://github.com/awgu
This TODO is no longer needed, as we use `_register_fused_optim` to register the overlapped optimizer in DDP. Also, remove comment about API being experimental, as this API is no longer going to be used by end user.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80453
Approved by: https://github.com/awgu
Fixes#79114
An implementation of a FSDP communication hook interface for a NO_SHARD strategy:
- `FullyShardedDataParallel.register_comm_hook(self, state: object, hook: callable)` checks current sharding strategy. If it is other that NO_SHARD, raises a runtime error. Otherwise, sets and shares a specified hook and its state with all submodules
- When FSDP is ready to communicate a gradient, checks if there is a registered hook, and calls it instead of all_reduce. Additionally, gradient pre and post devision are not performed if a hook is registered.
To test the interface, I've implemented a communication hook, that calls for `all_reduce`.
A unittest:
- checks that is a sharding strategy is anything but NO_SHARD, a runtime error is raised
- checks that for a NO_SHARD case, model with registered all_reduce hook and without a hook work the same.
- checks for 2 types of FSDP models: with the wrapped first layer and without. (to make sure submodules have a hook registered)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79833
Approved by: https://github.com/rohan-varma, https://github.com/awgu
This PR addresses issue address #75666.
Stateful communication hook now can be saved and reloaded to resume training.
Current PR adds the functionality for PowerSGD communication hook and tests that communication hook can be properly saved and restored.
PowerSGD implementation uses ``__slots__``, as a result introduced __getstate__ and __setstate__ methods are implemented to work with `__slots__` and not` __dict__`.
`__getstate__ `
Returns:
A dictionary that represents a ``PowerSGDState`` which will be pickled and saved.
``process_group`` is non-serializable and excluded from a returned state.
`__setstate__`
Takes a provided ``state`` and retrieves ``PowerSGDState``.
``process_group`` is set to default with a proper warning issued to a user.
Unit test
A hook-independent `_test_hook_pickling` is added with this PR, as well as `test_ddp_hook_pickling_powerSGD`, which tests `powerSGD`’s ability to be saved and reloaded.
Currently, the test creates a ddp model with a provided hook, trains it for 10 epochs and saves model’s state and hook’s state.
During reloading, unit test makes sure that a warning was logged (only one warning and the proper one). It then proceeds to check that reloaded hook and original hook are the same. Finally, it checks that a hook’s state was properly initialized:
- it compares slot values (all, but 2: `process_group` and `rng`) for original and reloaded state
- it checks that process group was set to a default group
- it checks that a random state was restored properly with np.testing.assert_array_equal, because `rng` is an instance of `np.random.RandomState`, represented by a tuple. One of entries is of `ndarray dtype[uint32]` type and `np.testing.assert_array_equal` is used for assertion.
Future To-Do:
- Implement similar __getstate__ and __setstate__ for other stateful communication hooks
- Add appropriate tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79334
Approved by: https://github.com/rohan-varma, https://github.com/awgu