This is for consistency with FSDP.
- `_FSDP_WRAPPED_MODULE` and `_CHECKPOINT_WRAPPED_MODULE` are exactly the wrapped module variable name, meaning you can call `getattr(module, _FSDP_WRAPPED_MODULE)` or `getattr(module, _CHECKPOINT_WRAPPED_MODULE)`.
- `_FSDP_PREFIX` and `_CHECKPOINT_PREFIX` include the trailing `"."` and are only used for FQNs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87951
Approved by: https://github.com/zhaojuanmao
We change `.module` to pass through `ActivationWrapper` directly to the inner wrapped module. This should fix the state dict issues.
Given the invariant that `.module` always returns the inner wrapped module, FSDP always registers the `FlatParameter` on the inner wrapped module, regardless of if there is an intermediate `ActivationWrapper` or not. This avoids casing on whether `ActivationWrapper` is added before or after FSDP construction.
This PR removes the added unit test in `test_fsdp_misc.py` for changing the wrapped module because I would rather not complicated `_lazy_init()` logic just to support that kind of adversarial behavior. The user should not be swapping out the wrapped module arbitrarily or deleting the `FlatParameter`. I mainly had those tests to make sure that all branches of the code I added was correct.
Differential Revision: [D40799961](https://our.internmc.facebook.com/intern/diff/D40799961)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87950
Approved by: https://github.com/zhaojuanmao
`_recursive_wrap()` returns `Tuple[nn.Module, int]`, where the `nn.Module` is the in-place modified module and the `int` is the numel wrapped. In that sense, the return value is not meant to be publicly used. The `apply_activation_checkpointing()` docs already suggest that the function returns `None`, so this PR simply follows that.
**Test Plan**
CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87871
Approved by: https://github.com/zhaojuanmao
Passing in `offload_to_cpu=True` to checkpoint_wrapper is a bit confusing, because this causes the activation checkpoint args to be ignored and we do CPU offloading. This isn't ideal from API design perspective, so proposing to make `offload_wrapper` its own concept.
Now, offload to CPU + checkpoint can be composed together, such as
```
# apply AC to transformer layers
apply_ac_wrapper(model, checkpoint_wrapper, check_fn=lambda mod: isinstance(mod, TransformerLayer))
# offload the rest of activations to CPU
model = offload_wrapper(model)
```
Will polish / add tests if this proposal sounds good.
Differential Revision: [D39719854](https://our.internmc.facebook.com/intern/diff/D39719854/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85459
Approved by: https://github.com/awgu
This fixes the activation offload for checkpoint wrapper, which was previously broken. It was broken because it was tightly coupled with activation checkpoint, i.e. we did:
```
with save_on_cpu:
checkpoint(module_forward())
```
which would not offload any activation tensors to CPU, as those activations would already be not saved by autograd due to the checkpoint implementation taking priority.
Now, if `offload_to_cpu` is specified, we only do `save_on_cpu` and no checkpoint, so all intermediate tensors are offloaded to CPU instead of checkpointed.
These wrappers can be composed, i.e. if we have
`(Linear, Linear) -> (Linear, Linear) -> (Linear, Linear)`
we can do
`Offload( checkpoint(Linear, Linear) -> checkpoint(Linear, Linear) -> checkpoint(Linear, Linear))`
and inner tensors would be checkpointed while outers will be offloaded.
Differential Revision: [D39448882](https://our.internmc.facebook.com/intern/diff/D39448882/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84907
Approved by: https://github.com/awgu
Those functions enable membership introspection into a ProcessGroup. A common scenario
that needs this is library code that consumes a PG but doesn't create it, which means
it likely doesn't know the global ranks used to create it.
Translating from local to global is necessary when using c10d collectives like broadcast
so if your library code adopts the convention of using local rank 0, it needs
to the following:
```python
import torch.distributed as dist
my_pg: dist.ProcessGroup = ...
def my_library_bcast(tensor)
dist.broadcast(tensor, src=dist.get_global_rank(my_pg, local_rank=0), my_pg)
```
This implements some of the helpers needed to implement the `clone` API from: https://github.com/pytorch/pytorch/issues/81291
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82134
Approved by: https://github.com/rohan-varma
Fixes https://github.com/pytorch/pytorch/issues/79114
An implementation of a FSDP communication hook interface for a sharded strategies:
- Added `reduce_scatter_hook` to default hooks. Note the difference of `reduce_scatter` from `all_reduce`, it requires 2 tensors:`input_gradient` and `output` variables and stores result in `output`, which is further used as a summed gradient shard.
- Adjusted FSDP logic to return `reduce_scatter_hook` as a default communication hook for sharded strategies, `DefaultState` is the same for sharded and non-sharded strategies.
- Adjusted low-precision hooks to work with both `all_reduce` and `reduce_scatter` depending on whether `output` tensor is provided or not.
Test plan:
Added all existing sharded strategies as an input parameters to existing tests.
For`test_default_communication_hook_behaviour` double checked how a linear layer is sharded across workers. This test creates a simple net ``1 X N``, where ``N`` - is the number of workers. For sharded cases, ``N`` parameters are sharded across ``N`` workers. This test checks that after backward, each worker has a proper value in it's chunk of the gradient, or the whole gradient on every worker is equal to an expected value.
Checked that low-precision tests work for sharded cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83254
Approved by: https://github.com/rohan-varma, https://github.com/awgu
This is a new version of #15648 based on the latest master branch.
Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.
In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)
Fixes https://github.com/pytorch/pytorch/issues/71105
@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
Allow checkpoint_wrapper to take in the checkpoint functional impl. This decouples it from torch.utils.checkpoint and allows other checkpoint implementations to be used.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83035
Approved by: https://github.com/awgu
### Description
Across PyTorch's docstrings, both `callable` and `Callable` for variable types. The Callable should be capitalized as we are referring to the `Callable` type, and not the Python `callable()` function.
### Testing
There shouldn't be any testing required.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82487
Approved by: https://github.com/albanD
Recently, `register_comm_hook` was introduced to `FSDP`, which at the moment supports only `NO_SHARD` strategy and has a default `all_reduce` hook implemented. This PR adds two lower precision hooks to an existing default hook.
I've also made slight adjustments to existing implementation of an `all_reduce` hook including:
`AllReduceState` ->` DefaultState `, motivation: `AllReduceState` is not specific to all_reduce. Gradients' pre- and post-division factors are also useful for other hooks, that require pre- and post-division, e.g. `fp16_hook` and `bf16_hook`.
I've put all 3 hooks into `default_hooks.py`
Additionally, `FSDP` supports `MixedPrecision` and, theoretically, it is possible to specify MixedPrecision for gradients and attach a lower precision hook to the model. To avoid double-casting, I've added a couple of checks to `fully_sharded_data_parallel`, i.e. casting to precision and back is performed by a lower precision hook only. I think, as a next step, it would be nice to ensure that user can't have both lower precision hook and MixedPrecision(reduce_dtype=<precision>) specified, but I am happy to discuss this and adjust current implementation.
As a test, I create two models: one with a lower precision hook and one with a `MixedPrecision(reduce_dtype=<precision>)` specified, perform one forward/backward and optimizer step and compare gradients.
PS. first version of this PR was reverted, because added unittests didn't include NCCL version checks for `bf16_hook` (thus failed on trunk). In this version, I've added appropriate checks for tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81711
Approved by: https://github.com/rohan-varma
Recently, `register_comm_hook` was introduced to `FSDP`, which at the moment supports only `NO_SHARD` strategy and has a default `all_reduce` hook implemented. This PR adds two lower precision hooks to an existing default hook.
I've also made slight adjustments to existing implementation of an `all_reduce` hook including:
- `AllReduceState` -> `DefaultState` , motivation: `AllReduceState` is not specific to `all_reduce`. Gradients' pre- and post-division factors are also useful for other hooks, that require pre- and post-division, e.g. fp16_hook and bf16_hook.
- I've put all 3 hooks into `default_hooks.py`
Additionally, `FSDP` supports `MixedPrecision` and, theoretically, it is possible to specify `MixedPrecision` for gradients and attach a lower precision hook to the model. To avoid double-casting, I've added a couple of checks to `fully_sharded_data_parallel`, i.e. casting to precision and back is performed by a lower precision hook only. I think, as a next step, it would be nice to ensure that user can't have both lower precision hook and `MixedPrecision(reduce_dtype=<precision>)` specified, but I am happy to discuss this and adjust current implementation.
As a test, I create two models: one with a lower precision hook and one with a `MixedPrecision(reduce_dtype=<precision>)` specified, perform one forward/backward and optimizer step and compare gradients.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80557
Approved by: https://github.com/rohan-varma
Remove `_checkpoint_wrapped_module` prefixes when creating keys for optimizer state_dict.
Having these does not actually create an issue for optim_state_dict save / load, but we'd like to strip these keys out for downstream code that consumes these APIs typically expecting checkpointing prefixes to not exist (as checkpointing should be a transparent operation which should not change module / parameter names).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80480
Approved by: https://github.com/awgu, https://github.com/fegin
Fixes#75666
Current PR adds the functionality for `PostLocalSGD` communication hook and tests that communication hook can be properly saved and restored. Similar to https://github.com/pytorch/pytorch/pull/79334, where serialization was added to `PowerSGD`.
``__getstate__``
Returns:
```
``Dict[str, Any]`` which will be pickled and saved.
``process_group`` and ``subgroup`` are not serializable and excluded from
a returned state.
```
``__setstate__``
```
Takes provided ``state`` and retrieves ``PostLocalSGDState``.
``process_group`` and ``subgroup`` are set to default process_group and subgroup respectively.
Default subgroup is equivalent to the subgroup on each node.
```
Small adjustment to `PowerSGD`'s warning message.
Refactored unittest, i.e. separated parity and log checks.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80435
Approved by: https://github.com/awgu
This TODO is no longer needed, as we use `_register_fused_optim` to register the overlapped optimizer in DDP. Also, remove comment about API being experimental, as this API is no longer going to be used by end user.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80453
Approved by: https://github.com/awgu
Fixes#79114
An implementation of a FSDP communication hook interface for a NO_SHARD strategy:
- `FullyShardedDataParallel.register_comm_hook(self, state: object, hook: callable)` checks current sharding strategy. If it is other that NO_SHARD, raises a runtime error. Otherwise, sets and shares a specified hook and its state with all submodules
- When FSDP is ready to communicate a gradient, checks if there is a registered hook, and calls it instead of all_reduce. Additionally, gradient pre and post devision are not performed if a hook is registered.
To test the interface, I've implemented a communication hook, that calls for `all_reduce`.
A unittest:
- checks that is a sharding strategy is anything but NO_SHARD, a runtime error is raised
- checks that for a NO_SHARD case, model with registered all_reduce hook and without a hook work the same.
- checks for 2 types of FSDP models: with the wrapped first layer and without. (to make sure submodules have a hook registered)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79833
Approved by: https://github.com/rohan-varma, https://github.com/awgu
This PR addresses issue address #75666.
Stateful communication hook now can be saved and reloaded to resume training.
Current PR adds the functionality for PowerSGD communication hook and tests that communication hook can be properly saved and restored.
PowerSGD implementation uses ``__slots__``, as a result introduced __getstate__ and __setstate__ methods are implemented to work with `__slots__` and not` __dict__`.
`__getstate__ `
Returns:
A dictionary that represents a ``PowerSGDState`` which will be pickled and saved.
``process_group`` is non-serializable and excluded from a returned state.
`__setstate__`
Takes a provided ``state`` and retrieves ``PowerSGDState``.
``process_group`` is set to default with a proper warning issued to a user.
Unit test
A hook-independent `_test_hook_pickling` is added with this PR, as well as `test_ddp_hook_pickling_powerSGD`, which tests `powerSGD`’s ability to be saved and reloaded.
Currently, the test creates a ddp model with a provided hook, trains it for 10 epochs and saves model’s state and hook’s state.
During reloading, unit test makes sure that a warning was logged (only one warning and the proper one). It then proceeds to check that reloaded hook and original hook are the same. Finally, it checks that a hook’s state was properly initialized:
- it compares slot values (all, but 2: `process_group` and `rng`) for original and reloaded state
- it checks that process group was set to a default group
- it checks that a random state was restored properly with np.testing.assert_array_equal, because `rng` is an instance of `np.random.RandomState`, represented by a tuple. One of entries is of `ndarray dtype[uint32]` type and `np.testing.assert_array_equal` is used for assertion.
Future To-Do:
- Implement similar __getstate__ and __setstate__ for other stateful communication hooks
- Add appropriate tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79334
Approved by: https://github.com/rohan-varma, https://github.com/awgu
- Uses state dict / load state dict hooks to ensure that modules wrapped with `CheckpointWrapper` can be loaded into non-checkpointed wrapped module.
This is because a training run can use activation checkpointing, then we can recover `state_dict`, and a future run may not want to wrap modules with activation checkpointing or decide to change activation checkpoint wrapping structure. To support this, we add hooks to remove / add the relevant prefix as needed.
Tests are added to ensure we can load into CheckpointWrapper module as well as local module from CheckpointWrapper-wrapped module. state_dict with FSDP is also verified.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77224
Approved by: https://github.com/zhaojuanmao
I find that sometimes disabling intra-subgroup gradient allreduce can still give a satisfying accuracy for some cases, so better to make such gradient averaging configurable. This does not take into account the saving in the communication of allreducing gradients.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76723
Approved by: https://github.com/rohan-varma