Commit Graph

11 Commits

Author SHA1 Message Date
zhouzaida
b51f92ebda [Docs] Fix docstring format (#99396)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99396
Approved by: https://github.com/awgu
2023-04-28 01:10:07 +00:00
Andrew Gu
803e30441f [FSDP][Docs] Per-device NCCL stream is per PG (#95705)
71ad1005f6/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp (L647-L649)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95705
Approved by: https://github.com/fegin
2023-03-07 13:38:03 +00:00
Chien-Chin Huang
4b0f1cc1ee [FSDP][optim_state_dict][10/N] Make optim_state_dict and optim_state_dict_to_load public (#92118)
Make optim_state_dict and optim_state_dict_to_load public APIs, and consolidate them with state_dict by using the same state_dict_type to decide how the optimizer state_dict is saved and loaded.
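
A minimal usage sketch (not part of the PR), assuming `model` is an FSDP-wrapped module and `optim` its optimizer; keyword arguments are used since the positional order of `optim_state_dict_to_load` has varied across releases:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

def save_optim_state(model, optim):
    # The active state_dict_type decides how both model and optimizer state
    # are materialized (e.g. FULL_STATE_DICT vs. SHARDED_STATE_DICT).
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
        model_sd = model.state_dict()
        optim_sd = FSDP.optim_state_dict(model, optim)
    return model_sd, optim_sd

def load_optim_state(model, optim, optim_sd):
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
        # Convert the saved optimizer state back into the form this rank's
        # optimizer expects, then load it.
        flattened = FSDP.optim_state_dict_to_load(
            model=model, optim=optim, optim_state_dict=optim_sd
        )
        optim.load_state_dict(flattened)
```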

Differential Revision: [D42488022](https://our.internmc.facebook.com/intern/diff/D42488022/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92118
Approved by: https://github.com/rohan-varma
2023-02-02 08:04:20 +00:00
Andrew Gu
3305265962 [FSDP] Clarify MixedPrecision docs (#91974)
New docs:
![Screen Shot 2023-01-10 at 8 07 19 PM](https://user-images.githubusercontent.com/31054793/211694428-c8ebf210-85c5-4b8a-a174-ee8022d8b8fd.png)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91974
Approved by: https://github.com/zhaojuanmao
2023-01-12 03:41:58 +00:00
Yanli Zhao
9b144ddbe4 Make input casting in root module only in default (#91365)
Make input casting happen only in the root module by default, while still allowing different mixed precision settings for different submodules.
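
For illustration, a hedged sketch of wrapping submodules with different MixedPrecision policies (assumes an initialized default process group; the exact MixedPrecision fields available depend on the release):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# Hypothetical two-block model.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

bf16 = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16)
fp16 = MixedPrecision(param_dtype=torch.float16, reduce_dtype=torch.float16)

# Submodules may use different mixed precision policies...
model[0] = FSDP(model[0], mixed_precision=bf16)
model[1] = FSDP(model[1], mixed_precision=fp16)
# ...while, by default, only the root FSDP instance casts the forward inputs.
model = FSDP(model, mixed_precision=fp16)
```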
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91365
Approved by: https://github.com/awgu
2022-12-29 03:20:32 +00:00
Shen Li
80542add73 [FSDP] Allow MixedPrecision to skip inputs (#90620)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90620
Approved by: https://github.com/rohan-varma, https://github.com/awgu
2022-12-11 06:39:38 +00:00
Rohan Varma
793a999ce0 Hybrid Sharded Data Parallel (#89915)
Adds 2 new hybrid sharding strategies to FSDP:
1. HYBRID_SHARD: applies ZeRO-3-style sharding within a node and data parallelism (replication) across nodes
2. HYBRID_SHARD_ZERO2: applies ZeRO-2-style sharding within a node and data parallelism (replication) across nodes

These are useful for medium-sized models and aim to decrease communication volume; tests and benchmarks will be run to understand which workloads are optimal under which sharding strategy.

Hybrid sharding in general works by sharding the model using a process group within a single node and creating inter-node process groups for replication / data parallelism. The user either passes in a tuple of these process groups or None, in which case we generate the process groups appropriately.
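
A hedged sketch of enabling the hybrid strategy (assumes torch.distributed is already initialized with one process per GPU across multiple nodes; `intra_node_pg` / `inter_node_pg` are hypothetical user-created groups):

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

model = nn.Linear(8, 8)

# Simplest form: leave process_group as None and FSDP constructs the
# intra-node sharding group and inter-node replication group itself.
fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.HYBRID_SHARD)

# Or pass the two groups explicitly as a (sharding_pg, replication_pg) tuple:
# fsdp_model = FSDP(
#     model,
#     sharding_strategy=ShardingStrategy.HYBRID_SHARD,
#     process_group=(intra_node_pg, inter_node_pg),
# )
```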

**Acknowledgements**
- @awgu's excellent prototype: 5ad3a16d48
- @liangluofb for ideation, feedback, and initial implementation and experimentation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89915
Approved by: https://github.com/awgu
2022-12-08 16:18:03 +00:00
Chien-Chin Huang
324ac93a43 [FSDP][state_dict][2/N] Move state_dict related enums/dataclasses/states to state_dict_utils.py, api.py and init_state_dict() (#88481)
**Motivation**:
Several enums, dataclasses, and states defined in fully_sharded_data_parallel.py should be moved to a place where composable FSDP can access them. This PR does the move.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88481
Approved by: https://github.com/rohan-varma, https://github.com/awgu
2022-11-11 12:28:37 +00:00
Andrew Gu
ab8f3333ff [FSDP][Docs] Simplify mixed_precision ctor docs (#88429)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88429
Approved by: https://github.com/mrshenli
2022-11-03 23:15:32 +00:00
Andrew Gu
c87f0501ab [FSDP][Docs] Add note mentioning rate limiter for backward prefetch (#88120)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88120
Approved by: https://github.com/mrshenli
2022-11-02 23:25:53 +00:00
Andrew Gu
9d9267c6f7 [FSDP()][3/N] Refactor public APIs (#87917)
- This PR defines a new `api.py` meant to hold the public API for FSDP (minus `FullyShardedDataParallel` itself). This is needed because several of the `_<...>_utils.py` files rely on the public API, and we cannot import from `torch.distributed.fsdp.fully_sharded_data_parallel` without a circular import. Calling the file `api.py` follows the convention used by `ShardedTensor`.
- This PR cleans up the wording in the `BackwardPrefetch`, `ShardingStrategy`, `MixedPrecision`, and `CPUOffload` docstrings.
- This PR adds the aforementioned classes to `fsdp.rst` to have them rendered in public docs.
- To abide by the public bindings contract (`test_public_bindings.py`), the aforementioned classes are removed from `fully_sharded_data_parallel.py`'s `__all__`. This is technically BC breaking if someone uses `from torch.distributed.fsdp.fully_sharded_data_parallel import *`; however, that does not happen in any of our own external or internal code (see the import sketch after this list).
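
For illustration, a sketch of the supported import path after this refactor; these names are re-exported from the package root rather than from `fully_sharded_data_parallel.py`'s `__all__`:

```python
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    MixedPrecision,
    ShardingStrategy,
)
```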
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87917
Approved by: https://github.com/mrshenli
2022-10-31 16:45:21 +00:00