I prefer to not modify the module if it does not have any of our APIs applied. The side effect of inserting a registry on the module when calling a getter is non-intuitive to me.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113654
Approved by: https://github.com/fegin
**Why this PR?**
For the composable APIs implementation, sometimes the internal APIs may not have the application (FSDP, DDP) root module but only the local module. One example is the state_dict/optimizer_state_dict implementation of FSDP. These APIs are designed to start with the root module of the model. It is tricky for these APIs to tell whether a random submodule is managed by either DDP or FSDP.
It will be useful to have APIs like:
`_get_module_state(module)`: return the composable state if this module is managed by composable API.
`_get_module_fsdp_state(module)`: return the FSDP state if this module is managed by FSDP.
**What does this PR propose?**
1. Make `_State` out of `_composable` module so that `FullyShardedDataParallel` can inherit from it.
2. A global `_module_state_mapping: Dict[nn.Module, _State]` that keeps the mapping of all submodules (not just root module) to the state.
3. Create `_get_module_state(module)` to look up `_module_state_mapping`.
4. Create `_get_module_fsdp_state(module)` that uses `_get_module_state(module)` to get the state then verifies if the state is `_FSDPState`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89147
Approved by: https://github.com/awgu
This PR adds the first version of the `replicate()` composable API. For this prototype version, I try to reuse as much code from existing `DistributedDataParallel` as possible, and iterate on it in later changes. The basic idea of this prototype is:
- create a `ReplicateState` object. It internally uses a `ParameterList` module to hold all parameters of modules marked by `replicate()` API.
- create an internal `_ddp` object, which reuses existing `DistributedDataParallel` implementation, and wraps the `ParameterList` object
- install pre-forward and after-forward hooks on the root module, which calls methods of `_ddp` to run initialization and forward
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87649
Approved by: https://github.com/zhaojuanmao