This PR moves `_fsdp_root_pre_forward()` to `_runtime_utils.py`.
Note: This PR includes a (temporary) fix for `NO_SHARD` + `CPUOffload(offload_params=True)`, where we set `non_blocking=False` when copying the gradient from device to host. The fix is included in this PR only because the test was **flaky** (but not consistently failing) here, so I needed it to unblock landing.
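For context, here is a minimal sketch of the kind of copy involved, assuming a CUDA device is available; the variable names are illustrative, not FSDP's internals:

```python
import torch

# Illustrative only: copy a gradient from GPU to a pinned CPU buffer, as CPU
# offloading does. non_blocking=False makes the copy synchronous, so the host
# tensor is fully populated before anything reads it, avoiding the race that
# made the NO_SHARD + offload_params=True test flaky.
device_grad = torch.randn(1024, device="cuda")
cpu_grad = torch.empty(device_grad.shape, dtype=device_grad.dtype, pin_memory=True)
cpu_grad.copy_(device_grad, non_blocking=False)  # synchronous device-to-host copy
```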
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87930
Approved by: https://github.com/mrshenli
This PR makes a second pass over the constructor. The logic has been grouped into `_init_<...>` functions based on intent (e.g. `_init_prefetching_state()` or `_init_runtime_state()`). This makes the initialization code for composable FSDP much cleaner, since it can reuse these grouped helpers instead of rewriting the same sequences of lower-level helper calls.
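A toy sketch of the resulting constructor shape; the helpers below are stand-ins, not the real `_init_utils.py` code:

```python
import torch.nn as nn

# Hypothetical stand-ins for the grouped _init_<...>() helpers.
def _init_prefetching_state(state, backward_prefetch) -> None:
    state.backward_prefetch = backward_prefetch

def _init_runtime_state(state) -> None:
    state._handles = []  # runtime bookkeeping lives on the wrapper

class ToyFSDP(nn.Module):
    """Toy wrapper showing the constructor style only, not real FSDP."""

    def __init__(self, module: nn.Module, backward_prefetch=None) -> None:
        super().__init__()
        self._module = module
        # Each helper initializes exactly one concern on `self`, so composable
        # FSDP can call the same helpers without duplicating call sequences.
        _init_prefetching_state(self, backward_prefetch)
        _init_runtime_state(self)

wrapped = ToyFSDP(nn.Linear(4, 4))
```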
This PR also moves `_ExecOrderData` into its own file `_exec_order_utils.py`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87923
Approved by: https://github.com/mrshenli
The goal of this PR is to make one pass over the FSDP constructor and refactor each helper so that it is no longer called as a `self.<...>` method. Subsequent PRs will make further passes over the FSDP constructor.
This PR touches many lines, but it is only reorganization. Methods are moved to `_init_utils.py` and `_common_utils.py`. This also marks the beginning of moving methods from `_utils.py` to `_common_utils.py` -- they will be coalesced eventually. I am only using `_common_utils.py` as a staging ground for the methods affected by the refactoring.
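As a toy illustration of the pattern (the helper name and body here are hypothetical, not copied from `_init_utils.py`):

```python
import torch.nn as nn

# Before (illustrative): a helper bound to the wrapper class,
#     class FullyShardedDataParallel(nn.Module):
#         def _check_single_device(self, module): ...
# After (illustrative): a module-level function that takes its inputs
# explicitly and can live in _init_utils.py / _common_utils.py, where
# composable FSDP can call it without going through the wrapper class.
def _check_single_device_module(module: nn.Module) -> None:
    devices = {p.device for p in module.parameters()}
    if len(devices) > 1:
        raise RuntimeError(f"FSDP expects a single-device module, got {devices}")

_check_single_device_module(nn.Linear(8, 8))  # all parameters on CPU, so this passes
```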
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87921
Approved by: https://github.com/mrshenli
- This PR defines a new `api.py` meant to hold the public API for FSDP (minus `FullyShardedDataParallel` itself). This is needed because several of the `_<...>_utils.py` files rely on the public API, and we cannot import from `torch.distributed.fsdp.fully_sharded_data_parallel` without a circular import. Calling the file `api.py` follows the convention used by `ShardedTensor`.
- This PR cleans up the wording in the `BackwardPrefetch`, `ShardingStrategy`, `MixedPrecision`, and `CPUOffload` docstrings.
- This PR adds the aforementioned classes to `fsdp.rst` to have them rendered in public docs.
- To abide by the public bindings contract (`test_public_bindings.py`), the aforementioned classes are removed from `fully_sharded_data_parallel.py`'s `__all__`. This is technically BC-breaking if someone uses `from torch.distributed.fsdp.fully_sharded_data_parallel import *`; however, that does not happen in any of our external or internal code (see the import sketch after this list).
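For reference, a usage sketch assuming the package root continues to re-export these classes, which is the non-breaking import path:

```python
# The public classes stay importable from the package root; only
# `from torch.distributed.fsdp.fully_sharded_data_parallel import *` is affected.
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    MixedPrecision,
    ShardingStrategy,
)

cpu_offload = CPUOffload(offload_params=True)
sharding_strategy = ShardingStrategy.FULL_SHARD
backward_prefetch = BackwardPrefetch.BACKWARD_PRE
```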
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87917
Approved by: https://github.com/mrshenli
Unlike the preceding reorganization-only PRs, this PR has meaningful changes. We stratify `TrainingState` into two levels: one per FSDP instance and one per `FlatParamHandle`/`FlatParameter` (see the sketch after the list).
- At the FSDP instance level, we only care about `IDLE`, FSDP computation (i.e. `FORWARD_BACKWARD`), or `SUMMON_FULL_PARAMS`. These dynamically modify behavior (e.g. `summon_full_params()` forces full precision).
- At the `FlatParamHandle` level, we care about the training state for invariants and debugging. Hence, we keep `IDLE`, `FORWARD`, `BACKWARD_PRE`, `BACKWARD_POST`, and `SUMMON_FULL_PARAMS`.
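A sketch of the split described above; the instance-level members match the list, while the handle-level enum name (`HandleTrainingState`) is an assumption about the final code:

```python
import enum

class TrainingState(enum.Enum):
    # Per FSDP instance: only the states that change runtime behavior.
    IDLE = enum.auto()
    FORWARD_BACKWARD = enum.auto()
    SUMMON_FULL_PARAMS = enum.auto()

class HandleTrainingState(enum.Enum):
    # Per FlatParamHandle/FlatParameter: finer-grained states kept for
    # invariant checks and debugging.
    IDLE = enum.auto()
    FORWARD = enum.auto()
    BACKWARD_PRE = enum.auto()
    BACKWARD_POST = enum.auto()
    SUMMON_FULL_PARAMS = enum.auto()
```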
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87916
Approved by: https://github.com/mrshenli