Commit Graph

70 Commits

Author SHA1 Message Date
PyTorch MergeBot
ab5b4c4419 Revert "[HSDP] Add device_mesh to FSDP and add dtensor state_dict support for HSDP (#107533)"
This reverts commit cc220e45a8.

Reverted https://github.com/pytorch/pytorch/pull/107533 on behalf of https://github.com/huydhn due to Sorry for reverting this, but it is failing in trunk with the same failure on test_dynamo_distributed cc220e45a8 ([comment](https://github.com/pytorch/pytorch/pull/107533#issuecomment-1701983247))
2023-09-01 01:26:30 +00:00
wz337
cc220e45a8 [HSDP] Add device_mesh to FSDP and add dtensor state_dict support for HSDP (#107533)
This PR:
1) Add device_mesh kwarg to FSDP. Remove init_device_mesh() from _runtime_utils.py, as device_mesh would be passed in by user as an kwarg.
2) change use_dtensor flag for state_dict_config and optim_state_dict_config to be private. If device_mesh is used with sharded model/optim state dict, _use_dtensor flag would be set to True and model/optim state dict would return dtensor state_dict. Otherwise, _use_dtensor flag would be set to False and model/optim state dict would return sharded_tensor state_dict.
3) Update _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to add support for HSDP to return 2D DTensor state_dict.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
2023-09-01 00:15:00 +00:00
Michael Voznesensky
42660015b4 [Dynamo x FSDP][2/x] Small changes to distributed to make it dynamo friendly (#106886)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106886
Approved by: https://github.com/awgu, https://github.com/wconstab
ghstack dependencies: #106884
2023-08-11 22:35:50 +00:00
weifengpy
4bc846c101 [FSDP] Ignore buffer type casting in ignored modules (#106766)
issue resolved: https://github.com/pytorch/pytorch/issues/97791

before this PR, mixed_precision applies to buffers from ignored modules. see ```test_state_dict_with_ignored_modules(mixed_precision=True)``` for reproduce

after, we avoid applying mixed_precision semantics to buffers from ignored modules
* step 1 initialization: state._ignored_buffer_names contains all the buffers from ignored modules
* step 2 lazy init at runtime: skip ignored buffers in ```_get_buffers_and_dtypes_for_computation```
* step 3 skip upcasting in state_dict hook: avoid upcasting for ignored buffers in ```_get_buffers_and_dtypes_for_computation```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106766
Approved by: https://github.com/awgu
2023-08-09 23:09:43 +00:00
Michael Voznesensky
d1a99a083f Reland Simplify handle indexing (#105006) (#106357)
This reverts commit a9a3c45649.

This PR changes the following:
- `_ExecOrderData.handle_to_handle_index` -> `FlatParamHandle._handle_index`
- `_ExecOrderData.handles_to_pre_forward_order_index` -> `FlatParamHandle._pre_forward_order_index`
- `_ExecOrderData.handles_to_post_forward_order_index` -> `FlatParamHandle._post_forward_index`
- `_FSDPState._needs_pre_forward_unshard` -> `FlatParamHandle._needs_pre_forward_unshard`
- `_FSDPState._needs_pre_backward_unshard` -> `FlatParamHandle._needs_pre_backward_unshard`
- `_FSDPState._handles_prefetched` -> `FlatParamHandle._prefetched`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106357
Approved by: https://github.com/awgu
2023-08-03 19:17:32 +00:00
Andrew Gu
15953fdf35 [FSDP][8/N] Replace _FSDPPolicy.policy with _Policy._run_policy (#104969)
This does some code organization improvement.
- It renames `_FSDPPolicy` to `_Policy` to show that it is not only for FSDP but for any module-level API.
- It formalizes the contract that such a policy should return something like `target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]]` that maps each module to wrap to its kwargs. It does so by requiring a `_run_policy` abstract method (this time private since users do not need to care about it). Then, our auto wrapping can just call `_run_policy()` to generate the dict and do any validation or post-processing.

This PR is technically BC-breaking because it removes the public `ModuleWrapPolicy.policy`. However, I do not think anyone was using that anyway, so this is a pretty safe breakage.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104969
Approved by: https://github.com/rohan-varma
ghstack dependencies: #104427, #104967, #104999
2023-08-03 12:42:14 +00:00
Andrew Gu
800287fb56 [FSDP] Optimize away intermediate div_ for HSDP (#106034)
### Background: Gradient Pre-Divide
Consider $N$ data parallel workers. Define $g_i$ to be the $i$ th worker's local unsharded gradient. Data parallel gradient reduction computes $\overline g = \frac{1}{N} \sum_{i \in [N]} g_i$.

$\sum_{i \in [N]} g_i$ increases the magnitude by a factor of $N$, which may overflow for fp16. However, if we pre-divide and compute $\sum_{i \in [N]} \frac{g_i}{N}$, then the $\frac{g_i}{N}$ may underflow. The current solution from Myle for FSDP is to pre-divide by $\sqrt{N}$ and post-divide by $\sqrt{N}$:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{i \in [N]} \frac{g_i}{\sqrt{N}}.$$

Now, consider HSDP with $N = S \cdot R$ data parallel workers, sharding over $S$ workers and replicating over $R$ workers. Define $g_{i,j}$ to be the $i \cdot S + j$ th worker's local unsharded gradient (so sharding indexes with $i$ and replication indexes with $j$). The existing implementation computes
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}} } \sum_{i \in [S]} \frac{g_i}{\sqrt{S}},$$
where the $\frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}}$ involves two separate `aten::div_` kernels.

### Revisiting Pre-Divide for HSDP
A minor optimization that we can do is with this intermediate `div_`. There are two options:
1. Compute $\overline{g}$ in the same way as FSDP:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{j \in [R]} \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{N}}.$$
2. Compute $\overline{g}$ still with an intermediate division for rescaling but coalescing the two `divs_` into one:
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{N}} } \sum_{i \in [S]} \frac{g_i}{\sqrt{S}}$$

This PR goes with the 1st approach prioritizing performance because (1) it matches the existing FSDP behavior and (2) it avoids a memor-bandwidth bound `div_` kernel that blocks all-reduce launch.

### Implementation Details
In order to accommodate this, we need to refactor the communication hook logic that baked the gradient pre/post-division into the default hook.
- We raise an error if registering a communication hook for HSDP since the current implementation would only apply the hook to the reduce-scatter, not the all-reduce, which may be unexpected.
- We change it so that `state._comm_hook is not None` iff a communication hook is registered. This makes the collectives and the pre/post-division in the default no-communication-hook path more visible in the code.

Differential Revision: [D47852459](https://our.internmc.facebook.com/intern/diff/D47852459)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106034
Approved by: https://github.com/rohan-varma
2023-07-28 18:36:26 +00:00
Andrew Gu
841b4acf1e [FSDP][Easy] Rename to _comm_hook, _comm_hook_state (#106033)
This is just out of preference to make the naming convention consistent with `register_comm_hook()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106033
Approved by: https://github.com/fegin
2023-07-26 19:59:11 +00:00
Andrew Gu
a9a3c45649 Revert "Simplify handle indexing (#105006)" (#105984)
This reverts commit 429d45f91a.

Unfortunately, https://github.com/pytorch/pytorch/pull/105006 broke backward prefetching (where backward prefetching working correctly was not captured in our unit tests).

I need more time to dig into this (tomorrow), but I think the issue is related to:
429d45f91a (diff-9a6937168d232432c34c2c4605b96f3147afa2786e287f74b6074b20aa5980e6R143-R146)

Follow-ups:
1. Investigate this thoroughly
2. Add unit tests to capture backward prefetch functionality
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105984
Approved by: https://github.com/fegin
2023-07-26 12:12:14 +00:00
Michael Voznesensky
429d45f91a Simplify handle indexing (#105006)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105006
Approved by: https://github.com/awgu
2023-07-21 05:53:23 +00:00
Michael Voznesensky
a832967627 Migrate tuple(handle) -> handle (#104488)
We strengthen the invariant that one FSDP managed module has one flatparameter, and remove unused code that would have supported 1:many module to flatparam mapping

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104488
Approved by: https://github.com/awgu
2023-07-19 22:33:35 +00:00
Andrew Gu
610f74627e [FSDP][4/N] Remove _get_fully_sharded_module_to_states (#104409)
`_get_fully_sharded_module_to_states()` was used to emulate auto wrapping without actually calling `fully_shard`. Since we committed to unifying (see previous PR), we can remove this function and its helpers/tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104409
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2023-07-08 12:40:14 +00:00
Andrew Gu
d9be0366d3 [FSDP][3/N] Unify fully_shard auto wrap (#104408)
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules.

This includes several important fixes:
- We should register the pre/post-forward hooks on the module regardless of it has managed parameters.
- We can permit `_module_handles` to return `[]` in the composable path (for when the module has no managed parameters).
- We should unify the paths for `_get_buffers_and_dtypes_for_computation()` (previously, composable path was buggy in some cases).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104408
Approved by: https://github.com/rohan-varma
2023-07-08 12:40:12 +00:00
Rohan Varma
0bf39d5663 [FSDP] Option for eval in fp32/bf16 (#104682)
In https://github.com/pytorch/pytorch/pull/97645 and some follow up diffs, we made FSDP run in full precision in eval mode, even if mixed precision was specified.

However, this is probably not the best idea and we should provide a flag for users to have control over this a bit more. Adding an env var FSDP_FULL_PREC_IN_EVAL and defaulting it to off, users who want to run eval in fp32 can toggle this before wrapping model in FSDP:

os.environ["FSDP_FULL_PREC_IN_EVAL"] = "1"

Verified that unittests, APS workflow, TNT workloads can run eval appropriately with this change.

Differential Revision: [D47246556](https://our.internmc.facebook.com/intern/diff/D47246556/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104682
Approved by: https://github.com/awgu
2023-07-07 08:14:23 +00:00
Andrew Gu
d982fdb5d5 [FSDP] Rework meta device init (#104189)
This addresses https://github.com/pytorch/pytorch/issues/104187.

After this PR, the contract with the user is that:
- If passing `param_init_fn=None`, each `nn.Module.reset_parameters()` should only initialize its own parameters/buffers (like `parameters(recurse=False)`/`buffers(recurse=False)`).
- If passing `param_init_fn` not equal to `None`, then similarly, one call to `param_init_fn(module)` should only initialize `module`'s own parameters/buffers.

With this contract and this PR's changes, meta device initialization through either `reset_parameters()` or `param_init_fn` should be correct. Those functions will run on the original parameter/buffer shapes allowing for correct shape-dependent computations like for fan-in/fan-out, and there will not be any re-initialization of any module.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104189
Approved by: https://github.com/rohan-varma
2023-07-01 00:25:12 +00:00
Rohan Varma
60e2a4a4a0 [2D parallel] workaround for FSDP init issue (#104398)
Closes https://github.com/pytorch/pytorch/issues/96491 and does so by relaxing FSDP's assumption that the entire input module must be on the same device. Now, FSDP can accept a module partially on CPU and GPU and just emits a warning.

Differential Revision: [D47117256](https://our.internmc.facebook.com/intern/diff/D47117256/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104398
Approved by: https://github.com/fegin
2023-06-29 16:07:07 +00:00
Andrew Gu
6493519fff [Easy][FSDP] Remove misleading asserts (#104274)
Since we do not call `_FSDPState.__init__()` and only use it for typing, it is not possible for these attributes to be `None`. The purpose of these `assert`s is to make sure that these attributes are set by `_init_process_group_state_for_hybrid_shard()`. If we care to make that explicit, I would posit that we should be using `hasattr` checks, not `is not None` checks, because if indeed `_init_process_group_state_for_hybrid_shard()` did not set these attributes, then even checking that it is not `None` would lead to an `AttributeError`. I do not include these `hasattr` checks for now since `_init_process_group_state_for_hybrid_shard()` is short enough that we can quickly tell by inspection that it sets the desired attributes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104274
Approved by: https://github.com/rohan-varma
2023-06-28 11:08:47 +00:00
Andrew Gu
ba9f6e6e92 [FSDP] Validate ignored_modules, ignored_states (#104273)
This checks that `ignored_modules` and `ignored_states` have the expected type and provides a reasonable error message if not. Otherwise, if someone passes a mix of modules and parameters to `ignored_states` for example, then our code may be silently incorrect.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104273
Approved by: https://github.com/rohan-varma
2023-06-28 11:08:47 +00:00
Andrew Gu
ec8aa6e592 [Easy][FSDP] Fix "column" -> "row" in PG example (#103975)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103975
Approved by: https://github.com/fduwjj
2023-06-21 20:41:50 +00:00
Michael Voznesensky
02f28de408 [dynamo x fsdp] Simplify stream logic handling (#103902)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103902
Approved by: https://github.com/awgu
2023-06-21 01:34:19 +00:00
Andrew Gu
71b560208c [FSDP] Fix device_id when buffer-only module (#103504)
There was an issue reported internally that with `sync_module_states=True`, if the model had buffers on CPU, even with `device_id` specified, FSDP would try to broadcast CPU buffers, leading to an error like:
```
RuntimeError: No backend type associated with device type cpu
```

After some investigation, I determined that we should _not_ fix this by moving the buffers to GPU just for the broadcast and then back to CPU. Instead, we should fix our `device_id` logic.

The issue is that we always used the _parameters_ as the proxy to tell whether we should move module states to the device specified by `device_id`. However, a module (often the root) may not have any parameters but have some buffers! In that case, the buffers are left on CPU even if `device_id` is specified. This PR fixes this by considering both parameters and buffers for movement to `device_id`.

Note that this PR preserves the logic that `ignored_modules` / `ignored_parameters` are not considered for this movement, meaning that ignored parameters are moved to `device_id`.

Note also that I had to move the unit test back from using MTPG to the normal PG since otherwise, I could not repro the original error. (It seems like MTPG does not complain if we try to use `dist._broadcast_coalesced()` with CPU tensors.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103504
Approved by: https://github.com/rohan-varma
2023-06-13 18:33:26 +00:00
Yanli Zhao
f47ee87765 Fix ignored_states when they are passed as generators (#102575)
This PR fixed the case where ignored_states are passed as generators, not List/Set

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102575
Approved by: https://github.com/awgu
2023-05-31 15:58:55 +00:00
Rohan Varma
3dfa755a1f [MTPG] Enable for some tests in test_fsdp_misc (#102043)
Enables MTPG for some FSDP tests in this file. Tests that need the
backward pass and warning logging are left as follow up work.

Backward pass issue: It seems that there is a hang with all_gather. Will sync with @kumpera on this.

Warning issue: We have a couple tests that regex check on warnings, but in the
multithreaded scenario these warnings are somehow not logged.

Differential Revision: [D43209769](https://our.internmc.facebook.com/intern/diff/D43209769/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102043
Approved by: https://github.com/awgu
2023-05-26 06:21:25 +00:00
Yanli Zhao
956bd03808 add ignored_states to FSDP/fully_shard (#102056)
Add 'ignored_states' that accepts either a list of ignored_parameters or a list of nn modules for FSDP model wrapper and fully_shard composable APIs, it is recommended to use 'ignored_states' over 'ignored_modules' moving forward

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102056
Approved by: https://github.com/awgu
2023-05-24 18:36:48 +00:00
medivh-xp
e06bd8f3b1 fsdp support create hybrid-sharded process group for custom backend (#100622)
FSDP creates communication groups for intra-node communication through dist.new_subgroups. Previously, dist.new_subgroups only supported creation based on the number of CUDA devices. However, issue #99706 removed the avaliable-check for CUDA devices, allowing for custom backend create group based on num of custom devices per node.

This PR allows FSDP to explicitly pass device num within the node when creating communication groups for intra-node communication, instead of defaulting to the number of CUDA devices.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100622
Approved by: https://github.com/awgu
2023-05-19 06:08:55 +00:00
medivh-xp
859e82a7a9 Making fsdp device-agnostic for custom-backend which implement cuda-semantics (#99024)
Custom backend implementation based on privateuse1 with semantics identical to CUDA (CUDA is so popular), named for example 'my_device', and registered as the same module name torch.my_device.

This PR aims to satisfy the constraints of such a backend, which can be directly integrated into the current FSDP implementation.

The main issues addressed are:

#### 1. Device decision for FSDP wrapping of Modules without Parameters

Users typically organize FSDP code as follows:
```python
m = Module().to('my_device:0')
fsdp_m = FSDP(m)
```
or like this:
```python
m = Module()
fsdp_m = FSDP(m, device_id=torch.device('my_device', 0))
```
If the model has Parameters, everything works fine because FSDP will prioritize the device where the Parameters are located. However, for Modules without Parameters, the to() call has no side effects, and FSDP will assume the current CUDA device, which prevents the use of devices other than the current CUDA device for Modules without Parameters. Therefore, when FSDP is called with a device_id argument, this configuration takes top priority.

#### 2. Abstraction of a cuda-like device

Now, in addition to compute_device, _FSDPState includes a device_handler member. In fact, this device_handler is now just a reference to either torch.cuda or torch.my_device. From now on, code that works based on _FSDPState should use state.device_handler to operate streams create, wait or sync, just like using torch.cuda previously.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99024
Approved by: https://github.com/awgu
2023-04-27 04:13:28 +00:00
Chien-Chin Huang
3de7fd461a [FSDP][Reland] Include duplicate parameters and modules when calling named_parameters and named_modules (#99448)
The default option of `named_parameters` and `named_modules` is to remove the duplicated parameters and modules. However, in FSDP, we need to know what parameters are shared. As a result, setting `remove_duplicate` to False is required in FSDP. Without setting `remove_duplicate` to False, FSDP won't be able to discover shared weights in some cases (e.g., the shared weights are in the same module or there are shared modules).

The previous PR is reverted due to some modules overwriting the signature of `named_parameters()`. This new PR adds a workaround for the case.

Differential Revision: [D45065973](https://our.internmc.facebook.com/intern/diff/D45065973/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99448
Approved by: https://github.com/zhaojuanmao
2023-04-25 00:27:07 +00:00
Nikita Shulga
ccc5d1daec Revert D44897935: Multisect successfully blamed D44897935 for test or build failures (#99353)
Summary:
This diff is reverting D44897935
D44897935: [FSDP] Include duplicate parameters and modules when calling named_parameters and named_modules (#98912) by fegin has been identified to be causing the following test or build failures:

Tests affected:
- [caffe2/torch/fb/module_factory/sync_sgd/tests:test_pyper_data_parallel_wrapper - caffe2.torch.fb.module_factory.sync_sgd.tests.test_pyper_data_parallel_wrapper.PyPerDataParallelWrapperTest: test_fsdp_submodules_pyper](https://www.internalfb.com/intern/test/562950025957458/)

Here's the Multisect link:
https://www.internalfb.com/multisect/1893714
Here are the tasks that are relevant to this breakage:

We're generating a revert to back out the changes in this diff, please note the backout may land if someone accepts it.

If you believe this diff has been generated in error you may Commandeer and Abandon it.

Test Plan: NA

Reviewed By: fegin

Differential Revision: D45027286

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99353
Approved by: https://github.com/izaitsevfb, https://github.com/fegin
2023-04-17 20:53:10 +00:00
Chien-Chin Huang
8e328762ff [FSDP] Include duplicate parameters and modules when calling named_parameters and named_modules (#98912)
The default option of `named_parameters` and `named_modules` is to remove the duplicated parameters and modules. However, in FSDP, we need to know what parameters are shared. As a result, setting `remove_duplicate` to False is required in FSDP. Without setting `remove_duplicate` to False, FSDP won't be able to discover shared weights in some cases (e.g., the shared weights are in the same module or there are shared modules).

Differential Revision: [D44897935](https://our.internmc.facebook.com/intern/diff/D44897935/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98912
Approved by: https://github.com/awgu
2023-04-13 20:37:11 +00:00
medivh-xp
0962114802 Fix 'fully_shard' may determine compute device incorrectly (#98831)
Fixes #98829

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98831
Approved by: https://github.com/awgu
2023-04-11 22:42:48 +00:00
Kazuaki Ishizaki
6514d71add Fix typos under torch/distributed directory (#98225)
This PR fixes typos in comments and messages of `.py` files under `torch/distributed` directory

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98225
Approved by: https://github.com/soulitzer, https://github.com/kit1980
2023-04-05 00:21:33 +00:00
Andrew Gu
66d07e3b19 [FSDP] Only move current FSDP's states to GPU during init (#98319)
Fixes https://github.com/pytorch/pytorch/issues/95813
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98319
Approved by: https://github.com/rohan-varma
2023-04-04 21:03:47 +00:00
Andrew Gu
10271a60a8 [FSDP] Skip _use_sharded_views() for SHARD_GRAD_OP (#98250)
This PR has `SHARD_GRAD_OP` (and `_HYBRID_SHARD_ZERO2`) skip `_use_sharded_views()` in the post-forward reshard since the strategy does not free the unsharded flat parameter and can preserve the unsharded views. This saves nontrivial CPU overhead both in the post-forward reshard (`_use_sharded_views()`) and the pre-backward unshard (`_use_unsharded_views()`).

<details>
<summary>(Before) Pre-backward hook: 4.356 ms</summary>

<img width="812" alt="Screenshot 2023-04-03 at 6 32 19 PM" src="https://user-images.githubusercontent.com/31054793/229641309-778cf1f9-4b5b-42ec-b2d8-0a1e6e7ce330.png">

</details>

<details>
<summary>(After) Pre-backward hook: 1.044 ms</summary>

![Screenshot 2023-04-04 at 9 05 53 AM](https://user-images.githubusercontent.com/31054793/229800917-9580ce6b-3721-469a-9212-f0cbfd8cbb52.png)

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98250
Approved by: https://github.com/rohan-varma
2023-04-04 17:07:28 +00:00
James Reed
3b1b585a59 [FSDP] Fix bug in determining whether parameters need to be materialized (#97488)
Previously, `_need_to_materialize_module` would return false because:

* `managed_params =_get_orig_params(module, ignored_params)` returns a generator
* `is_meta_module = any(param.is_meta for param in managed_params)` exhausts the generator in its check
* `any(fake.is_fake(param) for param in managed_params)` would try to iterate over the empty generator and get an empty sequence, thus returning `False`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97488
Approved by: https://github.com/ngimel, https://github.com/awgu
2023-03-25 08:24:57 +00:00
Rohan Varma
605a77fd59 Log FSDP mixed precision (#97367)
Log to clarify the mp config in jobs

Differential Revision: [D44307044](https://our.internmc.facebook.com/intern/diff/D44307044/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97367
Approved by: https://github.com/awgu
2023-03-24 16:01:59 +00:00
Andrew Gu
5ee230face [FSDP][1/N] Refactor module materialization (#94196)
**Overview**
This refactors module materialization (i.e. meta device or `torchdistX` deferred initialization) to compute the parameter and buffer names as needed instead of pre-computing them. These are needed to reacquire references to the states (e.g. `module.get_parameter(param_name)`) after materialization since the materialization may create new variables.

This refactor simplifies `_get_fully_sharded_module_to_states()` (the core function for "pseudo auto wrapping") to better enable lowest common ancestor (LCA) module computation for shared parameters, for which tracking parameter and buffer names may complicate the already non-obvious implementation.

**Discussion**
The tradeoff is a worst case quadratic traversal over modules if materializing all of them. However, since (1) the number of modules is relatively small, (2) the computation per module in the quadratic traversal is negligible, (3) this runs only once per training session, and (4) module materialization targets truly large models, I think this tradeoff is tolerable.

**For Reviewers**
- `_init_param_handle_from_module()` initializes _one_ `FlatParamHandle` from a fully sharded module and represents the module wrapper code path. For this code path, there is no need to reacquire references to the parameters/buffers for now since the managed parameters are only computed after materialization. This works because the managed parameters have a simple definition: any parameter in the local root module's tree excluding those already marked as flattened by FSDP. Similarly, FSDP marks buffers to indicate that they have already been processed (synced if `sync_module_states`).
- `_init_param_handles_from_module()` initializes _all_ `FlatParamHandle`s from a fully sharded module and represents the composable code path. For this code path, we must reacquire references to parameters/buffers because each logical wrapping is specified as a list of parameters/buffers to group together by those variables and because materialization may create new variables.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94196
Approved by: https://github.com/rohan-varma
2023-02-13 21:43:00 +00:00
Aaron Gokaslan
67d9790985 [BE] Apply almost all remaining flake8-comprehension checks (#94676)
Applies the remaining flake8-comprehension fixes and checks. This changes replace all remaining unnecessary generator expressions with list/dict/set comprehensions which are more succinct, performant, and better supported by our torch.jit compiler. It also removes useless generators such as 'set(a for a in b)`, resolving it into just the set call.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
2023-02-12 01:01:25 +00:00
Chien-Chin Huang
4b0f1cc1ee [FSDP][optim_state_dict][10/N] Make optim_state_dict and optim_state_dict_to_load public (#92118)
Make optim_state_dict and optim_state_dict_to_load public APIs and consolidate them with state_dict by using the same state_dict_type to decide how to perform the optimizer state_dict save and load.

Differential Revision: [D42488022](https://our.internmc.facebook.com/intern/diff/D42488022/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92118
Approved by: https://github.com/rohan-varma
2023-02-02 08:04:20 +00:00
Andrew Gu
63d6ee7d02 [FSDP][Easy] Remove outdated comment (#92739)
We pass `fully_sharded_module`, not `root_module`, after recent refactoring to unify composable and wrapper FSDP for now. This PR removes the comment explaining why before we passed in `root_module`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92739
Approved by: https://github.com/mrshenli
2023-01-23 15:52:49 +00:00
Andrew Gu
0d4bbd1996 [Lint] Add FSDP/composable API files to ufmt include (#90873)
This PR adds FSDP and composable API files to `.lintrunner.toml` so that (1) lintrunner enforces that those files are formatted and (2) `lintrunner f` formats those files for you.

There are two requirements here (see https://github.com/pytorch/pytorch/wiki/lintrunner for details):
1. Install lintrunner:
```
pip install lintrunner
lintrunner init
```
2. `lintrunner f` before you finalize your PR, which would now be enforced by CI after this PR.

The code changes in this PR outside of `.lintrunner.toml` are the result of `lintrunner f`.

---

I only plan to land this PR if all of the composable API developers agree that this is something that makes sense and is not too intrusive to the workflow.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90873
Approved by: https://github.com/yhcharles, https://github.com/mrshenli, https://github.com/rohan-varma
2023-01-18 05:33:34 +00:00
Yanli Zhao
2004df9097 Remove python ddp (#91663)
As it is not used by anyone and also it is not maintained by PyTorch
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91663
Approved by: https://github.com/rohan-varma
2023-01-04 05:22:30 +00:00
Yanli Zhao
f613633124 Remove _ignored_param_names (#91530)
'_ignored_param_names' is only used in 'param_hook' during state_dict() post hook processing to check a parameter key needs to be cloned or not. But it is not needed, as state_dict() post hook only passes fsdp managed parameter keys to 'param_hook', see https://github.com/pytorch/pytorch/blob/master/torch/distributed/fsdp/_state_dict_utils.py#L203. That means the passed parameter keys are always not part of '_ignored_param_names'.

so we should be able to safely remove '_ignored_param_names' and related codes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91530
Approved by: https://github.com/rohan-varma
2022-12-31 03:28:22 +00:00
Andrew Gu
aec09eeb3a [FSDP][7/N] Support replicate in fully_shard (#91044)
This PR supports nesting `replicate` in `fully_shard`.
- The PR achieves this by treating `replicate`-annotated modules are ignored modules. This means that all submodules in the `replicate`-annotated module's subtree are ignored, including nested `fully_shard`-annotated modules, which is the desired behavior.

---

This PR reworks some tree traversal.

One end goal is for `state._handles` to follow the same order for both the wrapper and composable paths. This implies that `_get_fsdp_handles()` returns the same value for both paths.
- The helper function `_get_fully_sharded_module_to_states()` now follows a left-to-right DFS from each fully sharded module instead of a BFS. The left-to-right DFS follows `.modules()` order.
- The composable auto "wrap" initialization function `_init_param_handles_from_module()` follows the reverse left-to-right DFS order. As noted in the code comments, this initialization order is a valid reverse topological sort, but it differs from the wrapper path. This is the _only_ difference with respect to initialization order through the entire process.
```
mod: Module(
    submod1: Submodule()
    submod2: Submodule(
        subsubmod: Subsubmodule(),
    ),
)
```
For left-to-right DFS, the order is `mod`, `submod1`, `submod2`, `subsubmod`. (For context, right-to-left DFS would be `mod`, `submod2`, `subsubmod`, `submod1`. In other words, the left-to-right vs. right-to-left corresponds to `.children()` vs. `reversed(.children())` respectively.) Then, reverse left-to-right DFS is `subsubmod`, `submod2`, `submod1`, `mod`, which is a valid initialization order. However, the wrapper auto wrap initialization order would be `submod1`, `subsubmod`, `submod2`, `mod` since it directly follows a left-to-right DFS and initializes as a part of the recursive DFS logic.
- At the end of `_init_param_handles_from_module()`, we reverse the newly populated `state._handles`, so this is the reverse reverse left-to-right DFS order, which is equivalent to the left-to-right DFS order. Thus, `state._handles` has the same order for both paths.

Another goal is for `_get_fsdp_states()` to not traverse into any submodule that is annotated with an API that is not compatible with `fully_shard` (e.g. `replicate`). To achieve this while preserving that `_get_fsdp_states()` follows `.modules()` order, we again use a left-to-right DFS.

The reason the DFSs may look strange is because I implemented them non-recursively, which requires a stack.

- `test_get_fully_sharded_module_to_states()` in `test_utils.py` checks the traversal order of `_get_fully_sharded_module_to_states()`.
- `test_policy()` in `test_fully_shard.py` checks the traversal order returned by `_get_fsdp_handles()`.

---

Due to a circular dependency issue, we must move the graph/tree traversal helpers to their own file `_traversal_utils.py`, and any usages must import the entire file like `import torch.distributed.fsdp._traversal_utils as traversal_utils` instead of `from torch.distributed.fsdp._traversal_utils import ...`.

The cycle comes from the fact that the traversals require `_composable()`, which requires `_get_registry()` from `composable/contract.py`, which when imported, imports `composable/fully_shard.py`, which requires the traversals.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91044
Approved by: https://github.com/mrshenli
2022-12-20 16:49:18 +00:00
Shen Li
e5a48da664 Allow FSDP to have ignored modules out of wrapped root (#91079)
Motivations for this change:

1. TorchRec returns inconsistent results on `m.named_parameters()`
   and `m.m1.named_parameters()` if m1 is a `ShardedModule`. Basically,
   `ShardedModule` appears in `m.named_modules()`, but its parameters
   are not in `m.named_parameters()`. As a result, when we identify
   `ShardedModule` and pass them as `ignored_modules` to FSDP, FSDP
   complains about key error in `_get_ignored_params`.
2. If users are manually wrapping submodules with FSDP, it could be
   easier for them to keep a global set of ignored parameters, instead
   of create a new collection for every FSDP invocation.

Given the above two reasons, we allow FSDP to have ignored modules
out of the wrapped root module.

Differential Revision: [D42132394](https://our.internmc.facebook.com/intern/diff/D42132394)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91079
Approved by: https://github.com/awgu
2022-12-19 14:28:25 +00:00
Andrew Gu
8cd1808dbf [FSDP] Introduce "fully sharded module"; remove comm. module (#90933)
This PR removes the "communication module" (comm. module / `comm_module`) concept from the FSDP code base since it causes disproportionate confusion compared to its benefit for now.

Instead, we introduce the term "fully sharded module" as the single concept to unify the wrapper and non-wrapper code paths. The definition is presented in a note at the top of `flat_param.py`. I reproduce it here:

---
We define the **"fully sharded module"** to be the original `nn.Module` that owns a `FlatParamHandle`. It is the *single* module logically responsible for the *single* unshard/reshard pair for the handle's `FlatParameter` for a given forward or backward pass. The fully sharded module should be passed to the `FlatParamHandle` constructor.

For the wrapper code path:
- The `FullyShardedDataParallel` module wrapping the fully sharded module runs the unshard/reshard on behalf of the fully sharded module by overriding `nn.Module.forward`.
- The fully sharded module is exactly the module passed to the `FullyShardedDataParallel` constructor's `module` argument and is saved in `_fsdp_wrapped_module`.

For the non-wrapper code path:
- Hooks registered on the fully sharded module run the unshard/reshard.
- The fully sharded module may either be the direct argument to `fully_shard` or a submodule chosen by the provided wrapping policy.
---

After this PR, `handle.flat_param._fqns`, `_param_infos`, and `_shared_param_infos` all prefix names from the same module, namely the fully sharded module. This should make state dict less confusing.

---
As an example, consider:
```
mod: Module(
  sub1: Submodule(
    subsub1: Subsubmodule(),
    subsub2: Subsubmodule(),
  ),
  sub2: Submodule(
    subsub1: Subsubmodule(),
    subsub2: Subsubmodule(),
  ),
)
```
For wrapper FSDP manual wrap:
```
mod.sub1 = FSDP(mod.sub1)
mod.sub2 = FSDP(mod.sub2)
mod = FSDP(mod)
```
For wrapper FSDP auto wrap:
```
mod = FSDP(mod, auto_wrap_policy=ModuleWrapPolicy({Submodule}))
```
(WIP) For non-wrapper FSDP manual wrap:
```
fully_shard(mod.sub1)
fully_shard(mod.sub2)
fully_shard(mod)
```
For non-wrapper FSDP auto wrap:
```
fully_shard(mod, policy=ModuleWrapPolicy({Submodule}))
```
The fully sharded module **in all cases** are `mod`, `mod.sub1`, `mod.sub2`, and notably, `subsub1` and `subsub2`s are not fully sharded modules.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90933
Approved by: https://github.com/rohan-varma
2022-12-16 18:45:52 +00:00
Andrew Gu
1ba4e3c711 [FSDP][BE] Remove _module_to_handles, HandleConfig; use term "fqn"; clarify docs (#90840)
This PR
- Removes `_module_to_handles` since it is no longer used. We instead use `_comm_module_to_handles`.
- Removes `HandleConfig` and stores its fields directly as attributes on `FlatParamHandle`.
- Uses the term `fqn`/`fqns` uniformly in `flat_param.py` instead of `prefixed_param_name` / `prefixed_param_names`.
- Clarifies some documentation.

I am including all of these BE items in the same PR to save CI.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90840
Approved by: https://github.com/rohan-varma
2022-12-14 21:37:37 +00:00
Andrew Gu
93aee0cdc9 [FSDP][Easy] ufmt files (#90548)
```
ufmt format torch/distributed/fsdp
ufmt format test/distributed/fsdp
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90548
Approved by: https://github.com/mrshenli, https://github.com/rohan-varma
2022-12-14 02:02:53 +00:00
Andrew Gu
fc429512d5 [FSDP] Clean up FlatParamHandle dtypes, post-backward hook (#90660)
This PR reworks the internal handling of parameter and gradient reduction mixed precision, cleans up the post-backward hook logic, and adds some minor changes to the communication hooks.

**Overview**
This PR addresses everything in https://github.com/pytorch/pytorch/issues/90657 except renaming `keep_low_precision_grads` to `keep_grads_in_reduce_dtype` since that is BC breaking. I recommend reading the issue before preceding.

For `MixedPrecision(param_dtype, reduce_dtype, ...)`, the exact rule for parameter and gradient reduction mixed precision that we are following is:
> If `param_dtype is not None` and `reduce_dtype is None`, then we infer `reduce_dtype = param_dtype`. Otherwise, we take `param_dtype` and `reduce_dtype` as is.

This PR enforces that, at the `FlatParamHandle` level, `handle._config.fwd_bwd_param_dtype` and `handle._config.reduce_dtype` are never `None`. The way to check if mixed precision is enabled is to compare against the original parameter dtype, which is now stored in `handle._orig_param_dtype`. It is no longer to check against `None`.

This avoids ambiguous cases such as when the user passes `MixedPrecision(param_dtype=torch.float32)`. In that case, our existing implementation mistakenly thinks that parameter mixed precision is enabled and either relies on no-ops silently or errors (such as one case reported by MosaicML).

**Additional Details**
- We remove `FullyShardedDataParallel._mixed_precision_enabled_for_params`, `FullyShardedDataParallel._mixed_precision_enabled_for_reduce`, and `FullyShardedDataParallel._mixed_precision_keep_low_precision_grads` since they are not used.
- The unit test `test_meta_device_with_mixed_precision()` exercises a tricky edge case with meta device initialization, `apply()` (calling into `summon_full_params()`), and `param_dtype=torch.float32` for a nested wrapping case, where each nested instance has parameters.
- We include some minor fixes/improvements to the communication hook implementation.

**Follow-Ups**
- We should get rid of `HandleConfig` and store its fields as attributes on `FlatParamHandle` directly.
- Rename `keep_low_precision_grads` to `keep_grads_in_reduce_dtype`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90660
Approved by: https://github.com/zhaojuanmao
2022-12-13 07:34:59 +00:00
Andrew Gu
e7efeb5282 [FSDP] Save _stream_to_name for debugging (#90611)
This saves a data structure `_stream_to_name: Dict[torch.cuda.Stream, str]` that maps each FSDP stream to its name. This can help in debugging by checking `_stream_to_name[torch.cuda.current_stream()]` to see if it is `"default"` or `"unshard"` in the post-backward hook for example.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90611
Approved by: https://github.com/rohan-varma
2022-12-11 03:46:18 +00:00
Rohan Varma
793a999ce0 Hybrid Sharded Data Parallel (#89915)
Adds 2 new hybrid sharding strategy to FSDP:
1. HYBRID_SHARD: applies zero-3 style sharding within a node, and data parallel across
2. HYBRID_SHARD_ZERO2: applies zero-2 style sharding within a node, and data parallel across

These are useful for medium sized models and aim to decrease communication volume, tests and benchmarks will be run to understand which workloads are optimal under which sharding strategy.

Hybrid sharding in general works by sharding the model using a process group within a single node, and creating intra-node process groups for replication / data parallelism. The user either needs to pass in a tuple of these process groups, or None, and we generate the process groups appropriately.

** Acknowledgements **
- @awgu 's excellent prototype: 5ad3a16d48
- @liangluofb For ideation, feedback, and initial implementation and experimentation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89915
Approved by: https://github.com/awgu
2022-12-08 16:18:03 +00:00