Commit Graph

68 Commits

Author SHA1 Message Date
Yanli Zhao
6ca991cacf [Composable API] Add fully_shard debug function to print sharded tree structure, module names and managed param fqns (#99133)
Adds a fully_shard debug function that prints the sharded tree structure in the following format and also returns module names and their managed parameter FQNs.

![Screenshot 2023-04-18 at 5 14 54 PM](https://user-images.githubusercontent.com/48731194/232931628-169a63a9-b4d5-4902-9cfd-f40113f3ec98.png)
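
The helper added here is an FSDP internal; the following is a rough, hypothetical sketch of the idea (the function name and output format are illustrative, not the actual API added by this PR):

```python
import torch.nn as nn

def print_module_tree(root: nn.Module):
    """Print the module tree and return a mapping from module name to its parameter FQNs."""
    name_to_fqns = {}
    for name, module in root.named_modules():
        fqns = [f"{name}.{p}" if name else p
                for p, _ in module.named_parameters(recurse=False)]
        indent = "  " * (name.count(".") + 1 if name else 0)
        print(f"{indent}{name or type(root).__name__}: {fqns}")
        name_to_fqns[name] = fqns
    return name_to_fqns

print_module_tree(nn.Sequential(nn.Linear(4, 4), nn.ReLU()))
```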

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99133
Approved by: https://github.com/rohan-varma
2023-04-19 19:27:43 +00:00
Chien-Chin Huang
bdaf32261f [FSDP] Ensure that customized non tensor optimizer state can be saved (#99214)
The current logic does not actually handle all different non-tensor optimizer states correctly. This PR fixes the issue and adds a test.

This PR will solve https://github.com/pytorch/pytorch/issues/99079
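
A toy example of the kind of optimizer state_dict this covers (all values below are made up):

```python
import torch

optim_state_dict = {
    "state": {
        0: {
            "exp_avg": torch.zeros(8),   # tensor state, sharded/gathered by FSDP
            "step": 3,                   # non-tensor state (Python int)
            "scaling_mode": "dynamic",   # customized non-tensor state
        },
    },
    "param_groups": [{"lr": 1e-3, "params": [0]}],
}
```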

Differential Revision: [D45021331](https://our.internmc.facebook.com/intern/diff/D45021331/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99214
Approved by: https://github.com/awgu, https://github.com/awaelchli
2023-04-17 21:54:16 +00:00
Nikita Shulga
ccc5d1daec Revert D44897935: Multisect successfully blamed D44897935 for test or build failures (#99353)
Summary:
This diff is reverting D44897935
D44897935: [FSDP] Include duplicate parameters and modules when calling named_parameters and named_modules (#98912) by fegin has been identified to be causing the following test or build failures:

Tests affected:
- [caffe2/torch/fb/module_factory/sync_sgd/tests:test_pyper_data_parallel_wrapper - caffe2.torch.fb.module_factory.sync_sgd.tests.test_pyper_data_parallel_wrapper.PyPerDataParallelWrapperTest: test_fsdp_submodules_pyper](https://www.internalfb.com/intern/test/562950025957458/)

Here's the Multisect link:
https://www.internalfb.com/multisect/1893714
Here are the tasks that are relevant to this breakage:

We're generating a revert to back out the changes in this diff; please note that the backout may land if someone accepts it.

If you believe this diff has been generated in error you may Commandeer and Abandon it.

Test Plan: NA

Reviewed By: fegin

Differential Revision: D45027286

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99353
Approved by: https://github.com/izaitsevfb, https://github.com/fegin
2023-04-17 20:53:10 +00:00
Chien-Chin Huang
8e328762ff [FSDP] Include duplicate parameters and modules when calling named_parameters and named_modules (#98912)
By default, `named_parameters` and `named_modules` remove duplicated parameters and modules. However, in FSDP, we need to know which parameters are shared, so setting `remove_duplicate` to False is required. Without setting `remove_duplicate` to False, FSDP cannot discover shared weights in some cases (e.g., when the shared weights are in the same module or when modules themselves are shared).
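
For example, with a tied weight:

```python
import torch.nn as nn

class TiedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(10, 4)
        self.out = nn.Linear(4, 10, bias=False)
        self.out.weight = self.emb.weight  # shared (tied) weight

m = TiedModel()
print(len(list(m.named_parameters())))                        # 1 -- the duplicate is removed
print(len(list(m.named_parameters(remove_duplicate=False))))  # 2 -- both FQNs of the shared weight
```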

Differential Revision: [D44897935](https://our.internmc.facebook.com/intern/diff/D44897935/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98912
Approved by: https://github.com/awgu
2023-04-13 20:37:11 +00:00
Andrew Gu
662a8cf74d [FSDP][8/N] Simplify addr padding internals (#97796)
This is a follow-up to the last PR to greatly simplify the approach. This should be much cleaner.

**Details**
Let `N` denote the number of original parameters flattened into a given flat parameter with `M` extra padding tensors.
- `_numels_with_padding`: length `N + M`
- `_is_padding_mask`: length `N + M`
- `_numels`, `_param_infos`, `_shapes`, `_fqns`, `_param_extensions`: length `N`

`_shard_param_indices` and `_shard_param_offsets` were used to determine (1) whether a given original parameter is in the local shard and, if so, (2) its offset in the _sharded_ flat parameter and (3) how many of its elements are in the _sharded_ flat parameter.

This PR reworks how to achieve (1), (2), and (3) to allow simplifying the previously mentioned data structures. In particular, it saves one extra tuple `_shard_param_infos: Tuple[_ShardParamInfo, ...]` of length `N`, where each `_ShardParamInfo` entry gives exactly the needed info. For example, the offset into the sharded flat parameter is now pre-computed, so we no longer need to accumulate `offset = 0; offset += numel_in_shard` in a `for` loop each time.

For optimizer state dict, `FSDPParamInfo.param_indices` now maps to the indexes with respect to the length `N` data structures, not the length `N + M` ones. The only purpose of `param_indices` is to be able to index into `flat_param._shard_param_infos[i]` to get the contained info to flatten the unsharded original parameter optimizer state and extract the part in the local shard.
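
A simplified, illustrative sketch of the kind of per-parameter entry described above (the field names are approximate, not the actual internal definition):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ShardParamInfo:
    in_shard: bool                          # (1) is any part of this parameter in the local shard?
    offset_in_shard: Optional[int] = None   # (2) precomputed offset into the *sharded* flat parameter
    numel_in_shard: Optional[int] = None    # (3) how many of its elements the local shard contains
```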

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97796
Approved by: https://github.com/rohan-varma
2023-03-28 22:19:44 +00:00
Andrew Gu
1c15cd48e2 [FSDP][7/N] Add alignment padding for use_orig_params=True (#97667)
This PR adds intra-`FlatParameter` 16-byte alignment padding to the `use_orig_params=True` code path to avoid clones in TorchInductor.

**Approach**
The `FlatParameter` maintains several data structures about its original parameters. Notably, the data structures `_param_infos`, `_shapes`, `_numels`, and `_fqns` have the same length and index in the same way.

This PR treats alignment padding _like_ an original parameter in that the padding gets flattened into the `FlatParameter`. Therefore, it must be reflected in the aforementioned data structures. However, given the way in which the data structures are used, we choose to do the following if the `i`th tensor flattened into the `FlatParameter` is padding:
- `_numels[i]` is the numel of padding
- `_param_infos[i] == _shapes[i] == _fqns[i] == None`

This choice is because (1) we must record the padding numel to account for it (e.g. for views) and (2) we prefer to preserve the invariant that the data structures index in the same way over avoiding `None` entries.
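
As a rough illustration of the padding bookkeeping (the constant and function below are illustrative, not FSDP internals):

```python
ALIGNMENT_BYTES = 16

def alignment_padding_numel(numel_so_far: int, element_size: int) -> int:
    """Padding elements needed so the next flattened parameter starts 16-byte aligned."""
    offset_bytes = numel_so_far * element_size
    remainder = offset_bytes % ALIGNMENT_BYTES
    return 0 if remainder == 0 else (ALIGNMENT_BYTES - remainder) // element_size

# e.g., 10 fp16 elements (2 bytes each) end at byte 20, so 6 padding elements reach byte 32
assert alignment_padding_numel(10, 2) == 6
```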

To ease the burden of other FSDP developers, we separate the parameter flattening logic:
- `_init_flat_param_and_metadata()`: This should be called only once in the `FlatParamHandle` constructor. The `FlatParameter` metadata is assumed to be static thereafter.
- `flatten_tensors()` / `flatten_tensors_into_flat_param()`: These can be used for optimizer and model state dict and can be called after construction time.

This separation allows `_init_flat_param_and_metadata()` to contain the much heavier metadata logic while keeping the latter methods much lighter. The only constraint is that the alignment padding logic must be kept consistent between the two, but this should be worth the simpler interface.

**Testing**
- This PR directly modifies the `use_orig_params=True` code path, so all existing tests passing gives good signal.
    - Some existing unit tests had to be adjusted to account for the alignment padding.
- This PR adds two tests in `test_fsdp_flatten_params.py` to explicitly test the sharding metadata with alignment for both parameter full precision and mixed precision since the latter requires possibly more padding elements due to the decreased per-element size.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97667
Approved by: https://github.com/rohan-varma
2023-03-28 01:46:43 +00:00
Andrew Gu
b9049a7f11 [FSDP][6/N] Rename param/module name helpers for clarity (#97666)
This is an easy PR. It contains some remaining local changes that I felt clarified naming.
- `_param_fqns` -> `_param_name_infos` since it returns a tuple of `fqn, param_name, module_name`, not only `fqn`. (similarly for `_shared_param_fqns` -> `_shared_param_name_infos`)
- nit: `parameter_module_names` -> `param_module_names` for consistency since we almost never fully spell out `parameter`. (similarly for `shared_parameter_module_names` -> `shared_param_module_names`)
- nit: `full_fqn` -> `fqn_from_global_root`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97666
Approved by: https://github.com/rohan-varma
2023-03-28 01:46:43 +00:00
Andrew Gu
30a6ed34a0 [FSDP][5/N] Lift FSDPParamInfo to use FlatParamHandle (#97665)
This PR changes `FSDPParamInfo` in `_optim_utils.py` to save the `FlatParamHandle`, not directly the `FlatParameter`. This is in preparation for subsequent PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97665
Approved by: https://github.com/rohan-varma
2023-03-28 01:46:43 +00:00
Andrew Gu
c622559968 [FSDP][3/N] Minor fixes (rename, assert message) (#97663)
This is an easy PR.
- It renames `_shard_indices` to `_shard_param_indices` for consistency.
- It fixes an old mention of `comm_module` in an assert message.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97663
Approved by: https://github.com/rohan-varma
2023-03-28 01:46:43 +00:00
Andrew Gu
a27882ecd1 [FSDP][2/N] Rename "flattened parameter" -> "flat parameter" (pt. 2) (#97662)
From our recent experience, we refer to FSDP's `FlatParameter` as "flat parameter", not "flattened parameter". This PR renames that in `_optim_utils.py`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97662
Approved by: https://github.com/rohan-varma
2023-03-28 01:46:43 +00:00
Kazuaki Ishizaki
35fd5c548e Fix typos under torch/distributed directory (#95638)
This PR fixes typos in comments and messages of `.py` files under torch/distributed directory

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95638
Approved by: https://github.com/usamah1, https://github.com/H-Huang, https://github.com/kit1980
2023-03-27 21:13:44 +00:00
Chien-Chin Huang
793cb3f424 [FSDP][optim_state_dict] Print out more useful error message for optim_state_dict (#96860)
Summary: Print out more useful error message for optim_state_dict

Test Plan: CI

Reviewed By: wz337

Differential Revision: D43556073

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96860
Approved by: https://github.com/rohan-varma, https://github.com/wz337
2023-03-21 01:04:24 +00:00
Aaron Gokaslan
5471621497 [BE] Remove unnecessary dict comprehensions (#97116)
Removes unnecessary dict comprehensions, optimizing the creation of dicts from iterables.
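
An example of the kind of rewrite applied:

```python
pairs = [("a", 1), ("b", 2)]

d = {k: v for k, v in pairs}  # unnecessary dict comprehension
d = dict(pairs)               # equivalent, simpler, and faster
```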

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97116
Approved by: https://github.com/kit1980
2023-03-20 00:56:57 +00:00
Chien-Chin Huang
15e58c19ec [FSDP][optim_state_dict] Copy step tensor so that each parameter has its own step (#96313)
Summary: When parameters are flattened, multiple parameters share the same step. When unflattening the parameters, the current implementation still makes these parameters share the same step. While this is not wrong, some training infra gets confused by the shared tensor storage. This PR fixes the issue.
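
A toy illustration of the fix (the FQNs are made up): each unflattened parameter gets its own cloned step tensor instead of aliasing a shared one.

```python
import torch

shared_step = torch.tensor(100.0)
unflat_osd_state = {
    fqn: {"step": shared_step.clone()}  # clone() gives each parameter independent storage
    for fqn in ("layer.weight", "layer.bias")
}
assert (unflat_osd_state["layer.weight"]["step"].data_ptr()
        != unflat_osd_state["layer.bias"]["step"].data_ptr())
```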

Test Plan: CI

Reviewed By: awgu

Differential Revision: D43893592

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96313
Approved by: https://github.com/zhaojuanmao
2023-03-10 04:51:30 +00:00
Chien-Chin Huang
92edac72aa [FSDP][optim_state_dict] Fix a memory leakage in optim_state_dict (#96263)
Summary: The original code uses a class variable to store the flat_parameter result. This could cause a memory leak.
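
An illustrative contrast (the class and attribute names are made up):

```python
class LeakyHelper:
    _cached_flat_param_state = {}  # class attribute: shared across instances and never released

    def record(self, key, tensor):
        self._cached_flat_param_state[key] = tensor

class SafeHelper:
    def __init__(self):
        self._flat_param_state = {}  # instance attribute: freed together with the instance

    def record(self, key, tensor):
        self._flat_param_state[key] = tensor
```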

Test Plan: CI and a E2E run

Reviewed By: awgu

Differential Revision: D43893577

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96263
Approved by: https://github.com/zhaojuanmao
2023-03-08 08:43:42 +00:00
Colin Taylor
16a4579335 [FSDP] [composable] [BE] warning should read TorchRec, not DMP (#95010)
Summary: as title

Test Plan: N/A

Differential Revision: D43375189

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95010
Approved by: https://github.com/awgu, https://github.com/fegin
2023-02-17 03:31:30 +00:00
Aaron Gokaslan
67d9790985 [BE] Apply almost all remaining flake8-comprehension checks (#94676)
Applies the remaining flake8-comprehension fixes and checks. These changes replace all remaining unnecessary generator expressions with list/dict/set comprehensions, which are more succinct, more performant, and better supported by our torch.jit compiler. They also remove useless generators such as `set(a for a in b)`, resolving them into just the set call.
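
For example:

```python
items = [1, 2, 2, 3]

s = set(a for a in items)  # useless generator expression inside set()
s = set(items)             # resolved into just the set call
```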

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
2023-02-12 01:01:25 +00:00
Aaron Gokaslan
3d82d8d0ed [BE] Enable more flake8-comprehensions checks (#94601)
I applied some flake8 fixes and enabled checking for them in the linter. I also enabled some checks for my previous comprehensions PR.

This is a follow up to #94323 where I enable the flake8 checkers for the fixes I made and fix a few more of them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94601
Approved by: https://github.com/ezyang
2023-02-10 23:40:29 +00:00
Chien-Chin Huang
2180a0dc0c [FSDP][optim_state_dict] Remove the dead code (#94448)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94448
Approved by: https://github.com/awgu
2023-02-09 06:32:40 +00:00
Aaron Gokaslan
3ce1ebb6fb Apply some safe comprehension optimizations (#94323)
Optimizes away unnecessary collection casts, unnecessary calls to list, tuple, and dict, and simplifies calls to the sorted builtin. This should strictly improve speed and readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94323
Approved by: https://github.com/albanD
2023-02-07 23:53:46 +00:00
Chien-Chin Huang
ab4fe01e72 [FSDP][optim_state_dict] Returns the initial states of the empty parameters for KeyedOptimizer/NamedOptimizer (#94130)
KeyedOptimizer and NamedOptimizer expect the states to exist in the state_dict when `load_state_dict` is called, even if the corresponding parameters are empty (size == 0). This PR adds support to make KeyedOptimizer work with `use_orig_params=True`.

Differential Revision: [D43019458](https://our.internmc.facebook.com/intern/diff/D43019458/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94130
Approved by: https://github.com/rohan-varma
2023-02-07 23:36:56 +00:00
Chien-Chin Huang
bc6d54f6d8 [FSDP][optim_state_dict] Let optim_state_dict ignore the non-FSDP managed parameters that do not reside on the rank (#94129)
When FSDP is used with other parallelism (e.g., TorchRec), some parameters that are not managed by FSDP may not reside on all the ranks (TorchRec uses model parallelism). When `use_orig_params=True`, FSDP synchronizes the FQNs among ranks. As a result, a rank may receive FQNs that it does not actually own. If an FQN belongs to a TorchRec-managed parameter, FSDP has to ignore the parameter state; otherwise, FSDP does not know how to store the state.

This PR adds the logic to ignore parameters that are not managed by FSDP and do not reside on the rank.
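
A minimal sketch of the filtering idea, assuming hypothetical FQN sets (not the actual FSDP code):

```python
def filter_optim_state(osd_state, fsdp_managed_fqns, local_fqns):
    """Drop states for FQNs that this rank neither manages via FSDP nor owns locally."""
    return {
        fqn: state
        for fqn, state in osd_state.items()
        if fqn in fsdp_managed_fqns or fqn in local_fqns
    }
```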

Differential Revision: [D42982778](https://our.internmc.facebook.com/intern/diff/D42982778/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94129
Approved by: https://github.com/rohan-varma
2023-02-07 06:29:28 +00:00
Chien-Chin Huang
0f5b6caa16 [FSDP][optim_state_dict] Ignore the state check on rank that does not own the corresponding parameter (#93318)
When a rank does not own a parameter (parameter.numel() == 0), its optim state is not valid and should not be checked against the current saved one.

Differential Revision: [D42865237](https://our.internmc.facebook.com/intern/diff/D42865237/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93318
Approved by: https://github.com/rohan-varma
2023-02-03 00:50:04 +00:00
Chien-Chin Huang
e32d99ae19 [FSDP][optim_state_dict] Make FSDP.optim_state_dict compatible with DMP (#93285)
`torchrec.DistributedModelParallel` overwrites `named_parameters` and is not compatible with `FullyShardedDataParallel`'s optim_state_dict. This PR adds a workaround in `FullyShardedDataParallel` to make the two work together.

Differential Revision: [D42764611](https://our.internmc.facebook.com/intern/diff/D42764611/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93285
Approved by: https://github.com/rohan-varma
2023-02-02 23:42:54 +00:00
Andrew Gu
10990734ce [FSDP][2/N] _summon_full_params -> _unshard_params (#92297)
**Overview**
This PR stack will add support for unsharding FSDP's sharded parameters for `fully_shard`. This PR takes the first step by doing some internal refactoring.
- The existing API for wrapper FSDP is the static method `summon_full_params()`, which calls into the helper `_summon_full_params()`.
- This PR refactors:
    - `summon_full_params()` core logic to `_unshard_params()`
    - `_summon_full_params()` to `_unshard_params_recurse()`, which has a `recurse: bool` argument
    - Previous `_unshard_params()` to `_unshard_fsdp_state_params()`, which applies to a single FSDP state

**Details**
- This PR introduces `_get_fsdp_states_with_modules()` and `_get_root_fsdp_states_with_modules()`, which additionally return the modules along with the FSDP states. The modules are needed for handling `FlatParameter` registration.
    - We may be able to remove this if we clean up the `use_orig_params=True` vs. `False` code paths because for `True`, the `FlatParameter` is not registered, meaning that it does not need to be de-registered.
    - Since `fully_shard` requires `use_orig_params=True`, we may not need `_get_fsdp_states_with_modules()` and `_get_root_fsdp_states_with_modules()`; however, I prefer to make the separation of FSDP state and module explicit for now for clarity.

**Follow-Ups**
- `writeback=True` and `rank0_only=True` raise an error. The previous explanation was:
> is not supported, as model parameter shapes will be different across ranks, and writing to them can lead to inconsistencies across ranks when the context is exited.

I am not exactly sure what the different model parameter shapes refer to. However, I believe that we can support `writeback=True` and `rank0_only=True` by broadcasting the `FlatParameter` from rank 0 in the `finally`, writing back, and freeing. This should not increase the peak memory since rank 0 already holds the unsharded `FlatParameter` in GPU memory before writing back, and nonzero ranks do not have any other unsharded `FlatParameter`s in GPU memory.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92297
Approved by: https://github.com/rohan-varma
2023-02-02 15:10:14 +00:00
Chien-Chin Huang
888771dc5d [FSDP][optim_state_dict] Fix _is_named_optimizer when the state is empty (#93303)
Optimizer state is not eagerly initialized -- only NamedOptimizer and KeyedOptimizer initialize it eagerly. This PR makes `_is_named_optimizer` work with regular optimizers.

Differential Revision: [D42858589](https://our.internmc.facebook.com/intern/diff/D42858589/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93303
Approved by: https://github.com/fduwjj
2023-01-31 03:49:26 +00:00
Chien-Chin Huang
a4238976a8 [FSDP][optim_state_dict] Ensure correct devices for tensors when doing all_gather (#92992)
When doing `_all_gather_optim_state`, we need to ensure that `step` tensors are on CPU and other tensors are on GPUs. This PR adds the logic to ensure this locality.
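
A rough sketch of the locality normalization, with illustrative names:

```python
import torch

def normalize_state_devices(param_state, device):
    """Keep `step` on CPU and move other tensor state onto the given GPU device."""
    out = {}
    for name, value in param_state.items():
        if not torch.is_tensor(value):
            out[name] = value
        elif name == "step":
            out[name] = value.cpu()
        else:
            out[name] = value.to(device)
    return out
```
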
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92992
Approved by: https://github.com/fduwjj
2023-01-27 06:50:36 +00:00
Chien-Chin Huang
8b1b47c36a [FSDP][optim_state_dict] Use all_gather to deal with uneven size tensors (#92991)
The current `_all_gather_optim_state` pads uneven tensors, which is not necessary since `all_gather` supports uneven tensors. This PR removes the padding logic.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92991
Approved by: https://github.com/rohan-varma, https://github.com/awgu
2023-01-27 06:46:44 +00:00
Chien-Chin Huang
8f294f785f [FSDP][optim_state_dict] Fix the conditions to check non-parameter associated states (#92744)
If a state is not associated with any parameter, `FSDP.optim_state_dict` should still save it. The current implementation for determining whether a state is associated with a parameter is not completely correct and can cause `use_orig_params=True` to have extra states.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92744
Approved by: https://github.com/awgu
2023-01-23 17:40:50 +00:00
Chien-Chin Huang
92d412d684 [FSDP][optim_state_dict][11/N] Let FSDP support NamedOptimizer/KeyedOptimizer when use_orig_params is False (#92184)
The current design of FSDP only supports NamedOptimizer/KeyedOptimizer when `use_orig_params` is True; this PR adds the support even when `use_orig_params` is False. This PR also adds support for user-defined optimizer states -- states that are not associated with any particular parameter.

Differential Revision: [D42497416](https://our.internmc.facebook.com/intern/diff/D42497416/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92184
Approved by: https://github.com/colin2328, https://github.com/rohan-varma
2023-01-18 21:24:30 +00:00
Chien-Chin Huang
1439cb0314 [FSDP][optim_state_dict][9/N] Rewrite the all-gather flow of optimizer state to support older GPUs (#91343)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91343
Approved by: https://github.com/rohan-varma
2023-01-17 17:21:19 +00:00
Rohan Varma
a155f64957 Update _optim_utils.py (#91935)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91935
Approved by: https://github.com/awgu, https://github.com/fegin
2023-01-11 22:06:26 +00:00
Chien-Chin Huang
0e8565d1d5 [FSDP][optim_state_dict][8/N] Enable fully_shard optim state_dict save and load (#91234)
**What does this PR do?**
This PR refactors `_optim_utils.py` to use `_FSDPState` instead of the `FullyShardedDataParallel` class. This change enables optim state_dict support for `fully_shard`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91234
Approved by: https://github.com/rohan-varma
2022-12-30 06:56:44 +00:00
Chien-Chin Huang
6cea4f3d57 [FSDP][optim_state_dict][7/N] Make FSDP support NamedOptimizer (#91160)
**What does this PR do?**
This PR refactors the FSDP optimizer state_dict APIs to accept `NamedOptimizer` as the input optimizer. The key difference is that the state_dict returned by `NamedOptimizer` is already keyed by FQN. This PR mainly changes the internal mapping to allow the optimizer state_dict to be keyed by FQN.
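
A toy contrast of the two layouts (the FQN below is made up):

```python
import torch

# A plain optimizer state_dict keys state by integer parameter index...
plain_osd = {
    "state": {0: {"exp_avg": torch.zeros(4)}},
    "param_groups": [{"lr": 1e-3, "params": [0]}],
}

# ...while a NamedOptimizer state_dict keys the same state by parameter FQN.
fqn_keyed_osd = {
    "state": {"encoder.layers.0.weight": {"exp_avg": torch.zeros(4)}},
    "param_groups": [{"lr": 1e-3, "params": ["encoder.layers.0.weight"]}],
}
```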

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91160
Approved by: https://github.com/fduwjj, https://github.com/rohan-varma
2022-12-22 04:35:26 +00:00
Chien-Chin Huang
1ab6ac4682 [FSDP][optim_state_dict][6/N] Refactor the optim_state_dict APIs to support hooks (#90798)
**What does this PR do?**

This PR splits the FSDP optim_state_dict APIs into common implementation parts that are shared across different frontend APIs (we have many now and will consolidate them gradually). This PR also adds `_optim_state_dict_post_hook` and `_load_optim_state_dict_pre_hook` for the integration with `NamedOptimizer`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90798
Approved by: https://github.com/rohan-varma, https://github.com/awgu
2022-12-21 21:38:14 +00:00
Andrew Gu
aec09eeb3a [FSDP][7/N] Support replicate in fully_shard (#91044)
This PR supports nesting `replicate` in `fully_shard`.
- The PR achieves this by treating `replicate`-annotated modules as ignored modules. This means that all submodules in the `replicate`-annotated module's subtree are ignored, including nested `fully_shard`-annotated modules, which is the desired behavior.

---

This PR reworks some tree traversal.

One end goal is for `state._handles` to follow the same order for both the wrapper and composable paths. This implies that `_get_fsdp_handles()` returns the same value for both paths.
- The helper function `_get_fully_sharded_module_to_states()` now follows a left-to-right DFS from each fully sharded module instead of a BFS. The left-to-right DFS follows `.modules()` order.
- The composable auto "wrap" initialization function `_init_param_handles_from_module()` follows the reverse left-to-right DFS order. As noted in the code comments, this initialization order is a valid reverse topological sort, but it differs from the wrapper path. This is the _only_ difference with respect to initialization order through the entire process.
```
mod: Module(
    submod1: Submodule()
    submod2: Submodule(
        subsubmod: Subsubmodule(),
    ),
)
```
For left-to-right DFS, the order is `mod`, `submod1`, `submod2`, `subsubmod`. (For context, right-to-left DFS would be `mod`, `submod2`, `subsubmod`, `submod1`. In other words, the left-to-right vs. right-to-left corresponds to `.children()` vs. `reversed(.children())` respectively.) Then, reverse left-to-right DFS is `subsubmod`, `submod2`, `submod1`, `mod`, which is a valid initialization order. However, the wrapper auto wrap initialization order would be `submod1`, `subsubmod`, `submod2`, `mod` since it directly follows a left-to-right DFS and initializes as a part of the recursive DFS logic.
- At the end of `_init_param_handles_from_module()`, we reverse the newly populated `state._handles`, so this is the reverse reverse left-to-right DFS order, which is equivalent to the left-to-right DFS order. Thus, `state._handles` has the same order for both paths.

Another goal is for `_get_fsdp_states()` to not traverse into any submodule that is annotated with an API that is not compatible with `fully_shard` (e.g. `replicate`). To achieve this while preserving that `_get_fsdp_states()` follows `.modules()` order, we again use a left-to-right DFS.

The reason the DFSs may look strange is that I implemented them non-recursively, which requires a stack.
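
A minimal sketch of such a stack-based, left-to-right pre-order DFS (ignoring deduplication of shared modules):

```python
import torch.nn as nn

def dfs_left_to_right(root: nn.Module):
    """Non-recursive pre-order (left-to-right) DFS that matches .modules() order."""
    stack, order = [root], []
    while stack:
        module = stack.pop()
        order.append(module)
        # Push children in reverse so the leftmost child is popped (visited) first.
        stack.extend(reversed(list(module.children())))
    return order
```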

- `test_get_fully_sharded_module_to_states()` in `test_utils.py` checks the traversal order of `_get_fully_sharded_module_to_states()`.
- `test_policy()` in `test_fully_shard.py` checks the traversal order returned by `_get_fsdp_handles()`.

---

Due to a circular dependency issue, we must move the graph/tree traversal helpers to their own file `_traversal_utils.py`, and any usages must import the entire file like `import torch.distributed.fsdp._traversal_utils as traversal_utils` instead of `from torch.distributed.fsdp._traversal_utils import ...`.

The cycle comes from the fact that the traversals require `_composable()`, which requires `_get_registry()` from `composable/contract.py`, which when imported, imports `composable/fully_shard.py`, which requires the traversals.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91044
Approved by: https://github.com/mrshenli
2022-12-20 16:49:18 +00:00
Andrew Gu
39d9dd135a [FSDP][Easy] ufmt files (#90858)
```
ufmt format torch/distributed/fsdp
ufmt format test/distributed/fsdp
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90858
Approved by: https://github.com/rohan-varma
2022-12-15 04:15:26 +00:00
Chien-Chin Huang
4a2d64994c [FSDP][optim_state_dict][4/N] Remove the unused _get_flat_param_to_fsdp_module API (#89980)
This is an easy PR, just remove an unused internal API.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89980
Approved by: https://github.com/awgu
2022-12-13 21:01:46 +00:00
Chien-Chin Huang
043de8d1b1 [FSDP][optim_state_dict][3/N] Support use_orig_param optim_state_dict (non-broadcast version) (#89900)
**What:**
This PR adds optim state_dict support for `use_orig_params` with rank0_only set to False. rank0_only support will be added in a following PR. The design of this PR focuses on simplicity and may not have good performance, especially for optim state_dict loading. Since optim state_dict loading is only called once at the beginning of training, performance is not the major concern.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89900
Approved by: https://github.com/awgu, https://github.com/rohan-varma
2022-12-13 20:45:21 +00:00
Chien-Chin Huang
44779d9bc6 [FSDP][optim_state_dict][2/N] Add _get_fqn_to_fsdp_param_info to map from original FQN to flat_param (#89899)
**Motivation:**
Add a helper to map from an original FQN to the corresponding flat_param. The helper gets flat_param directly from the FSDP state and the flat parameter handle, since flat_param is not registered to the module when `use_orig_params` is True.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89899
Approved by: https://github.com/awgu
2022-12-07 19:40:47 +00:00
Ram Rachum
351d73b97f Fix exception causes all over the codebase (#90271)
This is the continuation to #90134 and hopefully the final PR in this series.
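
An example of the kind of rewrite such a cleanup makes (the function and messages are made up):

```python
def parse_config_value(raw: str) -> int:
    try:
        return int(raw)
    except ValueError as err:
        # Chain the original exception so the traceback preserves its cause.
        raise RuntimeError(f"failed to parse config value {raw!r}") from err
```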

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90271
Approved by: https://github.com/kit1980
2022-12-07 04:29:00 +00:00
Chien-Chin Huang
72fdfad4ad [FSDP][optim_state_dict][1/N] Restructure _optim_state_dict to prepare the support of use_orig_param (#89898)
**Motivation:**
Restructure some APIs in _optim_state_dict.py to allow better future extension, mostly for supporting use_orig_params. NO logic change in this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89898
Approved by: https://github.com/awgu
2022-12-05 21:01:48 +00:00
Chien-Chin Huang
ae4074669e [FSDP][state_dict][6/N] Remove most FSDP module dependency from _optim_utils (#88638)
**What**
This PR removes most `FullyShardedDataParallel` dependencies from `optim_utils`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88638
Approved by: https://github.com/awgu
2022-11-12 03:16:37 +00:00
Andrew Gu
a689502275 [FSDP] Do not include empty state in _flatten_optim_state_dict() (#88353)
983c0e7f31/torch/optim/adam.py (L163)
The above line requires that a candidate optimizer state dict being loaded via `load_state_dict()` has non-empty state for its 0th parameter (via `state_values[0]`). This PR changes FSDP to only include non-empty mappings in the state returned by `_flatten_optim_state_dict()`, which is the subroutine for both `shard_full_optim_state_dict()` and `flatten_sharded_optim_state_dict()`.
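
A minimal sketch of the filtering, with illustrative names:

```python
def drop_empty_states(flat_osd_state):
    """Keep only entries with non-empty optimizer state, as described above."""
    return {key: state for key, state in flat_osd_state.items() if len(state) > 0}
```
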
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88353
Approved by: https://github.com/fegin
2022-11-03 11:33:10 +00:00
Andrew Gu
73de44fc56 [FSDP] Rename unflat_param_name -> fqn for consistency (#88123)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88123
Approved by: https://github.com/mrshenli
2022-11-02 23:25:53 +00:00
Andrew Gu
bf2819a836 [FSDP()][24/N] Refactor _lazy_init() (#87939)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87939
Approved by: https://github.com/zhaojuanmao
2022-11-02 16:35:47 +00:00
Andrew Gu
cbc9faebfe [FSDP()][1/N] Start refactoring FSDP root pre-forward (#87915)
Welcome! This PR starts the refactoring journey.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87915
Approved by: https://github.com/mrshenli
2022-10-29 06:50:30 +00:00
Andrew Gu
e3cf81e0a7 [FSDP] ufmt /fsdp (#87811)
This applies `ufmt` to all of the FSDP files in the `torch/distributed/fsdp/` directory.

**Test Plan**
CI

**Notes**
For VSCode users,
- Install `ufmt`: https://pypi.org/project/ufmt/
- Install VSCode `ufmt` extension: https://marketplace.visualstudio.com/items?itemName=omnilib.ufmt
- Include in `settings.json`:
```
{
    "[python]": {
        "editor.defaultFormatter": "omnilib.ufmt",
        "editor.formatOnSave": true,
    },
}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87811
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2022-10-27 04:25:55 +00:00
Rohan Varma
701b3dd773 optim utils all_gather_into_tensor (#87769)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87769
Approved by: https://github.com/awgu
2022-10-26 16:20:46 +00:00
Andrew Gu
be682befbc [FSDP] Add use_orig_params (#84911)
**Overview**
This PR adds the option to use the original parameters via `use_orig_params=True` in the FSDP constructor.
- This exposes the original parameters rather than the `FlatParameter`s from `named_parameters()`, which means that the optimizer runs on the original parameters. Hence, users may assign original parameters from the same `FlatParameter` to different parameter groups.
- This enables decoupling the original parameter variables from their storage without changing the variables themselves, which is critical for our upcoming execution-order-based non-recursive wrapping policy.

For more detailed design explanation, refer to the Quip shared internally.
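
A usage sketch, assuming a process group is already initialized (the model and the param-group split below are illustrative):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(nn.Transformer(), use_orig_params=True)

# named_parameters() now yields the original parameters, so parameters that share a
# FlatParameter can still be placed into different optimizer param groups.
decay, no_decay = [], []
for name, param in model.named_parameters():
    (no_decay if "bias" in name else decay).append(param)

optim = torch.optim.AdamW([
    {"params": decay, "weight_decay": 0.01},
    {"params": no_decay, "weight_decay": 0.0},
])
```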

**Follow-Ups**
See 85831 (removing link to avoid spamming the issue whenever I update this PR).

`test_fsdp_use_orig_params.py` adds ~4 min 46 seconds to the TTS on the AWS cluster.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84911
Approved by: https://github.com/rohan-varma
2022-10-07 18:07:17 +00:00