**Overview**
This refactors module materialization (i.e. meta device or `torchdistX` deferred initialization) to compute the parameter and buffer names as needed instead of pre-computing them. The names are needed to reacquire references to the states (e.g. via `module.get_parameter(param_name)`) after materialization, since materialization may create new parameter/buffer variables.
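As a hedged illustration of why the names matter, here is a minimal sketch (not the actual FSDP helper); `to_empty()` plus `reset_parameters()` stands in for meta-device or `torchdistX` deferred initialization:
```
import torch
import torch.nn as nn

def materialize_and_reacquire(module: nn.Module):
    # Compute the parameter names as needed (instead of pre-computing them)
    param_names = [name for name, _ in module.named_parameters()]
    # Materialization allocates new tensors, so previously held parameter
    # variables may no longer be the module's registered parameters
    module.to_empty(device=torch.device("cpu"))
    for submodule in module.modules():
        if hasattr(submodule, "reset_parameters"):
            submodule.reset_parameters()
    # Reacquire references to the (possibly new) parameter variables by name
    return [module.get_parameter(name) for name in param_names]

# Construct on the meta device, then materialize
module = nn.Linear(8, 8, device="meta")
params = materialize_and_reacquire(module)
```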
This refactor simplifies `_get_fully_sharded_module_to_states()` (the core function for "pseudo auto wrapping") to better enable lowest common ancestor (LCA) module computation for shared parameters, for which tracking parameter and buffer names may complicate the already non-obvious implementation.
**Discussion**
The tradeoff is a worst case quadratic traversal over modules if materializing all of them. However, since (1) the number of modules is relatively small, (2) the computation per module in the quadratic traversal is negligible, (3) this runs only once per training session, and (4) module materialization targets truly large models, I think this tradeoff is tolerable.
**For Reviewers**
- `_init_param_handle_from_module()` initializes _one_ `FlatParamHandle` from a fully sharded module and represents the module wrapper code path. For this code path, there is no need to reacquire references to the parameters/buffers for now since the managed parameters are only computed after materialization. This works because the managed parameters have a simple definition: any parameter in the local root module's tree excluding those already marked as flattened by FSDP. Similarly, FSDP marks buffers to indicate that they have already been processed (synced if `sync_module_states`).
- `_init_param_handles_from_module()` initializes _all_ `FlatParamHandle`s from a fully sharded module and represents the composable code path. For this code path, we must reacquire references to the parameters/buffers because each logical wrapping is specified as a list of parameter/buffer variables to group together, and materialization may replace those variables with new ones.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94196
Approved by: https://github.com/rohan-varma
Applies the remaining flake8-comprehensions fixes and checks. This change replaces all remaining unnecessary generator expressions with list/dict/set comprehensions, which are more succinct, more performant, and better supported by our torch.jit compiler. It also removes useless generators such as `set(a for a in b)`, simplifying them to just `set(b)`.
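For illustration (not code from this PR), this is the kind of rewrite the flake8-comprehensions checks enforce:
```
b = [1, 2, 2, 3]

# Before: unnecessary generator expressions passed to `set()`/`dict()`
s_old = set(a for a in b)
d_old = dict((a, a * a) for a in b)

# After: equivalent, more succinct set/dict comprehensions
s_new = {a for a in b}
d_new = {a: a * a for a in b}

assert s_old == s_new and d_old == d_new
```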
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
This PR adds FSDP and composable API files to `.lintrunner.toml` so that (1) lintrunner enforces that those files are formatted and (2) `lintrunner f` formats those files for you.
There are two requirements here (see https://github.com/pytorch/pytorch/wiki/lintrunner for details):
1. Install lintrunner:
```
pip install lintrunner
lintrunner init
```
2. Run `lintrunner f` before you finalize your PR; after this PR, CI enforces that formatting.
The code changes in this PR outside of `.lintrunner.toml` are the result of `lintrunner f`.
---
I only plan to land this PR if all of the composable API developers agree that this is something that makes sense and is not too intrusive to the workflow.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90873
Approved by: https://github.com/yhcharles, https://github.com/mrshenli, https://github.com/rohan-varma
This PR supports nesting `replicate` in `fully_shard`.
- The PR achieves this by treating `replicate`-annotated modules as ignored modules. This means that all submodules in the `replicate`-annotated module's subtree are ignored, including nested `fully_shard`-annotated modules, which is the desired behavior (see the sketch below).
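A hedged usage sketch of the behavior described above (import paths follow the composable APIs at the time of this PR and may differ across versions; a default process group must already be initialized):
```
import torch.nn as nn
from torch.distributed._composable import fully_shard, replicate

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
replicate(model[1])   # this submodule's subtree is managed by `replicate` ...
fully_shard(model)    # ... so `fully_shard` treats it as an ignored module
```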
---
This PR reworks some tree traversal.
One end goal is for `state._handles` to follow the same order for both the wrapper and composable paths. This implies that `_get_fsdp_handles()` returns the same value for both paths.
- The helper function `_get_fully_sharded_module_to_states()` now follows a left-to-right DFS from each fully sharded module instead of a BFS. The left-to-right DFS follows `.modules()` order.
- The composable auto "wrap" initialization function `_init_param_handles_from_module()` follows the reverse left-to-right DFS order. As noted in the code comments, this initialization order is a valid reverse topological sort, but it differs from the wrapper path. This is the _only_ difference with respect to initialization order through the entire process.
```
mod: Module(
    submod1: Submodule(),
    submod2: Submodule(
        subsubmod: Subsubmodule(),
    ),
)
```
For left-to-right DFS, the order is `mod`, `submod1`, `submod2`, `subsubmod`. (For context, right-to-left DFS would be `mod`, `submod2`, `subsubmod`, `submod1`. In other words, the left-to-right vs. right-to-left corresponds to `.children()` vs. `reversed(.children())` respectively.) Then, reverse left-to-right DFS is `subsubmod`, `submod2`, `submod1`, `mod`, which is a valid initialization order. However, the wrapper auto wrap initialization order would be `submod1`, `subsubmod`, `submod2`, `mod` since it directly follows a left-to-right DFS and initializes as a part of the recursive DFS logic.
- At the end of `_init_param_handles_from_module()`, we reverse the newly populated `state._handles`, so the result is the reverse of the reverse left-to-right DFS order, which is equivalent to the left-to-right DFS order. Thus, `state._handles` has the same order for both paths.
Another goal is for `_get_fsdp_states()` to not traverse into any submodule that is annotated with an API that is not compatible with `fully_shard` (e.g. `replicate`). To achieve this while preserving that `_get_fsdp_states()` follows `.modules()` order, we again use a left-to-right DFS.
The DFSs may look strange because I implemented them non-recursively, which requires an explicit stack.
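This minimal sketch (not the exact FSDP helper) shows a non-recursive left-to-right DFS that visits modules in the same order as `.modules()` on a module tree:
```
import torch.nn as nn

def left_to_right_dfs(root: nn.Module):
    visit_order = []
    stack = [root]
    while stack:
        module = stack.pop()
        visit_order.append(module)
        # Push children in reverse so that the leftmost child is popped first
        stack.extend(reversed(list(module.children())))
    return visit_order

# For the example above, this returns [mod, submod1, submod2, subsubmod];
# reversing it gives the valid initialization order described in the text.
```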
- `test_get_fully_sharded_module_to_states()` in `test_utils.py` checks the traversal order of `_get_fully_sharded_module_to_states()`.
- `test_policy()` in `test_fully_shard.py` checks the traversal order returned by `_get_fsdp_handles()`.
---
Due to a circular dependency issue, we must move the graph/tree traversal helpers to their own file `_traversal_utils.py`, and any usages must import the entire file like `import torch.distributed.fsdp._traversal_utils as traversal_utils` instead of `from torch.distributed.fsdp._traversal_utils import ...`.
The cycle comes from the fact that the traversals require `_composable()`, which requires `_get_registry()` from `composable/contract.py`, which, when imported, imports `composable/fully_shard.py`, which requires the traversals.
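To illustrate why the module-level import form breaks the cycle (a sketch; `_get_fsdp_states()` is one of the traversal helpers mentioned above):
```
# `from torch.distributed.fsdp._traversal_utils import _get_fsdp_states` would
# require the name to already be bound while the modules in the cycle are
# still executing. Importing the module object instead defers attribute
# resolution until call time, after all modules have finished importing.
import torch.distributed.fsdp._traversal_utils as traversal_utils

def uses_traversal(module):
    return traversal_utils._get_fsdp_states(module)
```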
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91044
Approved by: https://github.com/mrshenli
This PR removes the "communication module" (comm. module / `comm_module`) concept from the FSDP code base since, for now, it causes confusion disproportionate to its benefit.
Instead, we introduce the term "fully sharded module" as the single concept to unify the wrapper and non-wrapper code paths. The definition is presented in a note at the top of `flat_param.py`. I reproduce it here:
---
We define the **"fully sharded module"** to be the original `nn.Module` that owns a `FlatParamHandle`. It is the *single* module logically responsible for the *single* unshard/reshard pair for the handle's `FlatParameter` for a given forward or backward pass. The fully sharded module should be passed to the `FlatParamHandle` constructor.
For the wrapper code path:
- The `FullyShardedDataParallel` module wrapping the fully sharded module runs the unshard/reshard on behalf of the fully sharded module by overriding `nn.Module.forward`.
- The fully sharded module is exactly the module passed to the `FullyShardedDataParallel` constructor's `module` argument and is saved in `_fsdp_wrapped_module`.
For the non-wrapper code path:
- Hooks registered on the fully sharded module run the unshard/reshard.
- The fully sharded module may either be the direct argument to `fully_shard` or a submodule chosen by the provided wrapping policy.
---
After this PR, `handle.flat_param._fqns`, `_param_infos`, and `_shared_param_infos` all prefix names from the same module, namely the fully sharded module. This should make state dict less confusing.
---
As an example, consider:
```
mod: Module(
    sub1: Submodule(
        subsub1: Subsubmodule(),
        subsub2: Subsubmodule(),
    ),
    sub2: Submodule(
        subsub1: Subsubmodule(),
        subsub2: Subsubmodule(),
    ),
)
```
For wrapper FSDP manual wrap:
```
mod.sub1 = FSDP(mod.sub1)
mod.sub2 = FSDP(mod.sub2)
mod = FSDP(mod)
```
For wrapper FSDP auto wrap:
```
mod = FSDP(mod, auto_wrap_policy=ModuleWrapPolicy({Submodule}))
```
(WIP) For non-wrapper FSDP manual wrap:
```
fully_shard(mod.sub1)
fully_shard(mod.sub2)
fully_shard(mod)
```
For non-wrapper FSDP auto wrap:
```
fully_shard(mod, policy=ModuleWrapPolicy({Submodule}))
```
The fully sharded modules **in all cases** are `mod`, `mod.sub1`, and `mod.sub2`; notably, the `subsub1`s and `subsub2`s are not fully sharded modules.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90933
Approved by: https://github.com/rohan-varma
**BC Breaking Change**
This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap" suggests that some wrapping has been undone. I prefer `numel` over `params` because that is the unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves.
This only breaks code that passes `unwrapped_params` as a keyword argument, and I did not find any such usage (except one internal benchmark file, which does not actually depend on our `pytorch` code).
In a follow-up, I want to rename `min_num_params` to `min_nonwrapped_numel` in `size_based_auto_wrap_policy`, which is also BC breaking. Again, this is to differentiate between "params" being `nn.Parameter`s and "numel" being the unit for `param.numel()`.
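For illustration, a hypothetical custom policy using the new keyword name (a hedged sketch; the callable signature follows the auto wrap interface described in the next section):
```
import torch.nn as nn

def my_size_based_policy(
    module: nn.Module, recurse: bool, nonwrapped_numel: int
) -> bool:
    # Always recurse; wrap modules whose not-yet-wrapped parameters total >= 1e6 elements
    if recurse:
        return True
    return nonwrapped_numel >= 1_000_000
```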
**Overview**
This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is:
```
module_classes: Set[Type[nn.Module]] = ...
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls=module_classes,
)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
Now, users can instead write:
```
auto_wrap_policy = ModuleWrapPolicy(module_classes)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
This hides the unused arguments expected from the callable (`recurse` and `unwrapped_params`/`nonwrapped_numel`).
`ModuleWrapPolicy` inherits from an abstract base class `FSDPPolicy` that expects a `policy` property. This decouples the construction of such `FSDPPolicy` classes from their actual `policy`, which must abide by the `_recursive_wrap` interface. Any existing auto wrap policy can be rewritten as a class that inherits from `FSDPPolicy`, so this approach is fully backward compatible from a functionality perspective.
I call this base class `FSDPPolicy` to generalize over the cases where we may not want to actually perform any nested wrapping. In reality, the policy is meant for constructing `FlatParameter`s, which just happened to be induced by a nested wrapping before. Given this, I am changing the constructor argument in `fully_shard()` to simply `policy` instead of `auto_wrap_policy`.
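A hedged sketch of the class structure described above (names follow this PR; the real implementation may differ in its details):
```
import functools
from abc import ABC, abstractmethod
from typing import Callable, Set, Type

import torch.nn as nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

class FSDPPolicy(ABC):
    """Abstract base class whose `policy` must obey the `_recursive_wrap` interface."""

    @property
    @abstractmethod
    def policy(self) -> Callable:
        ...

class ModuleWrapPolicy(FSDPPolicy):
    def __init__(self, module_classes: Set[Type[nn.Module]]):
        # Thin layer over the existing policy; hides `recurse`/`nonwrapped_numel`
        self._policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=module_classes,
        )

    @property
    def policy(self) -> Callable:
        return self._policy
```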
This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88450
Approved by: https://github.com/zhaojuanmao
This PR introduces the composable FSDP API (with constructor semantics only) along with some further constructor refactoring. A notable contribution here is `_get_submodule_to_states()`, which performs auto wrapping without actually wrapping.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87924
Approved by: https://github.com/mrshenli
**Overview**
This PR introduces `FlatParamHandle` to enable non-recursive FSDP wrapping. The class absorbs the unflattening/flattening logic from `FlattenParamsWrapper` but does not require wrapping a particular `nn.Module`.
## Discussion
### Introducing `FlatParamHandle`
There is flexibility in the design space for how to allocate attributes and methods to `FlatParameter` versus a wrapping class like `FlatParamHandle` or `FlattenParamsWrapper`. Several points in the design space provide the same functionality, so deciding on an allocation is arguably stylistic, in which case preference should be given to cleaner designs.
The forefront consideration is that a `FlatParameter`'s metadata should be initialized once, while its data may be reloaded via checkpointing. This motivates decoupling the metadata initialization from the `FlatParameter` constructor, which should instead only handle the parameter data. Thus, we have both a `FlatParamHandle` managing a `FlatParameter` and the `FlatParameter` itself.
```
class FlatParamHandle:
    def __init__(self, module: nn.Module, params: Sequence[nn.Parameter]):
        # Calls `_init_flat_param()`
        ...

    def _init_flat_param(self, module: nn.Module, params: Sequence[nn.Parameter]):
        # Calls `flatten_params()` and initializes metadata
        ...

    @staticmethod
    def flatten_params(params: Sequence[torch.Tensor], requires_grad: bool) -> FlatParameter:
        # Also may be used for checkpoint reloading
        ...


class FlatParameter(nn.Parameter):
    # Constructor is not overridden
    ...
```
Under this separation, with `FlatParameter` serving solely as a data container, we keep the methods manipulating `FlatParameter` on the `FlatParamHandle`. Because `FlatParameter`'s constructor is not overridden, we should be able to replace it with another tensor type, e.g. `ShardedTensor`, with minimal changes.
### Compatibility with `FlattenParamsWrapper`
To ensure backward compatibility, `FlattenParamsWrapper` now holds a `FlatParamHandle`. Existing logic from `FlattenParamsWrapper` simply routes to the handle now.
A `FullyShardedDataParallel` instance holds references to all of its handles.
- For the recursive-wrapping paradigm, there is at most one handle, which is from its `FlattenParamsWrapper` if it manages parameters.
- For the non-recursive wrapping paradigm, there may be multiple handles, all owned by the single (root) `FullyShardedDataParallel` instance.
## For Reviewers
### `FlatParameter` Construction
In the existing implementation, a `FlatParameter`'s metadata was partially initialized in its constructor (e.g. `_param_numels`, `_param_shapes`) and partially initialized by the owning `FlattenParamsWrapper` (e.g. `_param_infos`, `_shared_param_infos`). The latter part was needed because it requires module information. With this PR, the metadata initialization is consolidated in `FlatParamHandle`.
- During model construction, a `FlatParameter` should be initialized via the handle constructor `FlatParamHandle(params, module)`.
- During sharded checkpoint loading, a `FlatParameter` should be initialized via the static method `FlatParamHandle.flatten_params(new_params)`.
- The checkpointing implementation is responsible for checking that `new_params` used to construct the `FlatParameter` data to load is consistent with the existing `FlatParameter`'s metadata.
These are the only two cases for `FlatParameter` construction right now, so there is no real functionality regression from not recomputing some of the metadata in the `FlatParameter` constructor. `nn.Module.load_state_dict()` is implemented using in-place `copy_()`, so the newly loaded `FlatParameter`'s metadata *should* match the existing `FlatParameter`'s metadata for correctness anyway. (I.e. we do not support a usage where we reload a `FlatParameter` with differing metadata into an existing `FlatParameter`.)
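A hedged usage sketch of the two construction paths (argument order, the `flat_param.py` import path, and the `flat_param` attribute follow this PR's description; the real signatures and paths may have since changed):
```
import torch
import torch.nn as nn
from torch.distributed.fsdp.flat_param import FlatParamHandle

module = nn.Linear(4, 4)
params = list(module.parameters())

# (1) Model construction: the handle constructor initializes the metadata
handle = FlatParamHandle(params, module)
flat_param = handle.flat_param

# (2) Sharded checkpoint loading: only the data is (re)constructed; the
#     existing `FlatParameter`'s metadata is reused and must stay consistent
new_params = [torch.zeros_like(p) for p in params]
new_flat_param = FlatParamHandle.flatten_params(new_params, requires_grad=True)
with torch.no_grad():
    flat_param.copy_(new_flat_param)  # in-place, as with `load_state_dict()`
```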
### BC Breaking
- `ShardMetadata` -> `FlatParamShardMetadata` to avoid name conflict with `ShardedTensor`
- `metadata()` -> removed (unused)
- `FlatParameter` attributes
- `_param_numels` -> `_numels`
- `_param_shapes` -> `_shapes`
- `_param_names` -> `_prefixed_param_names`
- `full_numel` -> `_unsharded_size.numel()`
- `_param_indice_in_shard` -> `_shard_indices`
- `_sharded_param_offsets` -> `_shard_param_offsets`
- `num_padded` -> `_shard_numel_padded`
- `param_offsets` -> not saved; directly constructed in `_get_flat_param_offsets()` and used once
- `FlattenParamsWrapper` `param_list` argument -> `params` for consistency with `FlatParameter`
## Follow-Ups
- The current `FlatParameter`'s `data` represents either the sharded unflattened parameter, unsharded unflattened parameter, or reduced-precision sharded unflattened parameter, depending dynamically on the runtime context. When its `data` represents one quantity, the other quantities are still saved as attributes on the `FlatParameter` (e.g. `_local_shard`, `_full_param_padded`, `_mp_shard`). `FullyShardedDataParallel` directly manipulates the `data`.
We should investigate the tradeoffs of having those attributes on the `FlatParameter` versus moving them to the `FlatParamHandle`. The motivation for the latter is to define a clean interface for `FullyShardedDataParallel` to manage parameter data in preparation for generalizing to multiple parameter groups, to managing non-`FlatParameter`s, and to supporting non-CUDA devices. (More explicitly, `FullyShardedDataParallel`'s parameter *variables* would be set to different `Tensor` variables, none of which own another, instead of `FullyShardedDataParallel`'s parameter variables' *data* being set to different `Tensor` variables, all owned by the `FlatParameter`, and the data management would be folded into handle, hidden from `FullyShardedDataParallel`.)
- We should investigate if we can coalesce the remaining logic in `FlattenParamsWrapper` into `FullyShardedDataParallel` and remove `FlattenParamsWrapper`.
- We may want to move the mixed precision logic to the handle instead of the `FullyShardedDataParallel` instance to enable per-`FlatParameter` mixed precision instead of per-`FullyShardedDataParallel`. Otherwise, the non-recursive wrapping path is bound to all-or-nothing mixed precision.
Differential Revision: [D37250558](https://our.internmc.facebook.com/intern/diff/D37250558)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79652
Approved by: https://github.com/zhaojuanmao, https://github.com/fegin, https://github.com/rohan-varma
- Uses state dict / load state dict hooks to ensure that the state of modules wrapped with `CheckpointWrapper` can be loaded into a module that is not wrapped with activation checkpointing.
This is needed because a training run can use activation checkpointing and save a `state_dict`, while a future run may not wrap modules with activation checkpointing or may change the checkpoint wrapping structure. To support this, we add hooks that remove or add the relevant prefix as needed.
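A simplified, hedged sketch of the prefix handling (the prefix string and the free functions below are illustrative; the actual `CheckpointWrapper` hooks operate per key relative to the wrapper's own position in the module tree):
```
from typing import Any, Dict

# Assumed prefix inserted by the checkpoint wrapper around its inner module
WRAPPER_PREFIX = "_checkpoint_wrapped_module."

def strip_wrapper_prefix(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    # State dict post-hook direction: the saved state_dict can load into a
    # module that is not wrapped with activation checkpointing
    return {key.replace(WRAPPER_PREFIX, ""): value for key, value in state_dict.items()}

def add_wrapper_prefix(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Load state dict pre-hook direction: a plain state_dict can load into a
    # wrapped module
    return {WRAPPER_PREFIX + key: value for key, value in state_dict.items()}
```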
Tests are added to ensure we can load a CheckpointWrapper-wrapped module's `state_dict` into another CheckpointWrapper-wrapped module as well as into a local (unwrapped) module. `state_dict` with FSDP is also verified.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77224
Approved by: https://github.com/zhaojuanmao
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75560
Make `apply_to_tensors` support the `OrderedDict` type
ghstack-source-id: 153492674
Test Plan: unit test
Reviewed By: rohan-varma
Differential Revision: D35524080
fbshipit-source-id: c71fa86ddb642b32aad6358fdcb040c4a0593f12
(cherry picked from commit 937f7acd65878cbea51e533fc4491fc577a904f7)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73530
Improve the unittests for the utility functions
ghstack-source-id: 150340208
Test Plan: CI
Reviewed By: zhaojuanmao
Differential Revision: D34529335
fbshipit-source-id: 4e7283a58850d4674d22fb038284d6c762729dd4
(cherry picked from commit e53b0fd9d224578a89806408b71e145b967d602a)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72084
Make the FSDP folder public
ghstack-source-id: 148173447
Test Plan: unit tests
Reviewed By: mrshenli
Differential Revision: D33903417
fbshipit-source-id: 7852a2adc4af09af48a5ffa52ebf210489f834d5
(cherry picked from commit bd06513cfe)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66904
Add more FSDP unit tests to cover core logic, freezing weights, and the flatten-parameter wrapper; these unit tests are refactored to align with PyTorch's commonly used test classes
ghstack-source-id: 141335614
Test Plan: unit tests
Reviewed By: mrshenli
Differential Revision: D31779565
fbshipit-source-id: c727110d1d7570c0ec49e42cadfc9e9a5e440073