### Background: Gradient Pre-Divide
Consider $N$ data parallel workers. Define $g_i$ to be the $i$ th worker's local unsharded gradient. Data parallel gradient reduction computes $\overline g = \frac{1}{N} \sum_{i \in [N]} g_i$.
$\sum_{i \in [N]} g_i$ increases the magnitude by a factor of $N$, which may overflow for fp16. However, if we pre-divide and compute $\sum_{i \in [N]} \frac{g_i}{N}$, then the $\frac{g_i}{N}$ may underflow. The current solution from Myle for FSDP is to pre-divide by $\sqrt{N}$ and post-divide by $\sqrt{N}$:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{i \in [N]} \frac{g_i}{\sqrt{N}}.$$
Now, consider HSDP with $N = S \cdot R$ data parallel workers, sharding over $S$ workers and replicating over $R$ workers. Define $g_{i,j}$ to be the $i \cdot S + j$ th worker's local unsharded gradient (so sharding indexes with $i$ and replication indexes with $j$). The existing implementation computes
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}} } \sum_{i \in [S]} \frac{g_i}{\sqrt{S}},$$
where the $\frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}}$ involves two separate `aten::div_` kernels.
### Revisiting Pre-Divide for HSDP
A minor optimization that we can do is with this intermediate `div_`. There are two options:
1. Compute $\overline{g}$ in the same way as FSDP:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{j \in [R]} \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{N}}.$$
2. Compute $\overline{g}$ still with an intermediate division for rescaling but coalescing the two `divs_` into one:
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{N}} } \sum_{i \in [S]} \frac{g_i}{\sqrt{S}}$$
This PR goes with the 1st approach prioritizing performance because (1) it matches the existing FSDP behavior and (2) it avoids a memor-bandwidth bound `div_` kernel that blocks all-reduce launch.
### Implementation Details
In order to accommodate this, we need to refactor the communication hook logic that baked the gradient pre/post-division into the default hook.
- We raise an error if registering a communication hook for HSDP since the current implementation would only apply the hook to the reduce-scatter, not the all-reduce, which may be unexpected.
- We change it so that `state._comm_hook is not None` iff a communication hook is registered. This makes the collectives and the pre/post-division in the default no-communication-hook path more visible in the code.
Differential Revision: [D47852459](https://our.internmc.facebook.com/intern/diff/D47852459)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106034
Approved by: https://github.com/rohan-varma
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)
That were reverted due to the conflict with internal source repo.
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export. deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
- Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add hack to squash older libstdc++ from conda environment in favor one from OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds loose the ability to catch exceptions (probably because they link with cupti statically, but I could not found where it is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)
That were reverted due to the conflict with internal source repo.
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export. deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
- Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
constraints:
1. No support for gradient accumulation
2. CPU offload runs step() on CPU. In future PRs ideally we'd run this on GPU.
3. When CPU offload + optimizer overlap, we have to copy the flat_param grad to CPU with non_blocking=False, otherwise step() might run on invalid data.
4. Step is waited on in post backward final cb, when in theory it can wait until the next forward.
Differential Revision: [D44809582](https://our.internmc.facebook.com/intern/diff/D44809582/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98667
Approved by: https://github.com/awgu, https://github.com/fegin
Purely out of preference, this PR renames the streams to `_unshard_stream` instead of `_streams_unshard` etc. since the former reads more naturally. The PR also removes some duplicated comments and adds back a unit test that streams are shared.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104966
Approved by: https://github.com/rohan-varma
When creating DeviceMesh, _init_process_group() would validate that all calling ranks pass in the same `mesh` argument. In FSDP, we are currently creating the DeviceMesh based on the pg of the root state so the mesh will always be valid. Adding the flag to DeviceMesh, so we can skip the all_gather_tensor of the validation during construction time.
_validate_mesh is default to True, but we manually flip it to False when initializing device mesh in FSDP's _runtime_utils.py.
Will modify skipping pg creation if existed for both 1D and 2D cases and then delete _init_process_groups flag in a follow up PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104807
Approved by: https://github.com/wanchaol
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules.
This includes several important fixes:
- We should register the pre/post-forward hooks on the module regardless of it has managed parameters.
- We can permit `_module_handles` to return `[]` in the composable path (for when the module has no managed parameters).
- We should unify the paths for `_get_buffers_and_dtypes_for_computation()` (previously, composable path was buggy in some cases).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104408
Approved by: https://github.com/rohan-varma
In https://github.com/pytorch/pytorch/pull/97645 and some follow up diffs, we made FSDP run in full precision in eval mode, even if mixed precision was specified.
However, this is probably not the best idea and we should provide a flag for users to have control over this a bit more. Adding an env var FSDP_FULL_PREC_IN_EVAL and defaulting it to off, users who want to run eval in fp32 can toggle this before wrapping model in FSDP:
os.environ["FSDP_FULL_PREC_IN_EVAL"] = "1"
Verified that unittests, APS workflow, TNT workloads can run eval appropriately with this change.
Differential Revision: [D47246556](https://our.internmc.facebook.com/intern/diff/D47246556/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104682
Approved by: https://github.com/awgu
This allows us use use_dtensor=True for ShardedStateDictConfig() before calling model.load_state_dict(). It only works for offload_to_cpu=False now.
Next PR will make use_dtensor=True work with offload_to_cpu=True for load_state_dict().
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104087
Approved by: https://github.com/fegin
This fixes https://github.com/pytorch/pytorch/issues/104148 (unfreezing parameters after `n` steps).
- This fixes a bug where we did not delete the post-backward hook state properly for the `requires_grad=False` case.
- This makes the `already_resharded` correct for `SHARD_GRAD_OP`.
- This generalizes `_clear_grads_if_needed()` to `_reset_flat_param_grad_info_if_needed()` to additionally include propagating the original parameters' `requires_grad` to the flat parameter.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104186
Approved by: https://github.com/rohan-varma, https://github.com/fegin
This PR makes a first attempt at improving FSDP's fine-tuning support by adding hooks to reshard frozen parameters in the backward pass.
- Without this, frozen parameters involved in gradient computation are kept as unsharded through the entire backward pass.
- The approach is to register a multi-grad ~~post~~-hook on the _input_ activations to the FSDP module, where the hook performs the resharding after all gradients for the FSDP module must have been computed (meaning that we are safe to reshard).
~~This PR relies on adding a "multi-grad post-hook" that differs from the existing "multi-grad hook" from `register_multi_grad_hook()`. I find that with `register_multi_grad_hook()`, sometimes the unit test counting the number of times `_post_backward_reshard()` is called fails (due to it not being called).~~ This was resolved in https://github.com/pytorch/pytorch/pull/102859.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101982
Approved by: https://github.com/rohan-varma
This PR creates a device_mesh and share it across all FSDP state. The device_mesh will later be used to test out dtensor state_dict (1d device_mesh).
Approved by: https://github.com/awgu
Add device mesh to fsdp state
skip dist.get_world_size(pg) != dist.get_world_size()
address test_fake_pg.py test failure
fix test_fake_py.py failure
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102551
Approved by: https://github.com/fegin
Custom backend implementation based on privateuse1 with semantics identical to CUDA (CUDA is so popular), named for example 'my_device', and registered as the same module name torch.my_device.
This PR aims to satisfy the constraints of such a backend, which can be directly integrated into the current FSDP implementation.
The main issues addressed are:
#### 1. Device decision for FSDP wrapping of Modules without Parameters
Users typically organize FSDP code as follows:
```python
m = Module().to('my_device:0')
fsdp_m = FSDP(m)
```
or like this:
```python
m = Module()
fsdp_m = FSDP(m, device_id=torch.device('my_device', 0))
```
If the model has Parameters, everything works fine because FSDP will prioritize the device where the Parameters are located. However, for Modules without Parameters, the to() call has no side effects, and FSDP will assume the current CUDA device, which prevents the use of devices other than the current CUDA device for Modules without Parameters. Therefore, when FSDP is called with a device_id argument, this configuration takes top priority.
#### 2. Abstraction of a cuda-like device
Now, in addition to compute_device, _FSDPState includes a device_handler member. In fact, this device_handler is now just a reference to either torch.cuda or torch.my_device. From now on, code that works based on _FSDPState should use state.device_handler to operate streams create, wait or sync, just like using torch.cuda previously.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99024
Approved by: https://github.com/awgu
Fixes#99545
There is currently no topological constraint dictating FSDP instances own ``FlatParamHandle`` s directly. If all parameters are managed by descendant FSDP instances leaving an FSDP instance with no direct ``state._handles``, the ``should_cast_forward_inputs`` decisions below in both ``_root_pre_forward()`` and ``_pre_forward()`` respectively can return incorrect decisions [^1].
For [``_root_pre_forward()``](436edc5ac3/torch/distributed/fsdp/_runtime_utils.py (L514)):
436edc5ac3/torch/distributed/fsdp/_runtime_utils.py (L602-L604)
For [``_pre_forward``](436edc5ac3/torch/distributed/fsdp/_runtime_utils.py (L384)):
436edc5ac3/torch/distributed/fsdp/_runtime_utils.py (L420-L422)
See the [related issue](https://github.com/pytorch/pytorch/issues/99545) for reproduction.
### Remediation
In this PR, I amend the two decision statements referenced above (in both `_root_pre_forward()` and `_pre_forward()`) to account for FSDP instances without direct handles:
```python
should_cast_forward_inputs = len(state._handles) > 0 and all(
not handle._force_full_precision for handle in state._handles
)
```
If one configures ``MixedPrecision`` in the example above with ``cast_forward_inputs=True`` and the ``should_cast_forward_inputs`` adjustment above, FSDP returns to the expected behavior and produces no error.
Though the check is the same in both ``_root_pre_forward()`` and ``_pre_forward()`` and hence could be refactored into a separate function, I figured it may make sense to retain separate statements to preserve the ability for root-specific behavior in the future. Whichever approach the team prefers I can update this PR with.
### Implementation considerations and questions:
1. Rather than write a test that would arguably have a poor utility/resource usage profile, I have not added any tests associated with this PR. The new decision logic is exercised by all existing tests (which continue to pass after this PR of course) so I think the utility of new tests is fairly modest. Let me know if you think new tests should be added and I'm happy to do so.
2. As discussed above, the decision statement shared among ``_pre_forward()`` and ``_root_pre_forward()`` could be factored out into a separate function. Given the simplicity of the statement and to retain current flexibility for root-specific decisions it might not be worth the refactor so I haven't done it yet. Let me know if you'd like me to do so.
3. The note below could be updated to indicate the utility of setting ``cast_forward_inputs=True`` for the situations addressed with this PR but I haven't done so since I'm not sure it's worth complicating the current usage guidance. I'd be happy to add verbiage describing the use case if the team wants it.
cde35b4069/torch/distributed/fsdp/api.py (L175-L181)
Thanks again to the PyTorch distributed team for your immensely valuable contributions to the open-source ML community!
[^1]: Though one could keep the existing decision logic and impose a new topological constraint requiring all FSDP instances have direct `_handles`, I think retaining the current wrapping flexibility is both convenient and useful enough (e.g. programmatic wrapping of modules that may or may not already have all parameters handled by descendant FSDP instances) to update the decision logic as discussed here instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99546
Approved by: https://github.com/awgu
If model.eval() is true, then runs the model in full precision.
Changes:
- Changed _force_full_precision to check self.is_training
- Check for _force_full_precision when casting gradients to reduced dtype
- Small change when accessing _full_prec_param_padded
- tests for class based and fully_shard APIs
Differential Revision: [D43933690](https://our.internmc.facebook.com/intern/diff/D43933690/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97645
Approved by: https://github.com/awgu
This PR ensures that when prefetching a `FlatParamHandle.unshard()`, we temporarily set the `FlatParamHandle._training_state` to the expected training state as if the `unshard()` were not prefetched since the `as_params` argument to `_use_unsharded_views()` depends on the handle's training state.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98249
Approved by: https://github.com/rohan-varma
- The previous PR addressed one tree traversal in `_root_pre_forward()` but not the main one from `_get_fsdp_handles()` that runs for all settings.
- This PR saves `_all_handles` to cache `_get_fsdp_handles()` and `_all_fsdp_states` to cache `_get_fsdp_states()` (renamed from `_fsdp_states` compared to last PR) on the root state.
- This PR introduces a dummy `_RootFSDPState` class that inherits from `_FSDPState` to be used only for type checking since some attributes are only defined for root states.
- I found this approach to be better than adding `_p_assert(state.root_only_attr is not None, ...)` upon each usage of `root_only_attr`.
- This hopefully also helps readers to quickly see which attributes are defined only on root states.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95465
Approved by: https://github.com/fduwjj