pytorch/torch/distributed/_composable/replicate.py
Rohan Varma 51ff9ce997 [Replicate] Simplify code a bit (#98889)
Simplifies the code, e.g. by making the state hold a single module (`self.module`) rather than a list (`self.modules`); a short before/after sketch follows the commit metadata.

Differential Revision: [D44899281](https://our.internmc.facebook.com/intern/diff/D44899281/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98889
Approved by: https://github.com/mrshenli, https://github.com/yhcharles
2023-04-13 03:21:06 +00:00
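A minimal before/after sketch of that simplification, assuming the pre-change state kept the tagged modules in a list field (`self.modules`, per the message above); the exact prior declaration is an inference from the message, not a quote from the diff:

    self.modules: List[nn.Module] = []       # before: a list of modules
    self.module: Optional[nn.Module] = None  # after: a single module (see the file below)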


from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

from .contract import _get_registry, contract


@contract()
def replicate(
    module: nn.Module,  # NOTE: contract now supports single module only
    **kwargs,
) -> nn.Module:
    r"""Replicates a module

    Args:
        module (torch.nn.Module): module to replicate

    Example::
        >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
        >>> module = nn.Linear(3, 3)
        >>> replicate(module)
    """
    torch._C._log_api_usage_once("torch.distributed.replicate")
    _ReplicateState().mark_module(module, **kwargs)
    return module

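# Illustrative sketch: how a caller typically drives ``replicate``. The helper
# name ``_example_replicate_usage`` and the process-group setup mentioned below
# are assumptions for illustration only, not part of this module's API.
def _example_replicate_usage() -> None:
    # Assumes the default process group was already initialized by the launcher,
    # e.g. ``torch.distributed.init_process_group("gloo")``.
    model = nn.Linear(3, 3)
    replicate(model)  # registers forward hooks; DDP is constructed lazily
    out = model(torch.randn(2, 3))  # first forward builds the DDP instance
    out.sum().backward()  # DDP averages gradients across ranks during backward
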
def _can_compose(module: nn.Module) -> bool:
    r"""Check if the module is composable with the `replicate` API."""
    return "fully_shard" not in _get_registry(module)

class _ReplicateState:
    def __init__(self) -> None:
        self.module: Optional[nn.Module] = None
        self.has_initialized: bool = False
        self._param_list: nn.ParameterList = nn.ParameterList()
        self.kwargs: dict = {}

    def mark_module(self, module: nn.Module, **kwargs) -> None:
        if not _can_compose(module):
            raise AssertionError(
                "Cannot apply `replicate()` on a Module already managed by `fully_shard`"
            )
        self.module = module
        replicate.state(module)._params_collected = False
        # Defer DDP construction to the first forward pass via these hooks.
        module.register_forward_pre_hook(self.forward_pre_hook)
        # TODO(@yhcharles): fix type error
        module.register_forward_hook(self.forward_post_hook)  # type: ignore[arg-type]
        self.kwargs = kwargs
    def _recursive_collect_params(self, module: nn.Module) -> None:
        # skip if managed by other APIs
        if not _can_compose(module):
            return

        # skip if module parameters already collected
        replicate_state = replicate.state(module)
        # if replicate_state is None, `module` is a child module that has not been
        # explicitly tagged as replicate().
        if replicate_state is not None:
            if hasattr(replicate_state, "_params_collected"):
                if replicate_state._params_collected:
                    return
                replicate_state._params_collected = True

        self._param_list.extend(
            param for param in module.parameters(recurse=False) if param.requires_grad
        )
        for child in module.children():
            self._recursive_collect_params(child)
    def init_helper(self) -> None:
        if self.has_initialized:
            return

        self.has_initialized = True
        # Collect all eligible parameters and hand them to DDP, which takes over
        # gradient synchronization from this point on.
        self._recursive_collect_params(self.module)  # type: ignore[arg-type]
        self._ddp = DistributedDataParallel(self._param_list, **self.kwargs)
    def forward_pre_hook(
        self, module: nn.Module, input: Tuple[torch.Tensor, ...]
    ) -> None:
        # Lazily build the DDP instance on the first forward call, then run
        # DDP's pre-forward logic.
        self.init_helper()
        self._ddp._pre_forward()

    def forward_post_hook(
        self,
        module: nn.Module,
        input: Tuple[torch.Tensor],
        output: torch.Tensor,
    ) -> torch.Tensor:
        return self._ddp._post_forward(output)
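
A hedged end-to-end sketch of how this composable API is typically driven from a training script launched with torchrun. The script name, model shape, and hyperparameters are illustrative assumptions; keyword arguments passed to `replicate()` are forwarded to the underlying `DistributedDataParallel` constructor (see `self.kwargs` above).

    # example_replicate_train.py -- illustrative only; launch with e.g.:
    #   torchrun --nproc_per_node=2 example_replicate_train.py
    import torch
    import torch.nn as nn
    from torch.distributed import init_process_group, destroy_process_group
    from torch.distributed._composable import replicate


    def main() -> None:
        init_process_group(backend="gloo")  # torchrun supplies rank/world-size env vars

        model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
        replicate(model)  # extra kwargs would be forwarded to DistributedDataParallel

        optim = torch.optim.SGD(model.parameters(), lr=0.01)
        for _ in range(3):
            loss = model(torch.randn(4, 8)).sum()
            loss.backward()  # DDP averages gradients across ranks here
            optim.step()
            optim.zero_grad()

        destroy_process_group()


    if __name__ == "__main__":
        main()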