Commit Graph

118 Commits

Author SHA1 Message Date
CK Luk
0ea126e834 add use_fake_all_gather and use_fake_reduce_scatter to FSDP for ablation studies (#113106)
Summary: As titled

Test Plan: Not needed because this is only for doing ablation studies

Reviewed By: awgu

Differential Revision: D50867908

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113106
Approved by: https://github.com/awgu
2023-11-17 05:43:30 +00:00
wz337
31ded95cd5 [2D] Bind _fsdp_extension to FSDP instances (#113237)
Currently, when we have a 2D composition, a global variable _extensions controls the 2D deviation we need to take in state_dict calls (see https://github.com/pytorch/pytorch/blob/release/2.1/torch/distributed/fsdp/_fsdp_extensions.py#L66-L68). This is problematic when we have both a 2D model and a plain FSDP model in the same dist environment, as _extensions will be mistakenly turned on for the plain FSDP model, resulting in a state_dict error (RuntimeError: No parent device_mesh is found for FSDP device_mesh.).

This PR binds _fsdp_extension to the FSDP instances to make sure that state_dict calls do not interfere with each other when mixing 2D and 1D parallelism.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113237
Approved by: https://github.com/fduwjj, https://github.com/fegin
2023-11-09 03:31:03 +00:00
Peter Bell
66c32d099a Use pytree.arg_tree_leaves everywhere (#112394)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112394
Approved by: https://github.com/lezcano
ghstack dependencies: #112391, #112392, #112393
2023-10-31 15:57:06 +00:00
Peter Bell
bbd5b935e4 Use pytree.tree_leaves everywhere (#112324)
This changes all the instances I could find of `tree_flatten(...)[0]` or
`x, _ = tree_flatten(...)` to use `tree_leaves`.
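A minimal before/after sketch of the pattern, assuming `torch.utils._pytree` exposes `tree_leaves` as introduced earlier in this stack:

```
import torch.utils._pytree as pytree

tree = {"a": [1, 2], "b": (3, 4)}

# Before: flatten and throw away the TreeSpec
leaves, _ = pytree.tree_flatten(tree)

# After: ask for just the leaves
leaves = pytree.tree_leaves(tree)
```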

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112324
Approved by: https://github.com/lezcano
ghstack dependencies: #112327, #112323
2023-10-30 03:39:04 +00:00
Luo Bo
b691d09010 fix: reset prefetch flag upon reshard (#111354)
The `prefetched` flag should be reset upon reshard. Otherwise, for zero2, the next access to the corresponding parameter will skip the "unshard" operation and result in a wrong parameter shape.

The need for unsharding is also mentioned [in the comment of `FlatParameterHandle.unshard`](https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_flat_param.py#L1241-L1242).

As [`FlatParameterHandle` already guards against an unnecessary all-gather](https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_flat_param.py#L1240), this shouldn't incur extra communication overhead.

_Personally, I also find `_prefetched` a bit misnamed; it should really be `_unsharded`._
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111354
Approved by: https://github.com/awgu
2023-10-16 18:31:33 +00:00
Edwiv
5caf2e55d4 [FSDP] fix: fix for fsdp zero2 validation error (#110139)
# Problem
When sharding_strategy is set to SHARD_GRAD_OP and forward_prefetch is turned on, validation after training fails with an incorrect weight shape.
<img width="1508" alt="image" src="https://github.com/pytorch/pytorch/assets/41232043/57a9c3bb-cb5c-46df-ac26-922740686f9e">

# Analysis
When using `SHARD_GRAD_OP`, the `free_unsharded_flat_param` in `_post_forward_reshard` is often False, so it does not set the handle's `_prefetched` flag to False after the forward.

The normal training phase resets this flag to False in `_post_backward_final_callback`, but the validation phase does not execute that hook, so after the first validation iteration the handle's `_prefetched` flag remains True.

This causes the handle to skip `_unshard` in the next `_pre_forward_unshard`, and `_prefetch_handle` does not prefetch, which results in an incorrect weight shape.
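A hedged repro sketch of the scenario described above; `module` and `batch` are placeholders, not from the PR:

```
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

model = FSDP(
    module,                                        # placeholder: an nn.Module on GPU
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
    forward_prefetch=True,
)

model.train()
model(batch).sum().backward()                      # one training iteration

model.eval()
with torch.no_grad():
    for _ in range(2):                             # the 2nd eval iteration hits the stale flag
        model(batch)
```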
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110139
Approved by: https://github.com/awgu
2023-10-14 20:59:28 +00:00
Luo Bo
5ace912263 fix: do not reshard parameters twice (#110948)
This PR fixes potential double resharding of parameters that both:

1. require no gradient, and
2. were used more than once during the forward pass.

[`_register_post_backward_hook`](https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_runtime_utils.py#L1415) handles the case correctly, this PR does the same for `_register_post_backward_reshard_only_hook`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110948
Approved by: https://github.com/awgu
2023-10-12 15:09:33 +00:00
Rohan Varma
24e5d61af8 Log usage of optimizer in backward (#110206)
This will allow us to inspect and aggregate jobs that use the optimizer-in-backward feature.

Differential Revision: [D48674740](https://our.internmc.facebook.com/intern/diff/D48674740/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110206
Approved by: https://github.com/awgu
2023-09-29 11:00:07 +00:00
Matthew Hoffman
68b0db1274 Define the public API for torch.distributed.fsdp (#109922)
Related: https://github.com/pytorch/pytorch/wiki/Public-API-definition-and-documentation
Related: https://github.com/microsoft/pylance-release/issues/2953

This fixes pylance issues for these classes:

```
"FullyShardedDataParallel" is not exported from module "torch.distributed.fsdp"
```

These classes all have public docs:

* [`BackwardPrefetch`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.BackwardPrefetch)
* [`CPUOffload`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.CPUOffload)
* [`FullyShardedDataParallel`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel)
* [`MixedPrecision`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision)
* [`ShardingStrategy`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy)

And it seems like all the newly added classes will have docs once they are released.
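A sketch of the usual convention for declaring the public API (the exact re-export list in the PR may differ):

```
# torch/distributed/fsdp/__init__.py -- illustrative only
from .fully_sharded_data_parallel import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel,
    MixedPrecision,
    ShardingStrategy,
)

__all__ = [
    "BackwardPrefetch",
    "CPUOffload",
    "FullyShardedDataParallel",
    "MixedPrecision",
    "ShardingStrategy",
]
```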

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109922
Approved by: https://github.com/wanchaol
2023-09-28 02:15:58 +00:00
wz337
0aedacb4f7 [2D][1/N] Add _enable_extension to fsdp state (#109242)
Add _enable_extension to fsdp state. We will use this to determine whether we should enable the extension or not.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109242
Approved by: https://github.com/fegin
2023-09-16 19:03:10 +00:00
wz337
66af4f6ec7 [HSDP] Add device_mesh to FSDP kwarg and add dtensor state_dict support for HSDP (#107533)
This PR:
1) Adds a device_mesh kwarg to FSDP and removes init_device_mesh() from _runtime_utils.py, as device_mesh is now passed in by the user as a kwarg (see the usage sketch after this list).
2) Makes the use_dtensor flag for state_dict_config and optim_state_dict_config private. If a device_mesh is used with a sharded model/optim state dict, the _use_dtensor flag is set to True and the model/optim state dict returns a DTensor state_dict. Otherwise, _use_dtensor is set to False and the model/optim state dict returns a sharded_tensor state_dict.
3) Updates _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to support returning a 2D DTensor state_dict for HSDP.
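A hedged usage sketch; the import path of `init_device_mesh` has moved across releases, and `module` is a placeholder:

```
from torch.distributed.device_mesh import init_device_mesh  # location varies by release
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardedStateDictConfig,
    ShardingStrategy,
    StateDictType,
)

# 2 replica groups x 4 shards per group (assumes 8 ranks and an initialized process group)
mesh_2d = init_device_mesh("cuda", (2, 4))

model = FSDP(
    module,                                   # placeholder nn.Module
    device_mesh=mesh_2d,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)

# With a device_mesh, the sharded state dict comes back as DTensors (_use_dtensor=True internally).
FSDP.set_state_dict_type(model, StateDictType.SHARDED_STATE_DICT, ShardedStateDictConfig())
state = model.state_dict()
```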

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
2023-09-05 21:21:21 +00:00
PyTorch MergeBot
ab5b4c4419 Revert "[HSDP] Add device_mesh to FSDP and add dtensor state_dict support for HSDP (#107533)"
This reverts commit cc220e45a8.

Reverted https://github.com/pytorch/pytorch/pull/107533 on behalf of https://github.com/huydhn due to Sorry for reverting this, but it is failing in trunk with the same failure on test_dynamo_distributed cc220e45a8 ([comment](https://github.com/pytorch/pytorch/pull/107533#issuecomment-1701983247))
2023-09-01 01:26:30 +00:00
wz337
cc220e45a8 [HSDP] Add device_mesh to FSDP and add dtensor state_dict support for HSDP (#107533)
This PR:
1) Adds a device_mesh kwarg to FSDP and removes init_device_mesh() from _runtime_utils.py, as device_mesh is now passed in by the user as a kwarg.
2) Makes the use_dtensor flag for state_dict_config and optim_state_dict_config private. If a device_mesh is used with a sharded model/optim state dict, the _use_dtensor flag is set to True and the model/optim state dict returns a DTensor state_dict. Otherwise, _use_dtensor is set to False and the model/optim state dict returns a sharded_tensor state_dict.
3) Updates _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to support returning a 2D DTensor state_dict for HSDP.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
2023-09-01 00:15:00 +00:00
weifengpy
ec10b17cfb [FSDP] verify backward_prefetch works correctly with unit test (#107058)
issue resolved: https://github.com/pytorch/pytorch/pull/105984

context:
* CI did not catch the commit that breaks backward_prefetch https://github.com/pytorch/pytorch/pull/105006
* we had an action item to add unit test to prevent similar cases: https://github.com/pytorch/pytorch/pull/105984

What's included in this unit test:
* monkey-patch `torch.distributed.fsdp._runtime_utils._get_handle_to_prefetch` and check which handles are prefetched (a minimal sketch follows the list below)

for backward_prefetch = BackwardPrefetch.BACKWARD_PRE
* state._exec_order_data.handles_post_forward_order equals forward order: encoder 0...5 -> decoder 0...5 -> root
* pre-backward hook order: root -> decoder 5...0 -> encoder 5...0
* prefetch order: decoder 5...0 -> encoder 5...0 -> None
  * when current_handle=encoder 0, _get_handle_to_prefetch returns None

for backward_prefetch = BackwardPrefetch.BACKWARD_POST
* state._exec_order_data.handles_post_forward_order equals forward order: encoder 0...5 -> decoder 0...5 -> root
* post-backward hook (AccumulateGrad) order: decoder 5, 4...0 -> encoder 5...0 -> root
* prefetch order: decoder 4...0 -> encoder 5...0 -> None -> None
  * 1st None: when current_handle=encoder 0, _get_handle_to_prefetch returns None
  * 2nd None: when current_handle=root, we get decoder 5 inside _get_handle_to_prefetch, but it is not needed, so None is returned
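A minimal sketch of the monkey-patch described above; the wrapped function is private and may change:

```
import torch.distributed.fsdp._runtime_utils as runtime_utils

prefetched = []
orig_get_handle_to_prefetch = runtime_utils._get_handle_to_prefetch

def recording_get_handle_to_prefetch(*args, **kwargs):
    handle = orig_get_handle_to_prefetch(*args, **kwargs)
    prefetched.append(handle)  # record which handle (possibly None) would be prefetched
    return handle

runtime_utils._get_handle_to_prefetch = recording_get_handle_to_prefetch
# ... run forward/backward, then assert `prefetched` matches the expected order ...
```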
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107058
Approved by: https://github.com/awgu
2023-08-25 01:12:43 +00:00
Andrew Gu
2b964d6efd [FSDP] Enable async all-reduce for HSDP (#106080)
**Overview**
This PR runs the HSDP all-reduce asynchronously so that it can overlap with both all-gather and reduce-scatter, which can lead to slight end-to-end speedups when the sharding process group is fully intra-node. Previously, the all-reduce serialized with the reduce-scatter, so it could only overlap with one all-gather.

For some clusters (e.g. our AWS cluster), `NCCL_CROSS_NIC=1` improves inter-node all-reduce times when overlapped with intra-node all-gather/reduce-scatter.

**Experiment**
<details>
<summary> Example 'before' trace </summary>
<img width="559" alt="hsdp_32gpus_old" src="https://github.com/pytorch/pytorch/assets/31054793/15222b6f-2b64-4e0b-a212-597335f05ba5">

</details>

<details>
<summary> Example 'after' trace </summary>
<img width="524" alt="hsdp_32gpus_new" src="https://github.com/pytorch/pytorch/assets/31054793/94f63a1d-4255-4035-9e6e-9e10733f4e44">

</details>

For the 6-encoder-layer, 6-decoder layer transformer with `d_model=8192`, `nhead=64` on 4 nodes / 32 40 GB A100s via AWS, the end-to-end iteration times are as follows (with AG == all-gather, RS == reduce-scatter, AR == all-reduce; bandwidth reported as algorithmic bandwidth):
- Reference FSDP:
    - **1160 ms / iteration**
    - ~23 ms / encoder AG/RS --> 24.46 GB/s bandwidth
    - ~40 ms / decoder AG/RS --> 26.5 GB/s bandwidth
    - 50 GB/s theoretical inter-node bandwidth
- Baseline 8-way HSDP (only overlap AR with AG) -- intra-node AG/RS, inter-node AR:
    - **665 ms / iteration**
    - ~3 ms / encoder AG/RS --> 187.5 GB/s bandwidth
    - ~5 ms / decoder AG/RS --> 212 GB/s bandwidth
    - ~30 ms / encoder AR --> 2.34 GB/s bandwidth
    - ~55 ms / decoder AR --> 2.65 GB/s bandwidth
    - 300 GB/s theoretical intra-node bandwidth
- New 8-way HSDP (overlap AR with AG and RS) -- intra-node AG/RS, inter-node AR:
    - **597 ms / iteration**
    - ~3 ms / encoder AG/RS --> 187.5 GB/s bandwidth
    - ~6.2 ms / decoder AG/RS --> 170.97 GB/s bandwidth (slower)
    - ~23 ms / encoder AR (non-overlapped) --> 3.057 GB/s bandwidth (faster)
    - ~49 ms / decoder AR (non-overlapped) --> 2.70 GB/s bandwidth (faster)
    - ~100 ms / decoder AR (overlapped) --> 1.325 GB/s bandwidth (slower)
    - Overlapping with reduce-scatter reduces all-reduce bandwidth utilization even though the all-reduce is inter-node and reduce-scatter is intra-node!
- New 8-way HSDP (overlap AR with AG and RS) with `NCCL_CROSS_NIC=1`:
    - **556 ms / iteration**
    - Speedup comes from faster overlapped AR

Thus, for this particular workload, the async all-reduce enables 16% iteration-time speedup compared to the existing HSDP and 52% speedup compared to FSDP. These speedups are pronounced due to the workload being communication bound, so any communication time reduction translates directly to speedup.

**Unit Test**
This requires >= 4 GPUs:
```
python -m pytest test/distributed/fsdp/test_fsdp_hybrid_shard.py -k test_fsdp_hybrid_shard_parity
```

Differential Revision: [D47852456](https://our.internmc.facebook.com/intern/diff/D47852456)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106080
Approved by: https://github.com/ezyang
ghstack dependencies: #106068
2023-08-23 18:36:15 +00:00
Andrew Gu
50e1378680 [FSDP] Break up _post_backward_hook into smaller funcs (#106068)
The post-backward hook has some complexity due to the different paths: {no communication hook, communication hook} x {`NO_SHARD`, `FULL_SHARD`/`SHARD_GRAD_OP`, `HYBRID_SHARD`/`_HYBRID_SHARD_ZERO2`} plus some options like CPU offloading and `use_orig_params=True` (requiring using sharded gradient views).

The PR following this one that adds async all-reduce for HSDP further complicates this since the bottom-half after all-reduce must still be run in the separate all-reduce stream, making it more unwieldy to unify with the existing bottom-half.

Nonetheless, this PR breaks up the post-backward hook into smaller logical functions to hopefully help readability.

Differential Revision: [D47852461](https://our.internmc.facebook.com/intern/diff/D47852461)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106068
Approved by: https://github.com/ezyang, https://github.com/fegin
2023-08-23 18:36:15 +00:00
Michael Voznesensky
42660015b4 [Dynamo x FSDP][2/x] Small changes to distributed to make it dynamo friendly (#106886)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106886
Approved by: https://github.com/awgu, https://github.com/wconstab
ghstack dependencies: #106884
2023-08-11 22:35:50 +00:00
weifengpy
4bc846c101 [FSDP] Ignore buffer type casting in ignored modules (#106766)
issue resolved: https://github.com/pytorch/pytorch/issues/97791

Before this PR, mixed_precision applied to buffers from ignored modules; see ```test_state_dict_with_ignored_modules(mixed_precision=True)``` to reproduce.

After this PR, we avoid applying mixed_precision semantics to buffers from ignored modules (a configuration sketch follows the steps below):
* step 1 initialization: state._ignored_buffer_names contains all the buffers from ignored modules
* step 2 lazy init at runtime: skip ignored buffers in ```_get_buffers_and_dtypes_for_computation```
* step 3 skip upcasting in state_dict hook: avoid upcasting for ignored buffers in ```_get_buffers_and_dtypes_for_computation```
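A hedged sketch of the configuration this affects; `module` and `frozen_block` are placeholders:

```
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

model = FSDP(
    module,                                  # placeholder nn.Module
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16,
        buffer_dtype=torch.float16,
    ),
    ignored_modules=[module.frozen_block],   # placeholder submodule holding fp32 buffers
)
# After this change, buffers owned by the ignored submodule keep their original dtype
# instead of being cast per the mixed_precision config.
```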
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106766
Approved by: https://github.com/awgu
2023-08-09 23:09:43 +00:00
Michael Voznesensky
d1a99a083f Reland Simplify handle indexing (#105006) (#106357)
This reverts commit a9a3c45649.

This PR changes the following:
- `_ExecOrderData.handle_to_handle_index` -> `FlatParamHandle._handle_index`
- `_ExecOrderData.handles_to_pre_forward_order_index` -> `FlatParamHandle._pre_forward_order_index`
- `_ExecOrderData.handles_to_post_forward_order_index` -> `FlatParamHandle._post_forward_index`
- `_FSDPState._needs_pre_forward_unshard` -> `FlatParamHandle._needs_pre_forward_unshard`
- `_FSDPState._needs_pre_backward_unshard` -> `FlatParamHandle._needs_pre_backward_unshard`
- `_FSDPState._handles_prefetched` -> `FlatParamHandle._prefetched`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106357
Approved by: https://github.com/awgu
2023-08-03 19:17:32 +00:00
Andrew Gu
800287fb56 [FSDP] Optimize away intermediate div_ for HSDP (#106034)
### Background: Gradient Pre-Divide
Consider $N$ data parallel workers. Define $g_i$ to be the $i$ th worker's local unsharded gradient. Data parallel gradient reduction computes $\overline g = \frac{1}{N} \sum_{i \in [N]} g_i$.

$\sum_{i \in [N]} g_i$ increases the magnitude by a factor of $N$, which may overflow for fp16. However, if we pre-divide and compute $\sum_{i \in [N]} \frac{g_i}{N}$, then the $\frac{g_i}{N}$ may underflow. The current solution from Myle for FSDP is to pre-divide by $\sqrt{N}$ and post-divide by $\sqrt{N}$:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{i \in [N]} \frac{g_i}{\sqrt{N}}.$$
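A quick numeric sanity check of this scheme (a standalone sketch, not FSDP internals):

```
import math
import torch

N = 8
local_grads = [torch.randn(4) for _ in range(N)]

pre_divided = [g / math.sqrt(N) for g in local_grads]   # pre-divide by sqrt(N)
reduced = torch.stack(pre_divided).sum(dim=0)           # simulated all-reduce (sum)
avg = reduced / math.sqrt(N)                            # post-divide by sqrt(N)

assert torch.allclose(avg, torch.stack(local_grads).mean(dim=0), atol=1e-6)
```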

Now, consider HSDP with $N = S \cdot R$ data parallel workers, sharding over $S$ workers and replicating over $R$ workers. Define $g_{i,j}$ to be the $i \cdot S + j$ th worker's local unsharded gradient (so sharding indexes with $i$ and replication indexes with $j$). The existing implementation computes
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}} } \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{S}},$$
where the $\frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}}$ involves two separate `aten::div_` kernels.

### Revisiting Pre-Divide for HSDP
A minor optimization that we can do is with this intermediate `div_`. There are two options:
1. Compute $\overline{g}$ in the same way as FSDP:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{j \in [R]} \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{N}}.$$
2. Compute $\overline{g}$ still with an intermediate division for rescaling but coalescing the two `div_`s into one:
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{N}} } \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{S}}$$

This PR goes with the 1st approach, prioritizing performance, because (1) it matches the existing FSDP behavior and (2) it avoids a memory-bandwidth-bound `div_` kernel that blocks the all-reduce launch.

### Implementation Details
In order to accommodate this, we need to refactor the communication hook logic that baked the gradient pre/post-division into the default hook.
- We raise an error if registering a communication hook for HSDP since the current implementation would only apply the hook to the reduce-scatter, not the all-reduce, which may be unexpected.
- We change it so that `state._comm_hook is not None` iff a communication hook is registered. This makes the collectives and the pre/post-division in the default no-communication-hook path more visible in the code.

Differential Revision: [D47852459](https://our.internmc.facebook.com/intern/diff/D47852459)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106034
Approved by: https://github.com/rohan-varma
2023-07-28 18:36:26 +00:00
Albert Chen
7c8efc9049 [PT][FSDP] Combine _utils.py into _common_utils.py [2/2] (#106181)
Summary:
https://github.com/pytorch/pytorch/issues/97813
This diff moves `_no_dispatch_record_stream` and `_same_storage_as_data_ptr`.

Test Plan: CI

Differential Revision: D47706114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106181
Approved by: https://github.com/awgu
2023-07-28 17:15:25 +00:00
Andrew Gu
841b4acf1e [FSDP][Easy] Rename to _comm_hook, _comm_hook_state (#106033)
This is just out of preference to make the naming convention consistent with `register_comm_hook()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106033
Approved by: https://github.com/fegin
2023-07-26 19:59:11 +00:00
Andrew Gu
035704e88d [FSDP][Easy] Move post-bwd hook logging to own func (#106032)
This is to help make `_post_backward_hook()` easier to read. I plan to refactor some other parts in future PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106032
Approved by: https://github.com/fegin
2023-07-26 19:59:11 +00:00
Daniel Dale
6b6702f506 Enhance no_grad-context FSDP backward handling (#105374)
Fixes #105369
Fixes #105371

Addressing two somewhat distinct issues that involve the same test in this PR:

1. To fix #105369:
    - Add a `no_grad` guard to [`_register_post_backward_reshard_only_hooks`](93f852f201/torch/distributed/fsdp/_runtime_utils.py (L1406)) to avoid registering post-backward hooks that would not be removed in that context.

2. To fix #105371:
    - Add a `grad` context condition to [`_use_sharded_flat_param`](93f852f201/torch/distributed/fsdp/flat_param.py (L1645C9-L1645C32)) logic to trigger post-forward `_use_sharded_views` in a `no_grad` context for `NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES`
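A hedged sketch of the usage pattern these fixes target; `fsdp_model`, `train_batch`, and `eval_batch` are placeholders:

```
import torch

fsdp_model.train()                          # placeholder: an FSDP-wrapped module
fsdp_model(train_batch).sum().backward()    # post-backward hooks fire and reshard as usual

fsdp_model.eval()
with torch.no_grad():
    _ = fsdp_model(eval_batch)              # no post-backward hooks should be registered here
```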

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105374
Approved by: https://github.com/awgu
2023-07-26 14:12:13 +00:00
Andrew Gu
c099b80073 [FSDP] Add record_function for explicit prefetching (#105985)
Example:
<img width="568" alt="Screenshot 2023-07-25 at 7 41 43 PM" src="https://github.com/pytorch/pytorch/assets/31054793/5f3f07b3-97f4-4493-9cab-5619484e2f6d">

This can be particularly helpful when `with_stack=False`, in which case it is harder to identify the prefetch.
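A generic sketch of the annotation pattern; the label and wrapped region are illustrative, not the PR's actual code:

```
from torch.profiler import record_function

with record_function("FullyShardedDataParallel::forward_prefetch"):  # illustrative label
    ...  # explicit prefetch work goes here
```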
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105985
Approved by: https://github.com/fegin
2023-07-26 12:16:35 +00:00
Andrew Gu
a9a3c45649 Revert "Simplify handle indexing (#105006)" (#105984)
This reverts commit 429d45f91a.

Unfortunately, https://github.com/pytorch/pytorch/pull/105006 broke backward prefetching (correct backward-prefetching behavior was not captured by our unit tests).

I need more time to dig into this (tomorrow), but I think the issue is related to:
429d45f91a (diff-9a6937168d232432c34c2c4605b96f3147afa2786e287f74b6074b20aa5980e6R143-R146)

Follow-ups:
1. Investigate this thoroughly
2. Add unit tests to capture backward prefetch functionality
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105984
Approved by: https://github.com/fegin
2023-07-26 12:12:14 +00:00
Rohan Varma
a326f5621e composable fsdp, checkpoint, + compile test (#105180)
Test to ensure that composable FSDP, checkpoint, and compile all work
together. Includes a change from https://github.com/pytorch/pytorch/pull/105090
which we can land in that PR first.

Differential Revision: [D47452973](https://our.internmc.facebook.com/intern/diff/D47452973/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105180
Approved by: https://github.com/awgu
2023-07-26 07:03:09 +00:00
Michael Voznesensky
429d45f91a Simplify handle indexing (#105006)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105006
Approved by: https://github.com/awgu
2023-07-21 05:53:23 +00:00
Michael Voznesensky
a832967627 Migrate tuple(handle) -> handle (#104488)
We strengthen the invariant that one FSDP-managed module has one FlatParameter and remove unused code that would have supported a 1:many module-to-FlatParameter mapping.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104488
Approved by: https://github.com/awgu
2023-07-19 22:33:35 +00:00
Nikita Shulga
5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violations (i.e., when a default arg is set to None but the type is not annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add a hack to `.ci/docker/install_conda.sh` to squash the older libstdc++ from the conda environment in favor of the one from the OS
- Update bazel CUDA builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link with cupti statically, but I could not find where it is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00
PyTorch MergeBot
15fd1ea118 Revert "[Reland] Update mypy to 1.4.1 (#105227)"
This reverts commit c9c4f8efc3.

Reverted https://github.com/pytorch/pytorch/pull/105227 on behalf of https://github.com/atalman due to trying to mitigate ci sev #105248 ([comment](https://github.com/pytorch/pytorch/pull/105227#issuecomment-1636510935))
2023-07-14 22:28:35 +00:00
Nikita Shulga
c9c4f8efc3 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violations (i.e., when a default arg is set to None but the type is not annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-14 20:45:12 +00:00
PyTorch MergeBot
3c5a494d7a Revert "Update mypy to 1.4.1 (#91983)"
This reverts commit 634659e262.

Reverted https://github.com/pytorch/pytorch/pull/91983 on behalf of https://github.com/malfet due to It's dependent change was reverted, so reverting this one as well, to keep CI clean ([comment](https://github.com/pytorch/pytorch/pull/91983#issuecomment-1636059709))
2023-07-14 15:59:16 +00:00
Nikita Shulga
634659e262 Update mypy to 1.4.1 (#91983)
Mostly fixes for PEP-484 violations (i.e., when a default arg is set to None but the type is not annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91983
Approved by: https://github.com/kit1980, https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/thiagocrepaldi, https://github.com/aaronenyeshi
2023-07-13 16:30:36 +00:00
Rohan Varma
242fc29c96 [FSDP] Refactor optimizer in backward (#104813)
1) Use zero_grad(set_to_none=True) to set the grad to None, 2) call prepare_grad_for_optim() before the call to .step(), and 3) use _reset_flat_param_grad_info to set the flat parameter's gradient back to None. These changes should just be refactors and remain equivalent to how gradient memory was managed before.

Differential Revision: [D47310761](https://our.internmc.facebook.com/intern/diff/D47310761/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104813
Approved by: https://github.com/awgu
2023-07-13 06:42:53 +00:00
Rohan Varma
f2eed129c4 FSDP optimizer overlap (#98667)
Constraints (a hedged usage sketch of optimizer-in-backward follows the list below):

1. No support for gradient accumulation
2. CPU offload runs step() on CPU. In future PRs ideally we'd run this on GPU.
3. When CPU offload and optimizer overlap are combined, we have to copy the flat_param grad to CPU with non_blocking=False; otherwise, step() might run on invalid data.
4. step() is waited on in the post-backward final callback, when in theory it could wait until the next forward.
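A hedged sketch of how optimizer-in-backward is typically attached before wrapping with FSDP; the helper is private and its location/signature may differ across releases, and `module` is a placeholder:

```
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.optim import _apply_optimizer_in_backward

_apply_optimizer_in_backward(
    optimizer_class=torch.optim.SGD,
    params=module.parameters(),        # placeholder nn.Module
    optimizer_kwargs={"lr": 1e-2},
)
fsdp_model = FSDP(module)              # FSDP then overlaps step() with the backward pass
```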

Differential Revision: [D44809582](https://our.internmc.facebook.com/intern/diff/D44809582/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98667
Approved by: https://github.com/awgu, https://github.com/fegin
2023-07-13 06:42:53 +00:00
Andrew Gu
954bae8e53 [FSDP][Easy] Rename streams; add back stream sharing test (#104966)
Purely out of preference, this PR renames the streams to `_unshard_stream` instead of `_streams_unshard` etc. since the former reads more naturally. The PR also removes some duplicated comments and adds back a unit test that streams are shared.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104966
Approved by: https://github.com/rohan-varma
2023-07-13 00:24:41 +00:00
Iris
4f8ba6f8f6 [DeviceMesh]Add validate mesh flag to DeviceMesh (#104807)
When creating a DeviceMesh, _init_process_group() validates that all calling ranks pass in the same `mesh` argument. In FSDP, we are currently creating the DeviceMesh based on the pg of the root state, so the mesh will always be valid. This PR adds the flag to DeviceMesh so we can skip the all_gather_tensor validation at construction time.

_validate_mesh defaults to True, but we manually flip it to False when initializing the device mesh in FSDP's _runtime_utils.py.

We will modify skipping of process-group creation when one already exists for both the 1D and 2D cases, and then delete the _init_process_groups flag in a follow-up PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104807
Approved by: https://github.com/wanchaol
2023-07-12 23:42:13 +00:00
Andrew Gu
63d1fb21f5 [FSDP] Default limit_all_gathers=True (#104900)
This PR defaults to `limit_all_gathers=True`.

I included a `record_function()` for the rate limiter synchronization to help with user confusion on the gap in the pre-forward:
<img width="874" alt="Screenshot 2023-07-10 at 3 28 18 PM" src="https://github.com/pytorch/pytorch/assets/31054793/61f55e0e-58d7-4162-9395-bea06d3e8d8a">
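A hedged sketch of the new default and how to opt out; `module` is a placeholder:

```
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

fsdp_model = FSDP(module)                           # limit_all_gathers now defaults to True
fsdp_model = FSDP(module, limit_all_gathers=False)  # opt out of the rate limiter
```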

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104900
Approved by: https://github.com/fegin
2023-07-11 01:04:29 +00:00
Andrew Gu
d9be0366d3 [FSDP][3/N] Unify fully_shard auto wrap (#104408)
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules.

This includes several important fixes:
- We should register the pre/post-forward hooks on the module regardless of whether it has managed parameters.
- We can permit `_module_handles` to return `[]` in the composable path (for when the module has no managed parameters).
- We should unify the paths for `_get_buffers_and_dtypes_for_computation()` (previously, composable path was buggy in some cases).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104408
Approved by: https://github.com/rohan-varma
2023-07-08 12:40:12 +00:00
Rohan Varma
0bf39d5663 [FSDP] Option for eval in fp32/bf16 (#104682)
In https://github.com/pytorch/pytorch/pull/97645 and some follow-up diffs, we made FSDP run in full precision in eval mode, even if mixed precision was specified.

However, this is probably not the best idea, and we should give users a bit more control over it. This PR adds an env var, FSDP_FULL_PREC_IN_EVAL, defaulting it to off; users who want to run eval in fp32 can toggle it before wrapping the model in FSDP:

os.environ["FSDP_FULL_PREC_IN_EVAL"] = "1"

Verified that unit tests, the APS workflow, and TNT workloads can run eval appropriately with this change.
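A hedged usage sketch; `module` and `batch` are placeholders:

```
import os
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

os.environ["FSDP_FULL_PREC_IN_EVAL"] = "1"   # set before wrapping the model

fsdp_model = FSDP(module, mixed_precision=MixedPrecision(param_dtype=torch.bfloat16))
fsdp_model.eval()
with torch.no_grad():
    _ = fsdp_model(batch)                    # eval forward runs in full precision
```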

Differential Revision: [D47246556](https://our.internmc.facebook.com/intern/diff/D47246556/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104682
Approved by: https://github.com/awgu
2023-07-07 08:14:23 +00:00
Iris
434fcffa21 [6/n][FSDP] Update _sharded_pre_load_state_dict_hook to use DTensor when use_dtensor=True in ShardedStateDictConfig (#104087)
This allows us to use use_dtensor=True for ShardedStateDictConfig() before calling model.load_state_dict(). For now, it only works with offload_to_cpu=False.

Next PR will make use_dtensor=True work with offload_to_cpu=True for load_state_dict().

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104087
Approved by: https://github.com/fegin
2023-07-06 05:36:19 +00:00
Wanchao Liang
8457703e8d lazy init device mesh in fsdp (#104447)
Since FSDP state is lazily initialized, we also need to lazily initialize the device mesh; otherwise, the DeviceMesh all-gather check would trigger a mismatch in all-gather counts in the FSDP tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104447
Approved by: https://github.com/wconstab
2023-06-30 04:40:16 +00:00
Rohan Varma
c866446d6c [FSDP] Check module.training for _root_cast_forward_inputs (#104223)
We might erroneously cast forward inputs for the root if it doesn't
manage any handles (FSDP parameters). As a fix, pass in the module and check
its training attribute to ensure we don't cast inputs in eval mode.

Differential Revision: [D47041673](https://our.internmc.facebook.com/intern/diff/D47041673/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104223
Approved by: https://github.com/fegin
2023-06-28 16:38:01 +00:00
Andrew Gu
9db8ad7f1d [FSDP] Support unfreezing params for reshard-only hook (#104186)
This fixes https://github.com/pytorch/pytorch/issues/104148 (unfreezing parameters after `n` steps).

- This fixes a bug where we did not delete the post-backward hook state properly for the `requires_grad=False` case.
- This makes the `already_resharded` correct for `SHARD_GRAD_OP`.
- This generalizes `_clear_grads_if_needed()` to `_reset_flat_param_grad_info_if_needed()` to additionally include propagating the original parameters' `requires_grad` to the flat parameter.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104186
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2023-06-28 11:04:57 +00:00
shibo19
c2095af3f8 make funcs argument type from torch.cuda.stream as torch.Stream (#104156)
Fixes #ISSUE_NUMBER
1. We want to support FSDP for custom devices, so we change the functions' argument type from torch.cuda.Stream to torch.Stream.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104156
Approved by: https://github.com/awgu
2023-06-28 06:02:56 +00:00
Michael Voznesensky
02f28de408 [dynamo x fsdp] Simplify stream logic handling (#103902)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103902
Approved by: https://github.com/awgu
2023-06-21 01:34:19 +00:00
Andrew Gu
48056b168f [FSDP] Reshard frozen params in backward (#101982)
This PR makes a first attempt at improving FSDP's fine-tuning support by adding hooks to reshard frozen parameters in the backward pass.
- Without this, frozen parameters involved in gradient computation are kept as unsharded through the entire backward pass.
- The approach is to register a multi-grad ~~post~~-hook on the _input_ activations to the FSDP module, where the hook performs the resharding after all gradients for the FSDP module have been computed (meaning that we are safe to reshard).

~~This PR relies on adding a "multi-grad post-hook" that differs from the existing "multi-grad hook" from `register_multi_grad_hook()`. I find that with `register_multi_grad_hook()`, sometimes the unit test counting the number of times `_post_backward_reshard()` is called fails (due to it not being called).~~ This was resolved in https://github.com/pytorch/pytorch/pull/102859.
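For reference, a standalone sketch of the multi-grad hook mechanism this relies on (plain autograd usage, not FSDP internals):

```
import torch
from torch.autograd.graph import register_multi_grad_hook

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)

def on_all_grads_ready(grads):
    # Runs once gradients for all registered tensors have been computed
    # (some entries may be None); this is the point at which resharding is safe.
    print("all gradients ready")

handle = register_multi_grad_hook((x, y), on_all_grads_ready)
(x * y).sum().backward()
handle.remove()
```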
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101982
Approved by: https://github.com/rohan-varma
2023-06-08 21:12:45 +00:00
Iris
d5142c52d3 [FSDP]Remove dim_group from device_mesh init (#103218)
1) remove dim_group
2) don't init device_mesh if not using default_pg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103218
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2023-06-08 03:29:19 +00:00
Iris
a02a58d862 [FSDP][1/N]Add device_mesh to FSDPstate (#102317) (#102551)
This PR creates a device_mesh and shares it across all FSDP states. The device_mesh will later be used to test out the DTensor state_dict (1D device_mesh).
Approved by: https://github.com/awgu

Add device mesh to fsdp state
skip dist.get_world_size(pg) != dist.get_world_size()
address test_fake_pg.py test failure
fix test_fake_py.py failure

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102551
Approved by: https://github.com/fegin
2023-06-07 04:14:00 +00:00