Commit Graph

17 Commits

Andrew Gu
1ee189ce8e [FSDP] Issue warning when clamping to NO_SHARD (#90060)
Fixes https://github.com/pytorch/pytorch/issues/90050. I hope that this was not meant as an onboarding task :/
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90060
Approved by: https://github.com/zhaojuanmao
2022-12-03 15:58:25 +00:00
Andrew Gu
d1760d7a42 [FSDP][Easy] Remove outdated TODO (#89217)
**Overview**
This PR removes an outdated TODO:
```
# TODO (awgu): When exposing the original parameters, we need to also
# use this attribute to prevent re-synchronizing parameters.
```

**Justification**
We only pass `managed_params` to `_sync_module_params_and_buffers()`, where `managed_params` is defined as
```
managed_params = list(_get_orig_params(root_module, state._ignored_params))
```
This `_get_orig_params()` call excludes parameters already flattened by FSDP. Thus, `_sync_module_params_and_buffers()` will not re-sync already-synchronized parameters. Each parameter appears in `managed_params` for some FSDP instance exactly once and hence is only synchronized once.
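For illustration, a minimal sketch of the exclusion behavior described above (not the actual `_get_orig_params()` implementation; the `_is_flat_param` marker is an assumed placeholder for however FSDP identifies already-flattened parameters):
```
from typing import Iterator, Set

import torch.nn as nn

def _get_orig_params_sketch(
    root_module: nn.Module, ignored_params: Set[nn.Parameter]
) -> Iterator[nn.Parameter]:
    # `Module.parameters()` deduplicates shared parameters, so each original
    # parameter is yielded (and hence synchronized) at most once.
    for param in root_module.parameters():
        if param in ignored_params:
            continue  # skip ignored parameters
        if getattr(param, "_is_flat_param", False):
            continue  # skip parameters FSDP already flattened (and synced)
        yield param
```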
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89217
Approved by: https://github.com/mrshenli
2022-11-30 01:42:16 +00:00
Andrew Gu
d01bf1d1f1 [FSDP] Introduce ModuleWrapPolicy for simplicity (#88450)
**BC Breaking Change**
This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap" suggests that some wrapping has been undone. I prefer `numel` over `params` because that is the unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves.

This only breaks code that passes `unwrapped_params` as a keyword argument, but I did not see any such usage (except one internal benchmark file, which does not actually depend on our `pytorch` code).

In a follow-up, I want to rename `min_num_params` to `min_nonwrapped_numel` in `size_based_auto_wrap_policy`, which is also BC breaking. Again, this is to differentiate between "params" being `nn.Parameter`s and "numel" being the unit for `param.numel()`.

**Overview**
This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is:
```
module_classes: Set[Type[nn.Module]] = ...
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls=module_classes,
)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
Now, users can instead write:
```
auto_wrap_policy = ModuleWrapPolicy(module_classes)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
This hides the unused arguments that the policy callable is expected to accept (`recurse` and `unwrapped_params`/`nonwrapped_numel`).

`ModuleWrapPolicy` inherits from an abstract base class `FSDPPolicy` that expects a `policy` property. This decouples the construction of such `FSDPPolicy` classes from their actual `policy`, which must abide by the `_recursive_wrap` interface. Any existing auto wrap policy can be rewritten as a class that inherits from `FSDPPolicy`, so this approach is fully backward compatible from a functionality perspective.
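As a rough sketch of this relationship (only the two class names and the `(module, recurse, nonwrapped_numel)` callable interface come from the description above; the helper `_module_wrap_policy` and its details are illustrative assumptions):
```
import abc
import functools
from typing import Callable, Set, Type

import torch.nn as nn

class FSDPPolicy(abc.ABC):
    """Abstract base class whose `policy` property returns a callable that
    abides by the `_recursive_wrap` interface."""

    @property
    @abc.abstractmethod
    def policy(self) -> Callable:
        ...

def _module_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    module_classes: Set[Type[nn.Module]],
) -> bool:
    if recurse:
        return True  # always recurse into children
    return isinstance(module, tuple(module_classes))  # wrap matching classes

class ModuleWrapPolicy(FSDPPolicy):
    def __init__(self, module_classes: Set[Type[nn.Module]]):
        self._policy = functools.partial(
            _module_wrap_policy, module_classes=module_classes
        )

    @property
    def policy(self) -> Callable:
        return self._policy
```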

I call this base class `FSDPPolicy` to generalize over the cases where we may not want to actually perform any nested wrapping. In reality, the policy is meant for constructing `FlatParameter`s, which previously just happened to be induced by nested wrapping. Given this, I am changing the constructor argument in `fully_shard()` to simply `policy` instead of `auto_wrap_policy`.

This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88450
Approved by: https://github.com/zhaojuanmao
2022-11-12 04:14:32 +00:00
Chien-Chin Huang
7aa144ac54 [FSDP][state_dict][5/N] Remove the FSDP module dependency from _state_dict_utils (#88637)
**What**
This PR completely removes the `FullyShardedDataParallel` dependency from `_state_dict_utils` -- `_state_dict_utils` now depends only on `_FSDPState` and all the utils modules.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88637
Approved by: https://github.com/awgu
2022-11-11 21:22:13 +00:00
Chien-Chin Huang
324ac93a43 [FSDP][state_dict][2/N] Move state_dict related enums/dataclasses/states to state_dict_utils.py, api.py and init_state_dict() (#88481)
**Motivation**:
Several enums, dataclasses, and states defined in `fully_sharded_data_parallel.py` should be moved to a place that composable FSDP can access. This PR does the move.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88481
Approved by: https://github.com/rohan-varma, https://github.com/awgu
2022-11-11 12:28:37 +00:00
Andrew Gu
95a9721a15 [FSDP()][Easy] Rename _State to _FSDPState (#88234)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88234
Approved by: https://github.com/mrshenli
2022-11-03 11:29:01 +00:00
Andrew Gu
73de44fc56 [FSDP] Rename unflat_param_name -> fqn for consistency (#88123)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88123
Approved by: https://github.com/mrshenli
2022-11-02 23:25:53 +00:00
Andrew Gu
f35d5145a1 [FSDP] Simplify _get_buffer_names() (#88122)
This is a follow-up to a previous PR in this stack. The PR simplifies the `_get_buffer_names()` implementation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88122
Approved by: https://github.com/mrshenli
2022-11-02 23:25:53 +00:00
Andrew Gu
572a3d2d6e [FSDP] Remove unneeded torch.no_grad() context when offloading to CPU (#88121)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88121
Approved by: https://github.com/mrshenli
2022-11-02 23:25:53 +00:00
Andrew Gu
32d22edc67 [FSDP()][27/N] Add forward hook registration (#88040)
This PR adds the forward hook registration to composable FSDP and adds a unit test for the runtime.
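A minimal sketch of the general mechanism using standard `nn.Module` hook registration; the hook bodies are placeholders, not FSDP's actual runtime logic:
```
import torch
import torch.nn as nn

def _pre_forward(module: nn.Module, args):
    # In FSDP, this is where parameters would be unsharded before computation.
    return args

def _post_forward(module: nn.Module, args, output):
    # In FSDP, this is where parameters would be resharded after computation.
    return output

def register_runtime_hooks(module: nn.Module) -> None:
    module.register_forward_pre_hook(_pre_forward)
    module.register_forward_hook(_post_forward)

model = nn.Linear(4, 4)
register_runtime_hooks(model)
out = model(torch.randn(2, 4))  # both hooks fire around the forward pass
```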
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88040
Approved by: https://github.com/zhaojuanmao, https://github.com/rohan-varma
2022-11-02 23:25:53 +00:00
Andrew Gu
bf2819a836 [FSDP()][24/N] Refactor _lazy_init() (#87939)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87939
Approved by: https://github.com/zhaojuanmao
2022-11-02 16:35:47 +00:00
Andrew Gu
d172dcf316 [FSDP()][21/N] Refactor and fix _cast_buffers() (#87935)
This PR refactors and fixes `_cast_buffers()`.

**Before**
Buffers were not correctly cast back to their original dtypes for submodules when using buffer mixed precision.
- `_cast_buffers(recurse=False)` incorrectly casts all buffers, including those in submodules. This is because of the outer loop over `self.modules()`:
c40033be16/torch/distributed/fsdp/fully_sharded_data_parallel.py (L700)
- There was a unit test that checked that buffers were cast as expected (`test_mixed_precision_e2e_full_shard()`). The unit test _coincidentally_ passed because all modules shared the same buffer name `"buffer"`. In `_cast_buffers()`, the `dict` mapping buffer name to original dtype is populated lazily (during `_lazy_init()`). However, the keys are unprefixed:
c40033be16/torch/distributed/fsdp/fully_sharded_data_parallel.py (L712-L717)
- Thus, even though (1) `_cast_buffers(recurse=False)` was only called on the root and (2) `self._buffer_name_to_orig_dtype` had unprefixed names as keys, the unit test still passed because (1) `_cast_buffers()` still looped over all buffers despite `recurse=False` and (2) all submodules' buffers were named `"buffer"` and had the same original and low-precision dtypes and hence were cast correctly.

If we change each submodule to have its own distinct buffer name, then the unit test fails. This PR makes that change to showcase the fix.

**After**
This PR separates `_cast_buffers()` into three methods: `_get_buffers_and_dtypes_for_computation()`, `_get_buffers_and_dtypes_for_checkpoint()`, and `_cast_buffers_to_dtype_and_device()`. This is to separate the different use cases (casting for computation and casting for checkpointing) and the corresponding code paths. Plus, the signature for `_cast_buffers_to_dtype_and_device()` makes it clear exactly what buffers are being cast and to what dtype.

Both `_get_...()` functions assume that they are called on the root only for now. This coincides with the construction of `_buffer_name_to_orig_dtype` in the FSDP constructor, which loops over all submodules. (This means that for non-root modules, their `_buffer_name_to_orig_dtype` is populated but not used.) The `dict`'s keys are clean since the cast of buffers back to their original dtype happens inside a `summon_full_params()` context, which cleans the names.
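For illustration, a hedged sketch of what a helper like `_cast_buffers_to_dtype_and_device()` might look like, based on the description above (the actual signature in `torch.distributed.fsdp` may differ):
```
from typing import Iterable, Optional

import torch

def _cast_buffers_to_dtype_and_device(
    buffers: Iterable[torch.Tensor],
    buffer_dtypes: Iterable[Optional[torch.dtype]],
    device: torch.device,
) -> None:
    # Cast each buffer to its target dtype (if given and floating point) and
    # move it to `device`, assigning via `.data` to preserve references.
    for buffer, dtype in zip(buffers, buffer_dtypes):
        if dtype is None or not torch.is_floating_point(buffer):
            buffer.data = buffer.to(device=device)
        else:
            buffer.data = buffer.to(device=device, dtype=dtype)
```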

**Follow-Ups**
- We can try to move `_get_buffers_and_dtypes_for_checkpoint()` into `_state_dict_utils.py` in a follow-up.
- We may want to move to per-module buffer casting (i.e. do not have the root module cast for all submodules).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87935
Approved by: https://github.com/mrshenli
2022-11-02 11:32:56 +00:00
Andrew Gu
5a53f024e4 [FSDP()][15/N] Refactor _init_streams() (#87928)
This PR is easy. I think I move `_init_streams()` again in a later PR though :/
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87928
Approved by: https://github.com/mrshenli
2022-11-01 17:26:03 +00:00
Andrew Gu
8039317c07 [FSDP()][12/N] Easy cleanup (#87925)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87925
Approved by: https://github.com/mrshenli
2022-11-01 12:39:24 +00:00
Andrew Gu
c1e28731b3 [FSDP()][10/N][11/N] Introduce composable (ctor only) (#87924)
This PR introduces the composable FSDP API (with constructor semantics only) along with some further constructor refactoring. A notable contribution here is `_get_submodule_to_states()`, which performs auto wrapping without actually wrapping.
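A toy sketch of "auto wrapping without actually wrapping": traverse the module tree with a wrap predicate and collect the submodules that would have been wrapped, instead of replacing them with FSDP wrapper instances (function and argument names here are illustrative assumptions):
```
from typing import Callable, List

import torch.nn as nn

def _get_target_submodules(
    root: nn.Module, should_wrap: Callable[[nn.Module], bool]
) -> List[nn.Module]:
    return [m for m in root.modules() if m is not root and should_wrap(m)]

# Example: collect the nn.Linear submodules of a small model.
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
assert len(_get_target_submodules(model, lambda m: isinstance(m, nn.Linear))) == 2
```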
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87924
Approved by: https://github.com/mrshenli
2022-11-01 12:39:24 +00:00
Andrew Gu
78170701a3 [FSDP()][9/N] Refactor ctor (continued) (#87923)
This PR makes a second pass over the constructor. The logic has been grouped into `_init_<...>` functions based on intent (e.g. `_init_prefetching_state()` or `_init_runtime_state()`). This makes the initialization code for composable FSDP much cleaner, since it does not have to re-write the same sequences of lower-level helper calls.
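An illustrative-only sketch of the pattern (attribute names and helper contents are assumptions, not the actual PyTorch signatures):
```
class _StateSketch:
    """Stand-in for the state object shared by wrapper and composable FSDP."""

def _init_prefetching_state(state: _StateSketch) -> _StateSketch:
    state.backward_prefetch = None  # placeholder prefetching config
    return state

def _init_runtime_state(state: _StateSketch) -> _StateSketch:
    state.streams = {}              # placeholder for streams, handles, etc.
    state.training_state = "IDLE"
    return state

def _init_state(state: _StateSketch) -> _StateSketch:
    # The constructor reduces to a short, readable sequence of intent-named
    # helper calls that both wrapper FSDP and composable FSDP can reuse.
    _init_prefetching_state(state)
    _init_runtime_state(state)
    return state
```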

This PR also moves `_ExecOrderData` into its own file `_exec_order_utils.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87923
Approved by: https://github.com/mrshenli
2022-11-01 12:39:21 +00:00
Andrew Gu
d89cf2fdc9 [FSDP()][7/N] Refactor most of ctor (#87921)
The goal of this PR is to make one pass over the FSDP constructor and refactor each helper method call so that it is no longer a `self.<...>` call. Subsequent PRs will make further passes over the FSDP constructor.

This PR changes a lot of lines, but it is only reorganization. Methods are moved to `_init_utils.py` and `_common_utils.py`. This also marks the beginning of moving methods from `_utils.py` to `_common_utils.py` -- they will be coalesced eventually. I am only using `_common_utils.py` as a staging ground for the methods affected by the refactoring.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87921
Approved by: https://github.com/mrshenli
2022-10-31 16:45:24 +00:00