Commit Graph

34 Commits

Andrew Gu
c622559968 [FSDP][3/N] Minor fixes (rename, assert message) (#97663)
This is an easy PR.
- It renames `_shard_indices` to `_shard_param_indices` for consistency.
- It fixes an old mention of `comm_module` in an assert message.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97663
Approved by: https://github.com/rohan-varma
2023-03-28 01:46:43 +00:00
Kazuaki Ishizaki
35fd5c548e Fix typos under torch/distributed directory (#95638)
This PR fixes typos in comments and messages of `.py` files under the torch/distributed directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95638
Approved by: https://github.com/usamah1, https://github.com/H-Huang, https://github.com/kit1980
2023-03-27 21:13:44 +00:00
Chien-Chin Huang
f5a0b31a95 [FSDP][optim_state_dict] Make FSDP optim_state_dict aware of DDP prefix (#96415)
Summary: When wrapping FSDP within DDP, the optimizer state_dict may be broken due to the prefix added by DDP. This PR fixes the issue.
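
For context, `DistributedDataParallel` stores the wrapped model under `self.module`, so parameter FQNs gain a `module.` prefix. A minimal sketch of the kind of prefix handling involved (the helper name is hypothetical; this is not the actual fix):
```
# Hypothetical helper, not the actual fix: DDP wraps the model under
# `self.module`, so parameter FQNs gain a "module." prefix that FSDP's
# optim state_dict FQN lookups do not expect.
DDP_PREFIX = "module."

def strip_ddp_prefix(fqn: str) -> str:
    """Remove the DDP wrapper prefix from a fully qualified parameter name."""
    return fqn[len(DDP_PREFIX):] if fqn.startswith(DDP_PREFIX) else fqn

assert strip_ddp_prefix("module.layer.weight") == "layer.weight"
```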

Test Plan: CI

Differential Revision: D43893609

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96415
Approved by: https://github.com/zhaojuanmao
2023-03-13 21:07:34 +00:00
Andrew Gu
6c30dc6cee [FSDP] Save _all_handles; _all_fsdp_states to root (#95465)
- The previous PR addressed one tree traversal in `_root_pre_forward()` but not the main one from `_get_fsdp_handles()` that runs for all settings.
- This PR saves `_all_handles` to cache `_get_fsdp_handles()` and `_all_fsdp_states` to cache `_get_fsdp_states()` (renamed from `_fsdp_states` compared to last PR) on the root state.
- This PR introduces a dummy `_RootFSDPState` class that inherits from `_FSDPState` to be used only for type checking since some attributes are only defined for root states.
    - I found this approach to be better than adding `_p_assert(state.root_only_attr is not None, ...)` upon each usage of `root_only_attr`.
    - This hopefully also helps readers to quickly see which attributes are defined only on root states.
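
For illustration, a minimal sketch of the typing pattern described above (attribute types are assumptions; this is not the actual implementation):
```
# Illustrative sketch of the pattern above, not the actual implementation.
from typing import List

class _FSDPState:
    ...  # attributes shared by all FSDP states

class _RootFSDPState(_FSDPState):
    # Root-only caches; this subclass exists only so type checkers know these
    # attributes are defined when a state is known to be the root.
    _all_handles: List["FlatParamHandle"]
    _all_fsdp_states: List[_FSDPState]
```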
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95465
Approved by: https://github.com/fduwjj
2023-02-26 13:59:53 +00:00
Andrew Gu
9c45f47bbe [FSDP] Save _fsdp_states on root (#95343)
This saves an attribute `_fsdp_states: Optional[List[_FSDPState]]`. For the root, it is populated with all `_FSDPState`s in the root's tree. For non-root, it is `None`.

This is used to avoid doing the tree traversal during `_root_pre_forward()` when `forward_prefetch=True`.

Differential Revision: [D43536895](https://our.internmc.facebook.com/intern/diff/D43536895)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95343
Approved by: https://github.com/fegin
2023-02-23 21:18:05 +00:00
Chien-Chin Huang
eb81e7ec22 [FSDP] Avoid printing incorrect warning for _get_param_to_fqns (#94494)
There exists a hack for `_get_param_to_fqns` and `_apply_to_modules`. The condition for the hack's warning is incorrect and results in overwhelming messages for users. This PR fixes the issue.

The original hack is not removed; it will be removed once support for DMP + FSDP is deprecated.

Differential Revision: [D43135611](https://our.internmc.facebook.com/intern/diff/D43135611/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94494
Approved by: https://github.com/rohan-varma
2023-02-12 17:09:30 +00:00
Yanli Zhao
e0c24ec2a5 Print fqn in the warning message (#94313)
Print the FQN in the warning message, and make the `else` branch match the `if` in `_apply_to_modules()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94313
Approved by: https://github.com/fegin
2023-02-08 06:45:53 +00:00
Chien-Chin Huang
e32d99ae19 [FSDP][optim_state_dict] Make FSDP.optim_state_dict compatible with DMP (#93285)
`torchrec.DistributedModelParallel` overrides `named_parameters` and is not compatible with `FullyShardedDataParallel`'s optim_state_dict. This PR adds a workaround in `FullyShardedDataParallel` to make both work together.

Differential Revision: [D42764611](https://our.internmc.facebook.com/intern/diff/D42764611/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93285
Approved by: https://github.com/rohan-varma
2023-02-02 23:42:54 +00:00
Andrew Gu
10990734ce [FSDP][2/N] _summon_full_params -> _unshard_params (#92297)
**Overview**
This PR stack will add support for unsharding FSDP's sharded parameters for `fully_shard`. This PR takes the first step by doing some internal refactoring.
- The existing API for wrapper FSDP is the static method `summon_full_params()`, which calls into the helper `_summon_full_params()`.
- This PR refactors:
    - `summon_full_params()` core logic to `_unshard_params()`
    - `_summon_full_params()` to `_unshard_params_recurse()`, which has a `recurse: bool` argument
    - Previous `_unshard_params()` to `_unshard_fsdp_state_params()`, which applies to a single FSDP state

**Details**
- This PR introduces `_get_fsdp_states_with_modules()` and `_get_root_fsdp_states_with_modules()`, which additionally return the modules along with the FSDP states. The modules are needed for handling `FlatParameter` registration.
    - We may be able to remove this if we clean up the `use_orig_params=True` vs. `False` code paths because for `True`, the `FlatParameter` is not registered, meaning that it does not need to be de-registered.
    - Since `fully_shard` requires `use_orig_params=True`, we may not need `_get_fsdp_states_with_modules()` and `_get_root_fsdp_states_with_modules()`; however, I prefer to make the separation of FSDP state and module explicit for now for clarity.

**Follow-Ups**
- Combining `writeback=True` and `rank0_only=True` raises an error. The previous explanation was:
> is not supported, as model parameter shapes will be different across ranks, and writing to them can lead to inconsistencies across ranks when the context is exited.

I am not exactly sure what the different model parameter shapes refer to. However, I believe that we can support `writeback=True` and `rank0_only=True` by broadcasting the `FlatParameter` from rank 0 in the `finally`, writing back, and freeing. This should not increase the peak memory since rank 0 already holds the unsharded `FlatParameter` in GPU memory before writing back, and nonzero ranks do not have any other unsharded `FlatParameter`s in GPU memory.
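
For reference, a minimal usage sketch of the public API whose core logic this PR refactors (assumes an initialized process group, e.g. via torchrun, and a GPU):
```
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed has been initialized (e.g. via torchrun).
model = FSDP(nn.Linear(8, 8).cuda())
# `writeback=True` persists in-context edits to the sharded parameters;
# `rank0_only=True` materializes full parameters only on rank 0.
with FSDP.summon_full_params(model, writeback=True, rank0_only=False):
    full_weight = model.weight  # unsharded parameter is visible here
```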
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92297
Approved by: https://github.com/rohan-varma
2023-02-02 15:10:14 +00:00
Chien-Chin Huang
4b0f1cc1ee [FSDP][optim_state_dict][10/N] Make optim_state_dict and optim_state_dict_to_load public (#92118)
Make `optim_state_dict` and `optim_state_dict_to_load` public APIs and consolidate them with `state_dict` by using the same `state_dict_type` to decide how to perform the optimizer state_dict save and load.
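
A sketch of the consolidated save/load flow (assumes an initialized process group; the argument order of `optim_state_dict_to_load` has changed across PyTorch releases, so check the docs for your version):
```
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed has been initialized (e.g. via torchrun).
model = FSDP(nn.Linear(16, 16).cuda())
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
# ... run some training steps so the optimizer has state ...

# Save: gathered/sharded according to the configured state_dict_type.
osd = FSDP.optim_state_dict(model, optim)
# Load: convert back into the form the local optimizer expects.
# Note: the argument order here has differed across PyTorch releases.
flattened_osd = FSDP.optim_state_dict_to_load(model, optim, osd)
optim.load_state_dict(flattened_osd)
```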

Differential Revision: [D42488022](https://our.internmc.facebook.com/intern/diff/D42488022/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92118
Approved by: https://github.com/rohan-varma
2023-02-02 08:04:20 +00:00
Yanli Zhao
2004df9097 Remove python ddp (#91663)
It is not used by anyone and is not maintained by PyTorch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91663
Approved by: https://github.com/rohan-varma
2023-01-04 05:22:30 +00:00
Chien-Chin Huang
0e8565d1d5 [FSDP][optim_state_dict][8/N] Enable fully_shard optim state_dict save and load (#91234)
**What does this PR do?**
This PR refactors `_optim_utils.py` to use `_FSDPState` instead of the `FullyShardedDataParallel` class. This change enables the support of optim state_dict for `fully_shard`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91234
Approved by: https://github.com/rohan-varma
2022-12-30 06:56:44 +00:00
Andrew Gu
aec09eeb3a [FSDP][7/N] Support replicate in fully_shard (#91044)
This PR supports nesting `replicate` in `fully_shard`.
- The PR achieves this by treating `replicate`-annotated modules as ignored modules. This means that all submodules in the `replicate`-annotated module's subtree are ignored, including nested `fully_shard`-annotated modules, which is the desired behavior.
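
A sketch of the nesting described above (the composable APIs were private/experimental at this point, so the import path is an assumption tied to this era of the code base):
```
import torch.nn as nn
from torch.distributed._composable import fully_shard, replicate

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = nn.Linear(8, 8)
        self.head = nn.Linear(8, 8)

# Assumes torch.distributed has been initialized (e.g. via torchrun).
model = Model()
replicate(model.sub)   # `model.sub`'s subtree is replicated (DDP-style) ...
fully_shard(model)     # ... and treated as ignored by the outer fully_shard
```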

---

This PR reworks some tree traversal.

One end goal is for `state._handles` to follow the same order for both the wrapper and composable paths. This implies that `_get_fsdp_handles()` returns the same value for both paths.
- The helper function `_get_fully_sharded_module_to_states()` now follows a left-to-right DFS from each fully sharded module instead of a BFS. The left-to-right DFS follows `.modules()` order.
- The composable auto "wrap" initialization function `_init_param_handles_from_module()` follows the reverse left-to-right DFS order. As noted in the code comments, this initialization order is a valid reverse topological sort, but it differs from the wrapper path. This is the _only_ difference with respect to initialization order through the entire process.
```
mod: Module(
    submod1: Submodule()
    submod2: Submodule(
        subsubmod: Subsubmodule(),
    ),
)
```
For left-to-right DFS, the order is `mod`, `submod1`, `submod2`, `subsubmod`. (For context, right-to-left DFS would be `mod`, `submod2`, `subsubmod`, `submod1`. In other words, the left-to-right vs. right-to-left corresponds to `.children()` vs. `reversed(.children())` respectively.) Then, reverse left-to-right DFS is `subsubmod`, `submod2`, `submod1`, `mod`, which is a valid initialization order. However, the wrapper auto wrap initialization order would be `submod1`, `subsubmod`, `submod2`, `mod` since it directly follows a left-to-right DFS and initializes as a part of the recursive DFS logic.
- At the end of `_init_param_handles_from_module()`, we reverse the newly populated `state._handles`, so this is the reverse reverse left-to-right DFS order, which is equivalent to the left-to-right DFS order. Thus, `state._handles` has the same order for both paths.

Another goal is for `_get_fsdp_states()` to not traverse into any submodule that is annotated with an API that is not compatible with `fully_shard` (e.g. `replicate`). To achieve this while preserving that `_get_fsdp_states()` follows `.modules()` order, we again use a left-to-right DFS.

The reason the DFSs may look strange is that I implemented them non-recursively, which requires a stack.
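
For illustration, a minimal non-recursive left-to-right DFS over an `nn.Module` tree (the function name is illustrative; the visit order matches `.modules()` order as described above):
```
from typing import List
import torch.nn as nn

def left_to_right_dfs(root: nn.Module) -> List[nn.Module]:
    """Illustrative stack-based DFS that visits modules in `.modules()` order."""
    order: List[nn.Module] = []
    stack = [root]
    while stack:
        module = stack.pop()
        order.append(module)
        # Push children in reverse so the leftmost child is popped (visited) first.
        stack.extend(reversed(list(module.children())))
    return order
```
Reversing the returned list gives the reverse left-to-right DFS used as the initialization order.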

- `test_get_fully_sharded_module_to_states()` in `test_utils.py` checks the traversal order of `_get_fully_sharded_module_to_states()`.
- `test_policy()` in `test_fully_shard.py` checks the traversal order returned by `_get_fsdp_handles()`.

---

Due to a circular dependency issue, we must move the graph/tree traversal helpers to their own file `_traversal_utils.py`, and any usages must import the entire file like `import torch.distributed.fsdp._traversal_utils as traversal_utils` instead of `from torch.distributed.fsdp._traversal_utils import ...`.

The cycle comes from the fact that the traversals require `_composable()`, which requires `_get_registry()` from `composable/contract.py`, which when imported, imports `composable/fully_shard.py`, which requires the traversals.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91044
Approved by: https://github.com/mrshenli
2022-12-20 16:49:18 +00:00
Andrew Gu
e81ccfd1ed [FSDP][6/N] Add note explaining idioms for _FSDPState traversal (#90959)
This adds a note to explain how to do traversal in the new code base. These traversal helper methods were introduced in [1/N], [3/N], and [5/N].

I am working on updating the traversal helpers to account for other composable APIs (e.g. `replicate`). The rule is that the traversal should not proceed into an incompatible API's tree. This will be needed for `fully_shard` to be above `replicate`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90959
Approved by: https://github.com/mrshenli
2022-12-20 16:49:18 +00:00
Andrew Gu
32fde53713 [FSDP][5/N] Add manual "wrapping" support for fully_shard (#90874)
This PR adds manual "wrapping" support for `fully_shard`. For example, for
```
fully_shard(mod.sub)
fully_shard(mod)
```
`mod.sub` and `mod` will share the same FSDP data structures.

To have parity with wrapper FSDP, this PR only checks support for when each manual application of `fully_shard` passes `policy=None`. Hybrid auto / manual wrapping is not in scope for this PR since it is not supported for wrapper FSDP either. I can follow up to either add support properly or raise an error early.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90874
Approved by: https://github.com/mrshenli
2022-12-20 16:49:15 +00:00
Andrew Gu
da9af9868e [FSDP][4/N] Refactor func to share state/init handle attrs (#90871)
For `limit_all_gathers`, if we do not enforce that all FSDP instances have the same value, then the semantics guaranteed by the `bool` can be violated; it could be as if none of them set the value to `True`.

For `use_orig_params`, optimizer state dict assumes that the value is the same for all FSDP instances.
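
For illustration, a sketch of checking such an invariant with the public `FSDP.fsdp_modules()` helper (the attribute access is an assumption about where the setting is stored):
```
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def check_shared_setting(model, attr: str) -> None:
    """Illustrative check that every FSDP instance in `model` agrees on `attr`."""
    values = {getattr(fsdp_module, attr) for fsdp_module in FSDP.fsdp_modules(model)}
    assert len(values) <= 1, f"Inconsistent {attr!r} across FSDP instances: {values}"

# e.g. check_shared_setting(model, "limit_all_gathers")
```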

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90871
Approved by: https://github.com/mrshenli
2022-12-20 16:49:13 +00:00
Andrew Gu
8cd1808dbf [FSDP] Introduce "fully sharded module"; remove comm. module (#90933)
This PR removes the "communication module" (comm. module / `comm_module`) concept from the FSDP code base since, for now, it causes confusion disproportionate to its benefit.

Instead, we introduce the term "fully sharded module" as the single concept to unify the wrapper and non-wrapper code paths. The definition is presented in a note at the top of `flat_param.py`. I reproduce it here:

---
We define the **"fully sharded module"** to be the original `nn.Module` that owns a `FlatParamHandle`. It is the *single* module logically responsible for the *single* unshard/reshard pair for the handle's `FlatParameter` for a given forward or backward pass. The fully sharded module should be passed to the `FlatParamHandle` constructor.

For the wrapper code path:
- The `FullyShardedDataParallel` module wrapping the fully sharded module runs the unshard/reshard on behalf of the fully sharded module by overriding `nn.Module.forward`.
- The fully sharded module is exactly the module passed to the `FullyShardedDataParallel` constructor's `module` argument and is saved in `_fsdp_wrapped_module`.

For the non-wrapper code path:
- Hooks registered on the fully sharded module run the unshard/reshard.
- The fully sharded module may either be the direct argument to `fully_shard` or a submodule chosen by the provided wrapping policy.
---

After this PR, `handle.flat_param._fqns`, `_param_infos`, and `_shared_param_infos` all prefix names from the same module, namely the fully sharded module. This should make state dict less confusing.

---
As an example, consider:
```
mod: Module(
  sub1: Submodule(
    subsub1: Subsubmodule(),
    subsub2: Subsubmodule(),
  ),
  sub2: Submodule(
    subsub1: Subsubmodule(),
    subsub2: Subsubmodule(),
  ),
)
```
For wrapper FSDP manual wrap:
```
mod.sub1 = FSDP(mod.sub1)
mod.sub2 = FSDP(mod.sub2)
mod = FSDP(mod)
```
For wrapper FSDP auto wrap:
```
mod = FSDP(mod, auto_wrap_policy=ModuleWrapPolicy({Submodule}))
```
(WIP) For non-wrapper FSDP manual wrap:
```
fully_shard(mod.sub1)
fully_shard(mod.sub2)
fully_shard(mod)
```
For non-wrapper FSDP auto wrap:
```
fully_shard(mod, policy=ModuleWrapPolicy({Submodule}))
```
The fully sharded modules **in all cases** are `mod`, `mod.sub1`, and `mod.sub2`; notably, the `subsub1`s and `subsub2`s are not fully sharded modules.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90933
Approved by: https://github.com/rohan-varma
2022-12-16 18:45:52 +00:00
Andrew Gu
95ee5fecb1 [FSDP][1/N] Add _get_fsdp_states() (#90860)
- This PR introduces `_get_fsdp_states(module: nn.Module) -> List[_FSDPState]` to prepare for `fully_shard` manual "wrapping".
    - ~~I place it in `_runtime_utils.py`, not `_common_utils.py`, because in a follow-up PR, I will add `_get_root_fsdp_states()`, which requires `_lazy_init()`. I concluded that it would be preferred to have both of these getters be in the same place than to have them split, even if that means that `_get_fsdp_states()` is in `_runtime_utils.py`.~~ Due to circular import issues, I think I should still put it in `_common_utils.py`.
- This PR changes `FullyShardedDataParallel.fsdp_modules()` to be backed by `_get_fsdp_states()`.
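
A hypothetical sketch of what such a getter might look like (the per-module state lookup mirrors `_get_module_fsdp_state()` from the Composable API PR further down; details are illustrative, not the actual implementation):
```
from typing import List
import torch.nn as nn

def _get_fsdp_states(module: nn.Module) -> List["_FSDPState"]:
    """Illustrative sketch: collect each distinct FSDP state in `module`'s tree,
    in `.modules()` order."""
    states: List["_FSDPState"] = []
    for submodule in module.modules():
        state = _get_module_fsdp_state(submodule)  # per-module state lookup
        if state is not None and state not in states:
            states.append(state)
    return states
```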
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90860
Approved by: https://github.com/rohan-varma
2022-12-16 12:15:42 +00:00
Andrew Gu
1ba4e3c711 [FSDP][BE] Remove _module_to_handles, HandleConfig; use term "fqn"; clarify docs (#90840)
This PR
- Removes `_module_to_handles` since it is no longer used. We instead use `_comm_module_to_handles`.
- Removes `HandleConfig` and stores its fields directly as attributes on `FlatParamHandle`.
- Uses the term `fqn`/`fqns` uniformly in `flat_param.py` instead of `prefixed_param_name` / `prefixed_param_names`.
- Clarifies some documentation.

I am including all of these BE items in the same PR to save CI.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90840
Approved by: https://github.com/rohan-varma
2022-12-14 21:37:37 +00:00
Chien-Chin Huang
d52f121dba [Composable API] Common _State parent class for composable and wrapper FSDP (#89147)
**Why this PR?**

For the composable APIs implementation, sometimes the internal APIs may not have the application (FSDP, DDP) root module but only the local module. One example is the state_dict/optimizer_state_dict implementation of FSDP. These APIs are designed to start with the root module of the model. It is tricky for these APIs to tell whether an arbitrary submodule is managed by DDP or FSDP.

It will be useful to have APIs like:
`_get_module_state(module)`: return the composable state if this module is managed by composable API.
`_get_module_fsdp_state(module)`: return the FSDP state if this module is managed by FSDP.

**What does this PR propose?**
1. Move `_State` out of the `_composable` module so that `FullyShardedDataParallel` can inherit from it.
2. Add a global `_module_state_mapping: Dict[nn.Module, _State]` that maps every submodule (not just the root module) to its state.
3. Create `_get_module_state(module)` to look up `_module_state_mapping`.
4. Create `_get_module_fsdp_state(module)` that uses `_get_module_state(module)` to get the state then verifies if the state is `_FSDPState`.
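
A minimal sketch directly following the four points above (the `_State`/`_FSDPState` stubs stand in for the real classes; not the actual implementation):
```
from typing import Dict, Optional
import torch.nn as nn

class _State:  # simplified stand-in for the shared composable state base class
    pass

class _FSDPState(_State):  # simplified stand-in for FSDP's state
    pass

_module_state_mapping: Dict[nn.Module, _State] = {}

def _get_module_state(module: nn.Module) -> Optional[_State]:
    """Return the composable state managing `module`, if any."""
    return _module_state_mapping.get(module)

def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]:
    """Return the FSDP state managing `module`, or None if not FSDP-managed."""
    state = _get_module_state(module)
    return state if isinstance(state, _FSDPState) else None
```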

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89147
Approved by: https://github.com/awgu
2022-12-13 23:58:01 +00:00
Andrew Gu
b3d49c2fb8 [FSDP][1/N] fully_shard state dict (#90767)
Co-authored with @rohan-varma.

**Overview**
This adds preliminary `state_dict()` support for `fully_shard`.
- The only explicit branching between composable and wrapper code paths happens in the state dict hook registration, which is inevitable.
- We introduce a `_comm_module_prefix` to match the FQNs between the two code paths. This is needed since for composable, the FQNs are prefixed from the local FSDP root, whereas for state dict purposes, we want them to be prefixed from the comm. module. Thus, we need this `_comm_module_prefix` to be stripped during state dict.
    - In my understanding, the alternative of not using the `prefix` argument in `state_dict()` does not support the case where `fully_shard` is applied to a submodule (i.e. not the global root module), since we still need _part_ of `prefix` then.

**Follow-Ups**
- We can retire the `functools.partial` usage once @fegin's PR lands.
- We should add more thorough testing (e.g. sharded state dict, save and load together etc.).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90767
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2022-12-13 20:05:40 +00:00
Andrew Gu
45b40be078 [FSDP()] Fix fully_shard fwd hook registration (#90201)
I need to rebase later after Shen's PRs land.

The idea is to only register the pre/post-forward hook on the _root modules_ among the modules that consume a `FlatParameter`. (Yes, the term _root module_ is heavily overloaded. We may want to clarify that at some point. Here, _root_ is being used in the graph sense, meaning parent-less, and the scope is only among the modules consuming a `FlatParameter`.)

This avoids unnecessary pre/post-forward hooks running, which would lead to errors because the unshard is not truly idempotent.
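
For illustration, a sketch of computing the graph-sense roots among a set of modules (the function name is illustrative; not the actual implementation):
```
from typing import List, Set
import torch.nn as nn

def get_root_modules(modules: Set[nn.Module]) -> List[nn.Module]:
    """Illustrative: return the members of `modules` that have no other member
    as an ancestor (i.e. the parent-less modules within the set)."""
    roots = []
    for candidate in modules:
        has_ancestor_in_set = any(
            other is not candidate and candidate in set(other.modules())
            for other in modules
        )
        if not has_ancestor_in_set:
            roots.append(candidate)
    return roots
```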
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90201
Approved by: https://github.com/mrshenli, https://github.com/rohan-varma
2022-12-06 06:09:03 +00:00
Chien-Chin Huang
3c7f96665e [FSDP][state_dict][3/N] Change how state_dict utils access attributes in _FSDPState (#88635)
**What This PR Does**
`_state_dict_utils` currently accesses the FSDP states through `module`. To enable composable FSDP state_dict, these accesses need to go through `_FSDPState`. `module` is still required for most APIs, as state_dict has to access per-module information.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88635
Approved by: https://github.com/awgu
2022-11-11 15:20:36 +00:00
Andrew Gu
fc743ec059 [FSDP()] Have fully_shard() abide by @contract! (#88235)
We are making some progress on composability :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88235
Approved by: https://github.com/mrshenli
2022-11-03 13:41:54 +00:00
Andrew Gu
95a9721a15 [FSDP()][Easy] Rename _State to _FSDPState (#88234)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88234
Approved by: https://github.com/mrshenli
2022-11-03 11:29:01 +00:00
Andrew Gu
73de44fc56 [FSDP] Rename unflat_param_name -> fqn for consistency (#88123)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88123
Approved by: https://github.com/mrshenli
2022-11-02 23:25:53 +00:00
Andrew Gu
32d22edc67 [FSDP()][27/N] Add forward hook registration (#88040)
This PR adds the forward hook registration to composable FSDP and adds a unit test for the runtime.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88040
Approved by: https://github.com/zhaojuanmao, https://github.com/rohan-varma
2022-11-02 23:25:53 +00:00
Andrew Gu
0a752688bd [FSDP()][17/N] Refactor _fsdp_root_pre_forward() (#87930)
This PR moves `_fsdp_root_pre_forward()` to `_runtime_utils.py`.

Note: This PR includes a (temporary) fix for `NO_SHARD` + `CPUOffload(offload_params=True)`, where we set `non_blocking=False` when copying the gradient from device to host. It is only included in this PR since the test was **flaky** (but not consistently failing) on this PR, so I needed the fix to unblock landing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87930
Approved by: https://github.com/mrshenli
2022-11-02 11:32:42 +00:00
Andrew Gu
1f34067e9d [FSDP()][16/N] Refactor post-forward/pre-backward (#87929)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87929
Approved by: https://github.com/mrshenli
2022-11-01 17:26:03 +00:00
Andrew Gu
90c5f856b2 [FSDP()][14/N] Refactor pre-forward/post-backward (#87927)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87927
Approved by: https://github.com/mrshenli
2022-11-01 17:25:59 +00:00
Andrew Gu
78170701a3 [FSDP()][9/N] Refactor ctor (continued) (#87923)
This PR makes a second pass over the constructor. The logic has been grouped into `_init_<...>` functions based on intent (e.g. `_init_prefetching_state()` or `_init_runtime_state()`). This makes the initialization code for composable FSDP much cleaner than having to re-write the same sequences of lower-level helper calls.

This PR also moves `_ExecOrderData` into its own file `_exec_order_utils.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87923
Approved by: https://github.com/mrshenli
2022-11-01 12:39:21 +00:00
Andrew Gu
d89cf2fdc9 [FSDP()][7/N] Refactor most of ctor (#87921)
The goal of this PR is to make one pass over the FSDP constructor and refactor each helper method call to not be `self.<...>`. Subsequent PRs will make further passes over the FSDP constructor.

This PR looks like a lot of lines of code change, but it is only reorganization. Methods are moved to `_init_utils.py` and `_common_utils.py`. This also marks the beginning of moving methods from `_utils.py` to `_common_utils.py` -- they will be coalesced eventually. I am only using `_common_utils.py` as a staging ground to include the methods that have been affected by the refactoring.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87921
Approved by: https://github.com/mrshenli
2022-10-31 16:45:24 +00:00
Andrew Gu
9d9267c6f7 [FSDP()][3/N] Refactor public APIs (#87917)
- This PR defines a new `api.py` meant to hold the public API for FSDP (minus `FullyShardedDataParallel` itself). This is needed because several of the `_<...>_utils.py` files rely on the public API, and we cannot import from `torch.distributed.fsdp.fully_sharded_data_parallel` without a circular import. Calling the file `api.py` follows the convention used by `ShardedTensor`.
- This PR cleans up the wording in the `BackwardPrefetch`, `ShardingStrategy`, `MixedPrecision`, and `CPUOffload` docstrings.
- This PR adds the aforementioned classes to `fsdp.rst` to have them rendered in public docs.
- To abide by the public bindings contract (`test_public_bindings.py`), the aforementioned classes are removed from `fully_sharded_data_parallel.py`'s `__all__`. This is technically BC breaking if someone uses `from torch.distributed.fsdp.fully_sharded_data_parallel import *`; however, that does not happen in any of our own external or internal code.
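
For reference, the classes are importable from the package root, e.g.:
```
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)

# e.g. FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD,
#           cpu_offload=CPUOffload(offload_params=True))
```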
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87917
Approved by: https://github.com/mrshenli
2022-10-31 16:45:21 +00:00
Andrew Gu
e667c00656 [FSDP()][2/N] Refactor training state (#87916)
This PR actually has meaningful changes. We stratify `TrainingState` into two levels: one is per FSDP instance and one is per `FlatParamHandle`/`FlatParameter`.
- At the FSDP instance level, we only care about `IDLE`, FSDP computation (i.e. `FORWARD_BACKWARD`), or `SUMMON_FULL_PARAMS`. These dynamically modify behavior (e.g. `summon_full_params()` forces full precision).
- At the `FlatParamHandle` level, we care about the training state for invariants and debugging. Hence, we keep `IDLE`, `FORWARD`, `BACKWARD_PRE`, `BACKWARD_POST`, and `SUMMON_FULL_PARAMS`.
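
A sketch of the two-level split (member names are taken from this description; the exact definitions in the code base may differ):
```
import enum

class TrainingState(enum.Enum):
    """Per-FSDP-instance state."""
    IDLE = enum.auto()
    FORWARD_BACKWARD = enum.auto()
    SUMMON_FULL_PARAMS = enum.auto()

class HandleTrainingState(enum.Enum):
    """Per-FlatParamHandle / FlatParameter state."""
    IDLE = enum.auto()
    FORWARD = enum.auto()
    BACKWARD_PRE = enum.auto()
    BACKWARD_POST = enum.auto()
    SUMMON_FULL_PARAMS = enum.auto()
```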
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87916
Approved by: https://github.com/mrshenli
2022-10-29 06:50:30 +00:00