pytorch/torch/distributed/_composable/replicate.py
Chien-Chin Huang d52f121dba [Composable API]Common _State parent class for composable and wrapper FSDP (#89147)
**Why this PR?**

For the composable API implementation, internal APIs sometimes have access only to a local module rather than to the application's (FSDP, DDP) root module. One example is FSDP's state_dict/optimizer_state_dict implementation: these APIs are designed to start from the root module of the model, so it is tricky for them to tell whether an arbitrary submodule is managed by DDP or by FSDP.

It would be useful to have APIs such as:
- `_get_module_state(module)`: return the composable state if this module is managed by a composable API.
- `_get_module_fsdp_state(module)`: return the FSDP state if this module is managed by FSDP.

**What does this PR propose?**
1. Move `_State` out of the `_composable` module so that `FullyShardedDataParallel` can inherit from it.
2. A global `_module_state_mapping: Dict[nn.Module, _State]` that maps all submodules (not just the root module) to their state.
3. Create `_get_module_state(module)` to look up `_module_state_mapping`.
4. Create `_get_module_fsdp_state(module)`, which uses `_get_module_state(module)` to get the state and then verifies that the state is an `_FSDPState`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89147
Approved by: https://github.com/awgu
2022-12-13 23:58:01 +00:00


```python
from typing import List, Tuple

import torch
import torch.nn as nn
from . import _ddp
from .contract import contract


class _ReplicateState:
    def __init__(self) -> None:
        self.modules: List[nn.Module] = []
        self.has_initialized: bool = False
        self._param_list: nn.ParameterList = nn.ParameterList()
        self.kwargs: dict = {}

    def mark_modules(self, *modules: nn.Module, **kwargs) -> None:
        for module in modules:
            self.modules.append(module)
            replicate.state(module)._distributed_state = self
            replicate.state(module)._params_collected = False
            module.register_forward_pre_hook(self.forward_pre_hook)
            # TODO(@yhcharles): fix type error
            module.register_forward_hook(self.forward_post_hook)  # type: ignore[arg-type]
        self.kwargs = kwargs

    def _recursive_collect_params(self, module: nn.Module) -> None:
        # TODO: skip if managed by other APIs

        if hasattr(replicate.state(module), "_params_collected"):
            if replicate.state(module)._params_collected:
                return
            replicate.state(module)._params_collected = True

        self._param_list.extend(
            param
            for param in module.parameters(recurse=False)
            # for param in module.parameters()
            if param.requires_grad
        )
        for child in module.children():
            self._recursive_collect_params(child)

    def init_helper(self) -> None:
        if self.has_initialized:
            return

        self.has_initialized = True
        for module in self.modules:
            self._recursive_collect_params(module)

        self._ddp = _ddp.DistributedDataParallel(
            self._param_list, **self.kwargs
        )

    def forward_pre_hook(
        self, module: nn.Module, input: Tuple[torch.Tensor]
    ) -> None:
        self.init_helper()
        self._ddp.pre_forward()

    def forward_post_hook(
        self,
        module: nn.Module,
        input: Tuple[torch.Tensor],
        output: torch.Tensor,
    ) -> torch.Tensor:
        return self._ddp.post_forward(output)


@contract()
def replicate(
    module: nn.Module,  # NOTE: contract now supports single module only
    **kwargs,
) -> nn.Module:
    r"""Replicates a module

    Args:
        module (torch.nn.Module): module to replicate

    Example::
        >>> module = nn.Linear(3, 3)
        >>> replicate(module)
    """
    _ReplicateState().mark_modules(module, **kwargs)
    return module
```
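The control flow of `_ReplicateState` (hooks registered in `mark_modules`, lazy setup in `init_helper` on the first forward pass, and the `_params_collected` flag that keeps a marked module from being collected twice) can be mimicked without torch or a process group. The following is a minimal pure-Python sketch under that assumption: `Param`, `Module`, and `ReplicateStateSketch` are hypothetical stand-ins, not PyTorch classes, and the `_ddp.DistributedDataParallel` construction is reduced to a comment.

```python
from typing import Callable, Dict, List, Sequence


class Param:
    """Hypothetical stand-in for nn.Parameter."""

    def __init__(self, name: str, requires_grad: bool = True) -> None:
        self.name = name
        self.requires_grad = requires_grad


class Module:
    """Hypothetical stand-in for nn.Module: local params, children, forward hooks."""

    def __init__(self, params: Sequence[Param] = (), children: Sequence["Module"] = ()) -> None:
        self.local_params: List[Param] = list(params)
        self.child_modules: List["Module"] = list(children)
        self.pre_hooks: List[Callable] = []
        self.post_hooks: List[Callable] = []

    def __call__(self, x):
        for hook in self.pre_hooks:
            hook(self, (x,))
        out = x  # identity "forward" keeps the sketch small
        for hook in self.post_hooks:
            result = hook(self, (x,), out)
            if result is not None:
                out = result  # a forward hook may replace the output
        return out


class ReplicateStateSketch:
    def __init__(self) -> None:
        self.modules: List[Module] = []
        self.params_collected: Dict[Module, bool] = {}  # role of _params_collected
        self.param_list: List[Param] = []
        self.has_initialized = False
        self.init_calls = 0  # instrumentation for this example only

    def mark_modules(self, *modules: Module) -> None:
        for module in modules:
            self.modules.append(module)
            self.params_collected[module] = False
            module.pre_hooks.append(self.forward_pre_hook)
            module.post_hooks.append(self.forward_post_hook)

    def _recursive_collect_params(self, module: Module) -> None:
        # Only marked modules carry a flag; a marked module that was already
        # visited (e.g. reached through an ancestor) is skipped with its subtree.
        if module in self.params_collected:
            if self.params_collected[module]:
                return
            self.params_collected[module] = True
        self.param_list.extend(p for p in module.local_params if p.requires_grad)
        for child in module.child_modules:
            self._recursive_collect_params(child)

    def init_helper(self) -> None:
        if self.has_initialized:
            return
        self.has_initialized = True
        self.init_calls += 1
        for module in self.modules:
            self._recursive_collect_params(module)
        # The real code builds _ddp.DistributedDataParallel(self._param_list, ...) here.

    def forward_pre_hook(self, module: Module, inputs: tuple) -> None:
        self.init_helper()  # lazy: collection happens on the first forward only

    def forward_post_hook(self, module: Module, inputs: tuple, output):
        return output  # the real code returns self._ddp.post_forward(output)


inner = Module(params=[Param("w1"), Param("frozen", requires_grad=False)])
outer = Module(params=[Param("w0")], children=[inner])
state = ReplicateStateSketch()
state.mark_modules(outer, inner)  # inner is also reachable from outer
outer("x")
outer("y")
print([p.name for p in state.param_list], state.init_calls)  # ['w0', 'w1'] 1
```

Passing two modules to `mark_modules` here only exercises the skip path; `replicate` itself passes a single module, as the `NOTE` in its signature says. Frozen parameters are dropped for the same reason the real code filters on `param.requires_grad`.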