Commit Graph

98 Commits

Author SHA1 Message Date
Mihir Patel
3f1f057adf Remove parent device mesh check (#118620)
Removes the error raised when a device_mesh has a parent.

The comment says that HSDP + TP is not supported, but I'm able to do 2D parallelism + HSDP fine. The only issues are:
- this check
- https://github.com/pytorch/pytorch/pull/118618
- a series of PRs related to checkpointing with 3D meshes that I will open
We currently monkeypatch for the above which I am slowly upstreaming.

I imagine torch will have a better, native integration eventually, but this check seems too aggressive in the meantime given DTensor now lets users do some things themselves (which is amazing 🎉)!
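
A minimal sketch (not the author's exact recipe) of a 2D-parallelism + HSDP setup, assuming torchrun with 8 ranks, an initialized default process group, and current DeviceMesh multi-dim slicing syntax; the mesh sizes, module names, and parallelize plan are illustrative:

```
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

mesh_3d = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("replicate", "shard", "tp"))
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)).cuda()
model = parallelize_module(model, mesh_3d["tp"], {"0": ColwiseParallel()})  # TP on the last mesh dim
model = FSDP(
    model,
    device_mesh=mesh_3d["replicate", "shard"],  # 2D submesh whose parent is the 3D mesh
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    use_orig_params=True,
)
```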

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118620
Approved by: https://github.com/wz337, https://github.com/wanchaol
2024-02-02 05:29:49 +00:00
Catherine Lee
4f5785b6b3 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

# Parse mypy output lines of the form "path/to/file.py:123:45: error: ... [error-code]"
with open("error_file.txt", "r") as f:
    errors = f.readlines()

# Map file path -> {line number -> error code}
error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

# Append a "# type: ignore[code]" suppression to each flagged line (processed bottom-up)
for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```
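
For context, a small illustration (not from the PR) of the pattern mypy flags once the `possibly-undefined` error code is enabled:

```
def pick(flag: bool) -> int:
    if flag:
        value = 1
    return value  # mypy: Name "value" may be undefined  [possibly-undefined]
```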

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 21:07:01 +00:00
PyTorch MergeBot
40ece2e579 Revert "Enable possibly-undefined error code (#118533)"
This reverts commit 4f13f69a45.

Reverted https://github.com/pytorch/pytorch/pull/118533 on behalf of https://github.com/clee2000 due to sorry i'm trying to figure out a codev merge conflict, if this works i'll be back to rebase and merge ([comment](https://github.com/pytorch/pytorch/pull/118533#issuecomment-1917695185))
2024-01-30 19:00:34 +00:00
Edward Z. Yang
4f13f69a45 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 05:08:10 +00:00
Wanchao Liang
eebf115686 [fsdp][2d] FSDP sync module states handle tensor subclass (#117336)
This PR lets FSDP's sync module states kwarg handle tensor subclasses. Because FSDP works on the
"dp" mesh dimension, as long as FSDP operates on a different device mesh dimension than the
subclass, we can safely let FSDP broadcast the DTensor local shards.

fixes https://github.com/pytorch/pytorch/issues/117126
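
A rough sketch, under assumed mesh sizes and module names (torchrun/process-group init assumed), of the case this enables: sync_module_states=True on a TP model whose parameters are already DTensors, with FSDP broadcasting the DTensor local shards over the "dp" mesh dimension:

```
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
model = nn.Sequential(nn.Linear(8, 8)).cuda()
model = parallelize_module(model, mesh_2d["tp"], {"0": ColwiseParallel()})  # params become DTensors
model = FSDP(model, device_mesh=mesh_2d["dp"], sync_module_states=True, use_orig_params=True)
```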

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117336
Approved by: https://github.com/awgu
2024-01-13 19:33:47 +00:00
Wanchao Liang
848cfe8d45 [reland] unflatten_tensor on compute stream for DTensorExtension (#117020)
reland of https://github.com/pytorch/pytorch/pull/116559, which was reverted internally.

The underlying reason for the revert is that torch.dynamo.disable can't be used inside the
PyTorch codebase, as it conflicts with torch.deploy; although the latter only runs inference,
it somehow takes that odd dependency on FSDP.

We have seen the same issue with our functional collectives: we can't use any dynamo
components, otherwise torch.deploy complains.

Verified internally that after removing torch.dynamo.disable the test passes again.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117020
Approved by: https://github.com/awgu
2024-01-09 21:25:15 +00:00
Qinfan Wu
b847290ddd Back out "[2d] unflatten_tensor on compute stream for DTensorExtension (#116559)" (#116939)
Summary:
Original commit changeset: 65298112f3db

Original Phabricator Diff: D52530451

Differential Revision: D52583345

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116939
Approved by: https://github.com/842974287
2024-01-07 03:53:40 +00:00
Wanchao Liang
d9c0e37bab [2d] unflatten_tensor on compute stream for DTensorExtension (#116559)
Context: the existing FSDPExtension has a bug in the case where unflattening a tensor
involves compute/communication on a CUDA stream. The current FSDPExtension logic unflattens
the tensor on the unshard stream, which makes the runtime lose synchronization with the
compute stream; if there are dependencies between the compute stream and the unflatten-tensor
logic, the missing sync point can lead to NaNs.

This PR makes the FSDPExtension record the compute stream and lets DTensorExtension use the
compute stream directly for unflatten_tensor.

In the long term we might want the FSDP runtime to perform only the unshard in the unshard
stream and create the unsharded views in the compute stream. For now we fix this in the
extension directly, as that is the simplest change that does not affect the FSDP runtime logic.
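
Not FSDP internals, just a generic illustration under simple assumptions of the missing sync point described above: work launched on a side stream must wait on the compute stream before consuming its outputs, or it may read tensors that are not ready yet.

```
import torch

compute = torch.cuda.current_stream()
side = torch.cuda.Stream()  # stands in for the unshard stream

x = torch.randn(1024, 1024, device="cuda")
y = x @ x                    # produced on the compute stream
side.wait_stream(compute)    # the kind of sync point that was effectively missing
with torch.cuda.stream(side):
    z = y.relu()             # safe to consume y only because of the wait above
compute.wait_stream(side)    # sync back before the compute stream uses z
```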

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116559
Approved by: https://github.com/awgu, https://github.com/fduwjj, https://github.com/yifuwang
ghstack dependencies: #116426
2024-01-03 07:29:08 +00:00
Iris Zhang (PyTorch)
23fa9621e4 [DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#115099) (#115193)
Summary:

Rename _device_mesh.py to device_mesh.py, update all callsites, add documentation.
We created stubs for the public class and methods in torch.distributed.device_mesh so that it can be imported whether or not distributed is available.

Original diff reverted: D51629761
Original PR reverted: https://github.com/pytorch/pytorch/pull/115099
Prior to landing, all CI signals had passed. Shipit added the "ci/trunk" label to the PR but did not wait for it and went ahead with committing. More context can be found in the reverted PR above.
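
For reference, the public import path after the rename (illustrative usage, not part of the diff):

```
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
```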

Test Plan: CI.

Differential Revision: D51861018

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115193
Approved by: https://github.com/fegin
2023-12-08 08:44:32 +00:00
Nikita Shulga
a827ac71f2 Revert "[DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#115099)"
This reverts commit eaa64339d6.
2023-12-05 08:59:36 -08:00
Iris Zhang (PyTorch)
eaa64339d6 [DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#115099)
Summary:
Rename _device_mesh.py to device_mesh.py, update all callsites, adds documentation.

Original diff reverted: D51629761
Original PR reverted: https://github.com/pytorch/pytorch/pull/114991
It was failing a public module binding test on macOS due to the change in import order in torch/distributed/fsdp/_common_utils.py. Since the original import still works, we removed the changes to this file.

Test Plan: CI.

Differential Revision: D51825114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115099
Approved by: https://github.com/wanchaol, https://github.com/fegin
2023-12-05 05:44:52 +00:00
PyTorch MergeBot
3a2e2044cd Revert "[DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#114710) (#114991)"
This reverts commit 729ac7317a.

Reverted https://github.com/pytorch/pytorch/pull/114991 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/114991#issuecomment-1837214567))
2023-12-02 17:55:51 +00:00
Iris Zhang (PyTorch)
729ac7317a [DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#114710) (#114991)
Summary:

Same content of changes as https://github.com/pytorch/pytorch/pull/114710

Rename _device_mesh.py to device_mesh.py, update all callsites, adds documentation.
ghstack-source-id: 208980207
exported-using-ghexport

Test Plan: CI.

Reviewed By: wanchaol

Differential Revision: D51629761

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114991
Approved by: https://github.com/wanchaol, https://github.com/fduwjj, https://github.com/fegin
2023-12-02 04:39:41 +00:00
wz337
7b3e45be59 [DeviceMesh] Rename get_dim_groups to get_group (#114708)
Rename get_dim_groups to get_group and update all callsites.
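
An illustrative call-site after the rename; the mesh shape and dim names here are assumed, not from the diff:

```
from torch.distributed.device_mesh import init_device_mesh

mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
dp_group = mesh_2d.get_group("dp")  # formerly spelled get_dim_groups
```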

Differential Revision: [D51629801](https://our.internmc.facebook.com/intern/diff/D51629801/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114708
Approved by: https://github.com/XilunWu, https://github.com/wanchaol, https://github.com/fegin
2023-11-30 23:40:14 +00:00
Aaron Gokaslan
b7b2178204 [BE]: Remove useless lambdas (#113602)
Applies PLW0108, which removes useless lambdas in Python. The rule is in preview, so it is not ready to be enabled by default just yet; these are the autofixes from the rule.
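
A made-up before/after showing the kind of rewrite PLW0108 (unnecessary-lambda) performs:

```
values = [-3, 1, -2]
ordered = sorted(values, key=lambda x: abs(x))  # flagged: the lambda only forwards its argument
ordered = sorted(values, key=abs)               # the autofixed form
```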

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113602
Approved by: https://github.com/albanD
2023-11-14 20:06:48 +00:00
BJ Hargrave
670abff6ff docs: Fix docstring lint errors in torch/distributed/fsdp/_flat_param.py & torch/distributed/fsdp/_init_utils.py (#113358)
Fixes #113189

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113358
Approved by: https://github.com/kit1980
2023-11-11 01:53:02 +00:00
wz337
31ded95cd5 [2D] Bind _fsdp_extension to FSDP instances (#113237)
Currently, when we have 2D composition, a global variable _extensions controls the 2D deviation we need to take in state_dict calls (See https://github.com/pytorch/pytorch/blob/release/2.1/torch/distributed/fsdp/_fsdp_extensions.py#L66-L68). This is problematic when we have both a 2D model and a plain FSDP model in the same dist environment, as the _extensions will be mistakenly turned on for the plain FSDP model, resulting in state_dict error (RuntimeError: No parent device_mesh is found for FSDP device_mesh.).

This PR binds _fsdp_extension to each FSDP instance to make sure that state_dict calls do not interfere with each other when mixing 2D and 1D parallelism.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113237
Approved by: https://github.com/fduwjj, https://github.com/fegin
2023-11-09 03:31:03 +00:00
Iris Zhang
12c1465d76 [DeviceMesh] Make mesh_resources private (#112294)
This is to prepare moving DeviceMesh as a standalone distributed package.

`_mesh_resources` should only be used in torch.distributed package.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112294
Approved by: https://github.com/fegin
2023-10-28 17:28:46 +00:00
Matthew Hoffman
68b0db1274 Define the public API for torch.distributed.fsdp (#109922)
Related: https://github.com/pytorch/pytorch/wiki/Public-API-definition-and-documentation
Related: https://github.com/microsoft/pylance-release/issues/2953

This fixes pylance issues for these classes:

```
"FullyShardedDataParallel" is not exported from module "torch.distributed.fsdp"
```

These classes all have public docs:

* [`BackwardPrefetch`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.BackwardPrefetch)
* [`CPUOffload`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.CPUOffload)
* [`FullyShardedDataParallel`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel)
* [`MixedPrecision`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision)
* [`ShardingStrategy`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy)

And it seems like all the newly added classes will have docs once they are released.
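
For example, these imports (a small illustration, not part of the diff) now resolve cleanly for type checkers once the public API is declared:

```
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel,
    MixedPrecision,
    ShardingStrategy,
)
```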

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109922
Approved by: https://github.com/wanchaol
2023-09-28 02:15:58 +00:00
wz337
8140494afd [3/N][2D] Enable training with new 2D flow (#110034)
Replacing https://github.com/pytorch/pytorch/pull/109553 as it gets reverted.

This PR enables training with the new 2D flow and adds an associated test. In addition, this PR moves the FSDP-specific parts of tensor/parallel/_data_parallel_utils.py back to tensor/parallel/fsdp.py to avoid a circular dependency for ddp.py and test/distributed/tensor/parallel/test_ddp_2d_parallel.py.

state_dict related changes would be in later PRs.

cc. @fegin, @fduwjj, @wanchaol, @awgu
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110034
Approved by: https://github.com/fduwjj
2023-09-26 09:14:15 +00:00
PyTorch MergeBot
f5886bf352 Revert "[3/N][2D] Enable training with new 2D flow (#109553)"
This reverts commit 217b37c023.

Reverted https://github.com/pytorch/pytorch/pull/109553 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but those distributed failures look legit and they are failing in trunk https://hud.pytorch.org/pr/109553 ([comment](https://github.com/pytorch/pytorch/pull/109553#issuecomment-1734100546))
2023-09-25 16:37:19 +00:00
wz337
217b37c023 [3/N][2D] Enable training with new 2D flow (#109553)
This PR enables training with new 2D flow and adds associated test.

state_dict related changes would be in later PRs.

cc. @fegin, @fduwjj, @wanchaol, @awgu
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109553
Approved by: https://github.com/fegin, https://github.com/awgu
2023-09-25 05:32:07 +00:00
wz337
0aedacb4f7 [2D][1/N] Add _enable_extension to fsdp state (#109242)
Add _enable_extension to fsdp state. We will use this to determine whether we should enable the extension or not.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109242
Approved by: https://github.com/fegin
2023-09-16 19:03:10 +00:00
lxg2015
e19a855b4d [HSDP] Fix Node 1 unable receive parameters from Node 0 (#108331)
When using hybrid_shard mode FSDP, state.process_group covers only GPUs 0-7 on node 0, so the
GPUs on node 1 cannot receive parameters. Setting process_group to the default (global) group fixes this issue.

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108331
Approved by: https://github.com/awgu
2023-09-11 15:13:28 +00:00
wz337
7bc25e38c0 [HSDP] Raise error when HSDP device_mesh has a parent_mesh (#108603)
Since we don't support HSDP + TP yet, raise an error during HSDP initialization if the device_mesh passed in has a parent mesh.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108603
Approved by: https://github.com/awgu
2023-09-07 04:17:10 +00:00
Rohan Varma
db6d09c086 [RFC][FSDP] Don't move ignored params / buffers to device (#108033)
Since these are ignored by FSDP, don't move them.

Differential Revision: [D48727044](https://our.internmc.facebook.com/intern/diff/D48727044/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108033
Approved by: https://github.com/awgu
ghstack dependencies: #108032
2023-09-05 21:43:41 +00:00
Rohan Varma
3334ec3a00 [RFC] Don't materialize ignored modules for FSDP (#108032)
Per title. This seems needed for cases where I have a large embedding
I want to separately manage, but FSDP would initialize it and thus consume the
memory.

Currently the interaction with torchdistX materialize_module is not tested; this can be done
as follow-up work.
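
A hedged sketch (module and sizes made up, torchrun/process-group init assumed) of the use case: keep a large embedding out of FSDP's initialization so it can be materialized and managed separately.

```
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(1_000_000, 1024, device="meta")  # managed outside FSDP
        self.proj = nn.Linear(1024, 1024)

net = Net()
model = FSDP(net, ignored_modules=[net.embed], device_id=torch.cuda.current_device())
# with this change, FSDP no longer materializes net.embed; the user initializes it later
```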

Differential Revision: [D48722046](https://our.internmc.facebook.com/intern/diff/D48722046/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108032
Approved by: https://github.com/awgu
2023-09-05 21:43:41 +00:00
wz337
66af4f6ec7 [HSDP] Add device_mesh to FSDP kwarg and add dtensor state_dict support for HSDP (#107533)
This PR:
1) Add a device_mesh kwarg to FSDP. Remove init_device_mesh() from _runtime_utils.py, as the device_mesh is now passed in by the user as a kwarg.
2) Change the use_dtensor flag for state_dict_config and optim_state_dict_config to be private. If a device_mesh is used with a sharded model/optim state dict, the _use_dtensor flag is set to True and the model/optim state dict returns a DTensor state_dict; otherwise, _use_dtensor is set to False and the state dict returns a sharded_tensor state_dict.
3) Update _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to add support for HSDP to return a 2D DTensor state_dict (see the sketch below).
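
A hedged sketch of the flow described in (1)-(3), using the current public import path; the mesh sizes are illustrative and the distributed setup (torchrun, process-group init) is assumed:

```
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    StateDictType,
)

mesh_2d = init_device_mesh("cuda", (2, 4))  # (replicate, shard) layout for HSDP, assumed sizes
model = FSDP(
    nn.Linear(8, 8).cuda(),
    device_mesh=mesh_2d,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    sd = model.state_dict()  # values come back as 2D DTensors because a device_mesh was passed
```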

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
2023-09-05 21:21:21 +00:00
PyTorch MergeBot
ab5b4c4419 Revert "[HSDP] Add device_mesh to FSDP and add dtensor state_dict support for HSDP (#107533)"
This reverts commit cc220e45a8.

Reverted https://github.com/pytorch/pytorch/pull/107533 on behalf of https://github.com/huydhn due to Sorry for reverting this, but it is failing in trunk with the same failure on test_dynamo_distributed cc220e45a8 ([comment](https://github.com/pytorch/pytorch/pull/107533#issuecomment-1701983247))
2023-09-01 01:26:30 +00:00
wz337
cc220e45a8 [HSDP] Add device_mesh to FSDP and add dtensor state_dict support for HSDP (#107533)
This PR:
1) Add a device_mesh kwarg to FSDP. Remove init_device_mesh() from _runtime_utils.py, as the device_mesh is now passed in by the user as a kwarg.
2) Change the use_dtensor flag for state_dict_config and optim_state_dict_config to be private. If a device_mesh is used with a sharded model/optim state dict, the _use_dtensor flag is set to True and the model/optim state dict returns a DTensor state_dict; otherwise, _use_dtensor is set to False and the state dict returns a sharded_tensor state_dict.
3) Update _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to add support for HSDP to return a 2D DTensor state_dict.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
2023-09-01 00:15:00 +00:00
Michael Voznesensky
42660015b4 [Dynamo x FSDP][2/x] Small changes to distributed to make it dynamo friendly (#106886)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106886
Approved by: https://github.com/awgu, https://github.com/wconstab
ghstack dependencies: #106884
2023-08-11 22:35:50 +00:00
weifengpy
4bc846c101 [FSDP] Ignore buffer type casting in ignored modules (#106766)
issue resolved: https://github.com/pytorch/pytorch/issues/97791

Before this PR, mixed_precision applied to buffers from ignored modules; see ```test_state_dict_with_ignored_modules(mixed_precision=True)``` to reproduce.

After this PR, we avoid applying mixed_precision semantics to buffers from ignored modules (a repro sketch follows the steps below):
* step 1 initialization: state._ignored_buffer_names contains all the buffers from ignored modules
* step 2 lazy init at runtime: skip ignored buffers in ```_get_buffers_and_dtypes_for_computation```
* step 3 skip upcasting in state_dict hook: avoid upcasting for ignored buffers in ```_get_buffers_and_dtypes_for_computation```
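
A rough repro sketch under assumed module names (torchrun/process-group init assumed): buffers inside an ignored module should keep their original dtype even when buffer_dtype is set in MixedPrecision.

```
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.frozen = nn.BatchNorm1d(8)   # ignored; its buffers should stay fp32
        self.body = nn.Linear(8, 8)

net = Net().cuda()
model = FSDP(
    net,
    ignored_modules=[net.frozen],
    mixed_precision=MixedPrecision(param_dtype=torch.float16, buffer_dtype=torch.float16),
)
```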
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106766
Approved by: https://github.com/awgu
2023-08-09 23:09:43 +00:00
Michael Voznesensky
d1a99a083f Reland Simplify handle indexing (#105006) (#106357)
This reverts commit a9a3c45649.

This PR changes the following:
- `_ExecOrderData.handle_to_handle_index` -> `FlatParamHandle._handle_index`
- `_ExecOrderData.handles_to_pre_forward_order_index` -> `FlatParamHandle._pre_forward_order_index`
- `_ExecOrderData.handles_to_post_forward_order_index` -> `FlatParamHandle._post_forward_index`
- `_FSDPState._needs_pre_forward_unshard` -> `FlatParamHandle._needs_pre_forward_unshard`
- `_FSDPState._needs_pre_backward_unshard` -> `FlatParamHandle._needs_pre_backward_unshard`
- `_FSDPState._handles_prefetched` -> `FlatParamHandle._prefetched`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106357
Approved by: https://github.com/awgu
2023-08-03 19:17:32 +00:00
Andrew Gu
15953fdf35 [FSDP][8/N] Replace _FSDPPolicy.policy with _Policy._run_policy (#104969)
This does some code organization improvement.
- It renames `_FSDPPolicy` to `_Policy` to show that it is not only for FSDP but for any module-level API.
- It formalizes the contract that such a policy should return something like `target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]]` that maps each module to wrap to its kwargs. It does so by requiring a `_run_policy` abstract method (this time private since users do not need to care about it). Then, our auto wrapping can just call `_run_policy()` to generate the dict and do any validation or post-processing.

This PR is technically BC-breaking because it removes the public `ModuleWrapPolicy.policy`. However, I do not think anyone was using that anyway, so this is a pretty safe breakage.
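
For reference, the public auto-wrap usage is unchanged by this refactor; a small sketch with an assumed model and policy choice (torchrun/process-group init assumed):

```
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

layer = nn.TransformerEncoderLayer(d_model=64, nhead=4)
model = nn.TransformerEncoder(layer, num_layers=2).cuda()
model = FSDP(model, auto_wrap_policy=ModuleWrapPolicy({nn.TransformerEncoderLayer}))
```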

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104969
Approved by: https://github.com/rohan-varma
ghstack dependencies: #104427, #104967, #104999
2023-08-03 12:42:14 +00:00
Andrew Gu
800287fb56 [FSDP] Optimize away intermediate div_ for HSDP (#106034)
### Background: Gradient Pre-Divide
Consider $N$ data parallel workers. Define $g_i$ to be the $i$ th worker's local unsharded gradient. Data parallel gradient reduction computes $\overline g = \frac{1}{N} \sum_{i \in [N]} g_i$.

$\sum_{i \in [N]} g_i$ increases the magnitude by a factor of $N$, which may overflow for fp16. However, if we pre-divide and compute $\sum_{i \in [N]} \frac{g_i}{N}$, then the $\frac{g_i}{N}$ may underflow. The current solution from Myle for FSDP is to pre-divide by $\sqrt{N}$ and post-divide by $\sqrt{N}$:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{i \in [N]} \frac{g_i}{\sqrt{N}}.$$

Now, consider HSDP with $N = S \cdot R$ data parallel workers, sharding over $S$ workers and replicating over $R$ workers. Define $g_{i,j}$ to be the $i \cdot S + j$ th worker's local unsharded gradient (so sharding indexes with $i$ and replication indexes with $j$). The existing implementation computes
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}} } \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{S}},$$
where the $\frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}}$ involves two separate `aten::div_` kernels.

### Revisiting Pre-Divide for HSDP
A minor optimization that we can do is with this intermediate `div_`. There are two options:
1. Compute $\overline{g}$ in the same way as FSDP:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{j \in [R]} \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{N}}.$$
2. Compute $\overline{g}$ still with an intermediate division for rescaling but coalescing the two `divs_` into one:
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{N}} } \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{S}}$$

This PR goes with the 1st approach, prioritizing performance, because (1) it matches the existing FSDP behavior and (2) it avoids a memory-bandwidth-bound `div_` kernel that blocks the all-reduce launch.
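
A small numeric check (plain Python, illustrative only) of why approach 1 still recovers the mean: pre-dividing each local gradient by $\sqrt{N}$ and post-dividing the reduced sum by $\sqrt{N}$ multiplies out to $1/N$, while keeping intermediate magnitudes away from fp16 overflow/underflow.

```
import math

S, R = 4, 2          # shard and replicate group sizes (assumed)
N = S * R
grads = [float(i + 1) for i in range(N)]          # stand-in local gradients g_{i,j}
reduced = sum(g / math.sqrt(N) for g in grads)    # pre-divide, then sum (reduce)
mean = reduced / math.sqrt(N)                     # post-divide
assert abs(mean - sum(grads) / N) < 1e-12
```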

### Implementation Details
In order to accommodate this, we need to refactor the communication hook logic that baked the gradient pre/post-division into the default hook.
- We raise an error if registering a communication hook for HSDP since the current implementation would only apply the hook to the reduce-scatter, not the all-reduce, which may be unexpected.
- We change it so that `state._comm_hook is not None` iff a communication hook is registered. This makes the collectives and the pre/post-division in the default no-communication-hook path more visible in the code.

Differential Revision: [D47852459](https://our.internmc.facebook.com/intern/diff/D47852459)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106034
Approved by: https://github.com/rohan-varma
2023-07-28 18:36:26 +00:00
Andrew Gu
841b4acf1e [FSDP][Easy] Rename to _comm_hook, _comm_hook_state (#106033)
This is just out of preference to make the naming convention consistent with `register_comm_hook()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106033
Approved by: https://github.com/fegin
2023-07-26 19:59:11 +00:00
Andrew Gu
a9a3c45649 Revert "Simplify handle indexing (#105006)" (#105984)
This reverts commit 429d45f91a.

Unfortunately, https://github.com/pytorch/pytorch/pull/105006 broke backward prefetching (where backward prefetching working correctly was not captured in our unit tests).

I need more time to dig into this (tomorrow), but I think the issue is related to:
429d45f91a (diff-9a6937168d232432c34c2c4605b96f3147afa2786e287f74b6074b20aa5980e6R143-R146)

Follow-ups:
1. Investigate this thoroughly
2. Add unit tests to capture backward prefetch functionality
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105984
Approved by: https://github.com/fegin
2023-07-26 12:12:14 +00:00
Michael Voznesensky
429d45f91a Simplify handle indexing (#105006)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105006
Approved by: https://github.com/awgu
2023-07-21 05:53:23 +00:00
Michael Voznesensky
a832967627 Migrate tuple(handle) -> handle (#104488)
We strengthen the invariant that one FSDP-managed module has one FlatParameter, and remove unused code that would have supported a one-to-many module-to-FlatParameter mapping.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104488
Approved by: https://github.com/awgu
2023-07-19 22:33:35 +00:00
Andrew Gu
610f74627e [FSDP][4/N] Remove _get_fully_sharded_module_to_states (#104409)
`_get_fully_sharded_module_to_states()` was used to emulate auto wrapping without actually calling `fully_shard`. Since we committed to unifying (see previous PR), we can remove this function and its helpers/tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104409
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2023-07-08 12:40:14 +00:00
Andrew Gu
d9be0366d3 [FSDP][3/N] Unify fully_shard auto wrap (#104408)
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules.

This includes several important fixes:
- We should register the pre/post-forward hooks on the module regardless of it has managed parameters.
- We can permit `_module_handles` to return `[]` in the composable path (for when the module has no managed parameters).
- We should unify the paths for `_get_buffers_and_dtypes_for_computation()` (previously, composable path was buggy in some cases).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104408
Approved by: https://github.com/rohan-varma
2023-07-08 12:40:12 +00:00
Rohan Varma
0bf39d5663 [FSDP] Option for eval in fp32/bf16 (#104682)
In https://github.com/pytorch/pytorch/pull/97645 and some follow-up diffs, we made FSDP run in full precision in eval mode, even if mixed precision was specified.

However, this is probably not the best default, and we should give users a bit more control. This PR adds an env var FSDP_FULL_PREC_IN_EVAL, defaulting to off; users who want to run eval in fp32 can set it before wrapping the model in FSDP:

os.environ["FSDP_FULL_PREC_IN_EVAL"] = "1"

Verified that unit tests, the APS workflow, and TNT workloads run eval appropriately with this change.

Differential Revision: [D47246556](https://our.internmc.facebook.com/intern/diff/D47246556/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104682
Approved by: https://github.com/awgu
2023-07-07 08:14:23 +00:00
Andrew Gu
d982fdb5d5 [FSDP] Rework meta device init (#104189)
This addresses https://github.com/pytorch/pytorch/issues/104187.

After this PR, the contract with the user is that:
- If passing `param_init_fn=None`, each `nn.Module.reset_parameters()` should only initialize its own parameters/buffers (like `parameters(recurse=False)`/`buffers(recurse=False)`).
- If passing `param_init_fn` not equal to `None`, then similarly, one call to `param_init_fn(module)` should only initialize `module`'s own parameters/buffers.

With this contract and this PR's changes, meta device initialization through either `reset_parameters()` or `param_init_fn` should be correct. Those functions will run on the original parameter/buffer shapes allowing for correct shape-dependent computations like for fan-in/fan-out, and there will not be any re-initialization of any module.
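
As an illustration of this contract (a minimal sketch, not from the PR), a module like the following initializes only its own parameter in `reset_parameters()`, so shape-dependent init such as fan-in/fan-out stays correct under meta-device initialization:

```
import math
import torch
import torch.nn as nn

class MyLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, device=None):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features, device=device))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # touches only this module's own parameter; does not recurse into children
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
```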

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104189
Approved by: https://github.com/rohan-varma
2023-07-01 00:25:12 +00:00
Rohan Varma
60e2a4a4a0 [2D parallel] workaround for FSDP init issue (#104398)
Closes https://github.com/pytorch/pytorch/issues/96491 and does so by relaxing FSDP's assumption that the entire input module must be on the same device. Now, FSDP can accept a module partially on CPU and GPU and just emits a warning.

Differential Revision: [D47117256](https://our.internmc.facebook.com/intern/diff/D47117256/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104398
Approved by: https://github.com/fegin
2023-06-29 16:07:07 +00:00
Andrew Gu
6493519fff [Easy][FSDP] Remove misleading asserts (#104274)
Since we do not call `_FSDPState.__init__()` and only use it for typing, it is not possible for these attributes to be `None`. The purpose of these `assert`s is to make sure that these attributes are set by `_init_process_group_state_for_hybrid_shard()`. If we care to make that explicit, I would posit that we should be using `hasattr` checks, not `is not None` checks, because if indeed `_init_process_group_state_for_hybrid_shard()` did not set these attributes, then even checking that it is not `None` would lead to an `AttributeError`. I do not include these `hasattr` checks for now since `_init_process_group_state_for_hybrid_shard()` is short enough that we can quickly tell by inspection that it sets the desired attributes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104274
Approved by: https://github.com/rohan-varma
2023-06-28 11:08:47 +00:00
Andrew Gu
ba9f6e6e92 [FSDP] Validate ignored_modules, ignored_states (#104273)
This checks that `ignored_modules` and `ignored_states` have the expected type and provides a reasonable error message if not. Otherwise, if someone passes a mix of modules and parameters to `ignored_states` for example, then our code may be silently incorrect.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104273
Approved by: https://github.com/rohan-varma
2023-06-28 11:08:47 +00:00
Andrew Gu
ec8aa6e592 [Easy][FSDP] Fix "column" -> "row" in PG example (#103975)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103975
Approved by: https://github.com/fduwjj
2023-06-21 20:41:50 +00:00
Michael Voznesensky
02f28de408 [dynamo x fsdp] Simplify stream logic handling (#103902)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103902
Approved by: https://github.com/awgu
2023-06-21 01:34:19 +00:00
Andrew Gu
71b560208c [FSDP] Fix device_id when buffer-only module (#103504)
There was an issue reported internally that with `sync_module_states=True`, if the model had buffers on CPU, even with `device_id` specified, FSDP would try to broadcast CPU buffers, leading to an error like:
```
RuntimeError: No backend type associated with device type cpu
```

After some investigation, I determined that we should _not_ fix this by moving the buffers to GPU just for the broadcast and then back to CPU. Instead, we should fix our `device_id` logic.

The issue is that we always used the _parameters_ as the proxy to tell whether we should move module states to the device specified by `device_id`. However, a module (often the root) may not have any parameters but have some buffers! In that case, the buffers are left on CPU even if `device_id` is specified. This PR fixes this by considering both parameters and buffers for movement to `device_id`.

Note that this PR preserves the logic that `ignored_modules` / `ignored_parameters` are not considered for this movement, meaning that ignored parameters are not moved to `device_id`.

Note also that I had to move the unit test back from using MTPG to the normal PG since otherwise, I could not repro the original error. (It seems like MTPG does not complain if we try to use `dist._broadcast_coalesced()` with CPU tensors.)
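
A hedged repro sketch of the buffer-only case (module layout assumed, torchrun/process-group init assumed; the exact internal repro may differ):

```
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class BufferOnly(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("step", torch.zeros(1))  # buffers only, no parameters

model = FSDP(BufferOnly(), device_id=torch.cuda.current_device(), sync_module_states=True)
# before the fix, the CPU buffer was broadcast as-is and failed with the error above;
# with the fix, buffers are also moved to device_id before sync_module_states broadcasts them
```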

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103504
Approved by: https://github.com/rohan-varma
2023-06-13 18:33:26 +00:00
Yanli Zhao
f47ee87765 Fix ignored_states when they are passed as generators (#102575)
This PR fixes the case where ignored_states is passed as a generator rather than a List/Set.
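
A sketch of the generator case this fix covers (model and ignored parameters assumed, torchrun/process-group init assumed):

```
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

net = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 16)).cuda()
model = FSDP(net, ignored_states=(p for p in net[0].parameters()))  # generator, not a List/Set
```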

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102575
Approved by: https://github.com/awgu
2023-05-31 15:58:55 +00:00