Commit Graph

17 Commits

Author SHA1 Message Date
Rohan Varma
c43e88665a [Resubmit] helpers to torch.dist.utils (#95025)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95025
Approved by: https://github.com/fegin
2023-02-17 18:24:20 +00:00
Andrew Gu
5ee230face [FSDP][1/N] Refactor module materialization (#94196)
**Overview**
This refactors module materialization (i.e. meta device or `torchdistX` deferred initialization) to compute the parameter and buffer names as needed instead of pre-computing them. These are needed to reacquire references to the states (e.g. `module.get_parameter(param_name)`) after materialization since the materialization may create new variables.
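
As a rough illustration of the "reacquire by name" idea, here is a minimal sketch (not the actual FSDP helper; the function name and the `to_empty()` stand-in for materialization are assumptions for this example):
```
import torch
import torch.nn as nn

def materialize_and_reacquire(module: nn.Module, device: torch.device):
    # Record dotted names first, since materialization may create new
    # parameter/buffer variables (e.g. when replacing meta-device tensors).
    param_names = [name for name, _ in module.named_parameters()]
    buffer_names = [name for name, _ in module.named_buffers()]
    module.to_empty(device=device)  # stand-in for the actual materialization
    # Reacquire references to the (possibly new) variables by name.
    params = [module.get_parameter(name) for name in param_names]
    buffers = [module.get_buffer(name) for name in buffer_names]
    return params, buffers
```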

This refactor simplifies `_get_fully_sharded_module_to_states()` (the core function for "pseudo auto wrapping") to better enable lowest common ancestor (LCA) module computation for shared parameters, for which tracking parameter and buffer names may complicate the already non-obvious implementation.

**Discussion**
The tradeoff is a worst case quadratic traversal over modules if materializing all of them. However, since (1) the number of modules is relatively small, (2) the computation per module in the quadratic traversal is negligible, (3) this runs only once per training session, and (4) module materialization targets truly large models, I think this tradeoff is tolerable.

**For Reviewers**
- `_init_param_handle_from_module()` initializes _one_ `FlatParamHandle` from a fully sharded module and represents the module wrapper code path. For this code path, there is no need to reacquire references to the parameters/buffers for now since the managed parameters are only computed after materialization. This works because the managed parameters have a simple definition: any parameter in the local root module's tree excluding those already marked as flattened by FSDP. Similarly, FSDP marks buffers to indicate that they have already been processed (synced if `sync_module_states`).
- `_init_param_handles_from_module()` initializes _all_ `FlatParamHandle`s from a fully sharded module and represents the composable code path. For this code path, we must reacquire references to parameters/buffers because each logical wrapping is specified as a list of parameter/buffer variables to group together, and materialization may replace those variables with new ones.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94196
Approved by: https://github.com/rohan-varma
2023-02-13 21:43:00 +00:00
Aaron Gokaslan
67d9790985 [BE] Apply almost all remaining flake8-comprehension checks (#94676)
Applies the remaining flake8-comprehension fixes and checks. This change replaces all remaining unnecessary generator expressions with list/dict/set comprehensions, which are more succinct, more performant, and better supported by our torch.jit compiler. It also removes useless generators such as `set(a for a in b)`, resolving them into just the `set` call.
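
For illustration, the kind of rewrite these checks enforce looks like the following (a toy example, not an actual diff from this PR):
```
b = [1, 2, 2, 3]

# Before: unnecessary generator expressions passed to set()/list()
s = set(a for a in b)
l = list(x * 2 for x in b)

# After: direct comprehensions, which are more succinct and faster
s = {a for a in b}
l = [x * 2 for x in b]
```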

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
2023-02-12 01:01:25 +00:00
Andrew Gu
0d4bbd1996 [Lint] Add FSDP/composable API files to ufmt include (#90873)
This PR adds FSDP and composable API files to `.lintrunner.toml` so that (1) lintrunner enforces that those files are formatted and (2) `lintrunner f` formats those files for you.

There are two requirements here (see https://github.com/pytorch/pytorch/wiki/lintrunner for details):
1. Install lintrunner:
```
pip install lintrunner
lintrunner init
```
2. Run `lintrunner f` before you finalize your PR, which will now be enforced by CI after this PR.

The code changes in this PR outside of `.lintrunner.toml` are the result of `lintrunner f`.

---

I only plan to land this PR if all of the composable API developers agree that this is something that makes sense and is not too intrusive to the workflow.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90873
Approved by: https://github.com/yhcharles, https://github.com/mrshenli, https://github.com/rohan-varma
2023-01-18 05:33:34 +00:00
Andrew Gu
aec09eeb3a [FSDP][7/N] Support replicate in fully_shard (#91044)
This PR supports nesting `replicate` in `fully_shard`.
- The PR achieves this by treating `replicate`-annotated modules as ignored modules. This means that all submodules in the `replicate`-annotated module's subtree are ignored, including nested `fully_shard`-annotated modules, which is the desired behavior. (See the sketch below.)
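
A minimal sketch of the nesting (a toy model; assumes the composable APIs are importable from `torch.distributed._composable` and that a default process group is already initialized):
```
import torch.nn as nn
from torch.distributed._composable import fully_shard, replicate

model = nn.Sequential(nn.Linear(8, 8), nn.Sequential(nn.Linear(8, 8), nn.ReLU()))
replicate(model[1])  # `model[1]`'s subtree is ignored by the outer `fully_shard`
fully_shard(model)   # shards the remaining (non-`replicate`) parameters
```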

---

This PR reworks some tree traversal.

One end goal is for `state._handles` to follow the same order for both the wrapper and composable paths. This implies that `_get_fsdp_handles()` returns the same value for both paths.
- The helper function `_get_fully_sharded_module_to_states()` now follows a left-to-right DFS from each fully sharded module instead of a BFS. The left-to-right DFS follows `.modules()` order.
- The composable auto "wrap" initialization function `_init_param_handles_from_module()` follows the reverse left-to-right DFS order. As noted in the code comments, this initialization order is a valid reverse topological sort, but it differs from the wrapper path. This is the _only_ difference with respect to initialization order through the entire process.
```
mod: Module(
    submod1: Submodule()
    submod2: Submodule(
        subsubmod: Subsubmodule(),
    ),
)
```
For left-to-right DFS, the order is `mod`, `submod1`, `submod2`, `subsubmod`. (For context, right-to-left DFS would be `mod`, `submod2`, `subsubmod`, `submod1`. In other words, the left-to-right vs. right-to-left corresponds to `.children()` vs. `reversed(.children())` respectively.) Then, reverse left-to-right DFS is `subsubmod`, `submod2`, `submod1`, `mod`, which is a valid initialization order. However, the wrapper auto wrap initialization order would be `submod1`, `subsubmod`, `submod2`, `mod` since it directly follows a left-to-right DFS and initializes as a part of the recursive DFS logic.
- At the end of `_init_param_handles_from_module()`, we reverse the newly populated `state._handles`, so the final order is the reverse of the reverse left-to-right DFS order, which is equivalent to the left-to-right DFS order. Thus, `state._handles` has the same order for both paths.

Another goal is for `_get_fsdp_states()` to not traverse into any submodule that is annotated with an API that is not compatible with `fully_shard` (e.g. `replicate`). To achieve this while preserving that `_get_fsdp_states()` follows `.modules()` order, we again use a left-to-right DFS.

The DFSs may look strange because I implemented them non-recursively, which requires an explicit stack.
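
For reference, here is a minimal sketch of such a non-recursive left-to-right DFS over modules (simplified; not the actual FSDP helper):
```
import torch.nn as nn

def left_to_right_dfs(root: nn.Module):
    stack = [root]
    order = []
    while stack:
        module = stack.pop()
        order.append(module)
        # Push children in reverse so that the leftmost child is popped first.
        stack.extend(reversed(list(module.children())))
    return order  # e.g. [mod, submod1, submod2, subsubmod] for the example above
```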

- `test_get_fully_sharded_module_to_states()` in `test_utils.py` checks the traversal order of `_get_fully_sharded_module_to_states()`.
- `test_policy()` in `test_fully_shard.py` checks the traversal order returned by `_get_fsdp_handles()`.

---

Due to a circular dependency issue, we must move the graph/tree traversal helpers to their own file `_traversal_utils.py`, and any usages must import the entire file like `import torch.distributed.fsdp._traversal_utils as traversal_utils` instead of `from torch.distributed.fsdp._traversal_utils import ...`.

The cycle comes from the fact that the traversals require `_composable()`, which requires `_get_registry()` from `composable/contract.py`, which, when imported, imports `composable/fully_shard.py`, which in turn requires the traversals.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91044
Approved by: https://github.com/mrshenli
2022-12-20 16:49:18 +00:00
Andrew Gu
8cd1808dbf [FSDP] Introduce "fully sharded module"; remove comm. module (#90933)
This PR removes the "communication module" (comm. module / `comm_module`) concept from the FSDP code base since it causes disproportionate confusion compared to its benefit for now.

Instead, we introduce the term "fully sharded module" as the single concept to unify the wrapper and non-wrapper code paths. The definition is presented in a note at the top of `flat_param.py`. I reproduce it here:

---
We define the **"fully sharded module"** to be the original `nn.Module` that owns a `FlatParamHandle`. It is the *single* module logically responsible for the *single* unshard/reshard pair for the handle's `FlatParameter` for a given forward or backward pass. The fully sharded module should be passed to the `FlatParamHandle` constructor.

For the wrapper code path:
- The `FullyShardedDataParallel` module wrapping the fully sharded module runs the unshard/reshard on behalf of the fully sharded module by overriding `nn.Module.forward`.
- The fully sharded module is exactly the module passed to the `FullyShardedDataParallel` constructor's `module` argument and is saved in `_fsdp_wrapped_module`.

For the non-wrapper code path:
- Hooks registered on the fully sharded module run the unshard/reshard.
- The fully sharded module may either be the direct argument to `fully_shard` or a submodule chosen by the provided wrapping policy.
---

After this PR, `handle.flat_param._fqns`, `_param_infos`, and `_shared_param_infos` all compute their prefixed names starting from the same module, namely the fully sharded module. This should make state dict less confusing.

---
As an example, consider:
```
mod: Module(
  sub1: Submodule(
    subsub1: Subsubmodule(),
    subsub2: Subsubmodule(),
  ),
  sub2: Submodule(
    subsub1: Subsubmodule(),
    subsub2: Subsubmodule(),
  ),
)
```
For wrapper FSDP manual wrap:
```
mod.sub1 = FSDP(mod.sub1)
mod.sub2 = FSDP(mod.sub2)
mod = FSDP(mod)
```
For wrapper FSDP auto wrap:
```
mod = FSDP(mod, auto_wrap_policy=ModuleWrapPolicy({Submodule}))
```
(WIP) For non-wrapper FSDP manual wrap:
```
fully_shard(mod.sub1)
fully_shard(mod.sub2)
fully_shard(mod)
```
For non-wrapper FSDP auto wrap:
```
fully_shard(mod, policy=ModuleWrapPolicy({Submodule}))
```
The fully sharded modules **in all cases** are `mod`, `mod.sub1`, and `mod.sub2`; notably, the `subsub1`s and `subsub2`s are not fully sharded modules.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90933
Approved by: https://github.com/rohan-varma
2022-12-16 18:45:52 +00:00
Andrew Gu
d01bf1d1f1 [FSDP] Introduce ModuleWrapPolicy for simplicity (#88450)
**BC Breaking Change**
This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap" suggests that some wrapping has been undone. I prefer `numel` over `params` because that is the unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves.

This only breaks code that passes `unwrapped_params` as a keyword argument, and I did not see anything that does that (except one internal benchmark file, which does not actually depend on our `pytorch` code).

In a follow-up, I want to rename `min_num_params` to `min_nonwrapped_numel` in `size_based_auto_wrap_policy`, which is also BC breaking. Again, this is to differentiate between "params" being `nn.Parameter`s and "numel" being the unit for `param.numel()`.

**Overview**
This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is:
```
module_classes: Set[Type[nn.Module]] = ...
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls=module_classes,
)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
Now, users can instead write:
```
auto_wrap_policy = ModuleWrapPolicy(module_classes)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
This hides the unused arguments expected from the callable (`recurse` and `unwrapped_params`/`nonwrapped_numel`).

`ModuleWrapPolicy` inherits from an abstract base class `FSDPPolicy` that expects a `policy` property. This decouples the construction of such `FSDPPolicy` classes from their actual `policy`, which must abide by the `_recursive_wrap` interface. Any existing auto wrap policy can be rewritten as a class that inherits from `FSDPPolicy`, so this approach is fully backward compatible from a functionality perspective.
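
A hedged sketch of this pattern (simplified names and signature; not the exact upstream definitions):
```
import functools
from abc import ABC, abstractmethod
from typing import Callable, Set, Type

import torch.nn as nn

class FSDPPolicy(ABC):
    """Abstract base exposing a `policy` callable for `_recursive_wrap`."""
    @property
    @abstractmethod
    def policy(self) -> Callable:
        ...

class ModuleWrapPolicy(FSDPPolicy):
    def __init__(self, module_classes: Set[Type[nn.Module]]):
        # Hide the `recurse`/`nonwrapped_numel` arguments from the user.
        self._policy = functools.partial(
            _module_wrap_policy, module_classes=module_classes
        )

    @property
    def policy(self) -> Callable:
        return self._policy

def _module_wrap_policy(module, recurse, nonwrapped_numel, module_classes) -> bool:
    # Always recurse; wrap exactly the modules of the given classes.
    return True if recurse else isinstance(module, tuple(module_classes))
```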

I call this base class `FSDPPolicy` to generalize over the cases where we may not want to actually perform any nested wrapping. In reality, the policy is meant for constructing `FlatParameter`s, which just happened to be induced by a nested wrapping before. Given this, I am changing the constructor argument in `fully_shard()` to simply `policy` instead of `auto_wrap_policy`.

This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88450
Approved by: https://github.com/zhaojuanmao
2022-11-12 04:14:32 +00:00
Andrew Gu
c1e28731b3 [FSDP()][10/N][11/N] Introduce composable (ctor only) (#87924)
This PR introduces the composable FSDP API (with constructor semantics only) along with some further constructor refactoring. A notable contribution here is `_get_submodule_to_states()`, which performs auto wrapping without actually wrapping.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87924
Approved by: https://github.com/mrshenli
2022-11-01 12:39:24 +00:00
Andrew Gu
107f92a683 [FSDP] ufmt FSDP test (#87812)
This applies `ufmt` to all of the FSDP test files in the `test/distributed/fsdp/` directory.

**Test Plan**
CI

**Notes**
For VSCode users,
- Install `ufmt`: https://pypi.org/project/ufmt/
- Install VSCode `ufmt` extension: https://marketplace.visualstudio.com/items?itemName=omnilib.ufmt
- Include in `settings.json`:
```
{
    "[python]": {
        "editor.defaultFormatter": "omnilib.ufmt",
        "editor.formatOnSave": true,
    },
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87812
Approved by: https://github.com/rohan-varma
2022-10-27 04:25:55 +00:00
edward-io
e7ff9d44ad [fsdp] add ability to iterate through dataclasses in fsdp.utils (#82638)
### Description

Previously, FSDP was failing on a torchmultimodal model because `_apply_to_tensors` could not iterate over dataclasses.
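
A hedged sketch of the extension (simplified; not the actual `_apply_to_tensors` implementation):
```
import dataclasses
import torch

def apply_to_tensors(fn, container):
    if torch.is_tensor(container):
        return fn(container)
    if dataclasses.is_dataclass(container):
        # Rebuild the dataclass with every field transformed recursively.
        return dataclasses.replace(
            container,
            **{
                field.name: apply_to_tensors(fn, getattr(container, field.name))
                for field in dataclasses.fields(container)
            },
        )
    if isinstance(container, dict):
        return {k: apply_to_tensors(fn, v) for k, v in container.items()}
    if isinstance(container, (list, tuple, set)):
        return type(container)(apply_to_tensors(fn, x) for x in container)
    return container
```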

### Issue

None

### Testing

unit test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82638
Approved by: https://github.com/rohan-varma
2022-08-05 18:34:31 +00:00
Andrew Gu
be656f55b1 [FSDP] Introduce FlatParamHandle (#79652)
**Overview**
This PR introduces `FlatParamHandle` to enable non-recursive FSDP wrapping. The class absorbs the unflattening/flattening logic from `FlattenParamsWrapper` but does not require wrapping a particular `nn.Module`.

## Discussion
### Introducing `FlatParamHandle`
There is flexibility in the design space for how to allocate attributes and methods to `FlatParameter` versus a wrapping class like `FlatParamHandle` or `FlattenParamsWrapper`. Several points in the design space provide the same functionality, so deciding on an allocation is arguably stylistic, though then preference should be given to cleaner designs.

The forefront consideration is that a `FlatParameter`'s metadata should be initialized once, while its data may be reloaded via checkpointing. This motivates decoupling the metadata initialization from the `FlatParameter` constructor, which should instead only handle the parameter data. Thus, we have both a `FlatParamHandle` managing a `FlatParameter` and the `FlatParameter` itself.
```
from typing import Sequence

import torch
import torch.nn as nn

class FlatParamHandle:
    def __init__(self, module: nn.Module, params: Sequence[nn.Parameter]):
        # Calls `_init_flat_param()`
        ...
    def _init_flat_param(self, module: nn.Module, params: Sequence[nn.Parameter]):
        # Calls `flatten_params()` and initializes metadata
        ...
    @staticmethod
    def flatten_params(params: Sequence[torch.Tensor], requires_grad: bool) -> "FlatParameter":
        # Also may be used for checkpoint reloading
        ...

class FlatParameter(nn.Parameter):
    # Constructor is not overridden
    pass
```
Under this separation, with `FlatParameter` serving solely as a data container, we keep the methods that manipulate the `FlatParameter` on the `FlatParamHandle`. Because `FlatParameter`'s constructor is not overridden, we should be able to replace it with another tensor type, e.g. `ShardedTensor`, with minimal changes.

### Compatibility with `FlattenParamsWrapper`
To ensure backward compatibility, `FlattenParamsWrapper` now holds a `FlatParamHandle`. Existing logic from `FlattenParamsWrapper` simply routes to the handle now.

A `FullyShardedDataParallel` instance holds references to all of its handles.
- For the recursive-wrapping paradigm, there is at most one handle, which is from its `FlattenParamsWrapper` if it manages parameters.
- For the non-recursive wrapping paradigm, there may be multiple handles, all owned by the single (root) `FullyShardedDataParallel` instance.

## For Reviewers
### `FlatParameter` Construction
In the existing implementation, a `FlatParameter`'s metadata was partially initialized in its constructor (e.g. `_param_numels`, `_param_shapes`) and partially initialized by the owning `FlattenParamsWrapper` (e.g. `_param_infos`, `_shared_param_infos`). The latter part was needed because it requires module information. With this PR, the metadata initialization is consolidated in `FlatParamHandle`.
- During model construction, a `FlatParameter` should be initialized via the handle constructor `FlatParamHandle(params, module)`.
- During sharded checkpoint loading, a `FlatParameter` should be initialized via the static method `FlatParamHandle.flatten_params(new_params)`.
    - The checkpointing implementation is responsible for checking that `new_params` used to construct the `FlatParameter` data to load is consistent with the existing `FlatParameter`'s metadata.

These are the only two cases for `FlatParameter` construction right now, so there is no real functionality regression from not recomputing some of the metadata in the `FlatParameter` constructor. `nn.Module.load_state_dict()` is implemented using in-place `copy_()`, so the newly loaded `FlatParameter`'s metadata *should* match the existing `FlatParameter`'s metadata for correctness anyway. (I.e. we do not support a usage where we reload a `FlatParameter` with differing metadata into an existing `FlatParameter`.)

### BC Breaking
- `ShardMetadata` -> `FlatParamShardMetadata` to avoid name conflict with `ShardedTensor`
    - `metadata()` -> removed (unused)
- `FlatParameter` attributes
    - `_param_numels` -> `_numels`
    - `_param_shapes` -> `_shapes`
    - `_param_names` -> `_prefixed_param_names`
    - `full_numel` -> `_unsharded_size.numel()`
    - `_param_indice_in_shard` -> `_shard_indices`
    - `_sharded_param_offsets` -> `_shard_param_offsets`
    - `num_padded` -> `_shard_numel_padded`
    - `param_offsets` -> not saved; directly constructed in `_get_flat_param_offsets()` and used once
- `FlattenParamsWrapper` `param_list` argument -> `params` for consistency with `FlatParameter`

## Follow-Ups

- The current `FlatParameter`'s `data` represents either the sharded unflattened parameter, unsharded unflattened parameter, or reduced-precision sharded unflattened parameter, depending dynamically on the runtime context. When its `data` represents one quantity, the other quantities are still saved as attributes on the `FlatParameter` (e.g. `_local_shard`, `_full_param_padded`, `_mp_shard`). `FullyShardedDataParallel` directly manipulates the `data`.
We should investigate the tradeoffs of keeping those attributes on the `FlatParameter` versus moving them to the `FlatParamHandle`. The motivation for the latter is to define a clean interface for `FullyShardedDataParallel` to manage parameter data in preparation for generalizing to multiple parameter groups, managing non-`FlatParameter`s, and supporting non-CUDA devices. (More explicitly, `FullyShardedDataParallel`'s parameter *variables* would be set to different `Tensor` variables, none of which owns another, instead of the parameter variables' *data* being set to different `Tensor` variables all owned by the `FlatParameter`; the data management would be folded into the handle, hidden from `FullyShardedDataParallel`.)
- We should investigate if we can coalesce the remaining logic in `FlattenParamsWrapper` into `FullyShardedDataParallel` and remove `FlattenParamsWrapper`.
- We may want to move the mixed precision logic to the handle instead of the `FullyShardedDataParallel` instance to enable per-`FlatParameter` mixed precision instead of per-`FullyShardedDataParallel`. Otherwise, the non-recursive wrapping path is bound to all-or-nothing mixed precision.

Differential Revision: [D37250558](https://our.internmc.facebook.com/intern/diff/D37250558)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79652
Approved by: https://github.com/zhaojuanmao, https://github.com/fegin, https://github.com/rohan-varma
2022-07-22 19:16:50 +00:00
Rohan Varma
f9f8127414 CheckpointWrapper state_dict fix (#77224)
- Uses state dict / load state dict hooks to ensure that modules wrapped with `CheckpointWrapper` can be loaded into a non-checkpoint-wrapped module.

This is needed because a training run can use activation checkpointing and then save a `state_dict`, while a future run may not want to wrap modules with activation checkpointing or may decide to change the activation checkpoint wrapping structure. To support this, we add hooks that remove/add the relevant prefix as needed.
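
A hedged sketch of the prefix handling (the prefix constant is an assumption, and only the top-level wrapper case is shown; the real hooks live alongside `CheckpointWrapper`):
```
# Assumed wrapper prefix for this sketch; the actual constant may differ.
_CHECKPOINT_PREFIX = "_checkpoint_wrapped_module."

def strip_checkpoint_prefix(state_dict: dict) -> dict:
    # Remove the wrapper prefix so keys match a non-checkpoint-wrapped module.
    return {k.replace(_CHECKPOINT_PREFIX, ""): v for k, v in state_dict.items()}

def add_checkpoint_prefix(state_dict: dict) -> dict:
    # Re-insert the wrapper prefix when loading into a CheckpointWrapper-wrapped
    # module (nested wrappers would splice the prefix at the right position).
    return {_CHECKPOINT_PREFIX + k: v for k, v in state_dict.items()}
```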

Tests are added to ensure that a `CheckpointWrapper`-wrapped module's state dict can be loaded into both a `CheckpointWrapper`-wrapped module and a local (unwrapped) module. `state_dict` with FSDP is also verified.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77224
Approved by: https://github.com/zhaojuanmao
2022-05-17 03:39:31 +00:00
yanlizhao
887a93e5ac support PackedSequence type for apply_for_tensors
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76265

Support the `PackedSequence` type for `apply_for_tensors`; some RNN modules' outputs are `PackedSequence` types.
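
A hedged sketch of the `PackedSequence` handling only (simplified; not the actual utility):
```
from torch.nn.utils.rnn import PackedSequence

def apply_to_packed_sequence(fn, ps: PackedSequence) -> PackedSequence:
    # Only the `data` field holds the tensor to transform; the batch/index
    # metadata is carried over unchanged.
    return PackedSequence(
        fn(ps.data), ps.batch_sizes, ps.sorted_indices, ps.unsorted_indices
    )
```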

Differential Revision: [D35862156](https://our.internmc.facebook.com/intern/diff/D35862156/)

Approved by: https://github.com/mrshenli, https://github.com/rohan-varma
2022-04-26 22:03:25 +00:00
Yanli Zhao
0389f99c49 make apply_to_tensors support OrderedDict type (#75560)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75560

make apply_to_tensors support OrderedDict type
ghstack-source-id: 153492674

Test Plan: unit test

Reviewed By: rohan-varma

Differential Revision: D35524080

fbshipit-source-id: c71fa86ddb642b32aad6358fdcb040c4a0593f12
(cherry picked from commit 937f7acd65878cbea51e533fc4491fc577a904f7)
2022-04-11 05:26:31 +00:00
Chien-Chin Huang
e7a786ff34 [FSDP] Add the unittests for the _replace_by_prefix (#73530)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73530

Improve the unittests for the utility functions
ghstack-source-id: 150340208

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D34529335

fbshipit-source-id: 4e7283a58850d4674d22fb038284d6c762729dd4
(cherry picked from commit e53b0fd9d224578a89806408b71e145b967d602a)
2022-03-03 01:29:04 +00:00
Yanli Zhao
2336571cb7 make fsdp folder to be public (#72084)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72084

make fsdp folder to be public
ghstack-source-id: 148173447

Test Plan: unit tests

Reviewed By: mrshenli

Differential Revision: D33903417

fbshipit-source-id: 7852a2adc4af09af48a5ffa52ebf210489f834d5
(cherry picked from commit bd06513cfe)
2022-02-02 15:50:14 +00:00
Yanli Zhao
df3f82a1ef Add more FSDP unit tests to cover core logic, freezing weights and flatten parameter wrapper (#66904)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66904

Add more FSDP unit tests to cover core logic, freezing weights, and the flatten parameter wrapper. These unit tests are refactored to align with PyTorch's commonly used test classes
ghstack-source-id: 141335614

Test Plan: unit tests

Reviewed By: mrshenli

Differential Revision: D31779565

fbshipit-source-id: c727110d1d7570c0ec49e42cadfc9e9a5e440073
2021-10-22 16:50:52 -07:00