pytorch/torch/distributed/fsdp/flatten_params_wrapper.py
Rohan Varma f9f8127414 CheckpointWrapper state_dict fix (#77224)
- Uses state dict / load state dict hooks to ensure that the state_dict of a module wrapped with `CheckpointWrapper` can be loaded into a module that is not checkpoint-wrapped.

This is needed because a training run may use activation checkpointing and save a `state_dict`, while a future run may not want to wrap modules with activation checkpointing, or may change the activation checkpoint wrapping structure. To support this, we add hooks that remove / add the relevant prefix as needed.

Tests are added to ensure we can load a state_dict produced by a CheckpointWrapper-wrapped module into both a CheckpointWrapper module and a local (unwrapped) module. state_dict with FSDP is also verified.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77224
Approved by: https://github.com/zhaojuanmao
2022-05-17 03:39:31 +00:00


# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Tongzhou Wang
# Licensed under the MIT License.
import contextlib
from itertools import accumulate
from typing import (
Any,
Dict,
Generator,
Iterator,
List,
NamedTuple,
Optional,
Sequence,
Tuple,
)
import torch
import torch.nn as nn
from torch import Tensor
from torch.distributed.utils import _replace_by_prefix
ParamOffset = Tuple[int, int]
SharedParamInfo = Tuple[str, str, nn.Module, str, nn.Module, str]
FLAT_PARAM = "flat_param"
FPW_MODULE = "_fpw_module"
def _post_state_dict_hook(
module: nn.Module, state_dict: Dict[str, Any], prefix: str, *args: Any
) -> Dict[str, Any]:
"""
    _post_state_dict_hook() is called after state_dict() is executed
    and before the state_dict is returned to the user.
This API post-processes the keys of the state_dict to remove the
FlattenParamsWrapper internal prefix.
"""
# Move everything from FPW_MODULE up one level.
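    # For example, a key ``{prefix}_fpw_module.layer.weight`` becomes
    # ``{prefix}layer.weight``.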
_replace_by_prefix(state_dict, prefix + f"{FPW_MODULE}.", prefix)
return state_dict
def _pre_load_state_dict_hook(
state_dict: Dict[str, Any],
prefix: str,
*args: Any,
) -> None:
"""
    _pre_load_state_dict_hook() is called before _load_from_state_dict() is
    executed. This API pre-processes the keys of the state_dict to add the
FlattenParamsWrapper internal prefix.
"""
# Push everything down to FPW_MODULE level.
_replace_by_prefix(state_dict, prefix, prefix + f"{FPW_MODULE}.")
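    # For example, a key ``{prefix}layer.weight`` becomes
    # ``{prefix}_fpw_module.layer.weight``.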
    # The flat_param_* keys actually need to move back up one level.
flat_param_key = prefix + f"{FPW_MODULE}.{FLAT_PARAM}"
for k in list(state_dict.keys()):
if k.startswith(flat_param_key):
last_part = k.split(".")[-1]
assert last_part.startswith(
FLAT_PARAM
), f"Expected key to contain flat_param, but key name is {k}"
_replace_by_prefix(state_dict, k, prefix + last_part)
class ParamInfo(NamedTuple):
module_name: str
module: nn.Module
param_name: str
class ShardMetadata(NamedTuple):
param_names: List[str]
param_shapes: List[torch.Size]
param_numels: List[int]
param_offsets: List[ParamOffset]
class FlatParameter(nn.Parameter):
"""
    A parameter that is initialized from a list of parameters. All the
    parameters will be flattened and concatenated to form the flat parameter.
    Args:
        params (Sequence[nn.Parameter]):
            The parameters to be flattened and concatenated.
requires_grad (bool):
Set to True if gradients need to be computed for this parameter,
False otherwise.
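
    Example::

        >>> # A minimal illustrative sketch; the shapes here are assumptions.
        >>> p1 = nn.Parameter(torch.zeros(2, 3))
        >>> p2 = nn.Parameter(torch.zeros(4))
        >>> flat = FlatParameter([p1, p2])
        >>> flat.numel()  # 2 * 3 + 4
        10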
"""
def __new__(
cls, params: Sequence[nn.Parameter], requires_grad: bool = True
) -> "FlatParameter":
"""Make an object using the parent's __new__ function."""
        # An empty or non-list input doesn't make sense.
        if not isinstance(params, (list, tuple)) or len(params) == 0:
            raise ValueError("A non-empty list or tuple argument is needed")
# Normally, all items are Parameters. But during pickling, we will have a single
# Tensor as the input and later in __init__, the correct _param_numels and _param_shapes
# are set.
if not all(isinstance(p, (nn.Parameter, Tensor)) for p in params):
incorrect_parameters = [
p for p in params if not isinstance(p, (nn.Parameter, Tensor))
]
            raise ValueError(
                f"List items need to be Parameter (or Tensor) types, but got: {incorrect_parameters}"
            )
# Flattening involves (1) making a tensor flat (i.e. single dimensional) and
# (2) making a module hierarchy flat (using a single tensor to replace a tree of
# tensors). Therefore, adding back nesting and hierarchy is counter-productive.
# If nesting is encountered in the future, the reasonable thing to do is likely
# for the top level FlatParameter to absorb the nested one and keep the result flat,
# free from hierarchy.
if any(isinstance(p, FlatParameter) for p in params):
raise ValueError("Nesting FlatParameter is not supported")
data = torch.cat(
[
p.detach().reshape(-1) if isinstance(p, nn.Parameter) else p.reshape(-1)
for p in params
],
0,
)
return super(FlatParameter, cls).__new__(
cls, data, requires_grad=requires_grad
) # type: ignore[call-arg]
def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
self._is_sharded = False
self._param_numels = [p.numel() for p in params]
        # The total number of elements. This equals the sum of ``numel()``
        # over all the original parameters.
self.full_numel = sum(self._param_numels)
assert self.numel() <= self.full_numel, (
"Parameter numbers mismatched. "
f"The number of elements in FlatParameter: {self.numel()} vs. "
f"the number of elements in original parameters: {self.full_numel}."
)
# The shapes of each individual parameter.
self._param_shapes = [p.size() for p in params]
cumulative_sum = list(accumulate(self._param_numels))
begin = [0] + cumulative_sum[:-1]
end = [e - 1 for e in cumulative_sum]
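        # For example, numels [6, 4] yield begin [0, 6] and end [5, 9].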
self._param_infos: List[ParamInfo] = []
self._shared_param_infos: List[SharedParamInfo] = []
# The element offsets (begin/end pair) in the flat parameter of each
# individual parameter.
self._param_offsets = list(zip(begin, end))
        # The indices (begin/end pair) of the parameters that are included in
        # this FlatParameter. The default value covers all the parameters
        # because no sharding has happened yet.
self._param_indice_in_shard = (0, len(self._param_infos) - 1)
# The offsets in each parameter that is included in the FlatParameter.
self._sharded_param_offsets: List[ParamOffset] = [
(0, numel) for numel in self._param_numels
]
# The number of padding elements.
self.num_padded = 0
def shard_by_offsets(self, start: int, end: int, num_padded: int) -> None:
assert self._is_sharded
if start < 0 or end < 0 or end < start:
raise ValueError(
f"Shard the flatten parameter with an invalid offset pair {(start, end)}."
)
_shard_size = end - start + 1
self.num_padded = num_padded
if self.num_padded > _shard_size:
raise ValueError("The number of padding is larger than the shard size.")
self._sharded_param_offsets.clear()
ranges = []
for idx, offset in enumerate(self._param_offsets):
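            # Skip parameters whose element range [offset[0], offset[1]] does
            # not overlap the shard range [start, end].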
if start > offset[1] or end < offset[0]:
continue
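            # If the shard begins at or before this parameter, its local start
            # offset is 0; otherwise the shard starts partway into the
            # parameter. The local end offset is clipped to the shard's end.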
if start <= offset[0]:
sharded_param_start = 0
sharded_param_end = min(offset[1], end) - offset[0]
else:
sharded_param_start = start - offset[0]
sharded_param_end = min(offset[1], end) - offset[0]
ranges.append(idx)
self._sharded_param_offsets.append((sharded_param_start, sharded_param_end))
if ranges:
self._param_indice_in_shard = (ranges[0], ranges[-1])
def _offset_to_slice(self) -> slice:
if self._param_indice_in_shard[0] > self._param_indice_in_shard[1]:
return slice(0, 0)
return slice(self._param_indice_in_shard[0], self._param_indice_in_shard[1] + 1)
def get_param_views(
self, external_data: Optional[Tensor] = None
) -> Iterator[Tensor]:
"""Return a generator of views that map to the original parameters."""
        # Note: self.data could be sharded, so its numel is <= the sum of the
        # original parameter numels.
assert (
self.data.numel() <= self.full_numel
), f"Incorrect internal state {self.data.numel()} vs. {self.full_numel}"
data = external_data if external_data is not None else self
if data.numel() != self.full_numel:
raise ValueError(
f"Incorrect numel of supplied data: got {data.numel()} but expected {self.full_numel}"
)
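        # Split the flat data by each original parameter's numel and view each
        # chunk with its original (unflattened) shape.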
return (
t.view(s)
for (t, s) in zip(data.split(self._param_numels), self._param_shapes)
)
@property
def _num_unflattened_params(self) -> int:
"""Returns the number of unflattened parameters that comprise this
flattened parameter."""
assert hasattr(self, "_param_infos"), \
"`_param_infos` has not been set, meaning this `FlatParameter` " \
"has not been initialized yet"
num_unflat_params = len(self._param_infos)
assert num_unflat_params > 0, "`FlatParameter` corresponding to 0 " \
"unflattened parameters"
return num_unflat_params
@property
def param_info(self) -> List[ParamInfo]:
return self._param_infos
@property
def _param_names(self):
return [".".join([m, n]) if m else n for (m, _, n) in self._param_infos]
def metadata(self) -> Tuple[List[str], List[torch.Size], List[int]]:
"""Return tuple of (names, shapes, numels) metadata for this flat parameter."""
return self._param_names, self._param_shapes, self._param_numels
def shard_metadata(
self,
) -> ShardMetadata:
"""
        Return the (names, shapes, numels, offsets) metadata for the portion of
        the original parameters covered by this flat parameter's local shard.
"""
return ShardMetadata(
self._param_names[self._offset_to_slice()],
self._param_shapes[self._offset_to_slice()],
self._param_numels[self._offset_to_slice()],
self._sharded_param_offsets[:],
)
class FlattenParamsWrapper(nn.Module):
"""
A wrapper for transparently flattening a Module's parameters.
    The original implementation [1] reparameterizes a PyTorch module into a
    ReparamModule, which holds a single flattened parameter representing all
    parameters of the wrapped module.
Compared to the original implementation [1], this version:
- removes tracing
- supports shared parameters
- is renamed to FlattenParamsWrapper
[1] https://github.com/SsnL/PyTorch-Reparam-Module
Args:
module (nn.Module):
The module to wrap.
param_list (List[nn.Parameter]):
Only flatten parameters appearing in the given list.
            Note, if only a single param is in the list, it still gets
            flattened, and the original param is removed and replaced
            with the flattened one.
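
    Example::

        >>> # A minimal illustrative sketch; the module and parameter choices
        >>> # here are assumptions, not part of the FSDP test suite.
        >>> module = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
        >>> fpw = FlattenParamsWrapper(module, list(module.parameters()))
        >>> fpw.flat_param.numel()  # two (4, 4) weights plus two (4,) biases
        40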
"""
def __init__(self, module: nn.Module, param_list: List[nn.Parameter]):
super().__init__()
self._fpw_module = module
        # People may test whether this module contains parameters by checking
        # whether `getattr(module, "flat_param")` is None. This is not always
        # accurate, as the condition is also true while this module is
        # unflattened. `no_params` explicitly records that this module has no
        # parameters and is correct regardless of whether the parameters are
        # flattened or unflattened.
self.no_params = True
self.flat_param = None
# Register hook to be called after state_dict() to remove the
# "_fpw_module." prefix and before load_state_dict() to add it back.
# The hooks must be registered even if the target param_list is empty as
# all submodules in FlattenParamsWrapper should be pre/post processed by
# the hooks.
self._register_state_dict_hook(_post_state_dict_hook)
self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)
if len(param_list) == 0:
return
        # The unique set of parameters to be flattened.
unique_param_list = set(param_list)
self.no_params = False
        # Convert the list of Parameters to a set of (Module, parameter_name)
        # tuples, which survive even if the Parameter instances are reset.
        # The set may contain multiple (m, n) entries that point to the same
        # (shared) parameter.
self.param_set = set()
for m in self.modules():
for n, p in m.named_parameters(recurse=False):
if p in unique_param_list:
self.param_set.add((m, n))
params, param_infos, shared_param_infos = self._init_flatten_params()
self.flat_param = FlatParameter(params, params[0].requires_grad)
self.flat_param._param_infos = param_infos
self.flat_param._shared_param_infos = shared_param_infos
# This attribute is used to remember the flat_param inside the unflatten_params()
# context. With this attribute, FSDP can access the flat parameter metadata
# even if flat_param is temporarily deleted.
        # ``orig_flat_param`` is a list to avoid being tracked by ``state_dict()``.
self.orig_flat_param: List[Optional[FlatParameter]] = [None]
self._flatten_params()
# Sanity check for the string constants.
assert getattr(self, FPW_MODULE) is self._fpw_module
assert getattr(self, FLAT_PARAM) is self.flat_param
@property
def module(self) -> Any:
"""Support _fsdp_wrapped_module.module in case we are immitating DDP, which has .module
property to the underlying module.
"""
return self._fpw_module
def _init_flatten_params(
self,
) -> Tuple[List[nn.Parameter], List[ParamInfo], List[SharedParamInfo]]:
"""Build metadata for need-to-be-flatten parameters and returns a list
contains the need-to-be-flatten parameters.
This also fills param_infos and shared_param_infos.
"""
param_infos: List[ParamInfo] = []
shared_param_infos = []
shared_param_memo: Dict[nn.Parameter, Tuple[str, nn.Module, str]] = {}
params = []
for module_name, m in self.named_modules():
for n, p in m.named_parameters(recurse=False):
if p is not None and (m, n) in self.param_set:
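                    # A parameter already seen under another (module, name) is
                    # shared; record the alias instead of flattening it again.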
if p in shared_param_memo:
mname, shared_m, shared_n = shared_param_memo[p]
shared_param_infos.append(
(module_name, mname, m, n, shared_m, shared_n)
)
else:
shared_param_memo[p] = (module_name, m, n)
param_infos.append(ParamInfo(module_name, m, n))
params.append(p)
del shared_param_memo
assert (
len(set(p.dtype for p in params)) == 1
), "expects all parameters to have same dtype"
assert (
len(set(p.requires_grad for p in params)) == 1
), "expects all parameters to have same requires_grad"
assert len(params) == len(set(params)), "params list should not have dups"
return params, param_infos, shared_param_infos
def _flatten_params(self, external_data: Optional[FlatParameter] = None) -> None:
"""Flatten the managed parameters and replaced the original
attributes with views to the flat param. If `external_data`
is passed, it will be used as the flat_param.
"""
# register the flatten one
assert (
getattr(self, "flat_param", None) is not None or external_data is not None
), "Can not flatten params when both flat_param and external_data are None."
if external_data is not None:
self.flat_param = external_data
self.register_parameter("flat_param", self.flat_param)
        assert self.flat_param is not None  # avoid a mypy complaint.
# deregister the names as parameters
for _, m, n in self.flat_param._param_infos:
delattr(m, n)
for _, _, m, n, _, _ in self.flat_param._shared_param_infos:
delattr(m, n)
# register the views as plain attributes
self._unflatten_params_as_views()
def _unflatten_params_as_views(self) -> None:
"""Unlike ``_unflatten_params``, this function unflatten into views and keep
self.flat_param unchanged.
"""
assert (
self.flat_param is not None
), "Can not unflatten params as views when flat_param is None."
ps = self._get_param_views()
for (_, m, n), p in zip(self.flat_param._param_infos, ps):
setattr(m, n, p) # This will set as plain attr
for (_, _, m, n, shared_m, shared_n) in self.flat_param._shared_param_infos:
setattr(m, n, getattr(shared_m, shared_n))
def _unflatten_params(self) -> None:
"""Undo flattening and create separate parameters from the already flattened
self.flat_param.
"""
assert (
self.flat_param is not None
), "Can not unflatten params when flat_param is None."
ps = self._get_param_views()
for (_, m, n), p in zip(self.flat_param._param_infos, ps):
if hasattr(m, n):
delattr(m, n)
m.register_parameter(n, nn.Parameter(p))
for (_, _, m, n, shared_m, shared_n) in self.flat_param._shared_param_infos:
if hasattr(m, n):
delattr(m, n)
m.register_parameter(n, getattr(shared_m, shared_n))
del self.flat_param
@contextlib.contextmanager
def unflatten_params(self) -> Generator:
"""
Unflatten params. If the current instance is already unflattened, then
it will remain unflattened after the context manager exits.
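
        Example::

            >>> # An illustrative sketch, assuming ``fpw`` is a
            >>> # FlattenParamsWrapper instance (e.g. from the class docstring).
            >>> with fpw.unflatten_params():
            ...     original_state = fpw.module.state_dict()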
"""
if getattr(self, "flat_param", None) is None:
yield
else:
self.orig_flat_param[0] = self.flat_param
self._unflatten_params()
            # Put the yield in a try...finally so that, even if the body raises
            # and the caller handles the exception, the flattened state is
            # properly restored here.
try:
yield
finally:
self._flatten_params(self.orig_flat_param[0])
self.orig_flat_param[0] = None
def _get_param_views(
self, external_data: Optional[Tensor] = None
) -> Iterator[Tensor]:
"""Return a generator of views that map to the original parameters."""
assert self.flat_param is not None
return self.flat_param.get_param_views(external_data)
def __getattr__(self, name: str) -> Any:
"""Forward missing attributes to wrapped module."""
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
return getattr(self.module, name) # fallback to wrapped module
def __getitem__(self, key: int) -> Any:
"""Forward indexing calls in case the module is a nn.Sequential."""
return self.module.__getitem__(key)
def _unflatten_params_if_needed(self) -> None:
if self.flat_param is not None:
self._unflatten_params_as_views()
def forward(self, *inputs: Any, **kwinputs: Any) -> Any:
self._unflatten_params_if_needed()
return self.module(*inputs, **kwinputs)