Commit Graph

18 Commits

Andrew Gu
090fc62b24 [FSDP()] Register root pre-forward hook (#89572)
- This PR registers the FSDP root pre-forward hook as a module forward pre-hook, following the recently added kwargs support for those hooks.
- This PR also passes `prepend=True` for the normal (non-root) pre-forward hook (see the sketch below). This is not strictly required for this PR, but I believe it is needed for composability with activation checkpointing. (We want to run FSDP logic on the outside and AC logic on the inside, just as we recommend `FSDP(AC(module))` for the wrapper versions.)
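Roughly, the registration pattern looks like this; the hook names and bodies are placeholders, not FSDP's actual internals:
```python
import torch
import torch.nn as nn

# Placeholder hooks standing in for FSDP's internals (not the real implementations).
def _root_pre_forward_hook(module, args, kwargs):
    # Root-only setup (e.g. lazy init, stream setup) would go here.
    return args, kwargs

def _pre_forward_hook(module, args, kwargs):
    # Per-module pre-forward logic (e.g. unsharding) would go here.
    return args, kwargs

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))

# The root hook goes on the root module as an ordinary forward pre-hook;
# with_kwargs=True relies on the recently added kwargs support for these hooks.
model.register_forward_pre_hook(_root_pre_forward_hook, with_kwargs=True)

# Each (sub)module's normal pre-forward hook is prepended so it runs before
# hooks registered by inner mechanisms such as activation checkpointing
# (FSDP logic outside, AC logic inside).
for submodule in model:
    submodule.register_forward_pre_hook(_pre_forward_hook, prepend=True, with_kwargs=True)

model(torch.randn(2, 4))  # the root hook fires once, then each submodule's hook
```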

Fun fact: I originally chose the `[FSDP()]` prefix in the PR titles when we still referred to composable FSDP as functional-like FSDP, in which case `FSDP()` approximated "functional FSDP". I am preserving this usage to make searching for PRs relating to composable FSDP easier.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89572
Approved by: https://github.com/mrshenli
2022-11-28 16:56:32 +00:00
Chien-Chin Huang
c4fc5d372f [FSDP][state_dict][1/N] Moving state_dict logic to pre_state_dict_hook (#87900)
This is one step toward the ultimate goal: removing the overridden `state_dict` in FSDP. All the logic should live in either `pre_state_dict_hook` or `post_state_dict_hook`.

Since the current `nn.Module` does not support `pre_state_dict_hook`, this PR mimics `pre_state_dict_hook` by calling the pre hook inside the post hook, effectively discarding all the work done by `nn.Module.state_dict`. Once `pre_state_dict_hook` is supported by `nn.Module`, these pre hook calls can be moved out of the post hooks and registered to `nn.Module.pre_state_dict_hook`.

The major issue with this temporary solution is that `post_state_dict_hook` is called from the leaf modules to the root module. This makes calling `module._lazy_init()` from the hook invalid, as FSDP assumes `_lazy_init()` is called from the root. As a result, `FSDP.state_dict` currently contains only one piece of logic -- calling `module._lazy_init()`.
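A rough sketch of the workaround pattern, using `nn.Module`'s private `_register_state_dict_hook` API; the hook names and bodies are illustrative, not the FSDP implementation:
```python
import torch.nn as nn

def _pre_state_dict_hook(module: nn.Module, prefix: str) -> None:
    # Logic that would run *before* nn.Module.state_dict if a pre-hook existed,
    # e.g. summoning full parameters so the entries below see unsharded tensors.
    pass

def _post_state_dict_hook(module, state_dict, prefix, local_metadata):
    # Workaround: run the "pre" logic here, then redo/overwrite the entries that
    # nn.Module.state_dict already produced for this module.
    _pre_state_dict_hook(module, prefix)
    # ... transform state_dict in place for this module's keys ...
    return state_dict

model = nn.Linear(4, 4)
# _register_state_dict_hook is a private nn.Module API; the hook runs after
# state_dict() has collected a module's entries, so hooks fire leaf-to-root.
model._register_state_dict_hook(_post_state_dict_hook)
state = model.state_dict()
```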

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87900
Approved by: https://github.com/rohan-varma
2022-11-11 03:41:40 +00:00
Andrew Gu
6bf2776ac1 [FSDP][Perf] Do not call pad in no-padding case (#88769)
- Calling `F.pad()` issues a pad kernel from the CPU even when no padding is needed, which can incur non-negligible overhead. This PR removes that unnecessary call for the no-padding case (see the sketch below).
- This PR also skips zeroing the newly allocated sharded gradient tensor before the reduce-scatter when `use_orig_params=True` because there is no need: the reduce-scatter fills the tensor anyway, and we do not care about the values in the padding. For `use_orig_params=False`, the padding is exposed to the user, so we preserve the existing semantics of zeroing it. I left a to-do to follow up since we may optimize that.
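A minimal sketch of the no-padding guard; `_pad_if_needed` is a hypothetical helper name, not the function in this PR:
```python
import torch
import torch.nn.functional as F

def _pad_if_needed(sharded_flat_param: torch.Tensor, numel_padded: int) -> torch.Tensor:
    # Only issue the pad kernel when padding is actually needed, since F.pad
    # launches a kernel (with CPU launch overhead) even for a zero-length pad.
    numel = sharded_flat_param.numel()
    if numel_padded > numel:
        return F.pad(sharded_flat_param, [0, numel_padded - numel])
    return sharded_flat_param

# Analogously, a sharded gradient buffer that the reduce-scatter will fully
# overwrite can be allocated with torch.empty_like(...) instead of
# torch.zeros_like(...) when the padding values are never observed
# (the use_orig_params=True case).
```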
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88769
Approved by: https://github.com/zhaojuanmao
2022-11-10 18:18:55 +00:00
Chien-Chin Huang
4de50b2521 [FSDP] Allow to use TorchDispatch with FSDP (#88014)
Add `_no_dispatch_record_stream` to disable TorchDispatch before calling `record_stream()`.
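A sketch of what such a helper could look like, assuming the internal `no_dispatch` context manager from `torch.utils._mode_utils` (its location may vary across versions):
```python
import torch
from torch.utils._mode_utils import no_dispatch  # internal helper; path may differ by version

def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.cuda.Stream) -> None:
    # record_stream() has no useful TorchDispatch semantics for wrapper tensors,
    # so temporarily disable dispatch and call it on the underlying tensor.
    with no_dispatch():
        tensor.record_stream(stream)
```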
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88014
Approved by: https://github.com/awgu
2022-11-03 23:15:56 +00:00
Andrew Gu
95a9721a15 [FSDP()][Easy] Rename _State to _FSDPState (#88234)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88234
Approved by: https://github.com/mrshenli
2022-11-03 11:29:01 +00:00
Andrew Gu
6c858e3727 [FSDP][Easy] Remove unneeded TrainingState transition (#88232)
Follow-up from the previous PR in the stack.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88232
Approved by: https://github.com/mrshenli
2022-11-02 23:25:53 +00:00
Andrew Gu
32d22edc67 [FSDP()][27/N] Add forward hook registration (#88040)
This PR adds forward hook registration to composable FSDP and adds a unit test for the runtime.
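As a rough illustration of the wiring (hook names and bodies are placeholders, not the composable FSDP code):
```python
import torch.nn as nn

def _pre_forward(module, args):
    # Unshard parameters before the module's computation (placeholder body).
    return args

def _post_forward(module, args, output):
    # Reshard parameters and hook the output for the pre-backward (placeholder body).
    return output

def _register_runtime_hooks(module: nn.Module) -> None:
    # Composable FSDP attaches its runtime to plain nn.Modules via hooks rather
    # than overriding forward().
    module.register_forward_pre_hook(_pre_forward)
    module.register_forward_hook(_post_forward)
```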
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88040
Approved by: https://github.com/zhaojuanmao, https://github.com/rohan-varma
2022-11-02 23:25:53 +00:00
Andrew Gu
30dc6cee3a [FSDP()][26/N] Move _lazy_init() into _fsdp_root_pre_forward() (#87941)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87941
Approved by: https://github.com/mrshenli
2022-11-02 17:45:08 +00:00
Andrew Gu
f132c171ac [FSDP()][25/N] Add _post_forward_reshard() (#87940)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87940
Approved by: https://github.com/mrshenli
2022-11-02 17:16:30 +00:00
Andrew Gu
bf2819a836 [FSDP()][24/N] Refactor _lazy_init() (#87939)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87939
Approved by: https://github.com/zhaojuanmao
2022-11-02 16:35:47 +00:00
Andrew Gu
d172dcf316 [FSDP()][21/N] Refactor and fix _cast_buffers() (#87935)
This PR refactors and fixes `_cast_buffers()`.

**Before**
Buffers were not correctly cast back to their original dtypes for submodules when using buffer mixed precision.
- `_cast_buffers(recurse=False)` incorrectly casts all buffers, including those in submodules. This is because of this outer loop over `self.modules()`:
c40033be16/torch/distributed/fsdp/fully_sharded_data_parallel.py (L700)
- There was a unit test that checked that buffers were cast as expected (`test_mixed_precision_e2e_full_shard()`). The unit test _coincidentally_ passed because all modules shared the same buffer name `"buffer"`. In `_cast_buffers()`, the `dict` mapping buffer name to original dtype is populated lazily (during `_lazy_init()`). However, the keys are unprefixed:
c40033be16/torch/distributed/fsdp/fully_sharded_data_parallel.py (L712-L717)
- Thus, even though (1) `_cast_buffers(recurse=False)` was only called on the root and (2) `self._buffer_name_to_orig_dtype` had unprefixed names as keys, the unit test still passed because (1) `_cast_buffers()` still looped over all buffers despite `recurse=False` and (2) all submodules' buffers were named `"buffer"` and had the same original and low-precision dtypes and hence were cast correctly.

If we change each submodule to have its own distinct buffer name, then the unit test fails. This PR makes that change to demonstrate the fix it provides.

**After**
This PR separates `_cast_buffers()` into three methods: `_get_buffers_and_dtypes_for_computation()`, `_get_buffers_and_dtypes_for_checkpoint()`, and `_cast_buffers_to_dtype_and_device()`. This is to separate the different use cases (casting for computation and casting for checkpointing) and the corresponding code paths. Plus, the signature for `_cast_buffers_to_dtype_and_device()` makes it clear exactly what buffers are being cast and to what dtype.

Both `_get_...()` functions assume that they are called on the root only for now. This coincides with the construction of `_buffer_name_to_orig_dtype` in the FSDP constructor, which loops over all submodules. (This means that for non-root modules, their `_buffer_name_to_orig_dtype` is populated but not used.) The `dict`'s keys are clean since the cast of buffers back to their original dtypes happens inside a `summon_full_params()` context, which cleans the names.
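A rough sketch of the casting helper's shape, following the names in this description; the body is illustrative, not the PR's implementation:
```python
from typing import List, Optional
import torch

def _cast_buffers_to_dtype_and_device(
    buffers: List[torch.Tensor],
    buffer_dtypes: Optional[List[Optional[torch.dtype]]],
    device: torch.device,
) -> None:
    # Cast each buffer in place to its target dtype (when one is given) and move
    # it to `device`. The caller chooses the targets: low-precision dtypes for
    # computation, or the recorded original dtypes for checkpointing.
    if buffer_dtypes is not None and len(buffer_dtypes) != len(buffers):
        raise ValueError("Expected one dtype (or None) per buffer")
    for i, buffer in enumerate(buffers):
        dtype = buffer_dtypes[i] if buffer_dtypes is not None else None
        if dtype is None:
            buffer.data = buffer.to(device=device)
        else:
            buffer.data = buffer.to(device=device, dtype=dtype)
```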

**Follow-Ups**
- We can try to move `_get_buffers_and_dtypes_for_checkpoint()` into `_state_dict_utils.py` in a follow-up.
- We may want to move to per-module buffer casting (i.e. do not have the root module cast for all submodules).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87935
Approved by: https://github.com/mrshenli
2022-11-02 11:32:56 +00:00
Andrew Gu
19c7df89fb [FSDP()][20/N][Easy] Move functions in file (#87932)
This PR is easy. I just wanted to group the functions in the file into a consistent logical order.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87932
Approved by: https://github.com/mrshenli
2022-11-02 11:32:48 +00:00
Andrew Gu
4635f56da1 [FSDP()][18/N] Refactor pre_forward_unshard() (#87931)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87931
Approved by: https://github.com/mrshenli
2022-11-02 11:32:45 +00:00
Andrew Gu
0a752688bd [FSDP()][17/N] Refactor _fsdp_root_pre_forward() (#87930)
This PR moves `_fsdp_root_pre_forward()` to `_runtime_utils.py`.

Note: This PR includes a (temporary) fix for `NO_SHARD` + `CPUOffload(offload_params=True)`, where we set `non_blocking=False` when copying the gradient from device to host. It is only included in this PR because the test was **flaky** (but not consistently failing) on this PR, so I needed the fix to unblock landing.
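The gist of the temporary fix, with a hypothetical helper name and tensor names:
```python
import torch

def _offload_grad_to_cpu(gpu_grad: torch.Tensor, cpu_grad: torch.Tensor) -> None:
    # A non-blocking D2H copy into pageable CPU memory may not have completed by
    # the time the CPU tensor is read, so force a synchronous copy here.
    cpu_grad.copy_(gpu_grad, non_blocking=False)
```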

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87930
Approved by: https://github.com/mrshenli
2022-11-02 11:32:42 +00:00
Andrew Gu
1f34067e9d [FSDP()][16/N] Refactor post-forward/pre-backward (#87929)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87929
Approved by: https://github.com/mrshenli
2022-11-01 17:26:03 +00:00
Andrew Gu
90c5f856b2 [FSDP()][14/N] Refactor pre-forward/post-backward (#87927)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87927
Approved by: https://github.com/mrshenli
2022-11-01 17:25:59 +00:00
Andrew Gu
b1750d0440 [FSDP()][13/N] Refactor unshard/reshard/grads (#87926)
This PR is not too complicated. We just move unshard/reshard/grads out to `_runtime_utils.py` and make them take `state: _State` instead of `self`.
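A sketch of the method-to-free-function pattern; `handle.unshard()` and the bodies are illustrative, not the real runtime logic:
```python
from typing import List

# Before: a method on the FSDP class, e.g.
#     def _unshard(self, handles): ...   # reads self._streams, self._handles, ...
# After: a free function in _runtime_utils.py that takes the state explicitly.
def _unshard(state, handles: List) -> None:
    # Every former `self.<attr>` access becomes `state.<attr>`, so the same
    # runtime code path can serve both the wrapper class and the composable API.
    for handle in handles:
        handle.unshard()  # illustrative; the real logic also manages streams
```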
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87926
Approved by: https://github.com/mrshenli
2022-11-01 13:37:31 +00:00
Andrew Gu
cbc9faebfe [FSDP()][1/N] Start refactoring FSDP root pre-forward (#87915)
Welcome! This PR starts the refactoring journey.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87915
Approved by: https://github.com/mrshenli
2022-10-29 06:50:30 +00:00