Commit Graph

4 Commits

Author SHA1 Message Date
Chien-Chin Huang
d52f121dba [Composable API]Common _State parent class for composable and wrapper FSDP (#89147)
**Why this PR?**

For the composable APIs implementation, sometimes the internal APIs may not have the application (FSDP, DDP) root module but only the local module. One example is the state_dict/optimizer_state_dict implementation of FSDP. These APIs  are designed to start with the root module of the model. It is tricky for these APIs to tell whether a random submodule is managed by either DDP or FSDP.

It will be useful to have APIs like:
`_get_module_state(module)`: return the composable state if this module is managed by composable API.
`_get_module_fsdp_state(module)`: return the FSDP state if this module is managed by FSDP.

**What does this PR propose?**
1. Make `_State` out of `_composable` module so that `FullyShardedDataParallel` can inherit from it.
2. A global `_module_state_mapping: Dict[nn.Module, _State]` that keeps the mapping of all submodules (not just root module) to the state.
3. Create `_get_module_state(module)` to look up `_module_state_mapping`.
4. Create `_get_module_fsdp_state(module)` that uses `_get_module_state(module)` to get the state then verifies if the state is `_FSDPState`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89147
Approved by: https://github.com/awgu
2022-12-13 23:58:01 +00:00
Shen Li
a69cdd9cf8 Add global registry to composable API contract (#90579)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90579
Approved by: https://github.com/awgu, https://github.com/yhcharles
2022-12-10 22:41:10 +00:00
Charlie Yan
f3af5ba48e [WIP] Composable API: replicate and DistributedState (#87649)
This PR adds the first version of the `replicate()` composable API. For this prototype version, I try to reuse as much code from existing `DistributedDataParallel` as possible, and iterate on it in later changes. The basic idea of this prototype is:
- create a `ReplicateState` object. It internally uses a `ParameterList` module to hold all parameters of modules marked by `replicate()` API.
- create an internal `_ddp` object, which reuses existing `DistributedDataParallel` implementation, and wraps the `ParameterList` object
- install pre-forward and after-forward hooks on the root module, which calls methods of `_ddp` to run initialization and forward

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87649
Approved by: https://github.com/zhaojuanmao
2022-11-17 03:06:31 +00:00
Shen Li
ce3e0e9856 Add state to distributed composable API (#87838)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87838
Approved by: https://github.com/yhcharles
2022-10-28 13:31:40 +00:00