I prefer to not modify the module if it does not have any of our APIs applied. The side effect of inserting a registry on the module when calling a getter is non-intuitive to me.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113654
Approved by: https://github.com/fegin
**Overview**
This PR stack will add support for unsharding FSDP's sharded parameters for `fully_shard`. This PR takes the first step by doing some internal refactoring.
- The existing API for wrapper FSDP is the static method `summon_full_params()`, which calls into the helper `_summon_full_params()`.
- This PR refactors:
- `summon_full_params()` core logic to `_unshard_params()`
- `_summon_full_params()` to `_unshard_params_recurse()`, which has a `recurse: bool` argument
- Previous `_unshard_params()` to `_unshard_fsdp_state_params()`, which applies to a single FSDP state
**Details**
- This PR introduces `_get_fsdp_states_with_modules()` and `_get_root_fsdp_states_with_modules()`, which additionally return the modules along with the FSDP states. The modules are needed for handling `FlatParameter` registration.
- We may be able to remove this if we clean up the `use_orig_params=True` vs. `False` code paths because for `True`, the `FlatParameter` is not registered, meaning that it does not need to be de-registered.
- Since `fully_shard` requires `use_orig_params=True`, we may not need `_get_fsdp_states_with_modules()` and `_get_root_fsdp_root_modules()`; however, I prefer to make the separation of FSDP state and module explicit for now for clarity.
**Follow-Ups**
- `writeback=True` and `rank0_only=True` raises an error. The previous explanation was:
> is not supported, as model parameter shapes will be different across ranks, and writing to them can lead to inconsistencies across ranks when the context is exited.
I am not exactly sure what the different model parameter shapes refers to. However, I believe that we can support `writeback=True` and `rank0_only=True` by broadcasting the `FlatParameter` from rank 0 in the `finally`, writing back, and freeing. This should not increase the peak memory since rank 0 already holds the unsharded `FlatParameter` in GPU memory before writing back and nonzero ranks do not have any other unsharded `FlatParameter`s in GPU memory.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92297
Approved by: https://github.com/rohan-varma
**What does this PR do?**
This PR refactor `_optim_utils.py` to use `_FSDPState` instead of `FullyShardedDataParallel` class. This change enables the support of optim state_dict for `fully_shard`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91234
Approved by: https://github.com/rohan-varma
This PR supports nesting `replicate` in `fully_shard`.
- The PR achieves this by treating `replicate`-annotated modules are ignored modules. This means that all submodules in the `replicate`-annotated module's subtree are ignored, including nested `fully_shard`-annotated modules, which is the desired behavior.
---
This PR reworks some tree traversal.
One end goal is for `state._handles` to follow the same order for both the wrapper and composable paths. This implies that `_get_fsdp_handles()` returns the same value for both paths.
- The helper function `_get_fully_sharded_module_to_states()` now follows a left-to-right DFS from each fully sharded module instead of a BFS. The left-to-right DFS follows `.modules()` order.
- The composable auto "wrap" initialization function `_init_param_handles_from_module()` follows the reverse left-to-right DFS order. As noted in the code comments, this initialization order is a valid reverse topological sort, but it differs from the wrapper path. This is the _only_ difference with respect to initialization order through the entire process.
```
mod: Module(
submod1: Submodule()
submod2: Submodule(
subsubmod: Subsubmodule(),
),
)
```
For left-to-right DFS, the order is `mod`, `submod1`, `submod2`, `subsubmod`. (For context, right-to-left DFS would be `mod`, `submod2`, `subsubmod`, `submod1`. In other words, the left-to-right vs. right-to-left corresponds to `.children()` vs. `reversed(.children())` respectively.) Then, reverse left-to-right DFS is `subsubmod`, `submod2`, `submod1`, `mod`, which is a valid initialization order. However, the wrapper auto wrap initialization order would be `submod1`, `subsubmod`, `submod2`, `mod` since it directly follows a left-to-right DFS and initializes as a part of the recursive DFS logic.
- At the end of `_init_param_handles_from_module()`, we reverse the newly populated `state._handles`, so this is the reverse reverse left-to-right DFS order, which is equivalent to the left-to-right DFS order. Thus, `state._handles` has the same order for both paths.
Another goal is for `_get_fsdp_states()` to not traverse into any submodule that is annotated with an API that is not compatible with `fully_shard` (e.g. `replicate`). To achieve this while preserving that `_get_fsdp_states()` follows `.modules()` order, we again use a left-to-right DFS.
The reason the DFSs may look strange is because I implemented them non-recursively, which requires a stack.
- `test_get_fully_sharded_module_to_states()` in `test_utils.py` checks the traversal order of `_get_fully_sharded_module_to_states()`.
- `test_policy()` in `test_fully_shard.py` checks the traversal order returned by `_get_fsdp_handles()`.
---
Due to a circular dependency issue, we must move the graph/tree traversal helpers to their own file `_traversal_utils.py`, and any usages must import the entire file like `import torch.distributed.fsdp._traversal_utils as traversal_utils` instead of `from torch.distributed.fsdp._traversal_utils import ...`.
The cycle comes from the fact that the traversals require `_composable()`, which requires `_get_registry()` from `composable/contract.py`, which when imported, imports `composable/fully_shard.py`, which requires the traversals.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91044
Approved by: https://github.com/mrshenli