"""
This file includes private common utilities for FSDP.
"""

from enum import auto, Enum
from typing import Callable, Dict, List

import torch
import torch.distributed.fsdp.flat_param as flat_param_file
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    _CHECKPOINT_PREFIX,
)

# Attribute name under which FSDP stores the wrapped module.
FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
# Name prefix contributed by the FSDP wrapper in fully qualified names.
FSDP_PREFIX = FSDP_WRAPPED_MODULE + "."
# Attribute set on tensors that FSDP has already flattened.
FSDP_FLATTENED = "_fsdp_flattened"


class TrainingState(Enum):
    """
    An enum that indicates the state of a ``FullyShardedDataParallel``
    instance.
    """

    IDLE = auto()
    FORWARD_BACKWARD = auto()
    SUMMON_FULL_PARAMS = auto()


class HandleTrainingState(Enum):
    """
    An enum that indicates the state of a ``FlatParamHandle``.
    """

    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()
    SUMMON_FULL_PARAMS = auto()


def clean_tensor_name(tensor_name: str) -> str:
    """
    Cleans the parameter or buffer name by removing any module wrapper
    prefixes.

    Every occurrence of ``FSDP_PREFIX`` and of the ``CheckpointWrapper``
    prefix (``_CHECKPOINT_PREFIX``) is removed from ``tensor_name``, wherever
    it appears in the name (not only at the start).

    Args:
        tensor_name (str): A possibly wrapper-prefixed parameter or buffer
            name.

    Returns:
        str: ``tensor_name`` with the wrapper prefixes removed.
    """
    tensor_name = tensor_name.replace(FSDP_PREFIX, "")
    # TODO: Explicitly replacing the checkpoint wrapper prefix is not ideal as
    # it couples `CheckpointWrapper` and FSDP and also does not scale for more
    # module wrappers.
    tensor_name = tensor_name.replace(_CHECKPOINT_PREFIX, "")
    return tensor_name


def _set_fsdp_flattened(tensor: torch.Tensor) -> None:
    """
    Sets an attribute on ``tensor`` to mark it as flattened by FSDP. This is
    to avoid re-flattening it during nested construction.
    """
    setattr(tensor, FSDP_FLATTENED, True)


def _is_fsdp_flattened(tensor: torch.Tensor) -> bool:
    """Returns whether ``tensor`` has been marked as flattened by FSDP."""
    return getattr(tensor, FSDP_FLATTENED, False)


def _get_param_to_unflat_param_names(
    model: torch.nn.Module,
    dedup_shared_params: bool = True,
) -> Dict[torch.nn.Parameter, List[str]]:
    """
    Constructs a mapping from flattened parameter (including non-FSDP-module
    parameters) to its unflattened parameter names. For non-FSDP-module
    parameters, these mapped-to lists always contain a single element. The
    unflattened parameter names should match the keys of the model state dict.

    A parameter that is reachable from more than one module is treated as
    shared. When ``dedup_shared_params=True``, only the names from the first
    visit are kept, following the pre-order module walk of
    :func:`_apply_to_modules`. When ``dedup_shared_params=False``, the names
    from every visit are appended in walk order.

    Args:
        model (torch.nn.Module): Root module (which may or may not be a
            :class:`FullyShardedDataParallel` instance).
        dedup_shared_params (bool): If ``True``, only includes the first
            list of unflattened parameter names corresponding to a parameter
            in the module walk order; if ``False``, then includes all of the
            unflattened parameter names.

    Returns:
        Dict[torch.nn.Parameter, List[str]]: Mapping from each parameter
        object to its fully prefixed, wrapper-cleaned parameter names.
    """

    def module_fn(module, prefix, param_to_unflat_param_names):
        for param_name, param in module.named_parameters(recurse=False):
            # Names relative to `module`. A `FlatParameter` records the
            # names of the original parameters it flattens in `_fqns`.
            module_prefixed_param_names = (
                param._fqns
                if type(param) is flat_param_file.FlatParameter
                else [param_name]
            )
            # Fully prefixed from the top level (including `prefix`), with
            # module wrapper prefixes stripped.
            fully_prefixed_param_names = [
                clean_tensor_name(prefix + name)
                for name in module_prefixed_param_names
            ]
            # If this parameter has already been visited, then it is a
            # shared parameter; then, only take the first parameter name
            # unless deduplication is disabled.
            is_shared_param = param in param_to_unflat_param_names
            if not is_shared_param:
                param_to_unflat_param_names[param] = fully_prefixed_param_names
            elif not dedup_shared_params:
                param_to_unflat_param_names[param].extend(fully_prefixed_param_names)

    def return_fn(param_to_unflat_param_names):
        return param_to_unflat_param_names

    param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {}
    return _apply_to_modules(
        model,
        module_fn,
        return_fn,
        param_to_unflat_param_names,
    )


def _apply_to_modules(
    root_module: torch.nn.Module,
    module_fn: Callable,
    return_fn: Callable,
    *args,
    **kwargs,
):
    """
    Performs a pre-order traversal of the modules in the hierarchy rooted at
    ``root_module``, applying ``module_fn`` at each module and finally
    returning a value using ``return_fn``. The traversal constructs the full
    module prefix name (e.g. "module.submodule." just like in model state
    dict) and makes that available to ``module_fn``.

    Args:
        root_module (torch.nn.Module): Module at which the traversal starts;
            it is visited with the empty prefix ``""``.
        module_fn (Callable): Called as ``module_fn(module, prefix, *args,
            **kwargs)`` for every visited module.
        return_fn (Callable): Called once as ``return_fn(*args, **kwargs)``
            after the traversal; its result is returned.
        *args, **kwargs: Forwarded unchanged to both ``module_fn`` and
            ``return_fn`` (e.g. a shared accumulator).

    Returns:
        Whatever ``return_fn`` returns.
    """

    def f(module: torch.nn.Module, prefix: str, *args, **kwargs):
        # Call the module function before recursing over children (pre-order)
        module_fn(module, prefix, *args, **kwargs)
        for submodule_name, submodule in module.named_children():
            if submodule is not None:
                new_prefix = prefix + submodule_name + "."
                f(submodule, new_prefix, *args, **kwargs)

    f(root_module, "", *args, **kwargs)
    return return_fn(*args, **kwargs)