Commit Graph

52 Commits

Author SHA1 Message Date
weifengpy
4bc846c101 [FSDP] Ignore buffer type casting in ignored modules (#106766)
issue resolved: https://github.com/pytorch/pytorch/issues/97791

Before this PR, mixed_precision applied to buffers from ignored modules. See ```test_state_dict_with_ignored_modules(mixed_precision=True)``` to reproduce.

After this PR, we avoid applying mixed_precision semantics to buffers from ignored modules:
* Step 1 (initialization): `state._ignored_buffer_names` contains all buffer names from ignored modules.
* Step 2 (lazy init at runtime): skip ignored buffers in ```_get_buffers_and_dtypes_for_computation```.
* Step 3 (state_dict hook): avoid upcasting ignored buffers in ```_get_buffers_and_dtypes_for_computation``` (sketched below).
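A minimal sketch of the idea (attribute names beyond `_ignored_buffer_names` and the loop structure are illustrative, not the actual FSDP internals):
```python
# Sketch only: skip ignored-module buffers when collecting buffers/dtypes for
# low-precision computation.
import torch.nn as nn

def _get_buffers_and_dtypes_for_computation_sketch(state, root_module: nn.Module):
    buffers, buffer_dtypes = [], []
    for name, buffer in root_module.named_buffers():
        if name in state._ignored_buffer_names:
            continue  # leave ignored-module buffers out of mixed-precision casting
        buffers.append(buffer)
        dtype = state.mixed_precision.buffer_dtype or buffer.dtype
        buffer_dtypes.append(dtype)
    return buffers, buffer_dtypes
```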
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106766
Approved by: https://github.com/awgu
2023-08-09 23:09:43 +00:00
Jane Xu
7e47343d64 [BE] document more of FSDP checkpointing logic with a sprinkle of cleaning (#106069)
This PR should not make any functional difference. It:
- adds clearer documentation
- clarifies a type
- revises minor typos
- swaps a .keys for a .items call on a dictionary

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106069
Approved by: https://github.com/awgu
2023-08-02 17:19:04 +00:00
Andrew Gu
506b55fc29 [FSDP][Easy] Move _FSDPState attrs to avoid comment confusion (#106392)
Resubmit of https://github.com/pytorch/pytorch/pull/106333 after rebasing (I lost the original branch locally)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106392
Approved by: https://github.com/kwen2501
2023-08-01 20:39:22 +00:00
Andrew Gu
800287fb56 [FSDP] Optimize away intermediate div_ for HSDP (#106034)
### Background: Gradient Pre-Divide
Consider $N$ data parallel workers. Define $g_i$ to be the $i$-th worker's local unsharded gradient. Data parallel gradient reduction computes $\overline g = \frac{1}{N} \sum_{i \in [N]} g_i$.

$\sum_{i \in [N]} g_i$ increases the magnitude by a factor of $N$, which may overflow for fp16. However, if we pre-divide and compute $\sum_{i \in [N]} \frac{g_i}{N}$, then the $\frac{g_i}{N}$ may underflow. The current solution from Myle for FSDP is to pre-divide by $\sqrt{N}$ and post-divide by $\sqrt{N}$:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{i \in [N]} \frac{g_i}{\sqrt{N}}.$$
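As a rough sketch of the pre-/post-divide trick (the function name is illustrative, not the actual FSDP internal):
```python
# Sketch: pre-/post-divide by sqrt(N) around the reduction so neither the sum
# (overflow) nor the per-worker division (underflow) blows up in fp16.
import math
import torch.distributed as dist

def reduce_gradient_sketch(grad, process_group=None):
    world_size = dist.get_world_size(process_group)
    factor = math.sqrt(world_size)
    grad.div_(factor)                           # pre-divide by sqrt(N)
    dist.all_reduce(grad, group=process_group)  # sum over all N workers
    grad.div_(factor)                           # post-divide by sqrt(N)
    return grad
```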

Now, consider HSDP with $N = S \cdot R$ data parallel workers, sharding over $S$ workers and replicating over $R$ workers. Define $g_{i,j}$ to be the $(i \cdot S + j)$-th worker's local unsharded gradient (so sharding indexes with $i$ and replication indexes with $j$). The existing implementation computes
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}} } \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{S}},$$
where the $\frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}}$ involves two separate `aten::div_` kernels.

### Revisiting Pre-Divide for HSDP
A minor optimization that we can do is with this intermediate `div_`. There are two options:
1. Compute $\overline{g}$ in the same way as FSDP:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{j \in [R]} \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{N}}.$$
2. Compute $\overline{g}$ still with an intermediate division for rescaling but coalesce the two `div_`s into one:
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{N}} } \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{S}}.$$

This PR goes with the 1st approach, prioritizing performance, because (1) it matches the existing FSDP behavior and (2) it avoids a memory-bandwidth-bound `div_` kernel that blocks the all-reduce launch.

### Implementation Details
In order to accommodate this, we need to refactor the communication hook logic that baked the gradient pre/post-division into the default hook.
- We raise an error if registering a communication hook for HSDP since the current implementation would only apply the hook to the reduce-scatter, not the all-reduce, which may be unexpected.
- We change it so that `state._comm_hook is not None` iff a communication hook is registered. This makes the collectives and the pre/post-division in the default no-communication-hook path more visible in the code.
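A hedged sketch of the resulting control flow; `_comm_hook` and `_comm_hook_state` follow this PR, while the group/size attribute names are illustrative only:
```python
# Sketch: the default no-communication-hook path makes the collectives and the
# single pre-/post-divide explicit for HSDP.
import math
import torch.distributed as dist

def _reduce_grad_sketch(state, unsharded_grad, sharded_grad_out):
    if state._comm_hook is not None:
        # A registered communication hook replaces the default path
        # (registration is rejected for HSDP in this PR).
        state._comm_hook(state._comm_hook_state, unsharded_grad)
        return
    N = state._shard_size * state._replicate_size        # S * R for HSDP
    predivide = postdivide = math.sqrt(N)
    unsharded_grad.div_(predivide)
    dist.reduce_scatter_tensor(sharded_grad_out, unsharded_grad,
                               group=state._shard_group)             # over S workers
    dist.all_reduce(sharded_grad_out, group=state._replicate_group)  # over R workers
    sharded_grad_out.div_(postdivide)
```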

Differential Revision: [D47852459](https://our.internmc.facebook.com/intern/diff/D47852459)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106034
Approved by: https://github.com/rohan-varma
2023-07-28 18:36:26 +00:00
Albert Chen
7c8efc9049 [PT][FSDP] Combine _utils.py into _common_utils.py [2/2] (#106181)
Summary:
https://github.com/pytorch/pytorch/issues/97813
This diff moves `_no_dispatch_record_stream` and `_same_storage_as_data_ptr`.

Test Plan: CI

Differential Revision: D47706114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106181
Approved by: https://github.com/awgu
2023-07-28 17:15:25 +00:00
Albert Chen
b65b9e6ff4 [PT][FSDP] Combine _utils.py into _common_utils.py [1/3] (#105857)
Summary:
https://github.com/pytorch/pytorch/issues/97813

This diff moves `_override_module_mixed_precision`.

Test Plan: CI

Differential Revision: D47706059

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105857
Approved by: https://github.com/awgu
2023-07-25 17:37:08 +00:00
Michael Voznesensky
a832967627 Migrate tuple(handle) -> handle (#104488)
We strengthen the invariant that one FSDP-managed module has one `FlatParameter` and remove unused code that would have supported a 1:many module-to-`FlatParameter` mapping.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104488
Approved by: https://github.com/awgu
2023-07-19 22:33:35 +00:00
Andrew Gu
d9be0366d3 [FSDP][3/N] Unify fully_shard auto wrap (#104408)
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules.

This includes several important fixes:
- We should register the pre/post-forward hooks on the module regardless of whether it has managed parameters.
- We can permit `_module_handles` to return `[]` in the composable path (for when the module has no managed parameters).
- We should unify the paths for `_get_buffers_and_dtypes_for_computation()` (previously, the composable path was buggy in some cases).
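A minimal sketch of the unified `_auto_wrap()` flow described above; `policy` is shown as a plain predicate and `wrap_fn` stands in for `fully_shard` (composable) or the FSDP wrapper, so this is simplified relative to the real helper:
```python
# Sketch only: apply the wrapping function to each submodule selected by the policy.
import torch.nn as nn

def _auto_wrap_sketch(root_module: nn.Module, policy, wrap_fn):
    for submodule in root_module.modules():
        if submodule is not root_module and policy(submodule):
            wrap_fn(submodule)  # one `fully_shard` (or FSDP wrap) per target module
    # The caller then initializes the root itself, so every "wrap" goes through
    # the same `_init_param_handle_from_module()` path.
```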
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104408
Approved by: https://github.com/rohan-varma
2023-07-08 12:40:12 +00:00
Iris
a02a58d862 [FSDP][1/N]Add device_mesh to FSDPstate (#102317) (#102551)
This PR creates a device_mesh and shares it across all FSDP states. The device_mesh will later be used to test out DTensor state_dict (1D device_mesh).
Approved by: https://github.com/awgu

- Add device mesh to FSDP state
- Skip when `dist.get_world_size(pg) != dist.get_world_size()`
- Address `test_fake_pg.py` test failures

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102551
Approved by: https://github.com/fegin
2023-06-07 04:14:00 +00:00
PyTorch MergeBot
81ac076bce Revert "[FSDP]Add device_mesh to FSDPstate (#102317)"
This reverts commit 4c584acc5d.

Reverted https://github.com/pytorch/pytorch/pull/102317 on behalf of https://github.com/malfet due to Broke test_fake_pg, see https://github.com/pytorch/pytorch/actions/runs/5100633726/jobs/9173277369  ([comment](https://github.com/pytorch/pytorch/pull/102317#issuecomment-1566129496))
2023-05-28 12:53:28 +00:00
Iris
4c584acc5d [FSDP]Add device_mesh to FSDPstate (#102317)
This PR creates a device_mesh and shares it across all FSDP states. The device_mesh will later be used to test out DTensor state_dict (1D device_mesh).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102317
Approved by: https://github.com/awgu
2023-05-27 20:25:30 +00:00
Edward Z. Yang
f65732552e Support FakeTensor with FlatParameter (#101987)
In this PR we turn FlatParameter into a virtual tensor subclass
which doesn't actually ever get instantiated: __new__ will create
a Parameter instead (or a FakeTensor, if necessary).

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101987
Approved by: https://github.com/awgu, https://github.com/eellison
2023-05-23 23:12:08 +00:00
medivh-xp
859e82a7a9 Making FSDP device-agnostic for custom backends that implement CUDA semantics (#99024)
Consider a custom backend implemented on top of PrivateUse1 with semantics identical to CUDA (CUDA is so popular), named for example 'my_device' and registered under the module name torch.my_device.

This PR aims to satisfy the constraints of such a backend so that it can be directly integrated into the current FSDP implementation.

The main issues addressed are:

#### 1. Device decision for FSDP wrapping of Modules without Parameters

Users typically organize FSDP code as follows:
```python
m = Module().to('my_device:0')
fsdp_m = FSDP(m)
```
or like this:
```python
m = Module()
fsdp_m = FSDP(m, device_id=torch.device('my_device', 0))
```
If the model has parameters, everything works fine because FSDP prioritizes the device where the parameters are located. However, for modules without parameters, the to() call has no side effect, and FSDP assumes the current CUDA device, which prevents using any device other than the current CUDA device for such modules. Therefore, when FSDP is called with a device_id argument, that configuration now takes top priority.
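A hedged sketch of that device-resolution priority (the helper name and fallback details are illustrative, not the actual FSDP internals):
```python
# Sketch: device_id wins; otherwise follow existing parameters; otherwise fall
# back to the current CUDA device for parameterless modules.
import torch
import torch.nn as nn

def _determine_compute_device_sketch(module: nn.Module, device_id=None):
    if device_id is not None:
        return torch.device(device_id)          # explicit device_id takes top priority
    for param in module.parameters():
        if param.device.type != "cpu":
            return param.device                 # follow where the parameters live
    # Parameterless (or CPU-only) module: fall back to the current CUDA device
    return torch.device("cuda", torch.cuda.current_device())
```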

#### 2. Abstraction of a cuda-like device

Now, in addition to compute_device, _FSDPState includes a device_handler member. This device_handler is just a reference to either torch.cuda or torch.my_device. From now on, code that operates on _FSDPState should use state.device_handler to create, wait on, or synchronize streams, just as it previously used torch.cuda.
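A sketch of the device-handler idea (the selection logic below is illustrative):
```python
# Sketch: resolve a torch.cuda-like module from the compute device so stream
# create/wait/sync code is device-agnostic.
import torch

def _get_device_handler_sketch(compute_device: torch.device):
    if compute_device.type == "cuda":
        return torch.cuda
    # e.g. a PrivateUse1 backend registered as torch.my_device
    return getattr(torch, compute_device.type)

# Illustrative usage inside runtime code:
#   stream = state.device_handler.Stream()
#   state.device_handler.current_stream().wait_stream(stream)
```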
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99024
Approved by: https://github.com/awgu
2023-04-27 04:13:28 +00:00
Chien-Chin Huang
3de7fd461a [FSDP][Reland] Include duplicate parameters and modules when calling named_parameters and named_modules (#99448)
The default option of `named_parameters` and `named_modules` is to remove the duplicated parameters and modules. However, in FSDP, we need to know what parameters are shared. As a result, setting `remove_duplicate` to False is required in FSDP. Without setting `remove_duplicate` to False, FSDP won't be able to discover shared weights in some cases (e.g., the shared weights are in the same module or there are shared modules).
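For illustration, a made-up minimal module showing the effect of `remove_duplicate`:
```python
# With shared weights, the default named_parameters() hides one FQN.
import torch.nn as nn

class TiedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.out = nn.Linear(4, 10, bias=False)
        self.out.weight = self.embed.weight  # shared weight

m = TiedModel()
print([n for n, _ in m.named_parameters()])
# ['embed.weight'] -- the shared 'out.weight' is deduplicated away
print([n for n, _ in m.named_parameters(remove_duplicate=False)])
# ['embed.weight', 'out.weight'] -- both FQNs visible, so the sharing is discoverable
```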

The previous PR is reverted due to some modules overwriting the signature of `named_parameters()`. This new PR adds a workaround for the case.

Differential Revision: [D45065973](https://our.internmc.facebook.com/intern/diff/D45065973/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99448
Approved by: https://github.com/zhaojuanmao
2023-04-25 00:27:07 +00:00
Yanli Zhao
6ca991cacf [Composable API] Add fully_shard debug function to print sharded tree structure, module names and managed param fqns (#99133)
This adds a fully_shard debug function that prints the sharded tree structure in the following format and also returns module names and their managed parameter FQNs.

![Screenshot 2023-04-18 at 5 14 54 PM](https://user-images.githubusercontent.com/48731194/232931628-169a63a9-b4d5-4902-9cfd-f40113f3ec98.png)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99133
Approved by: https://github.com/rohan-varma
2023-04-19 19:27:43 +00:00
Nikita Shulga
ccc5d1daec Revert D44897935: Multisect successfully blamed D44897935 for test or build failures (#99353)
Summary:
This diff is reverting D44897935
D44897935: [FSDP] Include duplicate parameters and modules when calling named_parameters and named_modules (#98912) by fegin has been identified to be causing the following test or build failures:

Tests affected:
- [caffe2/torch/fb/module_factory/sync_sgd/tests:test_pyper_data_parallel_wrapper - caffe2.torch.fb.module_factory.sync_sgd.tests.test_pyper_data_parallel_wrapper.PyPerDataParallelWrapperTest: test_fsdp_submodules_pyper](https://www.internalfb.com/intern/test/562950025957458/)

Here's the Multisect link:
https://www.internalfb.com/multisect/1893714
Here are the tasks that are relevant to this breakage:

We're generating a revert to back out the changes in this diff; please note the backout may land if someone accepts it.

If you believe this diff has been generated in error you may Commandeer and Abandon it.

Test Plan: NA

Reviewed By: fegin

Differential Revision: D45027286

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99353
Approved by: https://github.com/izaitsevfb, https://github.com/fegin
2023-04-17 20:53:10 +00:00
Chien-Chin Huang
8e328762ff [FSDP] Include duplicate parameters and modules when calling named_parameters and named_modules (#98912)
The default option of `named_parameters` and `named_modules` is to remove the duplicated parameters and modules. However, in FSDP, we need to know what parameters are shared. As a result, setting `remove_duplicate` to False is required in FSDP. Without setting `remove_duplicate` to False, FSDP won't be able to discover shared weights in some cases (e.g., the shared weights are in the same module or there are shared modules).

Differential Revision: [D44897935](https://our.internmc.facebook.com/intern/diff/D44897935/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98912
Approved by: https://github.com/awgu
2023-04-13 20:37:11 +00:00
medivh-xp
0962114802 Fix 'fully_shard' may determine compute device incorrectly (#98831)
Fixes #98829

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98831
Approved by: https://github.com/awgu
2023-04-11 22:42:48 +00:00
Andrew Gu
c622559968 [FSDP][3/N] Minor fixes (rename, assert message) (#97663)
This is an easy PR.
- It renames `_shard_indices` to `_shard_param_indices` for consistency.
- It fixes an old mention of `comm_module` in an assert message.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97663
Approved by: https://github.com/rohan-varma
2023-03-28 01:46:43 +00:00
Kazuaki Ishizaki
35fd5c548e Fix typos under torch/distributed directory (#95638)
This PR fixes typos in comments and messages of `.py` files under torch/distributed directory

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95638
Approved by: https://github.com/usamah1, https://github.com/H-Huang, https://github.com/kit1980
2023-03-27 21:13:44 +00:00
Chien-Chin Huang
f5a0b31a95 [FSDP][optim_state_dict] Make FSDP optim_state_dict aware of DDP prefix (#96415)
Summary: When wrapping FSDP within DDP, optimizer state_dict may be broken due to the prefix of DDP. This PR fixes the issue.

Test Plan: CI

Differential Revision: D43893609

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96415
Approved by: https://github.com/zhaojuanmao
2023-03-13 21:07:34 +00:00
Andrew Gu
6c30dc6cee [FSDP] Save _all_handles; _all_fsdp_states to root (#95465)
- The previous PR addressed one tree traversal in `_root_pre_forward()` but not the main one from `_get_fsdp_handles()` that runs for all settings.
- This PR saves `_all_handles` to cache `_get_fsdp_handles()` and `_all_fsdp_states` to cache `_get_fsdp_states()` (renamed from `_fsdp_states` compared to last PR) on the root state.
- This PR introduces a dummy `_RootFSDPState` class that inherits from `_FSDPState` to be used only for type checking since some attributes are only defined for root states.
    - I found this approach to be better than adding `_p_assert(state.root_only_attr is not None, ...)` upon each usage of `root_only_attr`.
    - This hopefully also helps readers to quickly see which attributes are defined only on root states.
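A minimal sketch of the typing trick (the classes below are trivial stand-ins for the real FSDP types):
```python
# Sketch: a typing-only subclass declares the root-only attributes in one place
# instead of asserting their presence at every use site.
from typing import List, cast

class FlatParamHandle:  # stand-in
    ...

class _FSDPState:       # stand-in
    ...

class _RootFSDPState(_FSDPState):
    _all_handles: List[FlatParamHandle]   # caches _get_fsdp_handles(root)
    _all_fsdp_states: List[_FSDPState]    # caches _get_fsdp_states(root)

def _as_root_state(state: _FSDPState) -> _RootFSDPState:
    # The runtime object is unchanged; the cast only tells the type checker
    # that root-only attributes are available here.
    return cast(_RootFSDPState, state)
```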
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95465
Approved by: https://github.com/fduwjj
2023-02-26 13:59:53 +00:00
Andrew Gu
9c45f47bbe [FSDP] Save _fsdp_states on root (#95343)
This saves an attribute `_fsdp_states: Optional[List[_FSDPState]]`. For root, it is populated with all `_FSDPState`s in the root's tree. For non-root, it is `None`.

This is used to avoid doing the tree traversal during `_root_pre_forward()` when `forward_prefetch=True`.

Differential Revision: [D43536895](https://our.internmc.facebook.com/intern/diff/D43536895)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95343
Approved by: https://github.com/fegin
2023-02-23 21:18:05 +00:00
Chien-Chin Huang
eb81e7ec22 [FSDP] Avoid printing incorrect warning for _get_param_to_fqns (#94494)
There exists a hack for `_get_param_to_fqns` and `_apply_to_modules`. The condition for the hack's warning is incorrect and results in overwhelming messages for users. This PR fixes the issue.

The original hack is not removed; it will be removed once support for DMP + FSDP is deprecated.

Differential Revision: [D43135611](https://our.internmc.facebook.com/intern/diff/D43135611/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94494
Approved by: https://github.com/rohan-varma
2023-02-12 17:09:30 +00:00
Yanli Zhao
e0c24ec2a5 Print fqn in the warning message (#94313)
Print the FQN in the warning message, and also make the "else" match the "if" in `_apply_to_modules()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94313
Approved by: https://github.com/fegin
2023-02-08 06:45:53 +00:00
Chien-Chin Huang
e32d99ae19 [FSDP][optim_state_dict] Make FSDP.optim_state_dict compatible with DMP (#93285)
`torchrec.DistributedModelParallel` overwrites `named_parameters` and is not compatible with `FullyShardedDataParallel`'s optim_state_dict. This PR adds a workaround in `FullyShardedDataParallel` to make both work together.

Differential Revision: [D42764611](https://our.internmc.facebook.com/intern/diff/D42764611/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93285
Approved by: https://github.com/rohan-varma
2023-02-02 23:42:54 +00:00
Andrew Gu
10990734ce [FSDP][2/N] _summon_full_params -> _unshard_params (#92297)
**Overview**
This PR stack will add support for unsharding FSDP's sharded parameters for `fully_shard`. This PR takes the first step by doing some internal refactoring.
- The existing API for wrapper FSDP is the static method `summon_full_params()`, which calls into the helper `_summon_full_params()`.
- This PR refactors:
    - `summon_full_params()` core logic to `_unshard_params()`
    - `_summon_full_params()` to `_unshard_params_recurse()`, which has a `recurse: bool` argument
    - Previous `_unshard_params()` to `_unshard_fsdp_state_params()`, which applies to a single FSDP state

**Details**
- This PR introduces `_get_fsdp_states_with_modules()` and `_get_root_fsdp_states_with_modules()`, which additionally return the modules along with the FSDP states. The modules are needed for handling `FlatParameter` registration.
    - We may be able to remove this if we clean up the `use_orig_params=True` vs. `False` code paths because for `True`, the `FlatParameter` is not registered, meaning that it does not need to be de-registered.
    - Since `fully_shard` requires `use_orig_params=True`, we may not need `_get_fsdp_states_with_modules()` and `_get_root_fsdp_states_with_modules()`; however, I prefer to make the separation of FSDP state and module explicit for now for clarity.

**Follow-Ups**
- `writeback=True` with `rank0_only=True` raises an error. The previous explanation was:
> is not supported, as model parameter shapes will be different across ranks, and writing to them can lead to inconsistencies across ranks when the context is exited.

I am not exactly sure what the different model parameter shapes refers to. However, I believe that we can support `writeback=True` and `rank0_only=True` by broadcasting the `FlatParameter` from rank 0 in the `finally`, writing back, and freeing. This should not increase the peak memory since rank 0 already holds the unsharded `FlatParameter` in GPU memory before writing back and nonzero ranks do not have any other unsharded `FlatParameter`s in GPU memory.
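A hedged sketch of that idea (padding, freeing the unsharded storage, and the real handle bookkeeping are omitted; names are illustrative):
```python
# Sketch: broadcast rank 0's unsharded flat parameter, then write each rank's
# slice back into its local shard so all ranks stay consistent.
import torch
import torch.distributed as dist

def _writeback_rank0_only_sketch(flat_param: torch.Tensor,
                                 local_shard: torch.Tensor,
                                 group=None):
    dist.broadcast(flat_param, src=0, group=group)  # rank 0 holds the edited full copy
    rank = dist.get_rank(group)
    shard_numel = local_shard.numel()
    start = rank * shard_numel
    local_shard.copy_(flat_param.view(-1)[start:start + shard_numel])
```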
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92297
Approved by: https://github.com/rohan-varma
2023-02-02 15:10:14 +00:00
Chien-Chin Huang
4b0f1cc1ee [FSDP][optim_state_dict][10/N] Make optim_state_dict and optim_state_dict_to_load public (#92118)
Make optim_state_dict and optim_state_dict_to_load public APIs and consolidate them with state_dict by using the same state_dict_type to decide how to perform the optimizer state_dict save and load.
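A usage sketch assuming the now-public APIs (the chosen state-dict type is illustrative, and the load-side signature has varied across releases):
```python
# Save optimizer state through the public FSDP APIs, with state_dict_type
# controlling how both model and optimizer state are handled.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

def save_optim_state_sketch(model: FSDP, optim: torch.optim.Optimizer):
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        osd = FSDP.optim_state_dict(model, optim)  # sharded per state_dict_type
    # On load, FSDP.optim_state_dict_to_load(...) converts such a state dict back
    # into a form that optim.load_state_dict() accepts for this sharding.
    return osd
```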

Differential Revision: [D42488022](https://our.internmc.facebook.com/intern/diff/D42488022/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92118
Approved by: https://github.com/rohan-varma
2023-02-02 08:04:20 +00:00
Yanli Zhao
2004df9097 Remove python ddp (#91663)
Python DDP is not used by anyone and is not maintained by PyTorch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91663
Approved by: https://github.com/rohan-varma
2023-01-04 05:22:30 +00:00
Chien-Chin Huang
0e8565d1d5 [FSDP][optim_state_dict][8/N] Enable fully_shard optim state_dict save and load (#91234)
**What does this PR do?**
This PR refactor `_optim_utils.py` to use `_FSDPState` instead of `FullyShardedDataParallel` class. This change enables the support of optim state_dict for `fully_shard`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91234
Approved by: https://github.com/rohan-varma
2022-12-30 06:56:44 +00:00
Andrew Gu
aec09eeb3a [FSDP][7/N] Support replicate in fully_shard (#91044)
This PR supports nesting `replicate` in `fully_shard`.
- The PR achieves this by treating `replicate`-annotated modules as ignored modules. This means that all submodules in the `replicate`-annotated module's subtree are ignored, including nested `fully_shard`-annotated modules, which is the desired behavior.

---

This PR reworks some tree traversal.

One end goal is for `state._handles` to follow the same order for both the wrapper and composable paths. This implies that `_get_fsdp_handles()` returns the same value for both paths.
- The helper function `_get_fully_sharded_module_to_states()` now follows a left-to-right DFS from each fully sharded module instead of a BFS. The left-to-right DFS follows `.modules()` order.
- The composable auto "wrap" initialization function `_init_param_handles_from_module()` follows the reverse left-to-right DFS order. As noted in the code comments, this initialization order is a valid reverse topological sort, but it differs from the wrapper path. This is the _only_ difference with respect to initialization order through the entire process.
```
mod: Module(
    submod1: Submodule()
    submod2: Submodule(
        subsubmod: Subsubmodule(),
    ),
)
```
For left-to-right DFS, the order is `mod`, `submod1`, `submod2`, `subsubmod`. (For context, right-to-left DFS would be `mod`, `submod2`, `subsubmod`, `submod1`. In other words, the left-to-right vs. right-to-left corresponds to `.children()` vs. `reversed(.children())` respectively.) Then, reverse left-to-right DFS is `subsubmod`, `submod2`, `submod1`, `mod`, which is a valid initialization order. However, the wrapper auto wrap initialization order would be `submod1`, `subsubmod`, `submod2`, `mod` since it directly follows a left-to-right DFS and initializes as a part of the recursive DFS logic.
- At the end of `_init_param_handles_from_module()`, we reverse the newly populated `state._handles`, giving the reverse of the reverse left-to-right DFS order, which is equivalent to the left-to-right DFS order. Thus, `state._handles` has the same order for both paths.

Another goal is for `_get_fsdp_states()` to not traverse into any submodule that is annotated with an API that is not compatible with `fully_shard` (e.g. `replicate`). To achieve this while preserving that `_get_fsdp_states()` follows `.modules()` order, we again use a left-to-right DFS.

The reason the DFSs may look strange is that I implemented them non-recursively, which requires an explicit stack (see the sketch below).
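A minimal sketch of the non-recursive left-to-right DFS (the helper name is illustrative):
```python
# For the example above this visits mod, submod1, submod2, subsubmod,
# matching .modules() order.
import torch.nn as nn

def _dfs_modules_sketch(root: nn.Module):
    order, stack = [], [root]
    while stack:
        module = stack.pop()
        order.append(module)
        # Push children in reverse so the leftmost child is popped (visited) first
        stack.extend(reversed(list(module.children())))
    return order
```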

- `test_get_fully_sharded_module_to_states()` in `test_utils.py` checks the traversal order of `_get_fully_sharded_module_to_states()`.
- `test_policy()` in `test_fully_shard.py` checks the traversal order returned by `_get_fsdp_handles()`.

---

Due to a circular dependency issue, we must move the graph/tree traversal helpers to their own file `_traversal_utils.py`, and any usages must import the entire file like `import torch.distributed.fsdp._traversal_utils as traversal_utils` instead of `from torch.distributed.fsdp._traversal_utils import ...`.

The cycle comes from the fact that the traversals require `_composable()`, which requires `_get_registry()` from `composable/contract.py`, which when imported, imports `composable/fully_shard.py`, which requires the traversals.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91044
Approved by: https://github.com/mrshenli
2022-12-20 16:49:18 +00:00
Andrew Gu
e81ccfd1ed [FSDP][6/N] Add note explaining idioms for _FSDPState traversal (#90959)
This adds a note to explain how to do traversal in the new code base. These traversal helper methods were introduced in [1/N], [3/N], and [5/N].

I am working on updating the traversal helpers to account for other composable APIs (e.g. `replicate`). The rule is that the traversal should not proceed into an incompatible API's tree. This will be needed for `fully_shard` to be above `replicate`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90959
Approved by: https://github.com/mrshenli
2022-12-20 16:49:18 +00:00
Andrew Gu
32fde53713 [FSDP][5/N] Add manual "wrapping" support for fully_shard (#90874)
This PR adds manual "wrapping" support for `fully_shard`. For example, for
```
fully_shard(mod.sub)
fully_shard(mod)
```
`mod.sub` and `mod` will share the same FSDP data structures.

To have parity with wrapper FSDP, this PR only checks support for the case where each manual application of `fully_shard` passes `policy=None`. Hybrid auto/manual wrapping is not in scope for this PR since it is not supported for wrapper FSDP either. I can follow up to either add support properly or raise an error early.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90874
Approved by: https://github.com/mrshenli
2022-12-20 16:49:15 +00:00
Andrew Gu
da9af9868e [FSDP][4/N] Refactor func to share state/init handle attrs (#90871)
For `limit_all_gathers`, if we do not enforce that they all have the same value, then the entire semantics guaranteed by the `bool` can be violated. It could be as if none of them set that value to be `True`.

For `use_orig_params`, optimizer state dict assumes that the value is the same for all FSDP instances.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90871
Approved by: https://github.com/mrshenli
2022-12-20 16:49:13 +00:00
Andrew Gu
8cd1808dbf [FSDP] Introduce "fully sharded module"; remove comm. module (#90933)
This PR removes the "communication module" (comm. module / `comm_module`) concept from the FSDP code base since it causes disproportionate confusion compared to its benefit for now.

Instead, we introduce the term "fully sharded module" as the single concept to unify the wrapper and non-wrapper code paths. The definition is presented in a note at the top of `flat_param.py`. I reproduce it here:

---
We define the **"fully sharded module"** to be the original `nn.Module` that owns a `FlatParamHandle`. It is the *single* module logically responsible for the *single* unshard/reshard pair for the handle's `FlatParameter` for a given forward or backward pass. The fully sharded module should be passed to the `FlatParamHandle` constructor.

For the wrapper code path:
- The `FullyShardedDataParallel` module wrapping the fully sharded module runs the unshard/reshard on behalf of the fully sharded module by overriding `nn.Module.forward`.
- The fully sharded module is exactly the module passed to the `FullyShardedDataParallel` constructor's `module` argument and is saved in `_fsdp_wrapped_module`.

For the non-wrapper code path:
- Hooks registered on the fully sharded module run the unshard/reshard.
- The fully sharded module may either be the direct argument to `fully_shard` or a submodule chosen by the provided wrapping policy.
---

After this PR, `handle.flat_param._fqns`, `_param_infos`, and `_shared_param_infos` all prefix names from the same module, namely the fully sharded module. This should make state dict less confusing.

---
As an example, consider:
```
mod: Module(
  sub1: Submodule(
    subsub1: Subsubmodule(),
    subsub2: Subsubmodule(),
  ),
  sub2: Submodule(
    subsub1: Subsubmodule(),
    subsub2: Subsubmodule(),
  ),
)
```
For wrapper FSDP manual wrap:
```
mod.sub1 = FSDP(mod.sub1)
mod.sub2 = FSDP(mod.sub2)
mod = FSDP(mod)
```
For wrapper FSDP auto wrap:
```
mod = FSDP(mod, auto_wrap_policy=ModuleWrapPolicy({Submodule}))
```
(WIP) For non-wrapper FSDP manual wrap:
```
fully_shard(mod.sub1)
fully_shard(mod.sub2)
fully_shard(mod)
```
For non-wrapper FSDP auto wrap:
```
fully_shard(mod, policy=ModuleWrapPolicy({Submodule}))
```
The fully sharded modules **in all cases** are `mod`, `mod.sub1`, and `mod.sub2`; notably, the `subsub1`s and `subsub2`s are not fully sharded modules.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90933
Approved by: https://github.com/rohan-varma
2022-12-16 18:45:52 +00:00
Andrew Gu
95ee5fecb1 [FSDP][1/N] Add _get_fsdp_states() (#90860)
- This PR introduces `_get_fsdp_states(module: nn.Module) -> List[_FSDPState]` to prepare for `fully_shard` manual "wrapping".
    - ~~I place it in `_runtime_utils.py`, not `_common_utils.py`, because in a follow-up PR, I will add `_get_root_fsdp_states()`, which requires `_lazy_init()`. I concluded that it would be preferred to have both of these getters be in the same place than to have them split, even if that means that `_get_fsdp_states()` is in `_runtime_utils.py`.~~ Due to circular import issues, I think I should still put it in `_common_utils.py`.
- This PR changes `FullyShardedDataParallel.fsdp_modules()` to be backed by `_get_fsdp_states()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90860
Approved by: https://github.com/rohan-varma
2022-12-16 12:15:42 +00:00
Andrew Gu
1ba4e3c711 [FSDP][BE] Remove _module_to_handles, HandleConfig; use term "fqn"; clarify docs (#90840)
This PR
- Removes `_module_to_handles` since it is no longer used. We instead use `_comm_module_to_handles`.
- Removes `HandleConfig` and stores its fields directly as attributes on `FlatParamHandle`.
- Uses the term `fqn`/`fqns` uniformly in `flat_param.py` instead of `prefixed_param_name` / `prefixed_param_names`.
- Clarifies some documentation.

I am including all of these BE items in the same PR to save CI.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90840
Approved by: https://github.com/rohan-varma
2022-12-14 21:37:37 +00:00
Chien-Chin Huang
d52f121dba [Composable API]Common _State parent class for composable and wrapper FSDP (#89147)
**Why this PR?**

For the composable API implementation, internal APIs sometimes do not have access to the application's (FSDP, DDP) root module but only to a local module. One example is the state_dict/optimizer_state_dict implementation of FSDP: these APIs are designed to start from the root module of the model, so it is tricky for them to tell whether a random submodule is managed by either DDP or FSDP.

It will be useful to have APIs like:
`_get_module_state(module)`: return the composable state if this module is managed by composable API.
`_get_module_fsdp_state(module)`: return the FSDP state if this module is managed by FSDP.

**What does this PR propose?**
1. Make `_State` out of `_composable` module so that `FullyShardedDataParallel` can inherit from it.
2. A global `_module_state_mapping: Dict[nn.Module, _State]` that keeps the mapping of all submodules (not just the root module) to their state.
3. Create `_get_module_state(module)` to look up `_module_state_mapping`.
4. Create `_get_module_fsdp_state(module)` that uses `_get_module_state(module)` to get the state and then verifies that the state is `_FSDPState`.
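A hedged sketch of items 2-4 (bodies are simplified, and the registration step done by the composable contract machinery is shown inline for illustration):
```python
# Sketch: a global module -> state mapping plus typed lookups, so any submodule
# can be traced back to its managing composable/FSDP state.
from typing import Dict, Optional
import torch.nn as nn

class _State: ...
class _FSDPState(_State): ...

_module_state_mapping: Dict[nn.Module, _State] = {}

def _register_state_sketch(module: nn.Module, state: _State) -> None:
    for submodule in module.modules():   # map every submodule, not just the root
        _module_state_mapping[submodule] = state

def _get_module_state(module: nn.Module) -> Optional[_State]:
    return _module_state_mapping.get(module)

def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]:
    state = _get_module_state(module)
    return state if isinstance(state, _FSDPState) else None
```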

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89147
Approved by: https://github.com/awgu
2022-12-13 23:58:01 +00:00
Andrew Gu
b3d49c2fb8 [FSDP][1/N] fully_shard state dict (#90767)
Co-authored with @rohan-varma.

**Overview**
This adds preliminary `state_dict()` support for `fully_shard`.
- The only explicit branching between composable and wrapper code paths happens in the state dict hook registration, which is inevitable.
- We introduce a `_comm_module_prefix` to match the FQNs between the two code paths. This is needed since for composable, the FQNs are prefixed from the local FSDP root, whereas for state dict purposes, we want them to be prefixed from the comm. module. Thus, we need this `_comm_module_prefix` to be stripped during state dict.
    - In my understanding, the alternative of not using the `prefix` argument in `state_dict()` does not support the case where `fully_shard` is applied to a submodule (i.e. not the global root module) since we still need _part_ of `prefix` then.

**Follow-Ups**
- We can retire the `functools.partial` usage once @fegin's PR lands.
- We should add more thorough testing (e.g. sharded state dict, save and load together etc.).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90767
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2022-12-13 20:05:40 +00:00
Andrew Gu
45b40be078 [FSDP()] Fix fully_shard fwd hook registration (#90201)
I need to rebase later after Shen's PRs land.

The idea is to only register the pre/post-forward hook on the _root modules_ among the modules that consume a `FlatParameter`. (Yes, the term _root module_ is heavily overloaded. We may want to clarify that at some point. Here, _root_ is being used in the graph sense, meaning parent-less, and the scope is only among the modules consuming a `FlatParameter`.)

This avoids running unnecessary pre/post-forward hooks, which would lead to errors because the unshard is not truly idempotent.
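A hedged sketch of selecting those parent-less modules (the function name is illustrative):
```python
# Among the modules consuming one FlatParameter, keep only those with no
# ancestor also in the set; hooks are registered on these "roots" only.
from typing import List, Set
import torch.nn as nn

def _get_hook_target_modules_sketch(root: nn.Module,
                                    consumers: Set[nn.Module]) -> List[nn.Module]:
    targets: List[nn.Module] = []

    def visit(module: nn.Module, inside_consumer: bool) -> None:
        if module in consumers and not inside_consumer:
            targets.append(module)  # no ancestor consumes this FlatParameter
        inside = inside_consumer or module in consumers
        for child in module.children():
            visit(child, inside)

    visit(root, False)
    return targets
```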
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90201
Approved by: https://github.com/mrshenli, https://github.com/rohan-varma
2022-12-06 06:09:03 +00:00
Chien-Chin Huang
3c7f96665e [FSDP][state_dict][3/N] Change how state_dict utils access attributes in _FSDPState (#88635)
**What This PR Does**
`_state_dict_utils` currently accesses the FSDP states through `module`. To enable composable FSDP state_dict, these accesses need to go through `_FSDPState`. `module` is still required for most APIs as state_dict has to access per-module information.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88635
Approved by: https://github.com/awgu
2022-11-11 15:20:36 +00:00
Andrew Gu
fc743ec059 [FSDP()] Have fully_shard() abide by @contract! (#88235)
We are making some progress on composability :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88235
Approved by: https://github.com/mrshenli
2022-11-03 13:41:54 +00:00
Andrew Gu
95a9721a15 [FSDP()][Easy] Rename _State to _FSDPState (#88234)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88234
Approved by: https://github.com/mrshenli
2022-11-03 11:29:01 +00:00
Andrew Gu
73de44fc56 [FSDP] Rename unflat_param_name -> fqn for consistency (#88123)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88123
Approved by: https://github.com/mrshenli
2022-11-02 23:25:53 +00:00
Andrew Gu
32d22edc67 [FSDP()][27/N] Add forward hook registration (#88040)
This PR adds the forward hook registration to composable FSDP and adds a unit test for the runtime.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88040
Approved by: https://github.com/zhaojuanmao, https://github.com/rohan-varma
2022-11-02 23:25:53 +00:00
Andrew Gu
0a752688bd [FSDP()][17/N] Refactor _fsdp_root_pre_forward() (#87930)
This PR moves `_fsdp_root_pre_forward()` to `_runtime_utils.py`.

Note: This PR includes a (temporary) fix for `NO_SHARD` + `CPUOffload(offload_params=True)`, where we set `non_blocking=False` when copying the gradient from device to host. It is only included in this PR since the test was **flaky** (but not consistently failing) on this PR, so I needed the fix to unblock landing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87930
Approved by: https://github.com/mrshenli
2022-11-02 11:32:42 +00:00
Andrew Gu
1f34067e9d [FSDP()][16/N] Refactor post-forward/pre-backward (#87929)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87929
Approved by: https://github.com/mrshenli
2022-11-01 17:26:03 +00:00
Andrew Gu
90c5f856b2 [FSDP()][14/N] Refactor pre-forward/post-backward (#87927)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87927
Approved by: https://github.com/mrshenli
2022-11-01 17:25:59 +00:00
Andrew Gu
78170701a3 [FSDP()][9/N] Refactor ctor (continued) (#87923)
This PR makes a second pass over the constructor. The logic has been grouped into `_init_<...>` functions based on intent (e.g. `_init_prefetching_state()` or `_init_runtime_state()`). This makes the initialization code for composable FSDP much cleaner than having to re-write the same sequences of lower-level helper calls.
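An illustrative sketch of the intent-grouped helpers (bodies are trivial stand-ins, not the real `_init_<...>` implementations):
```python
# The ctor now reads as a short sequence of _init_* calls that composable FSDP
# can reuse instead of duplicating low-level helper calls.
class _StateSketch:
    pass

def _init_prefetching_state_sketch(state, backward_prefetch, forward_prefetch):
    state.backward_prefetch = backward_prefetch
    state.forward_prefetch = forward_prefetch
    return state

def _init_runtime_state_sketch(state):
    state.training_state = "IDLE"
    state._handles = []
    return state

state = _init_runtime_state_sketch(
    _init_prefetching_state_sketch(_StateSketch(),
                                   backward_prefetch=True,
                                   forward_prefetch=False)
)
```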

This PR also moves `_ExecOrderData` into its own file `_exec_order_utils.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87923
Approved by: https://github.com/mrshenli
2022-11-01 12:39:21 +00:00
Andrew Gu
d89cf2fdc9 [FSDP()][7/N] Refactor most of ctor (#87921)
The goal of this PR is to make one pass over the FSDP constructor and refactor each helper method call to not be `self.<...>`. Subsequent PRs will make further passes over the FSDP constructor.

This PR looks like a lot of lines of code change, but it is only reorganization. Methods are moved to `_init_utils.py` and `_common_utils.py`. This also marks the beginning of moving methods from `_utils.py` to `_common_utils.py` -- they will be coalesced eventually. I am only using `_common_utils.py` as a staging ground to include the methods that have been affected by the refactoring.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87921
Approved by: https://github.com/mrshenli
2022-10-31 16:45:24 +00:00