**Overview**
This PR stack will add support for unsharding FSDP's sharded parameters for `fully_shard`. This PR takes the first step by doing some internal refactoring.
- The existing API for wrapper FSDP is the static method `summon_full_params()`, which calls into the helper `_summon_full_params()` (a usage sketch follows this list).
- This PR refactors:
- `summon_full_params()` core logic to `_unshard_params()`
- `_summon_full_params()` to `_unshard_params_recurse()`, which has a `recurse: bool` argument
- Previous `_unshard_params()` to `_unshard_fsdp_state_params()`, which applies to a single FSDP state
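For reference, the public entry point being refactored here is `FullyShardedDataParallel.summon_full_params()`; below is a minimal usage sketch (the toy model and shapes are illustrative, and an initialized distributed process group is assumed):
```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Illustrative: wrap a toy model and unshard its parameters in a context.
model = FSDP(nn.Linear(8, 8).cuda())
with FSDP.summon_full_params(model, recurse=True, writeback=False):
    # Inside the context, each rank sees the full (unsharded) parameters.
    full_shapes = [p.shape for p in model.parameters()]
# On exit, parameters are resharded; with `writeback=False`, edits are discarded.
```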
**Details**
- This PR introduces `_get_fsdp_states_with_modules()` and `_get_root_fsdp_states_with_modules()`, which additionally return the modules along with the FSDP states. The modules are needed for handling `FlatParameter` registration.
- We may be able to remove this if we clean up the `use_orig_params=True` vs. `False` code paths because for `True`, the `FlatParameter` is not registered, meaning that it does not need to be de-registered.
- Since `fully_shard` requires `use_orig_params=True`, we may not need `_get_fsdp_states_with_modules()` and `_get_root_fsdp_states_with_modules()`; however, I prefer to make the separation of FSDP state and module explicit for now for clarity (see the sketch after this list).
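To make the state/module separation concrete, here is a minimal sketch (not the actual implementation) of a states-with-modules traversal, written in terms of the `_get_module_fsdp_state()` helper from the file below; the function name and return convention are hypothetical:
```python
from typing import List, Set, Tuple

import torch.nn as nn

def _sketch_get_fsdp_states_with_modules(
    root: nn.Module,
) -> Tuple[List["_FSDPState"], List[nn.Module]]:
    # Collect each distinct FSDP state in the module tree together with the
    # module it is attached to (needed later for `FlatParameter` registration).
    states: List["_FSDPState"] = []
    state_modules: List[nn.Module] = []
    visited: Set[int] = set()
    for submodule in root.modules():
        state = _get_module_fsdp_state(submodule)
        if state is not None and id(state) not in visited:
            visited.add(id(state))
            states.append(state)
            state_modules.append(submodule)
    return states, state_modules
```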
**Follow-Ups**
- Passing `writeback=True` together with `rank0_only=True` currently raises an error. The previous explanation was:
> is not supported, as model parameter shapes will be different across ranks, and writing to them can lead to inconsistencies across ranks when the context is exited.
I am not exactly sure what the differing model parameter shapes refer to. However, I believe that we can support `writeback=True` with `rank0_only=True` by broadcasting the `FlatParameter` from rank 0 in the `finally` block, writing back, and then freeing. This should not increase peak memory, since rank 0 already holds the unsharded `FlatParameter` in GPU memory before writing back, and nonzero ranks do not hold any other unsharded `FlatParameter`s in GPU memory.
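A rough sketch of that idea, assuming equal-sized, contiguous shards; the tensor arguments and freeing convention are hypothetical and not FSDP's actual internals:
```python
import torch
import torch.distributed as dist

def _broadcast_writeback_sketch(
    unsharded_flat_param: torch.Tensor,  # populated on rank 0; scratch buffer elsewhere
    local_shard: torch.Tensor,           # this rank's shard of the FlatParameter
    rank: int,
    group: dist.ProcessGroup,
) -> None:
    # Rank 0 broadcasts its unsharded FlatParameter so that every rank can
    # write back the slice corresponding to its own shard.
    dist.broadcast(unsharded_flat_param, src=0, group=group)
    shard_numel = local_shard.numel()
    start = rank * shard_numel
    local_shard.copy_(unsharded_flat_param[start : start + shard_numel])
    # Nonzero ranks then free their temporary unsharded buffer, so peak memory
    # matches what rank0_only=True already uses without writeback.
```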
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92297
Approved by: https://github.com/rohan-varma
"""
|
|
This file includes private common utilities for FSDP.
|
|
"""
|
|
|
|
import traceback
|
|
from enum import auto, Enum
|
|
from typing import (
|
|
Callable,
|
|
Dict,
|
|
Generator,
|
|
Iterable,
|
|
List,
|
|
no_type_check,
|
|
Optional,
|
|
Set,
|
|
)
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.distributed.fsdp.flat_param as flat_param_file
|
|
import torch.nn as nn
|
|
from torch.distributed._composable_state import _get_module_state, _State
|
|
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
|
_CHECKPOINT_PREFIX,
|
|
)
|
|
|
|
from .api import (
|
|
FullOptimStateDictConfig,
|
|
FullStateDictConfig,
|
|
OptimStateDictConfig,
|
|
ShardingStrategy,
|
|
StateDictConfig,
|
|
StateDictType,
|
|
)
|
|
|
|
FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
|
|
FSDP_PREFIX = FSDP_WRAPPED_MODULE + "."
|
|
FSDP_FLATTENED = "_fsdp_flattened"
|
|
|
|
|
|


class _FSDPState(_State):
    def __init__(self) -> None:
        # TODO: Move all the attributes to this class to enable typing for
        # FSDP/fully_shard.
        self._ignored_modules: Set[nn.Module] = set()
        self._ignored_params: Set[nn.Parameter] = set()
        self.process_group: Optional[dist.ProcessGroup] = None
        self.rank: int = -1
        self.world_size: int = -1
        self.sharding_strategy = ShardingStrategy.FULL_SHARD
        self._use_orig_params: bool = False
        self.training_state = TrainingState.IDLE
        self._unshard_params_ctx: Dict[nn.Module, Generator] = {}
        self._state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT
        self._state_dict_config: StateDictConfig = FullStateDictConfig()
        self._optim_state_dict_config: OptimStateDictConfig = FullOptimStateDictConfig()
        self._is_root: Optional[bool] = None
        self._handles: List[flat_param_file.FlatParamHandle] = []
        self._fully_sharded_module_to_handles: Dict[
            nn.Module, flat_param_file.FlatParamHandle
        ] = {}
        self.compute_device = torch.device("cuda", torch.cuda.current_device())


def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]:
    state = _get_module_state(module)
    if state is None or not isinstance(state, _FSDPState):
        return None
    return state


def _get_module_fsdp_state_if_fully_sharded_module(
    module: nn.Module,
) -> Optional[_FSDPState]:
    state = _get_module_fsdp_state(module)
    if state is None:
        return None
    if state == module:  # FullyShardedDataParallel module case.
        return state
    if module in state._fully_sharded_module_to_handles:  # fully_shard case.
        return state
    return None


class TrainingState(Enum):
    """
    An enum that indicates the state of a ``FullyShardedDataParallel`` instance.
    """

    IDLE = auto()
    FORWARD_BACKWARD = auto()
    SUMMON_FULL_PARAMS = auto()


class HandleTrainingState(Enum):
    """
    An enum that indicates the state of a ``FlatParamHandle``.
    """

    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()
    SUMMON_FULL_PARAMS = auto()


def _is_composable(state: _FSDPState):
    # TODO: This is a temporary hack to differentiate between code paths.
    return not isinstance(state, nn.Module)


@no_type_check
def _module_handles(state: _FSDPState, module: nn.Module) -> List:
    """
    Returns the ``FlatParamHandle`` s corresponding to ``module``. These are
    the handles that contain some parameter in ``module``.
    """
    if _is_composable(state):
        assert (
            module in state._fully_sharded_module_to_handles
        ), f"Expects a `comm_module` but got {module} on rank {state.rank}"
        return state._fully_sharded_module_to_handles[module][:]
    else:
        # NOTE: This assumes `module` is a `FullyShardedDataParallel` instance.
        return module._handles[:]


@no_type_check
def _has_fsdp_params(state: _FSDPState, module: nn.Module) -> bool:
    """Returns if ``module`` has parameters managed by FSDP."""
    return len(_module_handles(state, module)) > 0


def _get_sharding_strategy(handles: Iterable):
    """
    Returns the sharding strategy of the group of handles given by ``handles``
    or ``None`` if ``handles`` is empty. The input should be the handles
    corresponding to one module, so we enforce that they all share the same
    sharding strategy.
    """
    sharding_strategy = None
    for handle in handles:
        if sharding_strategy is None:
            sharding_strategy = handle._sharding_strategy
        elif (
            sharding_strategy is not None
            and sharding_strategy != handle._sharding_strategy
        ):
            raise AssertionError(
                "Expects each group of handles to have the same sharding "
                f"strategy but got {sharding_strategy} and {handle._sharding_strategy}"
            )
    return sharding_strategy


def clean_tensor_name(tensor_name: str) -> str:
    """
    Cleans the parameter or buffer name by removing any module wrapper
    prefixes.
    """
    tensor_name = tensor_name.replace(FSDP_PREFIX, "")
    # TODO: Explicitly replacing the checkpoint wrapper prefix is not ideal as
    # it couples `CheckpointWrapper` and FSDP and also does not scale for more
    # module wrappers.
    tensor_name = tensor_name.replace(_CHECKPOINT_PREFIX, "")
    return tensor_name
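
# Example (illustrative): since str.replace removes every occurrence of the
# wrapper prefixes defined above,
#   clean_tensor_name("_fsdp_wrapped_module.layer1._fsdp_wrapped_module.weight")
# returns "layer1.weight", i.e. the FQN as it would appear in the unwrapped
# model's state dict.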


def _set_fsdp_flattened(tensor: torch.Tensor) -> None:
    """
    Sets an attribute on ``tensor`` to mark it as flattened by FSDP. This is
    to avoid re-flattening it during nested construction.
    """
    setattr(tensor, FSDP_FLATTENED, True)


def _is_fsdp_flattened(tensor: torch.Tensor) -> bool:
    """Returns if ``tensor`` has been marked as flattened by FSDP."""
    return getattr(tensor, FSDP_FLATTENED, False)


def _get_param_to_fqns(
    model: torch.nn.Module,
    dedup_shared_params: bool = True,
) -> Dict[nn.Parameter, List[str]]:
    """
    Constructs a mapping from parameter to a list of its FQNs. Each normal
    parameter maps to a singleton list containing its FQN, while each
    ``FlatParameter`` maps to a list of its original parameter FQNs, which may
    have length greater than one. All FQNs are prefixed starting from
    ``model``.

    Args:
        model (torch.nn.Module): Root module (which may or may not be a
            :class:`FullyShardedDataParallel` instance).
        dedup_shared_params (bool): For shared parameters, if ``True``, only
            includes the FQNs corresponding to the first encounter of the
            shared parameter in the module traversal; if ``False``, then
            includes the FQNs across all encounters. (Default: ``True``)
    """

    def module_fn(module, prefix, param_to_fqns):
        for param_name, param in module.named_parameters(recurse=False):
            local_fqns = (
                param._fqns
                if type(param) is flat_param_file.FlatParameter
                else [param_name]
            )  # prefixed from `module`
            global_fqns = [
                clean_tensor_name(prefix + name) for name in local_fqns
            ]  # prefixed from the top level `model` (i.e. including `prefix`)
            is_shared_param = param in param_to_fqns
            if not is_shared_param:
                param_to_fqns[param] = global_fqns
            elif not dedup_shared_params:
                param_to_fqns[param].extend(global_fqns)

    def return_fn(param_to_fqns):
        return param_to_fqns

    param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {}
    return _apply_to_modules(
        model,
        module_fn,
        return_fn,
        param_to_unflat_param_names,
    )
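
# Example (illustrative): for a plain module registered as ``model.lin``,
# ``_get_param_to_fqns(model)`` maps ``model.lin.weight`` to ``["lin.weight"]``,
# while a ``FlatParameter`` that flattened ``lin.weight`` and ``lin.bias`` maps
# to ``["lin.weight", "lin.bias"]``.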


def _apply_to_modules(
    root_module: torch.nn.Module,
    module_fn: Callable,
    return_fn: Callable,
    *args,
    **kwargs,
):
    """
    Performs a pre-order traversal of the modules in the hierarchy rooted at
    ``root_module``, applying ``module_fn`` at each module and finally
    returning a value using ``return_fn``. The traversal constructs the full
    module prefix name (e.g. "module.submodule." just like in model state
    dict) and makes that available to ``module_fn``.
    """

    def f(module: torch.nn.Module, prefix: str, *args, **kwargs):
        # Call the module function before recursing over children (pre-order)
        module_fn(module, prefix, *args, **kwargs)
        for submodule_name, submodule in module.named_children():
            if submodule is not None:
                new_prefix = prefix + submodule_name + "."
                f(submodule, new_prefix, *args, **kwargs)

    f(root_module, "", *args, **kwargs)
    return return_fn(*args, **kwargs)
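
# Example (illustrative): collect every module prefix in a model.
#   prefixes: List[str] = []
#   _apply_to_modules(
#       model,
#       lambda module, prefix, acc: acc.append(prefix),
#       lambda acc: acc,
#       prefixes,
#   )
# For ``nn.Sequential(nn.Linear(2, 2))``, this yields ``["", "0."]``.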


@no_type_check
def _assert_in_training_states(
    state: _FSDPState,
    training_states: List[TrainingState],
) -> None:
    """Asserts that FSDP is in one of the states ``training_states``."""
    # Raise a `ValueError` instead of using `assert` to ensure that these
    # logical assertions run even if `assert`s are disabled
    if state.training_state not in training_states:
        msg = (
            f"expected to be in states {training_states} but current state is "
            f"{state.training_state}"
        )
        # Print the error on rank 0 in case this is called in the backward pass
        if state.rank == 0:
            if isinstance(state, nn.Module):
                print(f"Asserting FSDP instance is: {state}")
            print(f"ERROR: {msg}")
            traceback.print_stack()
        raise ValueError(msg)


def _get_root_modules(modules: Set[nn.Module]) -> Set[nn.Module]:
    """
    Returns:
        Set[nn.Module]: The subset of ``modules`` that are root modules (i.e.
        parent-less) with respect to the modules in the set itself. In other
        words, these are the modules in ``modules`` that are not the child of
        any other module in ``modules``.
    """
    root_modules: Set[nn.Module] = set()
    module_to_submodules = {module: set(module.modules()) for module in modules}
    for candidate_module in modules:
        is_root_module = True
        for module, submodules in module_to_submodules.items():
            is_child_module = (
                candidate_module is not module and candidate_module in submodules
            )
            if is_child_module:
                is_root_module = False
                break
        if is_root_module:
            root_modules.add(candidate_module)
    return root_modules
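
# Example (illustrative): if ``model.layer1`` is a submodule of ``model`` and
# ``other`` is unrelated, then
#   _get_root_modules({model, model.layer1, other})
# returns ``{model, other}`` since ``model.layer1`` is a child of another
# module in the set.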