Commit Graph

35 Commits

Author SHA1 Message Date
Andrew Gu
3e44fcee2f [FSDP][2/N] Move fsdp_modules(root_only=False) -> _get_fsdp_states() (#90861)
This PR migrates all internal usages of `FullyShardedDataParallel.fsdp_modules(root_only=False)` to `_get_fsdp_states()`. This is to unify the code paths for composable and wrapper FSDP.

This PR _does not_ change the usages in test files. This is because we should revisit those usages separately as a way to track which functionality for which we have not tested composable FSDP.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90861
Approved by: https://github.com/rohan-varma
2022-12-16 12:21:47 +00:00
Andrew Gu
d04e3c994f [FSDP] Fix input grad propagation when using param mixed precision (#90921)
For parameter mixed precision, we cast the inputs to the low precision parameter dtype. If the input has tensors that require gradient, then we must cast them in place in order for them to receive a gradient. The cast should be tracked by autograd (e.g. with `grad_fn` equal to `ToCopyBackward0`). This removes the `torch.no_grad` context when calling `_apply_to_tensors`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90921
Approved by: https://github.com/mrshenli, https://github.com/rohan-varma
2022-12-15 23:55:19 +00:00
Andrew Gu
c4718e9b09 [FSDP] Enable mixed hybrid/non-hybrid sharding strategies (#90846)
In the context of hybrid sharding strategies, we only need to enforce the same process groups among the instances using a hybrid sharding strategy, not all instances. We can even mix and match the two different hybrid sharding strategies. This PR relaxes the validation to support this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90846
Approved by: https://github.com/rohan-varma
2022-12-15 15:36:23 +00:00
Andrew Gu
1ba4e3c711 [FSDP][BE] Remove _module_to_handles, HandleConfig; use term "fqn"; clarify docs (#90840)
This PR
- Removes `_module_to_handles` since it is no longer used. We instead use `_comm_module_to_handles`.
- Removes `HandleConfig` and stores its fields directly as attributes on `FlatParamHandle`.
- Uses the term `fqn`/`fqns` uniformly in `flat_param.py` instead of `prefixed_param_name` / `prefixed_param_names`.
- Clarifies some documentation.

I am including all of these BE items in the same PR to save CI.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90840
Approved by: https://github.com/rohan-varma
2022-12-14 21:37:37 +00:00
Andrew Gu
b66cedd906 [FSDP] Fix use_orig_params=True + no_sync() (#90546)
`no_sync()` introduces a separate case where a `FlatParameter` maintains an _unsharded_ gradient, instead of a _sharded_ one. This PR fixes `no_sync()` with `use_orig_params=True` by dealing with this separate case.

The existing `use_orig_params=False` already bypasses the built-in parameter/gradient size check, where the `flat_param` is sharded, while the `flat_param.grad` is unsharded. For `use_orig_params=True`, we need to use the same `.data` hack to side step the size check that we used to side step the dtype check for `keep_low_precision_grads=True`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90546
Approved by: https://github.com/rohan-varma
2022-12-13 23:40:04 +00:00
Andrew Gu
fc429512d5 [FSDP] Clean up FlatParamHandle dtypes, post-backward hook (#90660)
This PR reworks the internal handling of parameter and gradient reduction mixed precision, cleans up the post-backward hook logic, and adds some minor changes to the communication hooks.

**Overview**
This PR addresses everything in https://github.com/pytorch/pytorch/issues/90657 except renaming `keep_low_precision_grads` to `keep_grads_in_reduce_dtype` since that is BC breaking. I recommend reading the issue before preceding.

For `MixedPrecision(param_dtype, reduce_dtype, ...)`, the exact rule for parameter and gradient reduction mixed precision that we are following is:
> If `param_dtype is not None` and `reduce_dtype is None`, then we infer `reduce_dtype = param_dtype`. Otherwise, we take `param_dtype` and `reduce_dtype` as is.

This PR enforces that, at the `FlatParamHandle` level, `handle._config.fwd_bwd_param_dtype` and `handle._config.reduce_dtype` are never `None`. The way to check if mixed precision is enabled is to compare against the original parameter dtype, which is now stored in `handle._orig_param_dtype`. It is no longer to check against `None`.

This avoids ambiguous cases such as when the user passes `MixedPrecision(param_dtype=torch.float32)`. In that case, our existing implementation mistakenly thinks that parameter mixed precision is enabled and either relies on no-ops silently or errors (such as one case reported by MosaicML).

**Additional Details**
- We remove `FullyShardedDataParallel._mixed_precision_enabled_for_params`, `FullyShardedDataParallel._mixed_precision_enabled_for_reduce`, and `FullyShardedDataParallel._mixed_precision_keep_low_precision_grads` since they are not used.
- The unit test `test_meta_device_with_mixed_precision()` exercises a tricky edge case with meta device initialization, `apply()` (calling into `summon_full_params()`), and `param_dtype=torch.float32` for a nested wrapping case, where each nested instance has parameters.
- We include some minor fixes/improvements to the communication hook implementation.

**Follow-Ups**
- We should get rid of `HandleConfig` and store its fields as attributes on `FlatParamHandle` directly.
- Rename `keep_low_precision_grads` to `keep_grads_in_reduce_dtype`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90660
Approved by: https://github.com/zhaojuanmao
2022-12-13 07:34:59 +00:00
Shen Li
7ec1cb8553 [FSDP] Fix _pre_forward type annotation (#90621)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90621
Approved by: https://github.com/awgu, https://github.com/Skylion007
2022-12-11 06:39:38 +00:00
Shen Li
80542add73 [FSDP] Allow MixedPrecision to skip inputs (#90620)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90620
Approved by: https://github.com/rohan-varma, https://github.com/awgu
2022-12-11 06:39:38 +00:00
Andrew Gu
31351c61dd [FSDP] Tighten post-bwd cast to reduce_dtype (#90615)
This lowers the `reduce_dtype` retrieval to the `handle` instead of the `state` in preparation for `fully_shard`, and this adds a guard to avoid a no-op `to()` call.

Note that this change pretty much gets overridden in following PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90615
Approved by: https://github.com/rohan-varma
2022-12-11 06:39:34 +00:00
Andrew Gu
e7efeb5282 [FSDP] Save _stream_to_name for debugging (#90611)
This saves a data structure `_stream_to_name: Dict[torch.cuda.Stream, str]` that maps each FSDP stream to its name. This can help in debugging by checking `_stream_to_name[torch.cuda.current_stream()]` to see if it is `"default"` or `"unshard"` in the post-backward hook for example.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90611
Approved by: https://github.com/rohan-varma
2022-12-11 03:46:18 +00:00
Shen Li
082450609c [FSDP] Allow nested FSDP wrapper to use different mixed precision (#90523)
The main change is to move `args` and `kwargs` dtype convertion
from `_root_pre_forward` to `_pre_forward`, so that every
FSDP has a chance to apply its own precision.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90523
Approved by: https://github.com/awgu, https://github.com/rohan-varma
2022-12-09 20:06:05 +00:00
Andrew Gu
2cf703214b [Composable API][Easy] Fix some follow-ups (#90471)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90471
Approved by: https://github.com/mrshenli
2022-12-09 00:26:38 +00:00
Rohan Varma
793a999ce0 Hybrid Sharded Data Parallel (#89915)
Adds 2 new hybrid sharding strategy to FSDP:
1. HYBRID_SHARD: applies zero-3 style sharding within a node, and data parallel across
2. HYBRID_SHARD_ZERO2: applies zero-2 style sharding within a node, and data parallel across

These are useful for medium sized models and aim to decrease communication volume, tests and benchmarks will be run to understand which workloads are optimal under which sharding strategy.

Hybrid sharding in general works by sharding the model using a process group within a single node, and creating intra-node process groups for replication / data parallelism. The user either needs to pass in a tuple of these process groups, or None, and we generate the process groups appropriately.

** Acknowledgements **
- @awgu 's excellent prototype: 5ad3a16d48
- @liangluofb For ideation, feedback, and initial implementation and experimentation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89915
Approved by: https://github.com/awgu
2022-12-08 16:18:03 +00:00
Andrew Gu
21a0e809c2 [Composable API] Match fully_shard() comm. schedule with wrapper FSDP (#90387)
- This PR introduces a new concept, the _communication module_ (denoted `comm_module`), that represents the module responsible for the unshard/reshard pair for a `FlatParamHandle`. This is well-defined because the current design assumes that each `FlatParamHandle` only has _one_ unshard/reshard pair for either the forward or backward pass.
    - For the wrapper code path, the `comm_module` is exactly the module already being passed to the `FlatParamHandle` constructor.
    - For the composable code path, the `comm_module` is not necessarily the module already being passed to the `FlatParamHandle`. This is because the module already being passed is always the local FSDP root module to give complete FQNs, instead of local FQNs. Distinguishing the communication module from the local FSDP root module can provide more flexibility for non-recursive wrapping designs in the future.
- This PR adds a unit test `test_unshard_reshard_order` that explicitly checks that `_unshard` and `_reshard` are called in the exactly the same order across the two code paths.
- This PR does not fix `test_checkpoint_fsdp_submodules_use_reentrant`. However, the error message changes, so this PR accommodates that.
    - The error is now the same as if we used the equivalent wrapper FSDP:
    ```
    test_model.u1 = FSDP(test_model.u1, use_orig_params=True)
    test_model.u2 = FSDP(test_model.u2, use_orig_params=True)
    ```
    - The error is also the same as if we used wrapper FSDP with `use_orig_params=False`, so it is not unique to `use_orig_params=True`.

---

**`comm_module` Example**

```
model = Model(
    seq1: nn.Sequential(
        nn.Linear
        nn.ReLU
        nn.Linear
        nn.ReLU
    )
    seq2: nn.Sequential(
        nn.Linear
        nn.ReLU
        nn.Linear
        nn.ReLU
    )
)
policy = ModuleWrapPolicy({nn.Sequential})
fully_shard(model, policy=policy)
FullyShardedDataParallel(model, auto_wrap_policy=policy)
```
- This policy constructs two `FlatParamHandle`s, one for `seq1` and one for `seq2`.
- `FullyShardedDataParallel` will pass `seq1` and `seq2` as the `module` argument to the two `FlatParamHandle`s, respectively.
- `fully_shard()` will pass `model` as the `module` argument to every `FlatParamHandle`.
- `FullyShardedDataParallel` will pass `seq1` and `seq2` as the `comm_module` argument to the two `FlatParamHandle`s, respectively.
- `fully_shard()` will pass `seq1` and `seq2` as the `comm_module` argument to the two `FlatParamHandle`s, respectively.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90387
Approved by: https://github.com/mrshenli
2022-12-08 15:55:20 +00:00
Andrew Gu
45b40be078 [FSDP()] Fix fully_shard fwd hook registration (#90201)
I need to rebase later after Shen's PRs land.

The idea is to only register the pre/post-forward hook on the _root modules_ among the modules that consume a `FlatParameter`. (Yes, the term _root module_ is heavily overloaded. We may want to clarify that at some point. Here, _root_ is being used in the graph sense, meaning parent-less, and the scope is only among the modules consuming a `FlatParameter`.)

This avoids unnecessary pre/post-forward hooks running, which would lead to errors because the unshard is not truly idempotent.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90201
Approved by: https://github.com/mrshenli, https://github.com/rohan-varma
2022-12-06 06:09:03 +00:00
Andrew Gu
c8aaad040e [FSDP] Limit all gather after pre-unshard (#89057)
To reuse memory when allocating the unsharded `FlatParameter` in the unshard stream, we only need to block the CPU thread on the preceding free event (i.e. `event.synchronize()`) before allocating the unsharded memory, which happens in `handle.unshard()`. Notably, this can be done after the pre-unshard logic, which at most performs _sharded_ allocations (low precision shard or H2D sharded `FlatParameter` copy) in its own pre-unshard stream. This enables the pre-unshard to overlap with any pending ops.

With this change, I believe that we should use `limit_all_gathers=True` all the time to stay true to FSDP's proposed memory semantics.

If a user wants to set `limit_all_gathers=False`, that would mean that he/she wants to overlap ops that are issued after the unshard logic's all-gather with ops that are pending at the time when FSDP _would_ block the CPU thread via `event.synchronize()`.
- If the user is willing to not reuse memory for that all-gather, then the user may as well have applied `NO_SHARD` and optionally ZeRO-1 (if this niche is important, then maybe we should consider hardening ZeRO-1). This is because now the unsharded memory for the all-gather additionally contributes to peak memory since it cannot reuse memory.
- If the user wanted to reuse memory for that all-gather, then we needed to block the CPU thread. There is no way around that given the caching allocator semantics.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89057
Approved by: https://github.com/mrshenli
2022-11-30 02:27:02 +00:00
Andrew Gu
6e2da426f0 [FSDP] Relax post-backward assert (#89791)
This assert was accidentally made stricter when transitioning from per-FSDP-instance training state to per-handle training state. This PR relaxes it again, which should restore compatibility for some reentrant AC plus FSDP cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89791
Approved by: https://github.com/zhaojuanmao
2022-11-29 17:25:56 +00:00
Andrew Gu
090fc62b24 [FSDP()] Register root pre-forward hook (#89572)
- This PR registers the FSDP root pre-forward hook as a module forward pre-hook following the recently added support for kwargs for those hooks.
- This PR also passes `prepend=True` for the normal (not root) pre-forward hook. This is not strictly required for this PR, but I believe it is needed for composability with activation checkpointing. (We want to run FSDP logic on the outside and AC logic on the inside, just like how we recommend `FSDP(AC(module))` for the wrapper versions.)

Fun fact: I originally chose the `[FSDP()]` prefix in the PR titles when we still referred to composable FSDP as functional-like FSDP, in which case `FSDP()` approximated "functional FSDP". I am preserving this usage to make searching for PRs relating to composable FSDP easier.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89572
Approved by: https://github.com/mrshenli
2022-11-28 16:56:32 +00:00
Chien-Chin Huang
c4fc5d372f [FSDP][state_dict][1/N] Moving state_dict logic to pre_state_dict_hook (#87900)
This is one step toward the ultimate goal: remove the overwritten state_dict in FSDP. All the logic should be either in `pre_state_dict_hook` or `post_state_dict_hook`.

Since current `nn.Module` does not support `pre_state_dict_hook`, this PR mimic `pre_state_dict_hook` by calling the pre hook inside post the hook, effectively ditching all the work done by `nn.Module.state_dict`. Once `pre_state_dict_hook` is supported by `nn.Module`, these pre hook calls can be moved out from the post hooks and be registered to `nn.Module.pre_state_dict_hook`.

The major issue of this temporary solution is that `post_state_dict_hook` is called from the leaf node to the root node. This makes the `module._lazy_init()` invalid as FSDP assumes `_lazy_init()` to be called from the root. As a result, `FSDP.state_dict` currently contains only one logic -- calling `module._lazy_init()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87900
Approved by: https://github.com/rohan-varma
2022-11-11 03:41:40 +00:00
Andrew Gu
6bf2776ac1 [FSDP][Perf] Do not call pad in no-padding case (#88769)
- Calling `F.pad()` issues a pad kernel from the CPU even if there is no padding needed, which can incur some non-negligible overhead. This PR removes that unnecessary call for the no-padding case.
- This PR also does not zero the newly-allocated sharded gradient tensor before the reduce-scatter if `use_orig_params=True` because there is no need. The reduce-scatter will fill the tensor anyway, and we do not care about the values in the padding. For `use_orig_params=False`, the padding is exposed to the user, so we preserve the existing semantics of zeroing it. I left a to-do to follow-up since we may optimize that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88769
Approved by: https://github.com/zhaojuanmao
2022-11-10 18:18:55 +00:00
Chien-Chin Huang
4de50b2521 [FSDP] Allow to use TorchDispatch with FSDP (#88014)
Add `_no_dispatch_record_stream` to disable TorchDispatch before calling `record_stream()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88014
Approved by: https://github.com/awgu
2022-11-03 23:15:56 +00:00
Andrew Gu
95a9721a15 [FSDP()][Easy] Rename _State to _FSDPState (#88234)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88234
Approved by: https://github.com/mrshenli
2022-11-03 11:29:01 +00:00
Andrew Gu
6c858e3727 [FSDP][Easy] Remove unneeded TrainingState transition (#88232)
Follow-up from previous PR in the stack
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88232
Approved by: https://github.com/mrshenli
2022-11-02 23:25:53 +00:00
Andrew Gu
32d22edc67 [FSDP()][27/N] Add forward hook registration (#88040)
This PR adds the forward hook registration to composable FSDP and adds a unit test for the runtime.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88040
Approved by: https://github.com/zhaojuanmao, https://github.com/rohan-varma
2022-11-02 23:25:53 +00:00
Andrew Gu
30dc6cee3a [FSDP()][26/N] Move _lazy_init() into _fsdp_root_pre_forward() (#87941)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87941
Approved by: https://github.com/mrshenli
2022-11-02 17:45:08 +00:00
Andrew Gu
f132c171ac [FSDP()][25/N] Add _post_forward_reshard() (#87940)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87940
Approved by: https://github.com/mrshenli
2022-11-02 17:16:30 +00:00
Andrew Gu
bf2819a836 [FSDP()][24/N] Refactor _lazy_init() (#87939)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87939
Approved by: https://github.com/zhaojuanmao
2022-11-02 16:35:47 +00:00
Andrew Gu
d172dcf316 [FSDP()][21/N] Refactor and fix _cast_buffers() (#87935)
This PR refactors and fixes `_cast_buffers()`.

**Before**
Buffers were not correctly cast back to their original dtypes for submodules when using buffer mixed precision.
- `_cast_buffers(recurse=False)` incorrectly casts all buffers, including those in submodules. This is because of this outer loop over `self.modules()`:
c40033be16/torch/distributed/fsdp/fully_sharded_data_parallel.py (L700)
- There was a unit test that checked that buffers were cast as expected (`test_mixed_precision_e2e_full_shard()`). The unit test _coincidentally_ passed because all modules shared the same buffer name `"buffer"`. In `_cast_buffers()`, the `dict` mapping buffer name to original dtype is populated lazily (during `_lazy_init()`). However, the keys are unprefixed:
c40033be16/torch/distributed/fsdp/fully_sharded_data_parallel.py (L712-L717)
- Thus, even though (1) `_cast_buffers(recurse=False)` was only called on the root and (2) `self._buffer_name_to_orig_dtype` had unprefixed names as keys, the unit test still passed because (1) `_cast_buffers()` still looped over all buffers despite `recurse=False` and (2) all submodules' buffers were named `"buffer"` and had the same original and low-precision dtypes and hence were cast correctly.

If we change each submodule to have its own distinct buffer name, then the unit test fails. This PR makes such a change to showcase the progression granted by this PR.

**After**
This PR separates `_cast_buffers()` into three methods: `_get_buffers_and_dtypes_for_computation()`, `_get_buffers_and_dtypes_for_checkpoint()`, and `_cast_buffers_to_dtype_and_device()`. This is to separate the different use cases (casting for computation and casting for checkpointing) and the corresponding code paths. Plus, the signature for `_cast_buffers_to_dtype_and_device()` makes it clear exactly what buffers are being cast and to what dtype.

Both `_get_...()` functions assume that they are called on the root only for now. This coincides with the construction of `_buffer_name_to_orig_dtype` in the FSDP constructor, which loops over all submodules. (This means that for non-root modules, their `_buffer_name_to_orig_dtype` is populated but not used.) The `dict`'s keys are clean since the buffer cast to original dtype happens in a `summon_full_params()` context, which cleans the names.

**Follow-Ups**
- We can try to move `_get_buffers_and_dtypes_for_checkpoint()` into `_state_dict_utils.py` in a follow-up.
- We may want to move to per-module buffer casting (i.e. do not have the root module cast for all submodules).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87935
Approved by: https://github.com/mrshenli
2022-11-02 11:32:56 +00:00
Andrew Gu
19c7df89fb [FSDP()][20/N][Easy] Move functions in file (#87932)
This PR is easy. I just wanted to group functions in the file according to the same logical order.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87932
Approved by: https://github.com/mrshenli
2022-11-02 11:32:48 +00:00
Andrew Gu
4635f56da1 [FSDP()][18/N] Refactor pre_forward_unshard() (#87931)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87931
Approved by: https://github.com/mrshenli
2022-11-02 11:32:45 +00:00
Andrew Gu
0a752688bd [FSDP()][17/N] Refactor _fsdp_root_pre_forward() (#87930)
This PR moves `_fsdp_root_pre_forward()` to `_runtime_utils.py`.

Note: This PR includes a (temporary) fix for `NO_SHARD` + `CPUOffload(offload_params=True)`, where we set `non_blocking=False` when copying the gradient from device to host. It is only included in this PR since the test was **flaky** (but not consistently failing) on this PR , so I needed to fix to unblock land.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87930
Approved by: https://github.com/mrshenli
2022-11-02 11:32:42 +00:00
Andrew Gu
1f34067e9d [FSDP()][16/N] Refactor post-forward/pre-backward (#87929)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87929
Approved by: https://github.com/mrshenli
2022-11-01 17:26:03 +00:00
Andrew Gu
90c5f856b2 [FSDP()][14/N] Refactor pre-forward/post-backward (#87927)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87927
Approved by: https://github.com/mrshenli
2022-11-01 17:25:59 +00:00
Andrew Gu
b1750d0440 [FSDP()][13/N] Refactor unshard/reshard/grads (#87926)
This PR is not too complicated. We just move unshard/reshard/grads out to `_runtime_utils.py` and make them take `state: _State` instead of `self`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87926
Approved by: https://github.com/mrshenli
2022-11-01 13:37:31 +00:00
Andrew Gu
cbc9faebfe [FSDP()][1/N] Start refactoring FSDP root pre-forward (#87915)
Welcome! This PR starts the refactoring journey.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87915
Approved by: https://github.com/mrshenli
2022-10-29 06:50:30 +00:00