Commit Graph

131 Commits

Author SHA1 Message Date
Aaron Orenstein
7c12cc7ce4 Flip default value for mypy disallow_untyped_defs [6/11] (#127843)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127843
Approved by: https://github.com/oulgen
ghstack dependencies: #127842
2024-06-08 18:49:29 +00:00
Jeeja
556e4ec6c9 [FSDP] Add device in pin_memory argument (#119878)
Add device to pin_memory argument to support other backends like HPU

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119878
Approved by: https://github.com/awgu
2024-05-14 10:30:00 +00:00
Yuxin Wu
9cd4bcb2c4 [FSDP] mark pre_backward_hook unserializable (#125464)
Saw a warning like this:

```
/opt/conda/lib/python3.10/site-packages/torch/utils/hooks.py:86: UserWarning: backward hook functools.partial(<function _pre_backward_hook at 0x7f9a3940fac0>, FullyShardedDataParallel(

....

), <torch.distributed.fsdp.flat_param.FlatParamHandle object at 0x7f25202a9720>) on tensor will not be serialized.  If this is expected, you can decorate the function with @torch.utils.hooks.unserializable_hook to suppress this warning
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125464
Approved by: https://github.com/ezyang
2024-05-06 20:20:31 +00:00
Andrew Gu
79af814369 [FSDP] Added private _unshard API (#124304)
Some toy example:
<img width="998" alt="Screenshot 2024-04-17 at 2 00 05 PM" src="https://github.com/pytorch/pytorch/assets/31054793/b5665a63-beb0-4ca1-92c6-c57a052812fd">

We define `FullyShardedDataParallel._unshard(async_op: bool = False)` that can be used to prefetch all-gathers. The user should make sure:
1. Run lazy init before the first `_unshard` call of training. For example, this can hackily be done via `root_module.check_is_root()` on the root FSDP module `root_module`.
2. Call `root_module._wait_unshard_streams_on_current_stream()` before the first `_unshard` call of the current iteration (just need to call it once after last optimizer step and before first `_unshard` of this iteration).

Differential Revision: [D56262876](https://our.internmc.facebook.com/intern/diff/D56262876)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124304
Approved by: https://github.com/wanchaol
2024-05-03 13:14:15 +00:00
Chirag Pandya
b6201a60c5 [BE] minor logging cleanup in distributed (#122921)
Summary:
    Minor logging cleanup in distributed library
    1. Don't use "f" formatted strings - address linter issues.
    2. Nits: Make use of unused `e` (error) in a few logs.
    3. Change info->debug as asked in issue #113545
    4. Nit: rename log -> logger in a few files for consistency
    5. Fix a linter error.

    Test Plan:
    1. Local build passes.
    2. Linter is happy.

    Reviewers: wanchaol

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122921
Approved by: https://github.com/wanchaol
2024-03-29 03:34:01 +00:00
Mihir Patel
d9d8c2b79f Remove HSDP validation check (#112435)
Currently, HSDP validates that all intra/inter node PGs are the same. This makes sense if you are only using HSDP with no other forms of parallelism and is a nice but not necessary sanity check.

However, if you want to mix HSDP with other forms, say tensor parallelism on the FFN of a transformer block, the intra/inter node PGs will be different for that layer. This check raises errors in this scenario, so we need to remove this assumption.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112435
Approved by: https://github.com/wz337, https://github.com/Skylion007
2024-02-05 22:27:53 +00:00
Wei (Will) Feng
91d5f94f85 [FSDP] Idempotent reshard (#117997)
address assertion error "Expects storage to be allocated" by making reshard idempotent https://github.com/pytorch/pytorch/issues/117510

```pytest test/distributed/fsdp/test_fsdp_fine_tune.py -k test_parity_with_non_frozen_fsdp```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117997
Approved by: https://github.com/awgu
2024-01-25 23:29:23 +00:00
Wanchao Liang
848cfe8d45 [reland] unflatten_tensor on compute stream for DTensorExtension (#117020)
reland of https://github.com/pytorch/pytorch/pull/116559, which was reverted by internal.

The underlying reason for the revert is that the torch.dynamo.disable can't be used by the
pytorch codebase, as it's conflicting with some torch.deploy together, although the later one
only run some inference, but it somehow take that weird dependency on fsdp..

We have seen this issue with our functional collectives that we can't
use any dynamo components otherwise torch.deploy would complain..

verified internally that after removing torch.dynamo.disable the test
passed again

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117020
Approved by: https://github.com/awgu
2024-01-09 21:25:15 +00:00
Qinfan Wu
b847290ddd Back out "[2d] unflatten_tensor on compute stream for DTensorExtension (#116559)" (#116939)
Summary:
Original commit changeset: 65298112f3db

Original Phabricator Diff: D52530451

Differential Revision: D52583345

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116939
Approved by: https://github.com/842974287
2024-01-07 03:53:40 +00:00
Wanchao Liang
d9c0e37bab [2d] unflatten_tensor on compute stream for DTensorExtension (#116559)
Context: Existing FSDPExtension have some bug in the case when the
unflatten tensor involves some compute/communications in cuda stream,
the current logic of FSDPExtension unflatten tensor happens in the
unshard stream, which makes runtime lost sync with the compute stream,
and if there're some dependencies between the compute stream and the
unflatten tensor logic, currently it would lose sync point, which could
possibly lead to NaN.

This PR make the FSDPExtension to record the compute stream and let
DTensorExtension to directly use the compute stream for unflatten_tensor.

In long term we might want to directly make the FSDP runtime logic to only
make the unshard happen in unshard stream, and use unshard views to
happen in the compute stream. We currently fix this in the Extension
directly as this is the simplest thing to do without affecting FSDP
runtime logic

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116559
Approved by: https://github.com/awgu, https://github.com/fduwjj, https://github.com/yifuwang
ghstack dependencies: #116426
2024-01-03 07:29:08 +00:00
voznesenskym
77d5f60740 [fsdp][torch.compile] FSDP changes (#115497)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115497
Approved by: https://github.com/albanD
2023-12-19 18:44:36 +00:00
voznesenskym
310f6ab11a [fsdp] Replace acc_grad hooking with register_post_accumulate_grad_hook on flat_param (#112184)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112184
Approved by: https://github.com/albanD
ghstack dependencies: #115315
2023-12-13 16:24:44 +00:00
Chien-Chin Huang
4ba649e207 [FSDP][state_dict] Avoid assigning the root _device_mesh to the children _device_mesh (#114384)
Assigning the root _device_mesh to the children _device_mesh is not correct as each FSDP state can have a different DeviceMesh. We are also replacing fully_shard with a new implementation. So there is no need to worry about the fully_shard behavior.

Differential Revision: [D51507959](https://our.internmc.facebook.com/intern/diff/D51507959/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114384
Approved by: https://github.com/wz337
2023-11-30 02:08:31 +00:00
CK Luk
0ea126e834 add use_fake_all_gather and use_fake_reduce_scatter to FSDP for ablation studies (#113106)
Summary: As titled

Test Plan: Not needed because this is only for doing ablation studies

Reviewed By: awgu

Differential Revision: D50867908

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113106
Approved by: https://github.com/awgu
2023-11-17 05:43:30 +00:00
wz337
31ded95cd5 [2D] Bind _fsdp_extension to FSDP instances (#113237)
Currently, when we have 2D composition, a global variable _extensions controls the 2D deviation we need to take in state_dict calls (See https://github.com/pytorch/pytorch/blob/release/2.1/torch/distributed/fsdp/_fsdp_extensions.py#L66-L68). This is problematic when we have both a 2D model and a plain FSDP model in the same dist environment, as the _extensions will be mistakenly turned on for the plain FSDP model, resulting in state_dict error (RuntimeError: No parent device_mesh is found for FSDP device_mesh.).

This PR binds _fsdp_extension to the FSDP instances to make sure that state_dict calls would not get interfered with each other when mixing both 2D and 1D parallelism.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113237
Approved by: https://github.com/fduwjj, https://github.com/fegin
2023-11-09 03:31:03 +00:00
Peter Bell
66c32d099a Use pytree.arg_tree_leaves everywhere (#112394)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112394
Approved by: https://github.com/lezcano
ghstack dependencies: #112391, #112392, #112393
2023-10-31 15:57:06 +00:00
Peter Bell
bbd5b935e4 Use pytree.tree_leaves everywhere (#112324)
This changes all the instances I could find of `tree_flatten(...)[0]` or
`x, _ = tree_flatten` to use `tree_leaves`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112324
Approved by: https://github.com/lezcano
ghstack dependencies: #112327, #112323
2023-10-30 03:39:04 +00:00
Luo Bo
b691d09010 fix: reset prefetch flag upon reshard (#111354)
The `prefetched` flag should be reset upon reshard. Otherwise, for zero2, next access to the corresponding parameter will skip "unshard" operation, and results in wrong parameter shape.

The need of unsharding is also metioned [in the comment of `FlatParameterHandle.unshard`](https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_flat_param.py#L1241-L1242).

As [`FlatParameterHandle` already guarded it against unnecessary all gather](https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_flat_param.py#L1240), this shouldn't incur extra communication overhead.

_Personally I also find `_prefetched` a bit of mis-named, it should really be `_unsharded`._
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111354
Approved by: https://github.com/awgu
2023-10-16 18:31:33 +00:00
Edwiv
5caf2e55d4 [FSDP] fix: fix for fsdp zero2 validation error (#110139)
# Problem
When sharding_strategy is set to SHARD_GRAD_OP and forward_prefetch is turned on, the validation after the train has an incorrect weight shape.
<img width="1508" alt="image" src="https://github.com/pytorch/pytorch/assets/41232043/57a9c3bb-cb5c-46df-ac26-922740686f9e">

# Analyze
When using `SHARD_GRAD_OP`, the `free_unsharded_flat_param` in `_post_forward_reshard` is often False, so it does not set the handle's `_prefetched` flag to False after the forward.

The normal train phase sets this flag to False in the `_post_backward_final_callback`, and the validation phase doesn't execute the hook, so after the first iter of the validation is done, the flag of the handle of the prefetched will remain True.

This will cause the handle to skip the `_unshard` in the next `_pre_forward_unshard`, and the `_prefetch_handle` will not do a prefetch, which will result in an incorrect weight shape.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110139
Approved by: https://github.com/awgu
2023-10-14 20:59:28 +00:00
Luo Bo
5ace912263 fix: do not reshard parameters twice (#110948)
This PR fixes potential double resharding of parameters that both:

1. requires no gradient and,
2. were used more than once during forward pass.

[`_register_post_backward_hook`](https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_runtime_utils.py#L1415) handles the case correctly, this PR does the same for `_register_post_backward_reshard_only_hook`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110948
Approved by: https://github.com/awgu
2023-10-12 15:09:33 +00:00
Rohan Varma
24e5d61af8 Log usage of optimizer in backward (#110206)
This will allow us to inspect and aggregate jobs that use optimizer in
backward

Differential Revision: [D48674740](https://our.internmc.facebook.com/intern/diff/D48674740/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110206
Approved by: https://github.com/awgu
2023-09-29 11:00:07 +00:00
Matthew Hoffman
68b0db1274 Define the public API for torch.distributed.fsdp (#109922)
Related: https://github.com/pytorch/pytorch/wiki/Public-API-definition-and-documentation
Related: https://github.com/microsoft/pylance-release/issues/2953

This fixes pylance issues for these classes:

```
"FullyShardedDataParallel" is not exported from module "torch.distributed.fsdp"
```

These classes all have public docs:

* [`BackwardPrefetch`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.BackwardPrefetch)
* [`CPUOffload`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.CPUOffload)
* [`FullyShardedDataParallel`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel)
* [`MixedPrecision`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision)
* [`ShardingStrategy`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy)

And it seems like all the newly added classes will have docs once they are released.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109922
Approved by: https://github.com/wanchaol
2023-09-28 02:15:58 +00:00
wz337
0aedacb4f7 [2D][1/N] Add _enable_extension to fsdp state (#109242)
Add _enable_extension to fsdp state. We will use this to determine whether we should enable the extension or not.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109242
Approved by: https://github.com/fegin
2023-09-16 19:03:10 +00:00
wz337
66af4f6ec7 [HSDP] Add device_mesh to FSDP kwarg and add dtensor state_dict support for HSDP (#107533)
This PR:
1) Add device_mesh kwarg to FSDP. Remove init_device_mesh() from _runtime_utils.py, as device_mesh would be passed in by user as an kwarg.
2) change use_dtensor flag for state_dict_config and optim_state_dict_config to be private. If device_mesh is used with sharded model/optim state dict, _use_dtensor flag would be set to True and model/optim state dict would return dtensor state_dict. Otherwise, _use_dtensor flag would be set to False and model/optim state dict would return sharded_tensor state_dict.
3) Update _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to add support for HSDP to return 2D DTensor state_dict.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
2023-09-05 21:21:21 +00:00
PyTorch MergeBot
ab5b4c4419 Revert "[HSDP] Add device_mesh to FSDP and add dtensor state_dict support for HSDP (#107533)"
This reverts commit cc220e45a8.

Reverted https://github.com/pytorch/pytorch/pull/107533 on behalf of https://github.com/huydhn due to Sorry for reverting this, but it is failing in trunk with the same failure on test_dynamo_distributed cc220e45a8 ([comment](https://github.com/pytorch/pytorch/pull/107533#issuecomment-1701983247))
2023-09-01 01:26:30 +00:00
wz337
cc220e45a8 [HSDP] Add device_mesh to FSDP and add dtensor state_dict support for HSDP (#107533)
This PR:
1) Add device_mesh kwarg to FSDP. Remove init_device_mesh() from _runtime_utils.py, as device_mesh would be passed in by user as an kwarg.
2) change use_dtensor flag for state_dict_config and optim_state_dict_config to be private. If device_mesh is used with sharded model/optim state dict, _use_dtensor flag would be set to True and model/optim state dict would return dtensor state_dict. Otherwise, _use_dtensor flag would be set to False and model/optim state dict would return sharded_tensor state_dict.
3) Update _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to add support for HSDP to return 2D DTensor state_dict.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
2023-09-01 00:15:00 +00:00
weifengpy
ec10b17cfb [FSDP] verify backward_prefetch works correctly with unit test (#107058)
issue resolved: https://github.com/pytorch/pytorch/pull/105984

context:
* CI did not catch the commit that breaks backward_prefetch https://github.com/pytorch/pytorch/pull/105006
* we had an action item to add unit test to prevent similar cases: https://github.com/pytorch/pytorch/pull/105984

what's included in this unit test
* monkey patch
torch.distributed.fsdp._runtime_utils._get_handle_to_prefetch and check which handles are prefetched

for backward_prefetch = BackwardPrefetch.BACKWARD_PRE
* state._exec_order_data.handles_post_forward_order equals forward order: encoder 0...5 -> decoder 0...5 -> root
* pre-backward hook order: root -> decoder 5...0 -> encoder 5...0
* prefetch order: decoder 5...0 -> encoder 5...0 -> None
  * when current_handle=encoder 0, _get_handle_to_prefetch returns None

for backward_prefetch = BackwardPrefetch.BACKWARD_POST
* state._exec_order_data.handles_post_forward_order equals forward order: encoder 0...5 -> decoder 0...5 -> root
* post-backward hook (AccumulateGrad) order: decoder 5, 4...0 -> encoder 5...0 -> root
* prefetch order: decoder 4...0 -> encoder 5...0 -> None -> None
  * 1st None: when current_handle=encoder 0, _get_handle_to_prefetch returns None
  * 2nd None: when current_handle=root, we get decoder 5 inside _get_handle_to_prefetch but is not needed. so returns None
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107058
Approved by: https://github.com/awgu
2023-08-25 01:12:43 +00:00
Andrew Gu
2b964d6efd [FSDP] Enable async all-reduce for HSDP (#106080)
**Overview**
This PR runs the HSDP all-reduce as async so that it can overlap with both all-gather and reduce-scatter, which can lead to slight end-to-end speedups when the sharding process group is fully intra-node. Previously, the all-reduce serializes with reduce-scatter, so it can only overlap with one all-gather.

For some clusters (e.g. our AWS cluster), `NCCL_CROSS_NIC=1` improves inter-node all-reduce times when overlapped with intra-node all-gather/reduce-scatter.

**Experiment**
<details>
<summary> Example 'before' trace </summary>
<img width="559" alt="hsdp_32gpus_old" src="https://github.com/pytorch/pytorch/assets/31054793/15222b6f-2b64-4e0b-a212-597335f05ba5">

</details>

<details>
<summary> Example 'after' trace </summary>
<img width="524" alt="hsdp_32gpus_new" src="https://github.com/pytorch/pytorch/assets/31054793/94f63a1d-4255-4035-9e6e-9e10733f4e44">

</details>

For the 6-encoder-layer, 6-decoder layer transformer with `d_model=8192`, `nhead=64` on 4 nodes / 32 40 GB A100s via AWS, the end-to-end iteration times are as follows (with AG == all-gather, RS == reduce-scatter, AR == all-reduce; bandwidth reported as algorithmic bandwidth):
- Reference FSDP:
    - **1160 ms / iteration**
    - ~23 ms / encoder AG/RS --> 24.46 GB/s bandwidth
    - ~40 ms / decoder AG/RS --> 26.5 GB/s bandwidth
    - 50 GB/s theoretical inter-node bandwidth
- Baseline 8-way HSDP (only overlap AR with AG) -- intra-node AG/RS, inter-node AR:
    - **665 ms / iteration**
    - ~3 ms / encoder AG/RS --> 187.5 GB/s bandwidth
    - ~5 ms / decoder AG/RS --> 212 GB/s bandwidth
    - ~30 ms / encoder AR --> 2.34 GB/s bandwidth
    - ~55 ms / decoder AR --> 2.65 GB/s bandwidth
    - 300 GB/s theoretical intra-node bandwidth
- New 8-way HSDP (overlap AR with AG and RS) -- intra-node AG/RS, inter-node AR:
    - **597 ms / iteration**
    - ~3 ms / encoder AG/RS --> 187.5 GB/s bandwidth
    - ~6.2 ms / decoder AG/RS --> 170.97 GB/s bandwidth (slower)
    - ~23 ms / encoder AR (non-overlapped) --> 3.057 GB/s bandwidth (faster)
    - ~49 ms / decoder AR (non-overlapped) --> 2.70 GB/s bandwidth (faster)
    - ~100 ms / decoder AR (overlapped) --> 1.325 GB/s bandwidth (slower)
    - Overlapping with reduce-scatter reduces all-reduce bandwidth utilization even though the all-reduce is inter-node and reduce-scatter is intra-node!
- New 8-way HSDP (overlap AR with AG and RS) with `NCCL_CROSS_NIC=1`:
    - **556 ms / iteration**
    - Speedup comes from faster overlapped AR

Thus, for this particular workload, the async all-reduce enables 16% iteration-time speedup compared to the existing HSDP and 52% speedup compared to FSDP. These speedups are pronounced due to the workload being communication bound, so any communication time reduction translates directly to speedup.

**Unit Test**
This requires >= 4 GPUs:
```
python -m pytest test/distributed/fsdp/test_fsdp_hybrid_shard.py -k test_fsdp_hybrid_shard_parity
```

Differential Revision: [D47852456](https://our.internmc.facebook.com/intern/diff/D47852456)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106080
Approved by: https://github.com/ezyang
ghstack dependencies: #106068
2023-08-23 18:36:15 +00:00
Andrew Gu
50e1378680 [FSDP] Break up _post_backward_hook into smaller funcs (#106068)
The post-backward hook has some complexity due to the different paths: {no communication hook, communication hook} x {`NO_SHARD`, `FULL_SHARD`/`SHARD_GRAD_OP`, `HYBRID_SHARD`/`_HYBRID_SHARD_ZERO2`} plus some options like CPU offloading and `use_orig_params=True` (requiring using sharded gradient views).

The PR following this one that adds async all-reduce for HSDP further complicates this since the bottom-half after all-reduce must still be run in the separate all-reduce stream, making it more unwieldy to unify with the existing bottom-half.

Nonetheless, this PR breaks up the post-backward hook into smaller logical functions to hopefully help readability.

Differential Revision: [D47852461](https://our.internmc.facebook.com/intern/diff/D47852461)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106068
Approved by: https://github.com/ezyang, https://github.com/fegin
2023-08-23 18:36:15 +00:00
Michael Voznesensky
42660015b4 [Dynamo x FSDP][2/x] Small changes to distributed to make it dynamo friendly (#106886)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106886
Approved by: https://github.com/awgu, https://github.com/wconstab
ghstack dependencies: #106884
2023-08-11 22:35:50 +00:00
weifengpy
4bc846c101 [FSDP] Ignore buffer type casting in ignored modules (#106766)
issue resolved: https://github.com/pytorch/pytorch/issues/97791

before this PR, mixed_precision applies to buffers from ignored modules. see ```test_state_dict_with_ignored_modules(mixed_precision=True)``` for reproduce

after, we avoid applying mixed_precision semantics to buffers from ignored modules
* step 1 initialization: state._ignored_buffer_names contains all the buffers from ignored modules
* step 2 lazy init at runtime: skip ignored buffers in ```_get_buffers_and_dtypes_for_computation```
* step 3 skip upcasting in state_dict hook: avoid upcasting for ignored buffers in ```_get_buffers_and_dtypes_for_computation```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106766
Approved by: https://github.com/awgu
2023-08-09 23:09:43 +00:00
Michael Voznesensky
d1a99a083f Reland Simplify handle indexing (#105006) (#106357)
This reverts commit a9a3c45649.

This PR changes the following:
- `_ExecOrderData.handle_to_handle_index` -> `FlatParamHandle._handle_index`
- `_ExecOrderData.handles_to_pre_forward_order_index` -> `FlatParamHandle._pre_forward_order_index`
- `_ExecOrderData.handles_to_post_forward_order_index` -> `FlatParamHandle._post_forward_index`
- `_FSDPState._needs_pre_forward_unshard` -> `FlatParamHandle._needs_pre_forward_unshard`
- `_FSDPState._needs_pre_backward_unshard` -> `FlatParamHandle._needs_pre_backward_unshard`
- `_FSDPState._handles_prefetched` -> `FlatParamHandle._prefetched`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106357
Approved by: https://github.com/awgu
2023-08-03 19:17:32 +00:00
Andrew Gu
800287fb56 [FSDP] Optimize away intermediate div_ for HSDP (#106034)
### Background: Gradient Pre-Divide
Consider $N$ data parallel workers. Define $g_i$ to be the $i$ th worker's local unsharded gradient. Data parallel gradient reduction computes $\overline g = \frac{1}{N} \sum_{i \in [N]} g_i$.

$\sum_{i \in [N]} g_i$ increases the magnitude by a factor of $N$, which may overflow for fp16. However, if we pre-divide and compute $\sum_{i \in [N]} \frac{g_i}{N}$, then the $\frac{g_i}{N}$ may underflow. The current solution from Myle for FSDP is to pre-divide by $\sqrt{N}$ and post-divide by $\sqrt{N}$:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{i \in [N]} \frac{g_i}{\sqrt{N}}.$$

Now, consider HSDP with $N = S \cdot R$ data parallel workers, sharding over $S$ workers and replicating over $R$ workers. Define $g_{i,j}$ to be the $i \cdot S + j$ th worker's local unsharded gradient (so sharding indexes with $i$ and replication indexes with $j$). The existing implementation computes
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}} } \sum_{i \in [S]} \frac{g_i}{\sqrt{S}},$$
where the $\frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}}$ involves two separate `aten::div_` kernels.

### Revisiting Pre-Divide for HSDP
A minor optimization that we can do is with this intermediate `div_`. There are two options:
1. Compute $\overline{g}$ in the same way as FSDP:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{j \in [R]} \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{N}}.$$
2. Compute $\overline{g}$ still with an intermediate division for rescaling but coalescing the two `divs_` into one:
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{N}} } \sum_{i \in [S]} \frac{g_i}{\sqrt{S}}$$

This PR goes with the 1st approach prioritizing performance because (1) it matches the existing FSDP behavior and (2) it avoids a memor-bandwidth bound `div_` kernel that blocks all-reduce launch.

### Implementation Details
In order to accommodate this, we need to refactor the communication hook logic that baked the gradient pre/post-division into the default hook.
- We raise an error if registering a communication hook for HSDP since the current implementation would only apply the hook to the reduce-scatter, not the all-reduce, which may be unexpected.
- We change it so that `state._comm_hook is not None` iff a communication hook is registered. This makes the collectives and the pre/post-division in the default no-communication-hook path more visible in the code.

Differential Revision: [D47852459](https://our.internmc.facebook.com/intern/diff/D47852459)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106034
Approved by: https://github.com/rohan-varma
2023-07-28 18:36:26 +00:00
Albert Chen
7c8efc9049 [PT][FSDP] Combine _utils.py into _common_utils.py [2/2] (#106181)
Summary:
https://github.com/pytorch/pytorch/issues/97813
This diffs moves `_no_dispatch_record_stream` and `_same_storage_as_data_ptr`

Test Plan: CI

Differential Revision: D47706114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106181
Approved by: https://github.com/awgu
2023-07-28 17:15:25 +00:00
Andrew Gu
841b4acf1e [FSDP][Easy] Rename to _comm_hook, _comm_hook_state (#106033)
This is just out of preference to make the naming convention consistent with `register_comm_hook()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106033
Approved by: https://github.com/fegin
2023-07-26 19:59:11 +00:00
Andrew Gu
035704e88d [FSDP][Easy] Move post-bwd hook logging to own func (#106032)
This is to help make `_post_backward_hook()` easier to read. I plan to refactor some other parts in future PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106032
Approved by: https://github.com/fegin
2023-07-26 19:59:11 +00:00
Daniel Dale
6b6702f506 Enhance no_grad-context FSDP backward handling (#105374)
Fixes #105369
Fixes #105371

Addressing two somewhat distinct issues that involve the same test in this PR:

1. To fix #105369:
    - Add a `no_grad` guard to [`_register_post_backward_reshard_only_hooks`](93f852f201/torch/distributed/fsdp/_runtime_utils.py (L1406)) to avoid registering post-backward hooks that would not be removed in that context.

2. To fix #105371:
    - Add a `grad` context condition to [`_use_sharded_flat_param`](93f852f201/torch/distributed/fsdp/flat_param.py (L1645C9-L1645C32)) logic to trigger post-forward `_use_sharded_views` in a `no_grad` context for `NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105374
Approved by: https://github.com/awgu
2023-07-26 14:12:13 +00:00
Andrew Gu
c099b80073 [FSDP] Add record_function for explicit prefetching (#105985)
Example:
<img width="568" alt="Screenshot 2023-07-25 at 7 41 43 PM" src="https://github.com/pytorch/pytorch/assets/31054793/5f3f07b3-97f4-4493-9cab-5619484e2f6d">

This can be particularly help when `with_stack=False`, in which case it is harder to tell the prefetch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105985
Approved by: https://github.com/fegin
2023-07-26 12:16:35 +00:00
Andrew Gu
a9a3c45649 Revert "Simplify handle indexing (#105006)" (#105984)
This reverts commit 429d45f91a.

Unfortunately, https://github.com/pytorch/pytorch/pull/105006 broke backward prefetching (where backward prefetching working correctly was not captured in our unit tests).

I need more time to dig into this (tomorrow), but I think the issue is related to:
429d45f91a (diff-9a6937168d232432c34c2c4605b96f3147afa2786e287f74b6074b20aa5980e6R143-R146)

Follow-ups:
1. Investigate this thoroughly
2. Add unit tests to capture backward prefetch functionality
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105984
Approved by: https://github.com/fegin
2023-07-26 12:12:14 +00:00
Rohan Varma
a326f5621e composable fsdp, checkpoint, + compile test (#105180)
Test to ensure that composable FSDP, checkpoint, and compile all work
together. Includes a change from https://github.com/pytorch/pytorch/pull/105090
which we can land in that PR first.

Differential Revision: [D47452973](https://our.internmc.facebook.com/intern/diff/D47452973/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105180
Approved by: https://github.com/awgu
2023-07-26 07:03:09 +00:00
Michael Voznesensky
429d45f91a Simplify handle indexing (#105006)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105006
Approved by: https://github.com/awgu
2023-07-21 05:53:23 +00:00
Michael Voznesensky
a832967627 Migrate tuple(handle) -> handle (#104488)
We strengthen the invariant that one FSDP managed module has one flatparameter, and remove unused code that would have supported 1:many module to flatparam mapping

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104488
Approved by: https://github.com/awgu
2023-07-19 22:33:35 +00:00
Nikita Shulga
5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add hack to squash older libstdc++ from conda environment in favor one from OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds loose the ability to catch exceptions (probably because they link with cupti statically, but I could not found where it is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00
PyTorch MergeBot
15fd1ea118 Revert "[Reland] Update mypy to 1.4.1 (#105227)"
This reverts commit c9c4f8efc3.

Reverted https://github.com/pytorch/pytorch/pull/105227 on behalf of https://github.com/atalman due to trying to mitigate ci sev #105248 ([comment](https://github.com/pytorch/pytorch/pull/105227#issuecomment-1636510935))
2023-07-14 22:28:35 +00:00
Nikita Shulga
c9c4f8efc3 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-14 20:45:12 +00:00
PyTorch MergeBot
3c5a494d7a Revert "Update mypy to 1.4.1 (#91983)"
This reverts commit 634659e262.

Reverted https://github.com/pytorch/pytorch/pull/91983 on behalf of https://github.com/malfet due to It's dependent change was reverted, so reverting this one as well, to keep CI clean ([comment](https://github.com/pytorch/pytorch/pull/91983#issuecomment-1636059709))
2023-07-14 15:59:16 +00:00
Nikita Shulga
634659e262 Update mypy to 1.4.1 (#91983)
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  -
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91983
Approved by: https://github.com/kit1980, https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/thiagocrepaldi, https://github.com/aaronenyeshi
2023-07-13 16:30:36 +00:00
Rohan Varma
242fc29c96 [FSDP] Refactor optimizer in backward (#104813)
1) Use zero_grad(set_to_none=True) to set grad to None, 2) call
prepare_grad_for_optim() before call to .step, 3) use
_reset_flat_param_grad_info to set flat param gradient back to None. These
changes should just be refactors and equivalent to how gradient memory was
managed  before.

Differential Revision: [D47310761](https://our.internmc.facebook.com/intern/diff/D47310761/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104813
Approved by: https://github.com/awgu
2023-07-13 06:42:53 +00:00
Rohan Varma
f2eed129c4 FSDP optimizer overlap (#98667)
constraints:

1. No support for gradient accumulation
2. CPU offload runs step() on CPU. In future PRs ideally we'd run this on GPU.
3. When CPU offload + optimizer overlap, we have to copy the flat_param grad to CPU with non_blocking=False, otherwise step() might run on invalid data.
4. Step is waited on in post backward final cb, when in theory it can wait until the next forward.

Differential Revision: [D44809582](https://our.internmc.facebook.com/intern/diff/D44809582/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98667
Approved by: https://github.com/awgu, https://github.com/fegin
2023-07-13 06:42:53 +00:00
Andrew Gu
954bae8e53 [FSDP][Easy] Rename streams; add back stream sharing test (#104966)
Purely out of preference, this PR renames the streams to `_unshard_stream` instead of `_streams_unshard` etc. since the former reads more naturally. The PR also removes some duplicated comments and adds back a unit test that streams are shared.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104966
Approved by: https://github.com/rohan-varma
2023-07-13 00:24:41 +00:00