Commit Graph

48 Commits

Author SHA1 Message Date
Chien-Chin Huang
ab4fe01e72 [FSDP][optim_state_dict] Returns the initial states of the empty parameters for KeyedOptimizer/NamedOptimizer (#94130)
KeyedOptimizer and NamedOptimizer expect the states to exist in the state_dict when `load_state_dict` is called, even if the corresponding parameters are empty (size == 0). This PR adds support to make KeyedOptimizer work with `use_orig_params=True`.

Differential Revision: [D43019458](https://our.internmc.facebook.com/intern/diff/D43019458/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94130
Approved by: https://github.com/rohan-varma
2023-02-07 23:36:56 +00:00
Chien-Chin Huang
bc6d54f6d8 [FSDP][optim_state_dict] Let optim_state_dict ignore the non-FSDP managed parameters that do not reside on the rank (#94129)
When FSDP is used with other parallelism (e.g., TorchRec), some parameters that are not managed by FSDP may not reside on all the ranks (TorchRec uses model parallelism). When `use_orig_params=True`, FSDP synchronizes the FQNs among ranks. As a result, a rank may receive FQNs that it does not actually own. If an FQN belongs to a TorchRec-managed parameter, FSDP has to ignore that parameter's state; otherwise, FSDP does not know how to store the state.

This PR adds the logic to ignore parameters that are not managed by FSDP and do not reside on the rank.

Differential Revision: [D42982778](https://our.internmc.facebook.com/intern/diff/D42982778/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94129
Approved by: https://github.com/rohan-varma
2023-02-07 06:29:28 +00:00
Chien-Chin Huang
0f5b6caa16 [FSDP][optim_state_dict] Ignore the state check on rank that does not own the corresponding parameter (#93318)
When a rank does not own a parameter (`parameter.numel() == 0`), its optimizer state is not valid and should not be checked against the currently saved one.

Differential Revision: [D42865237](https://our.internmc.facebook.com/intern/diff/D42865237/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93318
Approved by: https://github.com/rohan-varma
2023-02-03 00:50:04 +00:00
Chien-Chin Huang
e32d99ae19 [FSDP][optim_state_dict] Make FSDP.optim_state_dict compatible with DMP (#93285)
`torchrec.DistributedModelParallel` overrides `named_parameters` and is not compatible with `FullyShardedDataParallel`'s optim_state_dict. This PR adds a workaround in `FullyShardedDataParallel` to make the two work together.

Differential Revision: [D42764611](https://our.internmc.facebook.com/intern/diff/D42764611/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93285
Approved by: https://github.com/rohan-varma
2023-02-02 23:42:54 +00:00
Andrew Gu
10990734ce [FSDP][2/N] _summon_full_params -> _unshard_params (#92297)
**Overview**
This PR stack will add support for unsharding FSDP's sharded parameters for `fully_shard`. This PR takes the first step by doing some internal refactoring.
- The existing API for wrapper FSDP is the static method `summon_full_params()`, which calls into the helper `_summon_full_params()`.
- This PR refactors:
    - `summon_full_params()` core logic to `_unshard_params()`
    - `_summon_full_params()` to `_unshard_params_recurse()`, which has a `recurse: bool` argument
    - Previous `_unshard_params()` to `_unshard_fsdp_state_params()`, which applies to a single FSDP state

**Details**
- This PR introduces `_get_fsdp_states_with_modules()` and `_get_root_fsdp_states_with_modules()`, which additionally return the modules along with the FSDP states. The modules are needed for handling `FlatParameter` registration.
    - We may be able to remove this if we clean up the `use_orig_params=True` vs. `False` code paths because for `True`, the `FlatParameter` is not registered, meaning that it does not need to be de-registered.
    - Since `fully_shard` requires `use_orig_params=True`, we may not need `_get_fsdp_states_with_modules()` and `_get_root_fsdp_states_with_modules()`; however, I prefer to make the separation of FSDP state and module explicit for now for clarity.

**Follow-Ups**
- `writeback=True` and `rank0_only=True` raises an error. The previous explanation was:
> is not supported, as model parameter shapes will be different across ranks, and writing to them can lead to inconsistencies across ranks when the context is exited.

I am not exactly sure what the differing model parameter shapes refer to. However, I believe that we can support `writeback=True` and `rank0_only=True` by broadcasting the `FlatParameter` from rank 0 in the `finally`, writing back, and freeing. This should not increase the peak memory since rank 0 already holds the unsharded `FlatParameter` in GPU memory before writing back, and nonzero ranks do not have any other unsharded `FlatParameter`s in GPU memory.
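
A rough sketch of that follow-up idea (the helper and argument names below are hypothetical, not existing FSDP internals; it assumes nonzero ranks have allocated a receive buffer of the padded unsharded size):

```python
import torch.distributed as dist

def _writeback_from_rank0(unsharded_flat_param, local_shard, process_group):
    # Broadcast rank 0's (possibly user-modified) unsharded FlatParameter so
    # that every rank writes back consistent data to its local shard.
    dist.broadcast(unsharded_flat_param, src=0, group=process_group)
    rank = dist.get_rank(process_group)
    world_size = dist.get_world_size(process_group)
    # Each rank copies its chunk of the padded unsharded parameter back into
    # its local shard (padding assumed to sit at the end, as in FSDP sharding).
    chunk = unsharded_flat_param.chunk(world_size)[rank]
    local_shard.copy_(chunk[: local_shard.numel()])
```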
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92297
Approved by: https://github.com/rohan-varma
2023-02-02 15:10:14 +00:00
Chien-Chin Huang
888771dc5d [FSDP][optim_state_dict] Fix _is_named_optimizer when the state is empty (#93303)
Regular optimizers do not initialize their state eagerly -- only NamedOptimizer and KeyedOptimizer do. This PR makes `_is_named_optimizer` work with regular optimizers.
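
A minimal sketch of the intended check (illustrative only, not the exact implementation): treat empty state as "not a NamedOptimizer" instead of inspecting the first key.

```python
def _is_named_optimizer(optim_state_dict) -> bool:
    # NamedOptimizer/KeyedOptimizer key their state by FQN (str); regular
    # optimizers key by parameter index (int) and may have empty state because
    # they initialize lazily -- in that case there is no key to inspect.
    state = optim_state_dict.get("state", {})
    if len(state) == 0:
        return False
    return isinstance(next(iter(state.keys())), str)
```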

Differential Revision: [D42858589](https://our.internmc.facebook.com/intern/diff/D42858589/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93303
Approved by: https://github.com/fduwjj
2023-01-31 03:49:26 +00:00
Chien-Chin Huang
a4238976a8 [FSDP][optim_state_dict] Ensure correct devices for tensors when doing all_gather (#92992)
When doing `_all_gather_optim_state`, we need to ensure that `step` tensors are on CPU and other tensors are on GPUs. This PR adds the logic to enforce this placement.
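
A minimal sketch of that invariant (the helper name is illustrative; `device` would be the FSDP compute device):

```python
import torch

def _normalize_state_devices(param_state: dict, device: torch.device) -> None:
    # Keep scalar "step" tensors on CPU and move everything else to the GPU
    # so that the subsequent all_gather sees tensors on the expected devices.
    for name, value in param_state.items():
        if not torch.is_tensor(value):
            continue
        param_state[name] = value.cpu() if name == "step" else value.to(device)
```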
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92992
Approved by: https://github.com/fduwjj
2023-01-27 06:50:36 +00:00
Chien-Chin Huang
8b1b47c36a [FSDP][optim_state_dict] Use all_gather to deal with uneven size tensors (#92991)
The current `_all_gather_optim_state` pads uneven tensors, which is not necessary because `all_gather` supports uneven tensors. This PR removes the padding logic.
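
A minimal sketch of gathering uneven shards without padding, relying on `all_gather` accepting differently sized output tensors as described above (assumes the process group is initialized and `numels` holds every rank's shard size, e.g. exchanged beforehand):

```python
import torch
import torch.distributed as dist

def gather_uneven(local_shard: torch.Tensor, numels: list, device) -> torch.Tensor:
    # One output buffer per rank, each sized to that rank's (possibly uneven) shard.
    buffers = [
        torch.empty(n, dtype=local_shard.dtype, device=device) for n in numels
    ]
    dist.all_gather(buffers, local_shard.to(device))
    return torch.cat(buffers)
```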
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92991
Approved by: https://github.com/rohan-varma, https://github.com/awgu
2023-01-27 06:46:44 +00:00
Chien-Chin Huang
8f294f785f [FSDP][optim_state_dict] Fix the conditions to check non-parameter associated states (#92744)
If a state is not associated with any parameter, `FSDP.optim_state_dict` should still save it. The current implementation for determining whether a state is associated with a parameter is not completely correct and can cause `use_orig_params=True` to have extra states.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92744
Approved by: https://github.com/awgu
2023-01-23 17:40:50 +00:00
Chien-Chin Huang
92d412d684 [FSDP][optim_state_dict][11/N] Let FSDP support NamedOptimizer/KeyedOptimizer when use_orig_params is False (#92184)
The current design of FSDP only supports NamedOptimizer/KeyedOptimizer when `use_orig_params` is True; this PR adds the support even when `use_orig_params` is False. This PR also adds support for user-defined optimizer states -- states that are not associated with any particular parameter.

Differential Revision: [D42497416](https://our.internmc.facebook.com/intern/diff/D42497416/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92184
Approved by: https://github.com/colin2328, https://github.com/rohan-varma
2023-01-18 21:24:30 +00:00
Chien-Chin Huang
1439cb0314 [FSDP][optim_state_dict][9/N] Rewrite the all-gather flow of optimizer state to support older GPUs (#91343)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91343
Approved by: https://github.com/rohan-varma
2023-01-17 17:21:19 +00:00
Rohan Varma
a155f64957 Update _optim_utils.py (#91935)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91935
Approved by: https://github.com/awgu, https://github.com/fegin
2023-01-11 22:06:26 +00:00
Chien-Chin Huang
0e8565d1d5 [FSDP][optim_state_dict][8/N] Enable fully_shard optim state_dict save and load (#91234)
**What does this PR do?**
This PR refactors `_optim_utils.py` to use `_FSDPState` instead of the `FullyShardedDataParallel` class. This change enables optim state_dict support for `fully_shard`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91234
Approved by: https://github.com/rohan-varma
2022-12-30 06:56:44 +00:00
Chien-Chin Huang
6cea4f3d57 [FSDP][optim_state_dict][7/N] Make FSDP support NamedOptimizer (#91160)
**What does this PR do?**
This PR refactors the FSDP optimizer state_dict APIs to accept `NamedOptimizer` as the input optimizer. The key difference is that the state_dict returned by `NamedOptimizer` is already keyed by FQN. This PR mainly changes the internal mapping to allow the optimizer state_dict to be keyed by FQN.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91160
Approved by: https://github.com/fduwjj, https://github.com/rohan-varma
2022-12-22 04:35:26 +00:00
Chien-Chin Huang
1ab6ac4682 [FSDP][optim_state_dict][6/N] Refactor the optim_state_dict APIs to support hooks (#90798)
**What does this PR do?**

This PR splits the FSDP optim_state_dict APIs into common implementation parts that are shared across different frontend APIs (we have many now and will consolidate them gradually). This PR also adds `_optim_state_dict_post_hook` and `_load_optim_state_dict_pre_hook` for the integration with `NamedOptimizer`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90798
Approved by: https://github.com/rohan-varma, https://github.com/awgu
2022-12-21 21:38:14 +00:00
Andrew Gu
aec09eeb3a [FSDP][7/N] Support replicate in fully_shard (#91044)
This PR supports nesting `replicate` in `fully_shard`.
- The PR achieves this by treating `replicate`-annotated modules as ignored modules. This means that all submodules in the `replicate`-annotated module's subtree are ignored, including nested `fully_shard`-annotated modules, which is the desired behavior.
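
A hedged usage sketch of the nesting described above (the composable-API import path and the module layout are assumptions here; requires an initialized process group):

```python
import torch.nn as nn
from torch.distributed._composable import fully_shard, replicate

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda()
replicate(model[1])   # this subtree (including any nested fully_shard) is ignored
fully_shard(model)    # shards the remaining, non-replicated parameters
```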

---

This PR reworks some tree traversal.

One end goal is for `state._handles` to follow the same order for both the wrapper and composable paths. This implies that `_get_fsdp_handles()` returns the same value for both paths.
- The helper function `_get_fully_sharded_module_to_states()` now follows a left-to-right DFS from each fully sharded module instead of a BFS. The left-to-right DFS follows `.modules()` order.
- The composable auto "wrap" initialization function `_init_param_handles_from_module()` follows the reverse left-to-right DFS order. As noted in the code comments, this initialization order is a valid reverse topological sort, but it differs from the wrapper path. This is the _only_ difference with respect to initialization order through the entire process.
```
mod: Module(
    submod1: Submodule()
    submod2: Submodule(
        subsubmod: Subsubmodule(),
    ),
)
```
For left-to-right DFS, the order is `mod`, `submod1`, `submod2`, `subsubmod`. (For context, right-to-left DFS would be `mod`, `submod2`, `subsubmod`, `submod1`. In other words, the left-to-right vs. right-to-left corresponds to `.children()` vs. `reversed(.children())` respectively.) Then, reverse left-to-right DFS is `subsubmod`, `submod2`, `submod1`, `mod`, which is a valid initialization order. However, the wrapper auto wrap initialization order would be `submod1`, `subsubmod`, `submod2`, `mod` since it directly follows a left-to-right DFS and initializes as a part of the recursive DFS logic.
- At the end of `_init_param_handles_from_module()`, we reverse the newly populated `state._handles`, so this is the reverse reverse left-to-right DFS order, which is equivalent to the left-to-right DFS order. Thus, `state._handles` has the same order for both paths.

Another goal is for `_get_fsdp_states()` to not traverse into any submodule that is annotated with an API that is not compatible with `fully_shard` (e.g. `replicate`). To achieve this while preserving that `_get_fsdp_states()` follows `.modules()` order, we again use a left-to-right DFS.

The reason the DFSs may look strange is that I implemented them non-recursively, which requires a stack.
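
A minimal sketch of the stack-based left-to-right DFS (an illustrative helper, not the actual internal function):

```python
import torch.nn as nn

def dfs_left_to_right(root: nn.Module) -> list:
    # Push children in reverse so that the leftmost child is popped first,
    # matching `.modules()` order: mod, submod1, submod2, subsubmod above.
    stack, order = [root], []
    while stack:
        module = stack.pop()
        order.append(module)
        stack.extend(reversed(list(module.children())))
    return order
```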

- `test_get_fully_sharded_module_to_states()` in `test_utils.py` checks the traversal order of `_get_fully_sharded_module_to_states()`.
- `test_policy()` in `test_fully_shard.py` checks the traversal order returned by `_get_fsdp_handles()`.

---

Due to a circular dependency issue, we must move the graph/tree traversal helpers to their own file `_traversal_utils.py`, and any usages must import the entire file like `import torch.distributed.fsdp._traversal_utils as traversal_utils` instead of `from torch.distributed.fsdp._traversal_utils import ...`.

The cycle comes from the fact that the traversals require `_composable()`, which requires `_get_registry()` from `composable/contract.py`, which, when imported, imports `composable/fully_shard.py`, which requires the traversals.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91044
Approved by: https://github.com/mrshenli
2022-12-20 16:49:18 +00:00
Andrew Gu
39d9dd135a [FSDP][Easy] ufmt files (#90858)
```
ufmt format torch/distributed/fsdp
ufmt format test/distributed/fsdp
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90858
Approved by: https://github.com/rohan-varma
2022-12-15 04:15:26 +00:00
Chien-Chin Huang
4a2d64994c [FSDP][optim_state_dict][4/N] Remove the unused _get_flat_param_to_fsdp_module API (#89980)
This is an easy PR; it just removes an unused internal API.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89980
Approved by: https://github.com/awgu
2022-12-13 21:01:46 +00:00
Chien-Chin Huang
043de8d1b1 [FSDP][optim_state_dict][3/N] Support use_orig_param optim_state_dict (non-broadcast version) (#89900)
**What:**
This PR adds optim state_dict support for `use_orig_params` when `rank0_only` is False. `rank0_only` support will be added in a following PR. The design of this PR focuses on simplicity and may not have good performance, especially for optim state_dict loading. Since optim state_dict loading is only called once at the beginning of training, performance is not the major concern.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89900
Approved by: https://github.com/awgu, https://github.com/rohan-varma
2022-12-13 20:45:21 +00:00
Chien-Chin Huang
44779d9bc6 [FSDP][optim_state_dict][2/N] Add _get_fqn_to_fsdp_param_info to map from original FQN to flat_param (#89899)
**Motivation:**
Add a helper to map from each FQN to the corresponding flat_param. The helper gets the flat_param directly from the FSDP state and the flat parameter handle, since the flat_param is not registered to the module when `use_orig_params` is True.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89899
Approved by: https://github.com/awgu
2022-12-07 19:40:47 +00:00
Ram Rachum
351d73b97f Fix exception causes all over the codebase (#90271)
This is the continuation to #90134 and hopefully the final PR in this series.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90271
Approved by: https://github.com/kit1980
2022-12-07 04:29:00 +00:00
Chien-Chin Huang
72fdfad4ad [FSDP][optim_state_dict][1/N] Restructure _optim_state_dict to prepare the support of use_orig_param (#89898)
**Motivation:**
Restructure some APIs in `_optim_state_dict.py` to allow for better future extension, mostly for supporting `use_orig_params`. There is NO logic change in this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89898
Approved by: https://github.com/awgu
2022-12-05 21:01:48 +00:00
Chien-Chin Huang
ae4074669e [FSDP][state_dict][6/N] Remove most FSDP module dependency from _optim_utils (#88638)
**What**
This PR removes most `FullyShardedDataParallel` dependencies from `optim_utils`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88638
Approved by: https://github.com/awgu
2022-11-12 03:16:37 +00:00
Andrew Gu
a689502275 [FSDP] Do not include empty state in _flatten_optim_state_dict() (#88353)
983c0e7f31/torch/optim/adam.py (L163)
The above line requires that a candidate optimizer state dict being loaded via `load_state_dict()` has non-empty state for its 0th parameter (via `state_values[0]`). This PR changes FSDP to only include non-empty mappings in the state returned by `_flatten_optim_state_dict()`, which is the subroutine for both `shard_full_optim_state_dict()` and `flatten_sharded_optim_state_dict()`.
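
A minimal sketch of the filtering described above (not the literal implementation):

```python
def _drop_empty_state(flat_osd_state: dict) -> dict:
    # Keep only parameter keys whose state mapping is non-empty so that the
    # Adam loading path described above, which inspects `state_values[0]`,
    # never sees an empty per-parameter state dict.
    return {key: state for key, state in flat_osd_state.items() if len(state) > 0}
```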
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88353
Approved by: https://github.com/fegin
2022-11-03 11:33:10 +00:00
Andrew Gu
73de44fc56 [FSDP] Rename unflat_param_name -> fqn for consistency (#88123)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88123
Approved by: https://github.com/mrshenli
2022-11-02 23:25:53 +00:00
Andrew Gu
bf2819a836 [FSDP()][24/N] Refactor _lazy_init() (#87939)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87939
Approved by: https://github.com/zhaojuanmao
2022-11-02 16:35:47 +00:00
Andrew Gu
cbc9faebfe [FSDP()][1/N] Start refactoring FSDP root pre-forward (#87915)
Welcome! This PR starts the refactoring journey.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87915
Approved by: https://github.com/mrshenli
2022-10-29 06:50:30 +00:00
Andrew Gu
e3cf81e0a7 [FSDP] ufmt /fsdp (#87811)
This applies `ufmt` to all of the FSDP files in the `torch/distributed/fsdp/` directory.

**Test Plan**
CI

**Notes**
For VSCode users,
- Install `ufmt`: https://pypi.org/project/ufmt/
- Install VSCode `ufmt` extension: https://marketplace.visualstudio.com/items?itemName=omnilib.ufmt
- Include in `settings.json`:
```
{
    "[python]": {
        "editor.defaultFormatter": "omnilib.ufmt",
        "editor.formatOnSave": true,
    },
}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87811
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2022-10-27 04:25:55 +00:00
Rohan Varma
701b3dd773 optim utils all_gather_into_tensor (#87769)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87769
Approved by: https://github.com/awgu
2022-10-26 16:20:46 +00:00
Andrew Gu
be682befbc [FSDP] Add use_orig_params (#84911)
**Overview**
This PR adds the option to use the original parameters via `use_orig_params=True` in the FSDP constructor.
- This exposes the original parameters rather than the `FlatParameter`s from `named_parameters()`, which means that the optimizer runs on the original parameters. Hence, users may assign original parameters from the same `FlatParameter` to different parameter groups.
- This enables decoupling the original parameter variables from their storage without changing the variables themselves, which is critical for our upcoming execution-order-based non-recursive wrapping policy.
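
A hedged illustration of the parameter-group point above (hypothetical model; assumes the process group is already initialized):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda(), use_orig_params=True)

# named_parameters() now yields the original parameters, so they can be placed
# into different parameter groups, e.g. no weight decay for biases.
decay = [p for n, p in model.named_parameters() if not n.endswith("bias")]
no_decay = [p for n, p in model.named_parameters() if n.endswith("bias")]
optim = torch.optim.AdamW(
    [{"params": decay, "weight_decay": 0.01},
     {"params": no_decay, "weight_decay": 0.0}]
)
```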

For more detailed design explanation, refer to the Quip shared internally.

**Follow-Ups**
See 85831 (removing link to avoid spamming the issue whenever I update this PR).

`test_fsdp_use_orig_params.py` adds ~4 min 46 seconds to the TTS on the AWS cluster.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84911
Approved by: https://github.com/rohan-varma
2022-10-07 18:07:17 +00:00
Chien-Chin Huang
2067b768fc [FSDP] Delay moving tensor to CPU until necessary for optim_state_dict() (#85761)
Optimizer state_dict currently moves tensors to CPU immediately after the all-gather. However, for the sharded optimizer state_dict, this move is duplicated; we should wait until all the sharding is done. This PR may slightly reduce the performance of the full optimizer state_dict as it has to allocate more memory than without this PR, but the benchmark shows the extra memory allocation is fairly light.

Differential Revision: [D39855912](https://our.internmc.facebook.com/intern/diff/D39855912/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D39855912/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85761
Approved by: https://github.com/rohan-varma
2022-10-03 17:23:23 +00:00
Andrew Gu
ff71f45788 [FSDP] Add FSDPExtensions for TP support (#85039)
This adds `FSDPExtensions` to enable TP + FSDP composability. To be agnostic to both `ShardedTensor` and `DistributedTensor`, the design relies on customizable hooks.

Some notes:
- I preferred the `_ext` prefix (short for "extension") over `_param_extension` simply because it is shorter. It should not matter much because it is purely internal facing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85039
Approved by: https://github.com/kumpera, https://github.com/fegin
2022-09-28 18:34:17 +00:00
Andrew Gu
c6c3346d5a [FSDP] Short-term fix to remove optim_input (#84201)
This is a short-term quick fix to accommodate using the existing optimizer state APIs without passing `optim_input`. It preserves the existing `optim_input` code path, but if `optim_input` is `None` while `optim` is not, the APIs use the new code path that relies on the optimizer's `param_groups` to get the information previously provided by `optim_input`.
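
A hedged usage sketch of the new path (assumes an initialized process group and a CUDA device):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(nn.Linear(8, 8).cuda())
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
# No optim_input: the API falls back to reading the optimizer's param_groups.
full_osd = FSDP.full_optim_state_dict(model, optim)
```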

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84201
Approved by: https://github.com/rohan-varma
2022-09-16 21:24:15 +00:00
Andrew Gu
39676a977f [FSDP][Easy] Save unpadded/padded unsharded sizes as attributes (#84366)
Differential Revision: [D39331199](https://our.internmc.facebook.com/intern/diff/D39331199)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84366
Approved by: https://github.com/rohan-varma
2022-09-13 17:09:20 +00:00
Andrew Gu
afcc7c7f5c [FSDP] Generalize prefetching; lower unshard/reshard to handle (#83665)
### Additional Constructor Changes
- `self.sharding_strategy`
    - If the world size is 1, I clamp the sharding strategy to `NO_SHARD`, regardless of the passed-in sharding strategy, since the behavior is fully equivalent. This obviates the need for `p._is_sharded or self.world_size == 1` checks in the core code. Once we fully shift the paradigm to using handles, this should result in a clear net positive. However, for now, we still have some places where we interface directly with the `FlatParameter`, in which case we have some temporary hacky code.
- `HandleConfig`
    - As a part of the new design abstraction, much logic is lowered to the `FlatParamHandle`. This requires the handle be aware of mixed precision, CPU offloading, sharding strategy, and the process group (for world size > 1). To be less error-prone, I re-defined the `dataclass`s and `enum`s for the handle. These can be removed and coalesced with the existing ones.
    - The drawback is that the `FlattenParamsWrapper` constructor now takes in the `HandleConfig` to forward it to the `FlatParamHandle` constructor. I tolerate this since we plan to retire the FPW. For now, the handle's process group attributes are set later when we call `handle.shard()`.
    - We will dive into this logic lowering later. For now, the idea is we need to pass some extra info to the handle, which must go through the FPW.
- `FullyShardedDataParallel._shard_parameters()` -> `FlatParamHandle.shard()`
- [Important] Generalizing attributes to remove the 1 `FullyShardedDataParallel` : 1 `FlatParameter` assumption
    - **Before:** `_fsdp_graph_order`, `_pre_backward_hook_full_params_prefetched`, `_forward_full_params_prefetched`, `reshard_after_forward` are with respect to 1 `FullyShardedDataParallel`
    - **After:** (1) We use `FlatParamHandle` in place of `FullyShardedDataParallel`. (2) The atomic unit for forward and pre-backward is a _group_ of handles involved in the same module's forward/pre-backward. This is represented as `Tuple[FlatParamHandle, ...]`. For now, this is **always a singleton tuple**, but this shift enables a module having multiple FSDP parameters (which we have use cases for).
- `_reset_lazy_init()` attributes
    - The prefetched flags are merged into `self._handles_prefetched`, which is directly defined in the constructor. `reshard_after_forward` is retired since it can be fully determined by other attributes (`_is_root` and `sharding_strategy`).

## FSDP Runtime: Unshard

The first step is to read the existing `_rebuild_full_params()`. A few notable observations:
- It returns `Tuple[Tensor, bool]`. The first element is the _padded unsharded flattened parameter_, and the second element is whether we can free it upon exiting `summon_full_params()`. This return value is **only used in `summon_full_params()`**.
- If parameter mixed precision is enabled and the `FlatParameter` is already unsharded, then the low precision shard (`_mp_shard`) is still re-allocated on GPU. (It is freed at the end of the method.)
- If CPU offloading is enabled and the `FlatParameter` is already unsharded, then there is a no-op `p.data = p.data.to(self.compute_device, non_blocking=True)`.
- Inside `summon_full_params()`, `mixed_precision_cast_ran` is always `False`. Therefore, the return value for the `not p._is_sharded and mixed_precision_cast_ran` branch is unused.
- `summon_full_params()` can only be called (before forward or after backward) or (between forward and backward). Given this, I cannot think of a case where we call `summon_full_params()`, the `FlatParameter` is already unsharded, but `reshard_after_forward` is `True`. The `FlatParameter` should be sharded (before forward or after backward), and the `FlatParameter` may only be unsharded (between forward and backward) if `reshard_after_forward` is `False`.
- If parameter mixed precision is enabled and the sharding strategy is a sharded one, then inside `summon_full_params()`, the `FlatParameter` is unsharded in full precision. This involves allocating a new padded unsharded flattened parameter on GPU in full precision since `_full_param_padded` is in the low precision.

Some comments:
- Ideally, we reduce the complexity of the core code path: i.e. unshard for forward and pre-backward. If the return value is only used for `summon_full_params()`, we should consider if we can compartmentalize that logic.
- The branching is complex, and some return values are never used, where this fact is not immediately obvious. We should see if we can reduce the branch complexity.

Disclaimer: The difference in attribute semantics between `NO_SHARD` and the sharded strategies makes it challenging to unify the cases. This PR does not attempt to address that since it requires more design thought. However, it does attempt to reduce the complexity for the sharded strategies.

### Unshard: Core Code Path
Let us trace through the new logical unshard (a condensed sketch follows the numbered steps).
1. `FullyShardedDataParallel._unshard(self, handles: List[FlatParamHandle], prepare_gradient: bool)`
    - This iterates over the handles and calls `handle.pre_unshard()`, `handle.unshard()`, and `handle.post_unshard(prepare_gradient)` in the all-gather stream.
2. `FlatParamHandle.needs_unshard(self)`
    - We take an aside to look at this key subroutine.
    - For `NO_SHARD`, this returns `False`.
    - For sharded strategies, this checks if the padded unsharded flattened parameter is allocated. The padded unsharded flattened parameter is the base tensor for the unpadded unsharded flattened parameter, which is a view into the padded one. Thus, the padded one's allocation fully determines if the `FlatParameter` is unsharded.
    - For sharded strategies, to accommodate the parameter mixed precision + `summon_full_params()` case, we introduce `_full_prec_full_param_padded`, which is the padded unsharded flattened parameter in full precision. The helper `_get_padded_unsharded_flat_param()` takes care of this casing and returns the padded unsharded flattened parameter. Instead of allocating a new tensor each time, we manually manage `_full_prec_full_param_padded`'s storage just like for `_full_param_padded`.
3. `FlatParamHandle.pre_unshard(self)`
    - For sharded strategies, the postcondition is that the handle's `FlatParameter` points to the tensor to all-gather. This should be on the communication device and in the desired precision. The allocation and usage of the low precision shard for parameter mixed precision and the CPU -> GPU copy for CPU offloading both classify naturally in the pre-unshard.
    - For sharded strategies, if the `FlatParameter` does not need to be unsharded, `pre_unshard()` is a no-op. This avoids unnecessarily allocating and freeing the low precision shard.
    - For `NO_SHARD`, we simply preserve the existing semantics.
4. `FlatParamHandle.unshard(self)`
    - If the handle was resharded without freeing the padded unsharded flattened parameter (e.g. `summon_full_params()` between forward and backward when `reshard_after_forward=False`), then the `FlatParameter` points to the sharded flattened parameter. We need to switch to using the unsharded parameter. This is a design choice. Alternatively, we may not switch to using the sharded flattened parameter in `reshard()` if we do not free the padded unsharded flattened parameter. However, the postcondition that the `FlatParameter` points to the sharded flattened parameter after `reshard()` is helpful logically, so I prefer this approach.
    - Otherwise, this allocates the padded unsharded flattened parameter, all-gathers, and switches to using the unpadded unsharded flattened parameter.
    - In the future, we may add an option to `unshard()` that additionally all-gathers the gradient.
5. `FlatParamHandle.post_unshard(self, prepare_gradient: bool)`
    - For sharded strategies, if using parameter mixed precision, this frees the low precision shard. More generally, this should free any sharded allocations made in `pre_unshard()` since the all-gather has been launched. If using CPU offloading, the GPU copy of the local shard goes out of scope after `unshard()` and is able to be garbage collected. **We should understand if there is any performance difference between manually freeing versus deferring to garbage collection since our usage is inconsistent.** For now, I preserve the existing semantics here.
    - `prepare_gradient` is meant to be set to `True` for the pre-backward unshard and `False` for the forward unshard. This runs the equivalent logic of `_prep_grads_for_backward()`.
    - This post-unshard logic (notably the gradient preparation) now runs in the all-gather stream, which is fine because we always have the current stream wait for the all-gather stream immediately after `FullyShardedDataParallel._unshard()`. IIUC, we do not need to call `_mp_shard.record_stream(current_stream)` (where `current_stream` is the default stream) because `_mp_shard` is allocated and freed in the same (all-gather) stream.
    - A postcondition is that the `FlatParameter` is on the compute device. It should also have the unpadded unsharded size (though I do not have a check for this at the moment).
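
Condensing steps 1-5 into a rough driver sketch (illustrative only; the real code also has the current stream wait on the all-gather stream right after this call):

```python
import torch

def _unshard(handles, all_gather_stream, prepare_gradient: bool) -> None:
    # Run each handle's pre-unshard, all-gather, and post-unshard phases in the
    # all-gather stream; the caller synchronizes the compute stream afterwards.
    with torch.cuda.stream(all_gather_stream):
        for handle in handles:
            handle.pre_unshard()
            handle.unshard()
            handle.post_unshard(prepare_gradient)
```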

### Unshard: `summon_full_params()`
Now that we see how the logical unshard has been reorganized for the core code path, let us dive into `summon_full_params()`.

The two constraints are:
1. If using parameter mixed precision, we should unshard in full precision.
2. We must determine if we should free the padded unsharded flattened parameter upon exiting.

The first constraint is addressed as described before in the core unshard code path, so it remains to explore the second constraint.

I propose a simple rule: **We free iff we actually unshard the `FlatParameter` in `summon_full_params()`** (i.e. it was not already unsharded). We perform a case analysis:

**Parameter mixed precision enabled:**
* `NO_SHARD`: `flat_param.data` points to `flat_param._local_shard`, which is the full precision unsharded flattened parameter. This is **not safe to free**.
* `FULL_SHARD` / `SHARD_GRAD_OP`: We force full precision and all-gather to `_full_prec_full_param_padded`. We do not support `nested summon_full_params()`, so `_full_prec_full_param_padded` must be unallocated. We unshard, and it is **safe to free**.

**Parameter mixed precision disabled:**
* `NO_SHARD`: This is the same as with mixed precision enabled. This is **not safe to free**.
* `FULL_SHARD` / `SHARD_GRAD_OP`: We all-gather to `_full_param_padded`. It may already be unsharded.
    * Already unsharded: The unshard is a no-op. This is **not safe to free**.
        * For `FULL_SHARD`, this can happen for the root FSDP instance after `forward()` but before backward.
        * For `SHARD_GRAD_OP`, this can happen for all FSDP instances after `forward()` but before backward.
    * Needs unshard: We unshard. This is **safe to free**.

Therefore, we see that it is not safe to free when using `NO_SHARD` and when using a sharded strategy but the `FlatParameter` is already unsharded. This is precisely the proposed rule.

There were two notable edge cases that the existing code did not address.
1. The existing code tests if the `FlatParameter` is already unsharded by checking the allocation status of `_full_param_padded`. When using parameter mixed precision, this is the incorrect tensor to check. If `_full_param_padded` is allocated (e.g. when `reshard_after_forward=False` and calling `summon_full_params()` between forward and backward), the already-unsharded check is a false positive, and `summon_full_params()` does not correctly force full precision. https://github.com/pytorch/pytorch/issues/83068
    - This PR's `needs_unshard()` check correctly routes to the appropriate padded unsharded flattened parameter depending on the calling context (i.e. if it needs to force full precision or not).
2. The existing code does not free the GPU copy of the padded unsharded flattened parameter when calling `summon_full_params(offload_to_cpu=True)`. It unshards the `FlatParameter`, moves the padded unsharded flattened parameter to CPU, and sets the `FlatParameter` data to be the appropriate unpadded view into the padded unsharded flattened parameter on CPU. However, `_full_param_padded` still points to the all-gathered padded unsharded flattened parameter on GPU, which is kept in memory. https://github.com/pytorch/pytorch/issues/83076
    - This PR frees the GPU copy and reallocates it upon exiting `summon_full_params()`. This is essential for avoiding peak GPU memory usage from increasing as we recurse through the module tree. There may be some cases where we can avoid reallocation altogether, but that can be addressed in a follow-up PR.
    - This PR offloads the *unpadded* unsharded flattened parameter to CPU directly instead of the *padded* one. As far as I can tell, there is no need to include the padding since unflattening the original parameters does not require the padding.
    - The relevant code is in the context manager `FlatParamHandle.to_cpu()`.

### Unshard: Mixed-Precision Stream

This PR removes the mixed precision stream usage. As is, I do not think there is any extra overlap being achieved by the stream usage.

The low precision shard is allocated and copied to in the mixed precision stream ([code](1f99bdfcc4/torch/distributed/fsdp/fully_sharded_data_parallel.py (L1401-L1412))), and the current stream (in this case the all-gather stream) waits for the mixed precision stream ([code](1f99bdfcc4/torch/distributed/fsdp/fully_sharded_data_parallel.py (L1414))). However, we immediately schedule an all-gather that communicates that exact low precision shard ([code](1f99bdfcc4/torch/distributed/fsdp/fully_sharded_data_parallel.py (L3338))) with no other meaningful computation between. If we remove the mixed precision stream, the low precision shard is allocated and copied to in the all-gather stream (including the non-blocking CPU -> GPU copy if using CPU offloading).

Under this PR's design, we may consider a "pre-unshard" stream for all logical pre-unshard data transfers if we want to overlap in the future. IIUC, the overlap opportunity exists if there are multiple `FlatParameter`s per module, and we only have the all-gather stream wait for the data transfer corresponding to the local shard it communicates, not the others.

If we agree on removing the mixed-precision stream for now, I will remember to delete it from `_init_streams()`.

## FSDP Runtime: Reshard

Like with unshard, the first step is to look at the existing `_free_full_params()` and `_use_param_local_shard()`. A few notable observations:
- For `NO_SHARD` only, `_free_full_params()` includes a call to `_free_mp_shard()`.
- For `summon_full_params()`, there is a separate `_free_full_params_and_use_local_shard()` that duplicates the main logic of `_free_full_params()` and calls `_use_param_local_shard()`.
- In `forward()`, if `reshard_after_forward=True`, we call `_free_full_params()` and then `_free_mp_shard()`. Hence, for `NO_SHARD`, the `_free_mp_shard()` is a no-op.
- In the post-backward hook, we typically call `_free_full_params()` and `_free_mp_shard()`. The `_free_mp_shard()` is a no-op for `NO_SHARD` and if `reshard_after_forward=True`.

Some comments:
- The code certainly works, but some of the no-ops are subtle. When possible, we should make it clear when calls are no-ops or not. It is good that the existing code documents that `_free_mp_shard()` is a no-op in the post-backward hook when `reshard_after_forward=True`. However, there are still some non-obvious no-ops (around `NO_SHARD`).
- We should see if we can avoid the duplicate `_free_full_params_and_use_local_shard()`.

Let us trace through the logical reshard:
1. `FullyShardedDataParallel._reshard(self, handles: List[FlatParamHandle], free_unsharded_flat_params: List[bool])`
    - The two args should have the same length since they are to be zipped.
    - The goal of having `free_unsharded_flat_params` is that the caller should be explicit about whether the (padded) unsharded flattened parameter should be freed. The low precision shard is always meant to be freed (as early as possible), so there is no corresponding `List[bool]`.
2. `FlatParamHandle.reshard(self, free_unsharded_flat_param: bool)`
    - This frees the (padded) unsharded flattened parameter if `free_unsharded_flat_param` and switches to using the sharded flattened parameter.
    - Echoing back to forcing full precision in `summon_full_params()`, `_free_unsharded_flat_param()` frees the correct tensor by using `_get_padded_unsharded_flat_param()`.
3. `FlatParamHandle.post_reshard(self)`
    - I am not fully content with the existence of this method, but this seems to be an unavoidable consequence of `NO_SHARD`. Perhaps, this may be useful in the future for other reasons though.
    - Right now, this method is only meaningful for `NO_SHARD` + parameter mixed precision + outside `summon_full_params()`. `_mp_shard` is not freed in the post-unshard since it is also the low precision _unsharded_ flattened parameter, so we must delay the free until the post-reshard.

Below the `FlatParamHandle.reshard()` and `post_reshard()` layer, there should not be any no-ops.

One final comment I will mention is that I like the `pre_unshard()`, `unshard()`, `post_unshard()`, and `reshard()`, `post_reshard()` organization because it makes it clear what the boundaries are and their temporal relationship. Through that, we can set pre- and post-conditions. Furthermore, we can eventually convert logic to hooks that may be registered on the `FlatParamHandle` (for `pre_unshard()`, `post_unshard()`, and `post_reshard()`). This may improve the customizability of FSDP.

## FSDP Runtime: `forward()`

- This PR reorganizes `forward()` in preparation for non-recursive wrapping, which uses pre-forward and post-forward hooks that expect the signature `hook(module, input)`. For FSDP, the `module` and `input` arguments are not used.
- This PR creates a new method `_fsdp_root_pre_forward()` to handle the logic only the root FSDP should run.

## FSDP Prefetching

Finally, we dive into the prefetching changes. Some highlights:
1. This PR unifies the execution order validation and prefetching implementations.
    - Both involve the execution order and can be unified to share some boilerplate.
2. Execution order validation only runs when the distributed debug level is `INFO`.
    - We have yet to have one success case where we actually catch an unintended source of dynamism. The warning is also too verbose. Hence, we are gating it by the `INFO` level.
3. This PR moves prefetching to be with respect to groups of handles (as mentioned in the constructor comment).
    - This is essential for supporting prefetching with non-recursive wrapping.
4. This PR does not include "bubbles", i.e. modules with no handles, in the recorded execution order(s). This deviates from the existing implementation.
    - This makes prefetching possibly more aggressive (when there are such bubbles), but it should not have significant performance implications either way.
5. This PR changes backward prefetching to reset the post-forward order each iteration (as intended).
6. This PR changes forward prefetching to use the first iteration's pre-forward order instead of the first iteration's post-forward order. (We can discuss whether we want this in this PR or not. Otherwise, I can keep it as using the post-forward order to preserve the existing semantics.) This PR also removes the `all_gather_stream.wait_stream(current_stream)` before forward prefetching because it does not help with high GPU reserved memory. We can add that back if desired.

### Appendix
#### Reverse Post-Forward Order Is Not Always the Pre-Backward Order
The existing PT-D FSDP pre-backward prefetching uses the reverse post-forward order.
<details>
  <summary>Model Code</summary>

  ```
  class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 4, kernel_size=3),
            nn.BatchNorm2d(4),
            nn.ReLU(inplace=True),
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(4, 4, kernel_size=3),
            nn.BatchNorm2d(4),
            nn.ReLU(inplace=False),
        )
        self.block3 = nn.Linear(12, 8)
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(4, 10),
        )

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return self.head(x)

  model = Model().cuda()
  fsdp_kwargs = {}
  model.block1[1] = FSDP(model.block1[1], **fsdp_kwargs)  # BN2d
  model.block2[1] = FSDP(model.block2[1], **fsdp_kwargs)  # BN2d
  model.block1 = FSDP(model.block1, **fsdp_kwargs)
  model.block2 = FSDP(model.block2, **fsdp_kwargs)
  model.block3 = FSDP(model.block3, **fsdp_kwargs)
  model = FSDP(model, **fsdp_kwargs)
  ```
</details>

<details>
  <summary>Execution Orders </summary>

  ```
  Pre-backward hook for ('head.2.weight', 'head.2.bias') 140339520587136 (model)
  Pre-backward hook for ('weight', 'bias') 140339461194656 (block3)
  Pre-backward hook for ('0.weight', '0.bias') 140339520589776 (block2)
  Pre-backward hook for ('weight', 'bias') 140339520587664 (block2 BN)
  Pre-backward hook for ('weight', 'bias') 140339520586656 (block1 BN)
  Pre-backward hook for ('0.weight', '0.bias') 140339520588768 (block1)

  Pre-forward order:
  ('head.2.weight', 'head.2.bias') 140339520587136 (model)
  ('0.weight', '0.bias') 140339520588768 (block1)
  ('weight', 'bias') 140339520586656 (block1 BN)
  ('0.weight', '0.bias') 140339520589776 (block2)
  ('weight', 'bias') 140339520587664 (block2 BN)
  ('weight', 'bias') 140339461194656 (block3)

  Reverse post-forward order:
  ('head.2.weight', 'head.2.bias') 140339520587136 (model)
  ('weight', 'bias') 140339461194656 (block3)
  ('0.weight', '0.bias') 140339520589776 (block2)
  ('weight', 'bias') 140339520587664 (block2 BN)
  ('0.weight', '0.bias') 140339520588768 (block1)
  ('weight', 'bias') 140339520586656 (block1 BN)
  ```
</details>

Differential Revision: [D39293429](https://our.internmc.facebook.com/intern/diff/D39293429)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83665
Approved by: https://github.com/zhaojuanmao
2022-09-13 17:05:10 +00:00
Chien-Chin Huang
28c830ac07 [FSDP] Optimizer states may be on CPU, copy them to GPU before gathering (#84708)
**Background**:
Optimizer states may not always be on GPUs. Some examples include: 1) CPU offload is enabled, 2) after the Lightning trainer's fit() is called.

**What Does This PR Do?**
If states are not on GPUs, move them to GPUs before gathering the global states.

Differential Revision: [D39332300](https://our.internmc.facebook.com/intern/diff/D39332300/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84708
Approved by: https://github.com/awgu
2022-09-09 17:06:10 +00:00
Chien-Chin Huang
1840f24df7 [FSDP] Ensure that all ranks use the same order to iterate through optimizer states (#84654)
**Background:**
Optimizer states are of the type `Dict[int, Dict[str, torch.Tensor]]` and the order of `dict.items()` is the creation order of keys. Without checkpointing (state_dict/load_state_dict), the creation order of keys depends on the implementation of the optimizer (e.g., Adam seems to create `exp_avg` then `exp_avg_sq`). However, when loading states from a checkpoint, since the optimizer states are lazily initialized, the order depends on the user code (reading state_dict from IO). See the following example:

```
optimizer_state_dict = USER_CODE_TO_READ_STATE_FROM_IO()
optimizer.load_state_dict(optimizer_state_dict)
```
The key order of `optimizer_state_dict` depends on `USER_CODE_TO_READ_STATE_FROM_IO` and there is no guarantee that the order is the same across ranks.

**What Can Go Wrong?**
After the first checkpoint load, the key order of the optimizer may not be the same on different ranks. When users try to save another checkpoint, `_unflatten_optim_state()` is called to save the optimizer states. Inside `_unflatten_optim_state()`, `dict.items()` is called to iterate over all the local optimizer states and `all_gather()` is used to gather the local states. Since the order may differ across ranks, the gathered states are not correct.

We have seen some models get NaN loss after the second checkpoint load because of this issue.

**What This PR Does?**
This PR implements a `sorted_items()` helper to return sorted `(key, value)` pairs. We can do this because the keys are either integers or strings.
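
A minimal sketch of such a helper (illustrative; assumes the keys within a given dict are homogeneous, i.e. all ints or all strings):

```python
def sorted_items(dictionary: dict):
    # Iterate (key, value) pairs in sorted key order so that every rank walks
    # the optimizer state in the same order before the all_gather.
    return sorted(dictionary.items(), key=lambda kv: kv[0])
```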

Differential Revision: [D39315184](https://our.internmc.facebook.com/intern/diff/D39315184/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84654
Approved by: https://github.com/awgu
2022-09-09 07:19:01 +00:00
Andrew Gu
7f58db7424 [Easy][FSDP] ufmt _optim_utils.py (#84199)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84199
Approved by: https://github.com/rohan-varma
2022-08-30 18:31:33 +00:00
Chien-Chin Huang
3e1fc85b23 [FSDP] Implement sharded_optim_state_dict and flatten_sharded_optim_state_dict. (#77628)
As title

Differential Revision: [D36436496](https://our.internmc.facebook.com/intern/diff/D36436496/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77628
Approved by: https://github.com/awgu
2022-08-18 16:38:58 +00:00
Andrew Gu
790b122901 [FSDP] Move tensor sharding logic to FlatParamHandle (#80000)
This moves the tensor sharding logic from `FullyShardedDataParallel` to `FlatParamHandle`. In particular, `_get_shard()` and its related subroutines are moved to `FlatParamHandle` as static methods.

The motivation is to start refactoring to move the broader FSDP sharding logic in `_shard_parameters()` to `FlatParamHandle` (as a part of the multiple parameter group and possibly future pluggable sharding efforts). In other words, in follow-ups, I hope to move
cd08954463/torch/distributed/fsdp/fully_sharded_data_parallel.py (L1444-L1447)
to be part of `FlatParamHandle`.

Differential Revision: [D37726060](https://our.internmc.facebook.com/intern/diff/D37726060)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80000
Approved by: https://github.com/fegin
2022-07-22 19:21:51 +00:00
Andrew Gu
b069120d9c [FSDP] Deduplicate _orig_size and _unsharded_size (#79984)
This removes the `_orig_size` attribute that is initialized in `fully_sharded_data_parallel.py` since it represents the same quantity as `_unsharded_size` in `flat_param.py`. Since the quantity is not sharding dependent, we keep its initialization in `FlatParameter.init_metadata()` instead of in `FullyShardedDataParallel._shard_parameters()`.

Differential Revision: [D37726062](https://our.internmc.facebook.com/intern/diff/D37726062)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79984
Approved by: https://github.com/rohan-varma
2022-07-22 19:20:47 +00:00
Andrew Gu
be656f55b1 [FSDP] Introduce FlatParamHandle (#79652)
**Overview**
This PR introduces `FlatParamHandle` to enable non-recursive FSDP wrapping. The class absorbs the unflattening/flattening logic from `FlattenParamsWrapper` but does not require wrapping a particular `nn.Module`.

## Discussion
### Introducing `FlatParamHandle`
There is flexibility in the design space for how to allocate attributes and methods to `FlatParameter` versus a wrapping class like `FlatParamHandle` or `FlattenParamsWrapper`. Several points in the design space provide the same functionality, so deciding on an allocation is arguably stylistic, though then preference should be given to cleaner designs.

The forefront consideration is that a `FlatParameter`'s metadata should be initialized once, while its data may be reloaded via checkpointing. This motivates decoupling the metadata initialization from the `FlatParameter` constructor, which should instead only handle the parameter data. Thus, we have both a `FlatParamHandle` managing a `FlatParameter` and the `FlatParameter` itself.
```
class FlatParamHandle:
    def __init__(self, module: nn.Module, params: Sequence[nn.Parameter]):
        # Calls `_init_flat_param()`
    def _init_flat_param(self, module: nn.Module, params: Sequence[nn.Parameter]):
        # Calls `flatten_params()` and initializes metadata
    @staticmethod
    def flatten_params(params: Sequence[torch.Tensor], requires_grad: bool) -> FlatParameter:
        # Also may be used for checkpoint reloading
class FlatParameter(nn.Parameter):
    # Constructor is not overridden
```
Under this separation, with `FlatParameter` serving solely as a data container, we keep the methods that manipulate the `FlatParameter` on the `FlatParamHandle`. Because `FlatParameter`'s constructor is not overridden, we should be able to replace it with another tensor type, e.g. `ShardedTensor`, with minimal changes.

### Compatibility with `FlattenParamsWrapper`
To ensure backward compatibility, `FlattenParamsWrapper` now holds a `FlatParamHandle`. Existing logic from `FlattenParamsWrapper` simply routes to the handle now.

A `FullyShardedDataParallel` instance holds references to all of its handles.
- For the recursive-wrapping paradigm, there is at most one handle, which is from its `FlattenParamsWrapper` if it manages parameters.
- For the non-recursive wrapping paradigm, there may be multiple handles, all owned by the single (root) `FullyShardedDataParallel` instance.

## For Reviewers
### `FlatParameter` Construction
In the existing implementation, a `FlatParameter`'s metadata was partially initialized in its constructor (e.g. `_param_numels`, `_param_shapes`) and partially initialized by the owning `FlattenParamsWrapper` (e.g. `_param_infos`, `_shared_param_infos`). The latter part was needed due to requiring module information. With this PR, the metadata initialization is consolidated in `FlatParamHandle`.
- During model construction, a `FlatParameter` should be initialized via the handle constructor`FlatParamHandle(params, module)`.
- During sharded checkpoint loading, a `FlatParameter` should be initialized via the static method `FlatParamHandle.flatten_params(new_params)`.
    - The checkpointing implementation is responsible for checking that `new_params` used to construct the `FlatParameter` data to load is consistent with the existing `FlatParameter`'s metadata.

These are the only two cases for `FlatParameter` construction right now, so there is no real functionality regression by not recomputing some of the metadata in the `FlatParameter` constructor. `nn.Module.load_state_dict()` is implemented using in-place `copy_()`, so the newly loaded `FlatParameter`'s metadata *should* match the existing `FlatParameter`'s metadata for correctness anyway. (I.e. we do not support a usage where we reload a `FlatParameter` with differing metadata into an existing `FlatParameter`.)

### BC Breaking
- `ShardMetadata` -> `FlatParamShardMetadata` to avoid name conflict with `ShardedTensor`
    - `metadata()` -> removed (unused)
- `FlatParameter` attributes
    - `_param_numels` -> `_numels`
    - `_param_shapes` -> `_shapes`
    - `_param_names` -> `_prefixed_param_names`
    - `full_numel` -> `_unsharded_size.numel()`
    - `_param_indice_in_shard` -> `_shard_indices`
    - `_sharded_param_offsets` -> `_shard_param_offsets`
    - `num_padded` -> `_shard_numel_padded`
    - `param_offsets` -> not saved; directly constructed in `_get_flat_param_offsets()` and used once
- `FlattenParamsWrapper` `param_list` argument -> `params` for consistency with `FlatParameter`

## Follow-Ups

- The current `FlatParameter`'s `data` represents either the sharded unflattened parameter, unsharded unflattened parameter, or reduced-precision sharded unflattened parameter, depending dynamically on the runtime context. When its `data` represents one quantity, the other quantities are still saved as attributes on the `FlatParameter` (e.g. `_local_shard`, `_full_param_padded`, `_mp_shard`). `FullyShardedDataParallel` directly manipulates the `data`.
We should investigate the tradeoffs of having those attributes on the `FlatParameter` versus moving them to the `FlatParamHandle`. The motivation for the latter is to define a clean interface for `FullyShardedDataParallel` to manage parameter data in preparation for generalizing to multiple parameter groups, to managing non-`FlatParameter`s, and to supporting non-CUDA devices. (More explicitly, `FullyShardedDataParallel`'s parameter *variables* would be set to different `Tensor` variables, none of which own another, instead of `FullyShardedDataParallel`'s parameter variables' *data* being set to different `Tensor` variables, all owned by the `FlatParameter`, and the data management would be folded into handle, hidden from `FullyShardedDataParallel`.)
- We should investigate if we can coalesce the remaining logic in `FlattenParamsWrapper` into `FullyShardedDataParallel` and remove `FlattenParamsWrapper`.
- We may want to move the mixed precision logic to the handle instead of the `FullyShardedDataParallel` instance to enable per-`FlatParameter` mixed precision instead of per-`FullyShardedDataParallel`. Otherwise, the non-recursive wrapping path is bound to all-or-nothing mixed precision.

Differential Revision: [D37250558](https://our.internmc.facebook.com/intern/diff/D37250558)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79652
Approved by: https://github.com/zhaojuanmao, https://github.com/fegin, https://github.com/rohan-varma
2022-07-22 19:16:50 +00:00
Andrew Gu
ab09f34622 [FSDP] Fix full_optim_state_dict() hang (#80712)
Fixes https://github.com/pytorch/pytorch/issues/80581.

Context:
1f08c1d3d6/torch/distributed/fsdp/_optim_utils.py (L152-L163)

To-Do:
I do not understand why inserting this `torch.cuda.synchronize()` prevents the `.cpu()` call from hanging and why in particular, this `torch.cuda.synchronize()` must be called on **all ranks**. If it is only called on the saving ranks (i.e. rank 0), then the hang persists.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80712
Approved by: https://github.com/rohan-varma
2022-07-07 15:23:06 +00:00
Michael Carilli
ba27ee9e8f [CUDA graphs] Allows Adam and AdamW to be capture-safe (#77862)
Near term fix for https://github.com/pytorch/pytorch/issues/76368.

Q. Why does the user need to request `capturable=True` in the optimizer constructor? Why can't capture safety be completely automatic?
A. We need to set up capture-safe (device-side) state variables before capture. If we don't, and step() internally detects capture is underway, it's too late: the best we could do is create a device state variable and copy the current CPU value into it, which is not something we want baked into the graph.

Q. Ok, why not just do the capture-safe approach with device-side state variables all the time?
A. It incurs several more kernel launches per parameter, which could really add up and regress cpu overhead for ungraphed step()s. If the optimizer won't be captured, we should allow step() to stick with its current cpu-side state handling.

Q. But cuda RNG is a stateful thing that maintains its state on the cpu outside of capture and replay, and we capture it automatically. Why can't we do the same thing here?
A. The graph object can handle RNG generator increments because its capture_begin, capture_end, and replay() methods can see and access the generator object. But the graph object has no explicit knowledge of or access to optimizer steps in its capture scope. We could let the user tell the graph object which optimizers will be stepped in its scope, i.e. something like
```python
graph.will_use_optimizer(opt)
graph.capture_begin()
...
```
but that seems clunkier than an optimizer constructor arg.

I'm open to other ideas, but right now I think constructor arg is necessary and the least bad approach.
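
For reference, a minimal usage sketch of the constructor flag discussed above (requires a CUDA device):

```python
import torch

params = [torch.nn.Parameter(torch.randn(8, device="cuda"))]
# capturable=True keeps Adam's step/state on the device so that opt.step()
# can later be captured and replayed inside a CUDA graph.
opt = torch.optim.Adam(params, lr=1e-3, capturable=True)
```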

Long term, https://github.com/pytorch/pytorch/issues/71274 is a better fix.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77862
Approved by: https://github.com/ezyang
2022-06-13 01:56:47 +00:00
Andrew Gu
4615738a3d [FSDP] Allow different optim_input orders across ranks
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78599

Approved by: https://github.com/rohan-varma
2022-06-03 11:47:24 +00:00
Andrew Gu
8412f209f0 [FSDP] Remove unneeded padding logic for optim state dict
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78208

Approved by: https://github.com/rohan-varma
2022-05-25 17:22:03 +00:00
Andrew Gu
94d65b05e9 [FSDP] Optim state: ignore params if not in dict
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76671

Approved by: https://github.com/zhaojuanmao, https://github.com/rohan-varma
2022-05-04 18:39:21 +00:00
Andrew Gu
019259a66f [FSDP] Add scatter_full_optim_state_dict()
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75517

Approved by: https://github.com/zhaojuanmao
2022-04-18 13:19:23 +00:00