mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
**Why this PR?**

For the composable API implementation, internal APIs sometimes have access only to a local module rather than the application's (FSDP or DDP) root module. One example is FSDP's state_dict/optimizer_state_dict implementation. These APIs are designed to start from the root module of the model, so it is tricky for them to tell whether an arbitrary submodule is managed by DDP or FSDP. It would be useful to have APIs like:

- `_get_module_state(module)`: return the composable state if this module is managed by a composable API.
- `_get_module_fsdp_state(module)`: return the FSDP state if this module is managed by FSDP.

**What does this PR propose?**

1. Move `_State` out of the `_composable` module so that `FullyShardedDataParallel` can inherit from it.
2. Add a global `_module_state_mapping: Dict[nn.Module, _State]` that maps every submodule (not just the root module) to its state.
3. Create `_get_module_state(module)` to look up `_module_state_mapping`.
4. Create `_get_module_fsdp_state(module)`, which uses `_get_module_state(module)` to get the state and then verifies that the state is an `_FSDPState`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89147
Approved by: https://github.com/awgu
33 lines
949 B
Python
```python
from typing import cast, Dict, Optional

import torch.nn as nn


class _State:
    # Base class for the state that a composable API (or FSDP) associates with a module.
    pass


# Maps every managed module (the root and its submodules) to its state.
_module_state_mapping: Dict[nn.Module, _State] = {}


def _insert_module_state(module: nn.Module, state: _State) -> None:
    global _module_state_mapping
    assert module not in _module_state_mapping, f"Inserting {module} more than once."
    _module_state_mapping[module] = state


def _get_module_state(module: nn.Module) -> Optional[_State]:
    """
    Given a ``module``, this API determines whether the module is also a
    ``_State`` instance or whether it is managed by a composable API. If the
    module is also a ``_State``, ``module`` is cast to ``_State`` and returned.
    If it is managed by a composable API, the corresponding ``_State`` is
    returned. Otherwise, ``None`` is returned.
    """

    global _module_state_mapping
    if isinstance(module, _State):
        # e.g. FullyShardedDataParallel, which inherits from _State.
        return cast(_State, module)
    else:
        return _module_state_mapping.get(module, None)
```