pytorch/torch/distributed/tensor/parallel/ddp.py
Iris Zhang 72ce5dd13e [2D] Remove enable_2d_with_fsdp() API and make pre_dp_module_transform private (#112473)
As we have our new 2D flow out, we want to remove `enable_2d_with_fsdp()`.
In addition, we make `pre_dp_module_transform` private, as we may need to change the UX later on.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112473
Approved by: https://github.com/fegin, https://github.com/wanchaol
2023-11-16 01:14:00 +00:00

from typing import Any, List, Tuple

import torch.nn as nn

from torch.distributed.tensor.parallel._data_parallel_utils import (
    _flatten_tensor,
    _unflatten_tensor,
)

__all__ = []  # type: ignore[var-annotated]

def _get_submodule_n_params(module: nn.Module, path: str):
    """
    Get the submodule and the direct path of the parameter from the module.
    """
    if "." in path:
        path_list = path.split(".")
        parent_module_path = ".".join(path_list[:-1])
        module = module.get_submodule(parent_module_path)
        path = path_list[-1]
    return module, path
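

# Illustrative sketch (not part of the original helpers): a dotted parameter name
# from ``named_parameters()`` resolves to the owning submodule plus the leaf
# attribute name, which is what lets ``_update_module_param`` setattr() directly
# on the parent. For example, with a hypothetical ``net = nn.Sequential(nn.Linear(4, 4))``:
#
#     _get_submodule_n_params(net, "0.weight")  # -> (net[0], "weight")
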
def _update_module_param(param_list: List[Tuple[nn.Module, str, nn.Parameter]]):
    """
    Update parameters within the module.
    """
    for item in param_list:
        parent_module, module_path, t = item
        assert hasattr(parent_module, module_path)
        delattr(parent_module, module_path)
        setattr(parent_module, module_path, t)

def _reconstruct_dtensor(module: nn.Module, _input: Any):
    """
    Reconstruct DTensor parameters from local tensors.
    """
    param_list = []
    # TODO: Add perf optimizations to this iteration.
    for name, t in module.named_parameters():
        if hasattr(t, "_st_info"):
            dtensor = _unflatten_tensor(t, t._st_info)
            param_list.append((*_get_submodule_n_params(module, name), dtensor))
    _update_module_param(param_list)  # type: ignore[arg-type]

def _localize_dtensor(module: nn.Module, *_: Any):
    """
    Convert DTensor parameters to local tensors.
    """
    param_list = []
    for name, param in module.named_parameters():
        t, sharding_info = _flatten_tensor(param)
        if sharding_info is not None:
            t = nn.Parameter(t)
            t._st_info = sharding_info  # type: ignore[attr-defined]
            param_list.append((*_get_submodule_n_params(module, name), t))
    _update_module_param(param_list)  # type: ignore[arg-type]
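

# Hedged sketch of the round trip performed by the two converters above (assumes
# ``tp_module`` is a module whose parameters have already been converted to
# DTensors; the name is only for illustration):
#
#     _localize_dtensor(tp_module)         # DTensor params -> plain local nn.Parameters,
#                                          # sharding metadata stashed on ``_st_info``
#     _reconstruct_dtensor(tp_module, ())  # local tensors + ``_st_info`` -> DTensor params
#
# DDP therefore only ever sees regular local tensors outside of the forward pass.
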
def _pre_dp_module_transform(module: nn.Module):
    """
    Enable the composability between Tensor Parallelism (TP) and Data
    Parallelism (DP) in PyTorch when using DDP. We need to convert parameters that
    are DTensors to local tensors before wrapping with the data parallelism API.
    We then register two hooks: one to convert local tensors back to DTensors
    before the forward pass and one to convert DTensors back to local tensors
    after the forward pass. By integrating this way, we avoid any special handling
    of DTensor parameters by DDP and get DTensor's gradients propagated back to
    DP, e.g. the gradient buckets of DDP.

    For now, this API only works with ``DistributedDataParallel``. It will later
    support other DP methods such as FSDP.

    Args:
        module (:class:`nn.Module`):
            Module to which TP has been applied.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel
        >>> from torch.nn.parallel import DistributedDataParallel as DDP
        >>> from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform
        >>>
        >>> # Define the module.
        >>> m = module(...)
        >>> parallelize_module(m, PairwiseParallel())
        >>> _pre_dp_module_transform(m)
        >>> m = DDP(m)
        >>>
    """
    _localize_dtensor(module, None, None)
    # TODO: Add test cases and ensure that it works for nested modules.
    module.register_forward_pre_hook(_reconstruct_dtensor)
    module.register_forward_hook(_localize_dtensor)
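

# What the registered hooks do at run time (a sketch, not executed here; assumes
# the process group, device mesh, and TP parallelization have already been set up,
# and ``tp_module``/``inp`` are placeholders):
#
#     _pre_dp_module_transform(tp_module)   # localize DTensor params + register hooks
#     ddp_module = DDP(tp_module)
#     out = ddp_module(inp)                 # pre-hook rebuilds DTensors, post-hook localizes them
#     out.sum().backward()                  # gradients land in DDP's buckets as local tensors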