Saw a warning like this:
```
/opt/conda/lib/python3.10/site-packages/torch/utils/hooks.py:86: UserWarning: backward hook functools.partial(<function _pre_backward_hook at 0x7f9a3940fac0>, FullyShardedDataParallel(
....
), <torch.distributed.fsdp.flat_param.FlatParamHandle object at 0x7f25202a9720>) on tensor will not be serialized. If this is expected, you can decorate the function with @torch.utils.hooks.unserializable_hook to suppress this warning
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125464
Approved by: https://github.com/ezyang
Some toy example:
<img width="998" alt="Screenshot 2024-04-17 at 2 00 05 PM" src="https://github.com/pytorch/pytorch/assets/31054793/b5665a63-beb0-4ca1-92c6-c57a052812fd">
We define `FullyShardedDataParallel._unshard(async_op: bool = False)` that can be used to prefetch all-gathers. The user should make sure:
1. Run lazy init before the first `_unshard` call of training. For example, this can hackily be done via `root_module.check_is_root()` on the root FSDP module `root_module`.
2. Call `root_module._wait_unshard_streams_on_current_stream()` before the first `_unshard` call of the current iteration (just need to call it once after last optimizer step and before first `_unshard` of this iteration).
Differential Revision: [D56262876](https://our.internmc.facebook.com/intern/diff/D56262876)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124304
Approved by: https://github.com/wanchaol
Summary:
Minor logging cleanup in distributed library
1. Don't use "f" formatted strings - address linter issues.
2. Nits: Make use of unused `e` (error) in a few logs.
3. Change info->debug as asked in issue #113545
4. Nit: rename log -> logger in a few files for consistency
5. Fix a linter error.
Test Plan:
1. Local build passes.
2. Linter is happy.
Reviewers: wanchaol
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122921
Approved by: https://github.com/wanchaol
Currently, HSDP validates that all intra/inter node PGs are the same. This makes sense if you are only using HSDP with no other forms of parallelism and is a nice but not necessary sanity check.
However, if you want to mix HSDP with other forms, say tensor parallelism on the FFN of a transformer block, the intra/inter node PGs will be different for that layer. This check raises errors in this scenario, so we need to remove this assumption.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112435
Approved by: https://github.com/wz337, https://github.com/Skylion007
reland of https://github.com/pytorch/pytorch/pull/116559, which was reverted by internal.
The underlying reason for the revert is that the torch.dynamo.disable can't be used by the
pytorch codebase, as it's conflicting with some torch.deploy together, although the later one
only run some inference, but it somehow take that weird dependency on fsdp..
We have seen this issue with our functional collectives that we can't
use any dynamo components otherwise torch.deploy would complain..
verified internally that after removing torch.dynamo.disable the test
passed again
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117020
Approved by: https://github.com/awgu
Context: Existing FSDPExtension have some bug in the case when the
unflatten tensor involves some compute/communications in cuda stream,
the current logic of FSDPExtension unflatten tensor happens in the
unshard stream, which makes runtime lost sync with the compute stream,
and if there're some dependencies between the compute stream and the
unflatten tensor logic, currently it would lose sync point, which could
possibly lead to NaN.
This PR make the FSDPExtension to record the compute stream and let
DTensorExtension to directly use the compute stream for unflatten_tensor.
In long term we might want to directly make the FSDP runtime logic to only
make the unshard happen in unshard stream, and use unshard views to
happen in the compute stream. We currently fix this in the Extension
directly as this is the simplest thing to do without affecting FSDP
runtime logic
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116559
Approved by: https://github.com/awgu, https://github.com/fduwjj, https://github.com/yifuwang
ghstack dependencies: #116426
Currently, when we have 2D composition, a global variable _extensions controls the 2D deviation we need to take in state_dict calls (See https://github.com/pytorch/pytorch/blob/release/2.1/torch/distributed/fsdp/_fsdp_extensions.py#L66-L68). This is problematic when we have both a 2D model and a plain FSDP model in the same dist environment, as the _extensions will be mistakenly turned on for the plain FSDP model, resulting in state_dict error (RuntimeError: No parent device_mesh is found for FSDP device_mesh.).
This PR binds _fsdp_extension to the FSDP instances to make sure that state_dict calls would not get interfered with each other when mixing both 2D and 1D parallelism.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113237
Approved by: https://github.com/fduwjj, https://github.com/fegin
# Problem
When sharding_strategy is set to SHARD_GRAD_OP and forward_prefetch is turned on, the validation after the train has an incorrect weight shape.
<img width="1508" alt="image" src="https://github.com/pytorch/pytorch/assets/41232043/57a9c3bb-cb5c-46df-ac26-922740686f9e">
# Analyze
When using `SHARD_GRAD_OP`, the `free_unsharded_flat_param` in `_post_forward_reshard` is often False, so it does not set the handle's `_prefetched` flag to False after the forward.
The normal train phase sets this flag to False in the `_post_backward_final_callback`, and the validation phase doesn't execute the hook, so after the first iter of the validation is done, the flag of the handle of the prefetched will remain True.
This will cause the handle to skip the `_unshard` in the next `_pre_forward_unshard`, and the `_prefetch_handle` will not do a prefetch, which will result in an incorrect weight shape.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110139
Approved by: https://github.com/awgu
This PR:
1) Add device_mesh kwarg to FSDP. Remove init_device_mesh() from _runtime_utils.py, as device_mesh would be passed in by user as an kwarg.
2) change use_dtensor flag for state_dict_config and optim_state_dict_config to be private. If device_mesh is used with sharded model/optim state dict, _use_dtensor flag would be set to True and model/optim state dict would return dtensor state_dict. Otherwise, _use_dtensor flag would be set to False and model/optim state dict would return sharded_tensor state_dict.
3) Update _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to add support for HSDP to return 2D DTensor state_dict.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
This PR:
1) Add device_mesh kwarg to FSDP. Remove init_device_mesh() from _runtime_utils.py, as device_mesh would be passed in by user as an kwarg.
2) change use_dtensor flag for state_dict_config and optim_state_dict_config to be private. If device_mesh is used with sharded model/optim state dict, _use_dtensor flag would be set to True and model/optim state dict would return dtensor state_dict. Otherwise, _use_dtensor flag would be set to False and model/optim state dict would return sharded_tensor state_dict.
3) Update _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to add support for HSDP to return 2D DTensor state_dict.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
**Overview**
This PR runs the HSDP all-reduce as async so that it can overlap with both all-gather and reduce-scatter, which can lead to slight end-to-end speedups when the sharding process group is fully intra-node. Previously, the all-reduce serializes with reduce-scatter, so it can only overlap with one all-gather.
For some clusters (e.g. our AWS cluster), `NCCL_CROSS_NIC=1` improves inter-node all-reduce times when overlapped with intra-node all-gather/reduce-scatter.
**Experiment**
<details>
<summary> Example 'before' trace </summary>
<img width="559" alt="hsdp_32gpus_old" src="https://github.com/pytorch/pytorch/assets/31054793/15222b6f-2b64-4e0b-a212-597335f05ba5">
</details>
<details>
<summary> Example 'after' trace </summary>
<img width="524" alt="hsdp_32gpus_new" src="https://github.com/pytorch/pytorch/assets/31054793/94f63a1d-4255-4035-9e6e-9e10733f4e44">
</details>
For the 6-encoder-layer, 6-decoder layer transformer with `d_model=8192`, `nhead=64` on 4 nodes / 32 40 GB A100s via AWS, the end-to-end iteration times are as follows (with AG == all-gather, RS == reduce-scatter, AR == all-reduce; bandwidth reported as algorithmic bandwidth):
- Reference FSDP:
- **1160 ms / iteration**
- ~23 ms / encoder AG/RS --> 24.46 GB/s bandwidth
- ~40 ms / decoder AG/RS --> 26.5 GB/s bandwidth
- 50 GB/s theoretical inter-node bandwidth
- Baseline 8-way HSDP (only overlap AR with AG) -- intra-node AG/RS, inter-node AR:
- **665 ms / iteration**
- ~3 ms / encoder AG/RS --> 187.5 GB/s bandwidth
- ~5 ms / decoder AG/RS --> 212 GB/s bandwidth
- ~30 ms / encoder AR --> 2.34 GB/s bandwidth
- ~55 ms / decoder AR --> 2.65 GB/s bandwidth
- 300 GB/s theoretical intra-node bandwidth
- New 8-way HSDP (overlap AR with AG and RS) -- intra-node AG/RS, inter-node AR:
- **597 ms / iteration**
- ~3 ms / encoder AG/RS --> 187.5 GB/s bandwidth
- ~6.2 ms / decoder AG/RS --> 170.97 GB/s bandwidth (slower)
- ~23 ms / encoder AR (non-overlapped) --> 3.057 GB/s bandwidth (faster)
- ~49 ms / decoder AR (non-overlapped) --> 2.70 GB/s bandwidth (faster)
- ~100 ms / decoder AR (overlapped) --> 1.325 GB/s bandwidth (slower)
- Overlapping with reduce-scatter reduces all-reduce bandwidth utilization even though the all-reduce is inter-node and reduce-scatter is intra-node!
- New 8-way HSDP (overlap AR with AG and RS) with `NCCL_CROSS_NIC=1`:
- **556 ms / iteration**
- Speedup comes from faster overlapped AR
Thus, for this particular workload, the async all-reduce enables 16% iteration-time speedup compared to the existing HSDP and 52% speedup compared to FSDP. These speedups are pronounced due to the workload being communication bound, so any communication time reduction translates directly to speedup.
**Unit Test**
This requires >= 4 GPUs:
```
python -m pytest test/distributed/fsdp/test_fsdp_hybrid_shard.py -k test_fsdp_hybrid_shard_parity
```
Differential Revision: [D47852456](https://our.internmc.facebook.com/intern/diff/D47852456)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106080
Approved by: https://github.com/ezyang
ghstack dependencies: #106068
The post-backward hook has some complexity due to the different paths: {no communication hook, communication hook} x {`NO_SHARD`, `FULL_SHARD`/`SHARD_GRAD_OP`, `HYBRID_SHARD`/`_HYBRID_SHARD_ZERO2`} plus some options like CPU offloading and `use_orig_params=True` (requiring using sharded gradient views).
The PR following this one that adds async all-reduce for HSDP further complicates this since the bottom-half after all-reduce must still be run in the separate all-reduce stream, making it more unwieldy to unify with the existing bottom-half.
Nonetheless, this PR breaks up the post-backward hook into smaller logical functions to hopefully help readability.
Differential Revision: [D47852461](https://our.internmc.facebook.com/intern/diff/D47852461)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106068
Approved by: https://github.com/ezyang, https://github.com/fegin
issue resolved: https://github.com/pytorch/pytorch/issues/97791
before this PR, mixed_precision applies to buffers from ignored modules. see ```test_state_dict_with_ignored_modules(mixed_precision=True)``` for reproduce
after, we avoid applying mixed_precision semantics to buffers from ignored modules
* step 1 initialization: state._ignored_buffer_names contains all the buffers from ignored modules
* step 2 lazy init at runtime: skip ignored buffers in ```_get_buffers_and_dtypes_for_computation```
* step 3 skip upcasting in state_dict hook: avoid upcasting for ignored buffers in ```_get_buffers_and_dtypes_for_computation```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106766
Approved by: https://github.com/awgu
### Background: Gradient Pre-Divide
Consider $N$ data parallel workers. Define $g_i$ to be the $i$ th worker's local unsharded gradient. Data parallel gradient reduction computes $\overline g = \frac{1}{N} \sum_{i \in [N]} g_i$.
$\sum_{i \in [N]} g_i$ increases the magnitude by a factor of $N$, which may overflow for fp16. However, if we pre-divide and compute $\sum_{i \in [N]} \frac{g_i}{N}$, then the $\frac{g_i}{N}$ may underflow. The current solution from Myle for FSDP is to pre-divide by $\sqrt{N}$ and post-divide by $\sqrt{N}$:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{i \in [N]} \frac{g_i}{\sqrt{N}}.$$
Now, consider HSDP with $N = S \cdot R$ data parallel workers, sharding over $S$ workers and replicating over $R$ workers. Define $g_{i,j}$ to be the $i \cdot S + j$ th worker's local unsharded gradient (so sharding indexes with $i$ and replication indexes with $j$). The existing implementation computes
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}} } \sum_{i \in [S]} \frac{g_i}{\sqrt{S}},$$
where the $\frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}}$ involves two separate `aten::div_` kernels.
### Revisiting Pre-Divide for HSDP
A minor optimization that we can do is with this intermediate `div_`. There are two options:
1. Compute $\overline{g}$ in the same way as FSDP:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{j \in [R]} \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{N}}.$$
2. Compute $\overline{g}$ still with an intermediate division for rescaling but coalescing the two `divs_` into one:
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{N}} } \sum_{i \in [S]} \frac{g_i}{\sqrt{S}}$$
This PR goes with the 1st approach prioritizing performance because (1) it matches the existing FSDP behavior and (2) it avoids a memor-bandwidth bound `div_` kernel that blocks all-reduce launch.
### Implementation Details
In order to accommodate this, we need to refactor the communication hook logic that baked the gradient pre/post-division into the default hook.
- We raise an error if registering a communication hook for HSDP since the current implementation would only apply the hook to the reduce-scatter, not the all-reduce, which may be unexpected.
- We change it so that `state._comm_hook is not None` iff a communication hook is registered. This makes the collectives and the pre/post-division in the default no-communication-hook path more visible in the code.
Differential Revision: [D47852459](https://our.internmc.facebook.com/intern/diff/D47852459)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106034
Approved by: https://github.com/rohan-varma
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)
That were reverted due to the conflict with internal source repo.
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export. deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
- Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add hack to squash older libstdc++ from conda environment in favor one from OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds loose the ability to catch exceptions (probably because they link with cupti statically, but I could not found where it is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)
That were reverted due to the conflict with internal source repo.
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export. deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
- Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
constraints:
1. No support for gradient accumulation
2. CPU offload runs step() on CPU. In future PRs ideally we'd run this on GPU.
3. When CPU offload + optimizer overlap, we have to copy the flat_param grad to CPU with non_blocking=False, otherwise step() might run on invalid data.
4. Step is waited on in post backward final cb, when in theory it can wait until the next forward.
Differential Revision: [D44809582](https://our.internmc.facebook.com/intern/diff/D44809582/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98667
Approved by: https://github.com/awgu, https://github.com/fegin
Purely out of preference, this PR renames the streams to `_unshard_stream` instead of `_streams_unshard` etc. since the former reads more naturally. The PR also removes some duplicated comments and adds back a unit test that streams are shared.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104966
Approved by: https://github.com/rohan-varma