pytorch/torch/distributed/fsdp/_wrap_utils.py
Andrew Gu 5ee230face [FSDP][1/N] Refactor module materialization (#94196)
**Overview**
This refactors module materialization (i.e. meta device or `torchdistX` deferred initialization) to compute the parameter and buffer names as needed instead of pre-computing them. These names are needed to reacquire references to the states (e.g. `module.get_parameter(param_name)`) after materialization, since materialization may create new variables.
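
A minimal sketch of the pattern (not the FSDP code itself), assuming a plain meta-device module and using `nn.Module.to_empty()` as a stand-in for the materialization step:

```python
import torch.nn as nn

# Build a module on the meta device (no storage allocated yet).
module = nn.Linear(8, 8, device="meta")

# Compute the parameter names on demand, right before materializing.
param_names = [name for name, _ in module.named_parameters()]

# Materialize; this may create new parameter variables, so previously held
# references can become stale.
module.to_empty(device="cpu")

# Reacquire fresh references by name.
params = [module.get_parameter(name) for name in param_names]
```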

This refactor simplifies `_get_fully_sharded_module_to_states()` (the core function for "pseudo auto wrapping") to better enable lowest common ancestor (LCA) module computation for shared parameters, where tracking parameter and buffer names would further complicate an already non-obvious implementation.

**Discussion**
The tradeoff is a worst-case quadratic traversal over modules when materializing all of them. However, since (1) the number of modules is relatively small, (2) the computation per module in the quadratic traversal is negligible, (3) this runs only once per training session, and (4) module materialization targets truly large models, I think this tradeoff is tolerable.
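
For intuition, a toy illustration (hypothetical, not the actual FSDP traversal) of where the quadratic term comes from: recomputing a module's prefixed names requires walking that module's subtree, so doing it for every module in a deep chain visits a quadratic number of (module, descendant) pairs:

```python
import torch.nn as nn

# Build a deeply nested chain: Sequential(Sequential(...(Linear))).
root: nn.Module = nn.Linear(4, 4)
for _ in range(63):
    root = nn.Sequential(root)

# Recomputing names for a module walks its subtree, so doing so for every
# module in the chain is quadratic in the chain depth.
total_visits = 0
for module in root.modules():  # one pass per module
    total_visits += sum(1 for _ in module.modules())  # subtree walk
print(total_visits)  # 64 * 65 / 2 = 2080, quadratic in depth
```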

**For Reviewers**
- `_init_param_handle_from_module()` initializes _one_ `FlatParamHandle` from a fully sharded module and represents the module wrapper code path. For this code path, there is no need to reacquire references to the parameters/buffers for now since the managed parameters are only computed after materialization. This works because the managed parameters have a simple definition: any parameter in the local root module's tree excluding those already marked as flattened by FSDP. Similarly, FSDP marks buffers to indicate that they have already been processed (synced if `sync_module_states`).
- `_init_param_handles_from_module()` initializes _all_ `FlatParamHandle`s from a fully sharded module and represents the composable code path. For this code path, we must reacquire references to parameters/buffers because each logical wrapping is specified as a list of parameter/buffer variables to group together, and materialization may create new variables that invalidate those references (see the sketch below).
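
A hedged sketch of the reacquisition pattern for the composable path (illustrative only: `group` and `group_param_names` are made-up names, and `nn.Module.to_empty()` stands in for whatever materialization actually runs):

```python
import torch.nn as nn

module = nn.Linear(4, 4, device="meta")

# A logical wrapping is specified by the parameter variables to group together.
group = [module.weight, module.bias]

# Record names while the original variables are still valid.
group_param_names = [
    name
    for name, param in module.named_parameters()
    if any(param is grouped for grouped in group)
]

# Materialize; the parameter variables backing the module may be recreated.
module.to_empty(device="cpu")

# Reacquire references by name so the grouping points at the live variables.
group = [module.get_parameter(name) for name in group_param_names]
```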

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94196
Approved by: https://github.com/rohan-varma
2023-02-13 21:43:00 +00:00


import collections
import functools
import warnings
from typing import Any, Deque, Dict, List, NamedTuple, Set

import torch
import torch.nn as nn
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
from torch.distributed.fsdp._utils import (
_contains_batchnorm,
_override_batchnorm_mixed_precision,
)
from torch.distributed.fsdp.wrap import (
_FSDPPolicy,
_or_policy,
_recursive_wrap,
_wrap_batchnorm_individually,
)


class FullyShardedModuleState(NamedTuple):
"""
Module state for ``_get_fully_sharded_module_to_states()``, representing
a logical grouping (e.g. parameters to be flattened together).
"""
params: List[nn.Parameter]
buffers: List[torch.Tensor]


def _auto_wrap(
auto_wrap_kwargs: Dict[str, Any],
fsdp_kwargs: Dict[str, Any],
module_wrapper_cls: Any, # e.g. `FullyShardedDataParallel`
) -> None:
"""
Recursively auto wraps the root module given by the key "module" in
``auto_wrap_kwargs`` with the arguments in ``auto_wrap_kwargs`` and
``fsdp_kwargs``.

    Precondition: ``auto_wrap_kwargs`` contains the arguments expected by
    ``_recursive_wrap()``, where ``auto_wrap_policy`` is not ``None``.
``fsdp_kwargs`` contains all FSDP arguments except ``module``.
"""
auto_wrap_policy = auto_wrap_kwargs["auto_wrap_policy"]
# Support new way to pass an auto wrap policy
if isinstance(auto_wrap_policy, _FSDPPolicy):
auto_wrap_policy = auto_wrap_policy.policy
root_module = auto_wrap_kwargs["module"]
assert auto_wrap_policy is not None
# For auto wrapping, submodules should not already be wrapped with FSDP
# since double wrapping is not supported
for module_name, module in root_module.named_modules():
if isinstance(module, module_wrapper_cls):
raise ValueError(
f"Expected {module_name} to NOT be FullyShardedDataParallel "
"if using an `auto_wrap_policy`"
)
mixed_precision = fsdp_kwargs["mixed_precision"]
if mixed_precision is not None and _contains_batchnorm(root_module):
_override_batchnorm_mixed_precision(root_module)
auto_wrap_policy = functools.partial(
_or_policy, policies=[_wrap_batchnorm_individually, auto_wrap_policy]
)
warnings.warn(
"Both mixed precision and an `auto_wrap_policy` were specified "
"for FSDP, where the wrapped module has batch norm submodules. "
"The batch norm submodules will be wrapped as separate FSDP "
"instances with mixed precision disabled since some batch norm "
"kernels do not support low precision."
)
auto_wrap_kwargs["auto_wrap_policy"] = auto_wrap_policy
_recursive_wrap(**auto_wrap_kwargs, **fsdp_kwargs)


def _get_fully_sharded_module_to_states(
root_module: nn.Module,
auto_wrap_policy: _FSDPPolicy,
ignored_modules: Set[nn.Module],
ignored_params: Set[nn.Parameter],
) -> Dict[nn.Module, FullyShardedModuleState]:
"""
    Returns a mapping from fully sharded module to its parameters and buffers,
    where each entry logically represents a grouping according to the given
    auto wrap policy and ignored modules/parameters. However, this method does
    not actually perform any module wrapping.

    The mapped-to values are the states from the subtree rooted at the
    corresponding submodule key, excluding child submodules in the mapping and
    ignored state. Sibling submodules cannot be grouped together.

    Each non-ignored parameter and buffer appears exactly once in the returned
    ``dict``, and the ``dict`` is ordered by increasing tree depth. A mapped-to
    parameter list may be empty if the fully sharded module has no parameters
    or if its parameters were assigned to a parent fully sharded module
    instead.
"""
# Record the modules to wrap without actually wrapping
wrapped_modules_set: Set[nn.Module] = set() # these are only logically wrapped
wrapper_cls = functools.partial(_record_module_wrapper_cls, wrapped_modules_set)
if auto_wrap_policy is not None:
_recursive_wrap(
root_module,
auto_wrap_policy=auto_wrap_policy.policy,
wrapper_cls=wrapper_cls,
ignored_modules=ignored_modules,
ignored_params=ignored_params,
only_wrap_children=False,
)
# Always include the root module even if not wrapped by the given policy
wrapped_modules_set.add(root_module)
fully_sharded_module_to_states = collections.OrderedDict()
visited_params = set()
for ignored_param in ignored_params:
visited_params.add(ignored_param)
visited_buffers = set()
# Construct `wrapped_modules` to follow `.modules()` order to ensure that
# downstream data structures (`._handles`) match those of the wrapper path.
# NOTE: Since `.modules()` follows a depth-first order, which is a
# topological sort, and we iterate over `wrapped_modules` following that
# order, parent-child shared parameters are assigned to the parent module.
wrapped_modules: List[nn.Module] = []
for module in root_module.modules():
if module in wrapped_modules_set:
wrapped_modules.append(module)
for submodule in wrapped_modules:
# Perform a DFS from `submodule` and record all unvisited state that is
# not already associated with another module in `wrapped_modules`. We
# use DFS to follow the `.modules()` order.
        deque: Deque[nn.Module] = collections.deque()
        deque.append(submodule)
        params: List[nn.Parameter] = []
        buffers: List[torch.Tensor] = []
        while len(deque) > 0:
            module = deque.popleft()
            # Reverse `named_children()`, use `appendleft()`, and add to the
            # deque before processing to perform non-recursive DFS
            for _, child_module in reversed(list(module.named_children())):
                if child_module not in wrapped_modules_set:
                    deque.appendleft(child_module)
for param in module.parameters(recurse=False):
if param not in visited_params and not _is_fsdp_flattened(param):
params.append(param)
visited_params.add(param)
for buffer in module.buffers(recurse=False):
if buffer not in visited_buffers:
buffers.append(buffer)
visited_buffers.add(buffer)
fully_sharded_module_to_states[submodule] = FullyShardedModuleState(
params, buffers
)
return fully_sharded_module_to_states


def _record_module_wrapper_cls(
wrapped_modules_set: Set[nn.Module],
module: nn.Module,
**kwargs,
) -> nn.Module:
"""
This defines a pseudo-wrapper class to be passed to ``_recursive_wrap()``
that records the wrapped module to the input ``wrapped_modules_set``
without actually wrapping with a class.
"""
wrapped_modules_set.add(module)
return module