Commit Graph

38 Commits

Author SHA1 Message Date
Andrew Gu
10271a60a8 [FSDP] Skip _use_sharded_views() for SHARD_GRAD_OP (#98250)
This PR has `SHARD_GRAD_OP` (and `_HYBRID_SHARD_ZERO2`) skip `_use_sharded_views()` in the post-forward reshard since the strategy does not free the unsharded flat parameter and can preserve the unsharded views. This saves nontrivial CPU overhead both in the post-forward reshard (`_use_sharded_views()`) and the pre-backward unshard (`_use_unsharded_views()`).

<details>
<summary>(Before) Pre-backward hook: 4.356 ms</summary>

<img width="812" alt="Screenshot 2023-04-03 at 6 32 19 PM" src="https://user-images.githubusercontent.com/31054793/229641309-778cf1f9-4b5b-42ec-b2d8-0a1e6e7ce330.png">

</details>

<details>
<summary>(After) Pre-backward hook: 1.044 ms</summary>

![Screenshot 2023-04-04 at 9 05 53 AM](https://user-images.githubusercontent.com/31054793/229800917-9580ce6b-3721-469a-9212-f0cbfd8cbb52.png)

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98250
Approved by: https://github.com/rohan-varma
2023-04-04 17:07:28 +00:00
James Reed
3b1b585a59 [FSDP] Fix bug in determining whether parameters need to be materialized (#97488)
Previously, `_need_to_materialize_module` would return false because:

* `managed_params =_get_orig_params(module, ignored_params)` returns a generator
* `is_meta_module = any(param.is_meta for param in managed_params)` exhausts the generator in its check
* `any(fake.is_fake(param) for param in managed_params)` would try to iterate over the empty generator and get an empty sequence, thus returning `False`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97488
Approved by: https://github.com/ngimel, https://github.com/awgu
2023-03-25 08:24:57 +00:00
Rohan Varma
605a77fd59 Log FSDP mixed precision (#97367)
Log to clarify the mp config in jobs

Differential Revision: [D44307044](https://our.internmc.facebook.com/intern/diff/D44307044/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97367
Approved by: https://github.com/awgu
2023-03-24 16:01:59 +00:00
Andrew Gu
5ee230face [FSDP][1/N] Refactor module materialization (#94196)
**Overview**
This refactors module materialization (i.e. meta device or `torchdistX` deferred initialization) to compute the parameter and buffer names as needed instead of pre-computing them. These are needed to reacquire references to the states (e.g. `module.get_parameter(param_name)`) after materialization since the materialization may create new variables.

This refactor simplifies `_get_fully_sharded_module_to_states()` (the core function for "pseudo auto wrapping") to better enable lowest common ancestor (LCA) module computation for shared parameters, for which tracking parameter and buffer names may complicate the already non-obvious implementation.

**Discussion**
The tradeoff is a worst case quadratic traversal over modules if materializing all of them. However, since (1) the number of modules is relatively small, (2) the computation per module in the quadratic traversal is negligible, (3) this runs only once per training session, and (4) module materialization targets truly large models, I think this tradeoff is tolerable.

**For Reviewers**
- `_init_param_handle_from_module()` initializes _one_ `FlatParamHandle` from a fully sharded module and represents the module wrapper code path. For this code path, there is no need to reacquire references to the parameters/buffers for now since the managed parameters are only computed after materialization. This works because the managed parameters have a simple definition: any parameter in the local root module's tree excluding those already marked as flattened by FSDP. Similarly, FSDP marks buffers to indicate that they have already been processed (synced if `sync_module_states`).
- `_init_param_handles_from_module()` initializes _all_ `FlatParamHandle`s from a fully sharded module and represents the composable code path. For this code path, we must reacquire references to parameters/buffers because each logical wrapping is specified as a list of parameters/buffers to group together by those variables and because materialization may create new variables.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94196
Approved by: https://github.com/rohan-varma
2023-02-13 21:43:00 +00:00
Aaron Gokaslan
67d9790985 [BE] Apply almost all remaining flake8-comprehension checks (#94676)
Applies the remaining flake8-comprehension fixes and checks. This changes replace all remaining unnecessary generator expressions with list/dict/set comprehensions which are more succinct, performant, and better supported by our torch.jit compiler. It also removes useless generators such as 'set(a for a in b)`, resolving it into just the set call.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
2023-02-12 01:01:25 +00:00
Chien-Chin Huang
4b0f1cc1ee [FSDP][optim_state_dict][10/N] Make optim_state_dict and optim_state_dict_to_load public (#92118)
Make optim_state_dict and optim_state_dict_to_load public APIs and consolidate them with state_dict by using the same state_dict_type to decide how to perform the optimizer state_dict save and load.

Differential Revision: [D42488022](https://our.internmc.facebook.com/intern/diff/D42488022/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92118
Approved by: https://github.com/rohan-varma
2023-02-02 08:04:20 +00:00
Andrew Gu
63d6ee7d02 [FSDP][Easy] Remove outdated comment (#92739)
We pass `fully_sharded_module`, not `root_module`, after recent refactoring to unify composable and wrapper FSDP for now. This PR removes the comment explaining why before we passed in `root_module`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92739
Approved by: https://github.com/mrshenli
2023-01-23 15:52:49 +00:00
Andrew Gu
0d4bbd1996 [Lint] Add FSDP/composable API files to ufmt include (#90873)
This PR adds FSDP and composable API files to `.lintrunner.toml` so that (1) lintrunner enforces that those files are formatted and (2) `lintrunner f` formats those files for you.

There are two requirements here (see https://github.com/pytorch/pytorch/wiki/lintrunner for details):
1. Install lintrunner:
```
pip install lintrunner
lintrunner init
```
2. `lintrunner f` before you finalize your PR, which would now be enforced by CI after this PR.

The code changes in this PR outside of `.lintrunner.toml` are the result of `lintrunner f`.

---

I only plan to land this PR if all of the composable API developers agree that this is something that makes sense and is not too intrusive to the workflow.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90873
Approved by: https://github.com/yhcharles, https://github.com/mrshenli, https://github.com/rohan-varma
2023-01-18 05:33:34 +00:00
Yanli Zhao
2004df9097 Remove python ddp (#91663)
As it is not used by anyone and also it is not maintained by PyTorch
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91663
Approved by: https://github.com/rohan-varma
2023-01-04 05:22:30 +00:00
Yanli Zhao
f613633124 Remove _ignored_param_names (#91530)
'_ignored_param_names' is only used in 'param_hook' during state_dict() post hook processing to check a parameter key needs to be cloned or not. But it is not needed, as state_dict() post hook only passes fsdp managed parameter keys to 'param_hook', see https://github.com/pytorch/pytorch/blob/master/torch/distributed/fsdp/_state_dict_utils.py#L203. That means the passed parameter keys are always not part of '_ignored_param_names'.

so we should be able to safely remove '_ignored_param_names' and related codes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91530
Approved by: https://github.com/rohan-varma
2022-12-31 03:28:22 +00:00
Andrew Gu
aec09eeb3a [FSDP][7/N] Support replicate in fully_shard (#91044)
This PR supports nesting `replicate` in `fully_shard`.
- The PR achieves this by treating `replicate`-annotated modules are ignored modules. This means that all submodules in the `replicate`-annotated module's subtree are ignored, including nested `fully_shard`-annotated modules, which is the desired behavior.

---

This PR reworks some tree traversal.

One end goal is for `state._handles` to follow the same order for both the wrapper and composable paths. This implies that `_get_fsdp_handles()` returns the same value for both paths.
- The helper function `_get_fully_sharded_module_to_states()` now follows a left-to-right DFS from each fully sharded module instead of a BFS. The left-to-right DFS follows `.modules()` order.
- The composable auto "wrap" initialization function `_init_param_handles_from_module()` follows the reverse left-to-right DFS order. As noted in the code comments, this initialization order is a valid reverse topological sort, but it differs from the wrapper path. This is the _only_ difference with respect to initialization order through the entire process.
```
mod: Module(
    submod1: Submodule()
    submod2: Submodule(
        subsubmod: Subsubmodule(),
    ),
)
```
For left-to-right DFS, the order is `mod`, `submod1`, `submod2`, `subsubmod`. (For context, right-to-left DFS would be `mod`, `submod2`, `subsubmod`, `submod1`. In other words, the left-to-right vs. right-to-left corresponds to `.children()` vs. `reversed(.children())` respectively.) Then, reverse left-to-right DFS is `subsubmod`, `submod2`, `submod1`, `mod`, which is a valid initialization order. However, the wrapper auto wrap initialization order would be `submod1`, `subsubmod`, `submod2`, `mod` since it directly follows a left-to-right DFS and initializes as a part of the recursive DFS logic.
- At the end of `_init_param_handles_from_module()`, we reverse the newly populated `state._handles`, so this is the reverse reverse left-to-right DFS order, which is equivalent to the left-to-right DFS order. Thus, `state._handles` has the same order for both paths.

Another goal is for `_get_fsdp_states()` to not traverse into any submodule that is annotated with an API that is not compatible with `fully_shard` (e.g. `replicate`). To achieve this while preserving that `_get_fsdp_states()` follows `.modules()` order, we again use a left-to-right DFS.

The reason the DFSs may look strange is because I implemented them non-recursively, which requires a stack.

- `test_get_fully_sharded_module_to_states()` in `test_utils.py` checks the traversal order of `_get_fully_sharded_module_to_states()`.
- `test_policy()` in `test_fully_shard.py` checks the traversal order returned by `_get_fsdp_handles()`.

---

Due to a circular dependency issue, we must move the graph/tree traversal helpers to their own file `_traversal_utils.py`, and any usages must import the entire file like `import torch.distributed.fsdp._traversal_utils as traversal_utils` instead of `from torch.distributed.fsdp._traversal_utils import ...`.

The cycle comes from the fact that the traversals require `_composable()`, which requires `_get_registry()` from `composable/contract.py`, which when imported, imports `composable/fully_shard.py`, which requires the traversals.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91044
Approved by: https://github.com/mrshenli
2022-12-20 16:49:18 +00:00
Shen Li
e5a48da664 Allow FSDP to have ignored modules out of wrapped root (#91079)
Motivations for this change:

1. TorchRec returns inconsistent results on `m.named_parameters()`
   and `m.m1.named_parameters()` if m1 is a `ShardedModule`. Basically,
   `ShardedModule` appears in `m.named_modules()`, but its parameters
   are not in `m.named_parameters()`. As a result, when we identify
   `ShardedModule` and pass them as `ignored_modules` to FSDP, FSDP
   complains about key error in `_get_ignored_params`.
2. If users are manually wrapping submodules with FSDP, it could be
   easier for them to keep a global set of ignored parameters, instead
   of create a new collection for every FSDP invocation.

Given the above two reasons, we allow FSDP to have ignored modules
out of the wrapped root module.

Differential Revision: [D42132394](https://our.internmc.facebook.com/intern/diff/D42132394)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91079
Approved by: https://github.com/awgu
2022-12-19 14:28:25 +00:00
Andrew Gu
8cd1808dbf [FSDP] Introduce "fully sharded module"; remove comm. module (#90933)
This PR removes the "communication module" (comm. module / `comm_module`) concept from the FSDP code base since it causes disproportionate confusion compared to its benefit for now.

Instead, we introduce the term "fully sharded module" as the single concept to unify the wrapper and non-wrapper code paths. The definition is presented in a note at the top of `flat_param.py`. I reproduce it here:

---
We define the **"fully sharded module"** to be the original `nn.Module` that owns a `FlatParamHandle`. It is the *single* module logically responsible for the *single* unshard/reshard pair for the handle's `FlatParameter` for a given forward or backward pass. The fully sharded module should be passed to the `FlatParamHandle` constructor.

For the wrapper code path:
- The `FullyShardedDataParallel` module wrapping the fully sharded module runs the unshard/reshard on behalf of the fully sharded module by overriding `nn.Module.forward`.
- The fully sharded module is exactly the module passed to the `FullyShardedDataParallel` constructor's `module` argument and is saved in `_fsdp_wrapped_module`.

For the non-wrapper code path:
- Hooks registered on the fully sharded module run the unshard/reshard.
- The fully sharded module may either be the direct argument to `fully_shard` or a submodule chosen by the provided wrapping policy.
---

After this PR, `handle.flat_param._fqns`, `_param_infos`, and `_shared_param_infos` all prefix names from the same module, namely the fully sharded module. This should make state dict less confusing.

---
As an example, consider:
```
mod: Module(
  sub1: Submodule(
    subsub1: Subsubmodule(),
    subsub2: Subsubmodule(),
  ),
  sub2: Submodule(
    subsub1: Subsubmodule(),
    subsub2: Subsubmodule(),
  ),
)
```
For wrapper FSDP manual wrap:
```
mod.sub1 = FSDP(mod.sub1)
mod.sub2 = FSDP(mod.sub2)
mod = FSDP(mod)
```
For wrapper FSDP auto wrap:
```
mod = FSDP(mod, auto_wrap_policy=ModuleWrapPolicy({Submodule}))
```
(WIP) For non-wrapper FSDP manual wrap:
```
fully_shard(mod.sub1)
fully_shard(mod.sub2)
fully_shard(mod)
```
For non-wrapper FSDP auto wrap:
```
fully_shard(mod, policy=ModuleWrapPolicy({Submodule}))
```
The fully sharded module **in all cases** are `mod`, `mod.sub1`, `mod.sub2`, and notably, `subsub1` and `subsub2`s are not fully sharded modules.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90933
Approved by: https://github.com/rohan-varma
2022-12-16 18:45:52 +00:00
Andrew Gu
1ba4e3c711 [FSDP][BE] Remove _module_to_handles, HandleConfig; use term "fqn"; clarify docs (#90840)
This PR
- Removes `_module_to_handles` since it is no longer used. We instead use `_comm_module_to_handles`.
- Removes `HandleConfig` and stores its fields directly as attributes on `FlatParamHandle`.
- Uses the term `fqn`/`fqns` uniformly in `flat_param.py` instead of `prefixed_param_name` / `prefixed_param_names`.
- Clarifies some documentation.

I am including all of these BE items in the same PR to save CI.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90840
Approved by: https://github.com/rohan-varma
2022-12-14 21:37:37 +00:00
Andrew Gu
93aee0cdc9 [FSDP][Easy] ufmt files (#90548)
```
ufmt format torch/distributed/fsdp
ufmt format test/distributed/fsdp
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90548
Approved by: https://github.com/mrshenli, https://github.com/rohan-varma
2022-12-14 02:02:53 +00:00
Andrew Gu
fc429512d5 [FSDP] Clean up FlatParamHandle dtypes, post-backward hook (#90660)
This PR reworks the internal handling of parameter and gradient reduction mixed precision, cleans up the post-backward hook logic, and adds some minor changes to the communication hooks.

**Overview**
This PR addresses everything in https://github.com/pytorch/pytorch/issues/90657 except renaming `keep_low_precision_grads` to `keep_grads_in_reduce_dtype` since that is BC breaking. I recommend reading the issue before preceding.

For `MixedPrecision(param_dtype, reduce_dtype, ...)`, the exact rule for parameter and gradient reduction mixed precision that we are following is:
> If `param_dtype is not None` and `reduce_dtype is None`, then we infer `reduce_dtype = param_dtype`. Otherwise, we take `param_dtype` and `reduce_dtype` as is.

This PR enforces that, at the `FlatParamHandle` level, `handle._config.fwd_bwd_param_dtype` and `handle._config.reduce_dtype` are never `None`. The way to check if mixed precision is enabled is to compare against the original parameter dtype, which is now stored in `handle._orig_param_dtype`. It is no longer to check against `None`.

This avoids ambiguous cases such as when the user passes `MixedPrecision(param_dtype=torch.float32)`. In that case, our existing implementation mistakenly thinks that parameter mixed precision is enabled and either relies on no-ops silently or errors (such as one case reported by MosaicML).

**Additional Details**
- We remove `FullyShardedDataParallel._mixed_precision_enabled_for_params`, `FullyShardedDataParallel._mixed_precision_enabled_for_reduce`, and `FullyShardedDataParallel._mixed_precision_keep_low_precision_grads` since they are not used.
- The unit test `test_meta_device_with_mixed_precision()` exercises a tricky edge case with meta device initialization, `apply()` (calling into `summon_full_params()`), and `param_dtype=torch.float32` for a nested wrapping case, where each nested instance has parameters.
- We include some minor fixes/improvements to the communication hook implementation.

**Follow-Ups**
- We should get rid of `HandleConfig` and store its fields as attributes on `FlatParamHandle` directly.
- Rename `keep_low_precision_grads` to `keep_grads_in_reduce_dtype`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90660
Approved by: https://github.com/zhaojuanmao
2022-12-13 07:34:59 +00:00
Andrew Gu
e7efeb5282 [FSDP] Save _stream_to_name for debugging (#90611)
This saves a data structure `_stream_to_name: Dict[torch.cuda.Stream, str]` that maps each FSDP stream to its name. This can help in debugging by checking `_stream_to_name[torch.cuda.current_stream()]` to see if it is `"default"` or `"unshard"` in the post-backward hook for example.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90611
Approved by: https://github.com/rohan-varma
2022-12-11 03:46:18 +00:00
Rohan Varma
793a999ce0 Hybrid Sharded Data Parallel (#89915)
Adds 2 new hybrid sharding strategy to FSDP:
1. HYBRID_SHARD: applies zero-3 style sharding within a node, and data parallel across
2. HYBRID_SHARD_ZERO2: applies zero-2 style sharding within a node, and data parallel across

These are useful for medium sized models and aim to decrease communication volume, tests and benchmarks will be run to understand which workloads are optimal under which sharding strategy.

Hybrid sharding in general works by sharding the model using a process group within a single node, and creating intra-node process groups for replication / data parallelism. The user either needs to pass in a tuple of these process groups, or None, and we generate the process groups appropriately.

** Acknowledgements **
- @awgu 's excellent prototype: 5ad3a16d48
- @liangluofb For ideation, feedback, and initial implementation and experimentation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89915
Approved by: https://github.com/awgu
2022-12-08 16:18:03 +00:00
Andrew Gu
21a0e809c2 [Composable API] Match fully_shard() comm. schedule with wrapper FSDP (#90387)
- This PR introduces a new concept, the _communication module_ (denoted `comm_module`), that represents the module responsible for the unshard/reshard pair for a `FlatParamHandle`. This is well-defined because the current design assumes that each `FlatParamHandle` only has _one_ unshard/reshard pair for either the forward or backward pass.
    - For the wrapper code path, the `comm_module` is exactly the module already being passed to the `FlatParamHandle` constructor.
    - For the composable code path, the `comm_module` is not necessarily the module already being passed to the `FlatParamHandle`. This is because the module already being passed is always the local FSDP root module to give complete FQNs, instead of local FQNs. Distinguishing the communication module from the local FSDP root module can provide more flexibility for non-recursive wrapping designs in the future.
- This PR adds a unit test `test_unshard_reshard_order` that explicitly checks that `_unshard` and `_reshard` are called in the exactly the same order across the two code paths.
- This PR does not fix `test_checkpoint_fsdp_submodules_use_reentrant`. However, the error message changes, so this PR accommodates that.
    - The error is now the same as if we used the equivalent wrapper FSDP:
    ```
    test_model.u1 = FSDP(test_model.u1, use_orig_params=True)
    test_model.u2 = FSDP(test_model.u2, use_orig_params=True)
    ```
    - The error is also the same as if we used wrapper FSDP with `use_orig_params=False`, so it is not unique to `use_orig_params=True`.

---

**`comm_module` Example**

```
model = Model(
    seq1: nn.Sequential(
        nn.Linear
        nn.ReLU
        nn.Linear
        nn.ReLU
    )
    seq2: nn.Sequential(
        nn.Linear
        nn.ReLU
        nn.Linear
        nn.ReLU
    )
)
policy = ModuleWrapPolicy({nn.Sequential})
fully_shard(model, policy=policy)
FullyShardedDataParallel(model, auto_wrap_policy=policy)
```
- This policy constructs two `FlatParamHandle`s, one for `seq1` and one for `seq2`.
- `FullyShardedDataParallel` will pass `seq1` and `seq2` as the `module` argument to the two `FlatParamHandle`s, respectively.
- `fully_shard()` will pass `model` as the `module` argument to every `FlatParamHandle`.
- `FullyShardedDataParallel` will pass `seq1` and `seq2` as the `comm_module` argument to the two `FlatParamHandle`s, respectively.
- `fully_shard()` will pass `seq1` and `seq2` as the `comm_module` argument to the two `FlatParamHandle`s, respectively.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90387
Approved by: https://github.com/mrshenli
2022-12-08 15:55:20 +00:00
Ram Rachum
351d73b97f Fix exception causes all over the codebase (#90271)
This is the continuation to #90134 and hopefully the final PR in this series.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90271
Approved by: https://github.com/kit1980
2022-12-07 04:29:00 +00:00
Andrew Gu
45b40be078 [FSDP()] Fix fully_shard fwd hook registration (#90201)
I need to rebase later after Shen's PRs land.

The idea is to only register the pre/post-forward hook on the _root modules_ among the modules that consume a `FlatParameter`. (Yes, the term _root module_ is heavily overloaded. We may want to clarify that at some point. Here, _root_ is being used in the graph sense, meaning parent-less, and the scope is only among the modules consuming a `FlatParameter`.)

This avoids unnecessary pre/post-forward hooks running, which would lead to errors because the unshard is not truly idempotent.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90201
Approved by: https://github.com/mrshenli, https://github.com/rohan-varma
2022-12-06 06:09:03 +00:00
Andrew Gu
1ee189ce8e [FSDP] Issue warning when clamping to NO_SHARD (#90060)
Fixes https://github.com/pytorch/pytorch/issues/90050. I hope that this was not meant as an onboarding task :/
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90060
Approved by: https://github.com/zhaojuanmao
2022-12-03 15:58:25 +00:00
Andrew Gu
d1760d7a42 [FSDP][Easy] Remove outdated TODO (#89217)
**Overview**
This PR removes an outdated TODO:
```
# TODO (awgu): When exposing the original parameters, we need to also
# use this attribute to prevent re-synchronizing parameters.
```

**Justification**
We only pass `managed_params` to `_sync_module_params_and_buffers()`, where `managed_params` is defined as
```
managed_params = list(_get_orig_params(root_module, state._ignored_params))
```
This `_get_orig_params()` call excludes parameters already flattened by FSDP. Thus, `_sync_module_params_and_buffers()` will not re-sync already-synchronized parameters. Each parameter appears in `managed_params` for some FSDP instance exactly once and hence is only synchronized once.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89217
Approved by: https://github.com/mrshenli
2022-11-30 01:42:16 +00:00
Andrew Gu
d01bf1d1f1 [FSDP] Introduce ModuleWrapPolicy for simplicity (#88450)
**BC Breaking Change**
This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap"  suggests that some wrapping has been undone. I prefer `numel` over `params` because that is unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves.

This only breaks anything that passes `unwrapped_params` as a keyword argument, but I did not see anything that did that (except the one internal benchmark file but that does not actually depend on our `pytorch` code).

In a follow-up, I want to rename `min_num_params` to `min_nonwrapped_numel` in `size_based_auto_wrap_policy`, which is also BC breaking. Again, this is to differentiate between "params" being `nn.Parameter`s and "numel" being the unit for `param.numel()`.

**Overview**
This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is:
```
module_classes: Set[Type[nn.Module]] = ...
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls=module_classes,
)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
Now, users can instead write:
```
auto_wrap_policy = ModuleWrapPolicy(module_classes)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
This hides the unused arguments expected from the callable (`recurse` and `unwrapped_params`/`nonwrapped_numel`).

`ModuleWrapPolicy` inherits from an abstract base class `FSDPPolicy` that expects a `policy` property. This decouples the construct of such `FSDPPolicy` classes and their actual `policy`, which must abide by the `_recursive_wrap` interface. Any existing auto wrap policy can be rewritten as a class that inherits from `FSDPPolicy`, so this approach is fully backward compatible from a functionality perspective.

I call this base class `FSDPPolicy` to generalize over the cases where we may not want to actually perform any nested wrapping. In reality, the policy is meant for constructing `FlatParameter`s, which just happened to be induced by a nested wrapping before. Given this, I am changing the constructor argument in `fully_shard()` to simply `policy` instead of `auto_wrap_policy`.

This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88450
Approved by: https://github.com/zhaojuanmao
2022-11-12 04:14:32 +00:00
Chien-Chin Huang
7aa144ac54 [FSDP][state_dict][5/N] Remove the FSDP module dependency from _state_dict_utils (#88637)
**What**
This PR completely removes the `FullyShardedDataParallel` dependency from `_state_dict_utils` -- `_state_dict_utils` now depends only on `_FSDPState` and all the utils modules.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88637
Approved by: https://github.com/awgu
2022-11-11 21:22:13 +00:00
Chien-Chin Huang
324ac93a43 [FSDP][state_dict][2/N] Move state_dict related enums/dataclasses/states to state_dict_utils.py, api.py and init_state_dict() (#88481)
**Motivation**:
Several Enums, Dataclasses and states defined in fully_sharded_data_paralle.py should be moved to a place where the composable FSDP can access. This PR does the move.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88481
Approved by: https://github.com/rohan-varma, https://github.com/awgu
2022-11-11 12:28:37 +00:00
Andrew Gu
95a9721a15 [FSDP()][Easy] Rename _State to _FSDPState (#88234)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88234
Approved by: https://github.com/mrshenli
2022-11-03 11:29:01 +00:00
Andrew Gu
73de44fc56 [FSDP] Rename unflat_param_name -> fqn for consistency (#88123)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88123
Approved by: https://github.com/mrshenli
2022-11-02 23:25:53 +00:00
Andrew Gu
f35d5145a1 [FSDP] Simplify _get_buffer_names() (#88122)
This is a follow-up from a previous PR in this stack. The PR simplifies the `_get_buffer_names()` implementation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88122
Approved by: https://github.com/mrshenli
2022-11-02 23:25:53 +00:00
Andrew Gu
572a3d2d6e [FSDP] Remove unneeded torch.no_grad() context when offloading to CPU (#88121)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88121
Approved by: https://github.com/mrshenli
2022-11-02 23:25:53 +00:00
Andrew Gu
32d22edc67 [FSDP()][27/N] Add forward hook registration (#88040)
This PR adds the forward hook registration to composable FSDP and adds a unit test for the runtime.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88040
Approved by: https://github.com/zhaojuanmao, https://github.com/rohan-varma
2022-11-02 23:25:53 +00:00
Andrew Gu
bf2819a836 [FSDP()][24/N] Refactor _lazy_init() (#87939)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87939
Approved by: https://github.com/zhaojuanmao
2022-11-02 16:35:47 +00:00
Andrew Gu
d172dcf316 [FSDP()][21/N] Refactor and fix _cast_buffers() (#87935)
This PR refactors and fixes `_cast_buffers()`.

**Before**
Buffers were not correctly cast back to their original dtypes for submodules when using buffer mixed precision.
- `_cast_buffers(recurse=False)` incorrectly casts all buffers, including those in submodules. This is because of this outer loop over `self.modules()`:
c40033be16/torch/distributed/fsdp/fully_sharded_data_parallel.py (L700)
- There was a unit test that checked that buffers were cast as expected (`test_mixed_precision_e2e_full_shard()`). The unit test _coincidentally_ passed because all modules shared the same buffer name `"buffer"`. In `_cast_buffers()`, the `dict` mapping buffer name to original dtype is populated lazily (during `_lazy_init()`). However, the keys are unprefixed:
c40033be16/torch/distributed/fsdp/fully_sharded_data_parallel.py (L712-L717)
- Thus, even though (1) `_cast_buffers(recurse=False)` was only called on the root and (2) `self._buffer_name_to_orig_dtype` had unprefixed names as keys, the unit test still passed because (1) `_cast_buffers()` still looped over all buffers despite `recurse=False` and (2) all submodules' buffers were named `"buffer"` and had the same original and low-precision dtypes and hence were cast correctly.

If we change each submodule to have its own distinct buffer name, then the unit test fails. This PR makes such a change to showcase the progression granted by this PR.

**After**
This PR separates `_cast_buffers()` into three methods: `_get_buffers_and_dtypes_for_computation()`, `_get_buffers_and_dtypes_for_checkpoint()`, and `_cast_buffers_to_dtype_and_device()`. This is to separate the different use cases (casting for computation and casting for checkpointing) and the corresponding code paths. Plus, the signature for `_cast_buffers_to_dtype_and_device()` makes it clear exactly what buffers are being cast and to what dtype.

Both `_get_...()` functions assume that they are called on the root only for now. This coincides with the construction of `_buffer_name_to_orig_dtype` in the FSDP constructor, which loops over all submodules. (This means that for non-root modules, their `_buffer_name_to_orig_dtype` is populated but not used.) The `dict`'s keys are clean since the buffer cast to original dtype happens in a `summon_full_params()` context, which cleans the names.

**Follow-Ups**
- We can try to move `_get_buffers_and_dtypes_for_checkpoint()` into `_state_dict_utils.py` in a follow-up.
- We may want to move to per-module buffer casting (i.e. do not have the root module cast for all submodules).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87935
Approved by: https://github.com/mrshenli
2022-11-02 11:32:56 +00:00
Andrew Gu
5a53f024e4 [FSDP()][15/N] Refactor _init_streams() (#87928)
This PR is easy. I think I move `_init_streams()` again in a later PR though :/
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87928
Approved by: https://github.com/mrshenli
2022-11-01 17:26:03 +00:00
Andrew Gu
8039317c07 [FSDP()][12/N] Easy cleanup (#87925)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87925
Approved by: https://github.com/mrshenli
2022-11-01 12:39:24 +00:00
Andrew Gu
c1e28731b3 [FSDP()][10/N][11/N] Introduce composable (ctor only) (#87924)
This PR introduces the composable FSDP API (with constructor semantics only) along with some further constructor refactoring. A notable contribution here is `_get_submodule_to_states()`, which performs auto wrapping without actually wrapping.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87924
Approved by: https://github.com/mrshenli
2022-11-01 12:39:24 +00:00
Andrew Gu
78170701a3 [FSDP()][9/N] Refactor ctor (continued) (#87923)
This PR makes a second pass over the constructor. The logic has been grouped into `_init_<...>` functions based on intent (e.g. `_init_prefetching_state()` or `_init_runtime_state()`). This makes the initialization code for composable FSDP much cleaner than having to re-write the same sequences of lower-level helper calls.

This PR also moves `_ExecOrderData` into its own file `_exec_order_utils.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87923
Approved by: https://github.com/mrshenli
2022-11-01 12:39:21 +00:00
Andrew Gu
d89cf2fdc9 [FSDP()][7/N] Refactor most of ctor (#87921)
The goal of this PR is to make one pass over the FSDP constructor and refactor each helper method call to not be `self.<...>`. Subsequent PRs will make further passes over the FSDP constructor.

This PR looks like a lot of lines of code change, but it is only reorganization. Methods are moved to `_init_utils.py` and `_common_utils.py`. This also marks the beginning of moving methods from `_utils.py` to `_common_utils.py` -- they will be coalesced eventually. I am only using `_common_utils.py` as a staging ground to include the methods that have been affected by the refactoring.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87921
Approved by: https://github.com/mrshenli
2022-10-31 16:45:24 +00:00