pytorch/torch/distributed/fsdp/_common_utils.py
Andrew Gu d89cf2fdc9 [FSDP()][7/N] Refactor most of ctor (#87921)
The goal of this PR is to make one pass over the FSDP constructor and refactor each helper method call to not be `self.<...>`. Subsequent PRs will make further passes over the FSDP constructor.

This PR touches many lines, but it is only reorganization: methods are moved to `_init_utils.py` and `_common_utils.py`. It also marks the beginning of moving methods from `_utils.py` to `_common_utils.py` -- the two will be coalesced eventually. For now, I am only using `_common_utils.py` as a staging ground for the methods affected by the refactoring.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87921
Approved by: https://github.com/mrshenli
2022-10-31 16:45:24 +00:00

"""
This file includes private common utilities for FSDP.
"""

from enum import auto, Enum
from typing import Callable, Dict, List

import torch
import torch.distributed.fsdp.flat_param as flat_param_file
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    _CHECKPOINT_PREFIX,
)

FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
FSDP_PREFIX = FSDP_WRAPPED_MODULE + "."
FSDP_FLATTENED = "_fsdp_flattened"
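
# Illustrative note (hypothetical submodule names): an FSDP instance stores the
# module it wraps under the attribute named by ``FSDP_WRAPPED_MODULE``, so an
# un-cleaned parameter name may look like
#
#     "_fsdp_wrapped_module.layer.weight"
#
# ``clean_tensor_name`` below strips such wrapper prefixes.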


class TrainingState(Enum):
    """
    An enum that indicates the state of a ``FullyShardedDataParallel`` instance.
    """

    IDLE = auto()
    FORWARD_BACKWARD = auto()
    SUMMON_FULL_PARAMS = auto()


class HandleTrainingState(Enum):
    """
    An enum that indicates the state of a ``FlatParamHandle``.
    """

    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()
    SUMMON_FULL_PARAMS = auto()


def clean_tensor_name(tensor_name: str) -> str:
    """
    Cleans the parameter or buffer name by removing any module wrapper
    prefixes.
    """
    tensor_name = tensor_name.replace(FSDP_PREFIX, "")
    # TODO: Explicitly replacing the checkpoint wrapper prefix is not ideal as
    # it couples `CheckpointWrapper` and FSDP and also does not scale for more
    # module wrappers.
    tensor_name = tensor_name.replace(_CHECKPOINT_PREFIX, "")
    return tensor_name
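
# Illustrative sketch (assumes ``_CHECKPOINT_PREFIX`` equals
# "_checkpoint_wrapped_module."; names here are hypothetical):
#
#     clean_tensor_name("_fsdp_wrapped_module._checkpoint_wrapped_module.lin.weight")
#     # -> "lin.weight", matching the key in the unwrapped model's state dict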


def _set_fsdp_flattened(tensor: torch.Tensor) -> None:
    """
    Sets an attribute on ``tensor`` to mark it as flattened by FSDP. This is to
    avoid re-flattening it during nested construction.
    """
    setattr(tensor, FSDP_FLATTENED, True)


def _is_fsdp_flattened(tensor: torch.Tensor) -> bool:
    """Returns if ``tensor`` has been marked as flattened by FSDP."""
    return getattr(tensor, FSDP_FLATTENED, False)
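
# Illustrative sketch of how the flattened marker is used:
#
#     t = torch.empty(3)
#     assert not _is_fsdp_flattened(t)
#     _set_fsdp_flattened(t)
#     assert _is_fsdp_flattened(t)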


def _get_param_to_unflat_param_names(
    model: torch.nn.Module,
    dedup_shared_params: bool = True,
) -> Dict[torch.nn.Parameter, List[str]]:
    """
    Constructs a mapping from each flattened parameter (including
    non-FSDP-module parameters) to its unflattened parameter names. For
    non-FSDP-module parameters, the mapped-to lists always contain a single
    element. The unflattened parameter names should match the keys of the
    model state dict.

    For shared parameters, only the first parameter name is included by
    default (following the ``torch.nn.Module.parameters()`` order).

    Args:
        model (torch.nn.Module): Root module (which may or may not be a
            :class:`FullyShardedDataParallel` instance).
        dedup_shared_params (bool): If ``True``, only includes the first
            list of unflattened parameter names corresponding to a parameter
            in the module walk order; if ``False``, then includes all of the
            unflattened parameter names.
    """

    def module_fn(module, prefix, param_to_unflat_param_names):
        for param_name, param in module.named_parameters(recurse=False):
            module_prefixed_param_names = (
                param._fqns
                if type(param) is flat_param_file.FlatParameter
                else [param_name]
            )  # prefixed from `module`
            fully_prefixed_param_names = [
                clean_tensor_name(prefix + name)
                for name in module_prefixed_param_names
            ]  # fully prefixed from the top level including `prefix`
            # If this parameter has already been visited, then it is a shared
            # parameter; in that case, only keep the first parameter name
            # unless ``dedup_shared_params=False``
            is_shared_param = param in param_to_unflat_param_names
            if not is_shared_param:
                param_to_unflat_param_names[param] = fully_prefixed_param_names
            elif not dedup_shared_params:
                param_to_unflat_param_names[param].extend(fully_prefixed_param_names)

    def return_fn(param_to_unflat_param_names):
        return param_to_unflat_param_names

    param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {}
    return _apply_to_modules(
        model,
        module_fn,
        return_fn,
        param_to_unflat_param_names,
    )
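
# Illustrative sketch (hypothetical model; the exact grouping depends on the
# wrapping policy): for ``fsdp_model = FSDP(nn.Sequential(nn.Linear(2, 2)))``,
# where the linear's weight and bias are flattened into one ``FlatParameter``,
# the returned mapping has a single entry roughly of the form
#
#     {<FlatParameter>: ["0.weight", "0.bias"]}
#
# while any parameter not managed by FSDP maps to a one-element list holding
# its own cleaned name.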


def _apply_to_modules(
    root_module: torch.nn.Module,
    module_fn: Callable,
    return_fn: Callable,
    *args,
    **kwargs,
):
    """
    Performs a pre-order traversal of the modules in the hierarchy rooted at
    ``root_module``, applying ``module_fn`` at each module and finally
    returning a value using ``return_fn``. The traversal constructs the full
    module prefix name (e.g. "module.submodule." just as in the model state
    dict) and makes that available to ``module_fn``.
    """

    def f(module: torch.nn.Module, prefix: str, *args, **kwargs):
        # Call the module function before recursing over children (pre-order)
        module_fn(module, prefix, *args, **kwargs)
        for submodule_name, submodule in module.named_children():
            if submodule is not None:
                new_prefix = prefix + submodule_name + "."
                f(submodule, new_prefix, *args, **kwargs)

    f(root_module, "", *args, **kwargs)
    return return_fn(*args, **kwargs)
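
# Illustrative sketch of the traversal contract (``model`` is hypothetical):
#
#     def module_fn(module, prefix, prefixes):
#         prefixes.append(prefix)
#
#     def return_fn(prefixes):
#         return prefixes
#
#     _apply_to_modules(model, module_fn, return_fn, [])
#     # -> ["", "layer1.", "layer1.conv.", ...] in pre-order, matching the
#     #    prefixes used in the model's state-dict keys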