# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Tongzhou Wang
# Licensed under the MIT License.

import contextlib
from itertools import accumulate
from typing import (
    Any,
    Dict,
    Generator,
    Iterator,
    List,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
)

import torch
import torch.nn as nn
from torch import Tensor

from torch.distributed.utils import _replace_by_prefix


ParamOffset = Tuple[int, int]
SharedParamInfo = Tuple[str, str, nn.Module, str, nn.Module, str]
FLAT_PARAM = "flat_param"
FPW_MODULE = "_fpw_module"


def _post_state_dict_hook(
    module: nn.Module, state_dict: Dict[str, Any], prefix: str, *args: Any
) -> Dict[str, Any]:
    """
    _post_state_dict_hook() is called after the state_dict() is executed
    and before returning the state_dict to the users.
    This API post-processes the keys of the state_dict to remove the
    FlattenParamsWrapper internal prefix.
    """
    # Move everything from the FPW_MODULE level up one level.
    _replace_by_prefix(state_dict, prefix + f"{FPW_MODULE}.", prefix)
    return state_dict


def _pre_load_state_dict_hook(
    state_dict: Dict[str, Any],
    prefix: str,
    *args: Any,
) -> None:
    """
    _pre_load_state_dict_hook() is called before _load_from_state_dict() is
    executed. This API pre-processes the keys of the state_dict to add the
    FlattenParamsWrapper internal prefix.
    """
    # Push everything down to the FPW_MODULE level.
    _replace_by_prefix(state_dict, prefix, prefix + f"{FPW_MODULE}.")
    # The flat_param_* keys actually need to move one level up.
    flat_param_key = prefix + f"{FPW_MODULE}.{FLAT_PARAM}"
    for k in list(state_dict.keys()):
        if k.startswith(flat_param_key):
            last_part = k.split(".")[-1]
            assert last_part.startswith(
                FLAT_PARAM
            ), f"Expected key to contain flat_param, but key name is {k}"
            _replace_by_prefix(state_dict, k, prefix + last_part)


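# Rough illustration of the key rewriting done by the two hooks above, assuming
# a hypothetical wrapper whose wrapped module has a parameter "layer.weight"
# (the names here are made up for exposition):
#
#   state_dict():       "<prefix>_fpw_module.layer.weight" -> "<prefix>layer.weight"
#                       "<prefix>flat_param" lives on the wrapper itself, so it is
#                       already at the right level and is left untouched.
#   load_state_dict():  "<prefix>layer.weight" -> "<prefix>_fpw_module.layer.weight"
#                       "<prefix>_fpw_module.flat_param" (pushed down with the rest)
#                       is then moved back up to "<prefix>flat_param".

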
class ParamInfo(NamedTuple):
    module_name: str
    module: nn.Module
    param_name: str


class ShardMetadata(NamedTuple):
    param_names: List[str]
    param_shapes: List[torch.Size]
    param_numels: List[int]
    param_offsets: List[ParamOffset]


class FlatParameter(nn.Parameter):
    """
    A parameter that is initialized from a list of parameters. All the
    parameters will be flattened and concatenated to form the flat parameter.

    Args:
        params (Sequence[nn.Parameter]):
            The parameters to be flattened and concatenated.
        requires_grad (bool):
            Set to True if gradients need to be computed for this parameter,
            False otherwise.
    """

    def __new__(
        cls, params: Sequence[nn.Parameter], requires_grad: bool = True
    ) -> "FlatParameter":
        """Make an object using the parent's __new__ function."""

        # An empty or non-list input doesn't make sense.
        if not isinstance(params, (list, tuple)) or len(params) == 0:
            raise ValueError("A non-empty list or tuple argument is needed")

        # Normally, all items are Parameters. But during pickling, we will have a single
        # Tensor as the input and later in __init__, the correct _param_numels and _param_shapes
        # are set.
        if not all(isinstance(p, (nn.Parameter, Tensor)) for p in params):
            incorrect_parameters = [
                p for p in params if not isinstance(p, (nn.Parameter, Tensor))
            ]
            raise ValueError(
                f"List items need to be Parameter or Tensor types, but got {incorrect_parameters}"
            )

        # Flattening involves (1) making a tensor flat (i.e. single dimensional) and
        # (2) making a module hierarchy flat (using a single tensor to replace a tree of
        # tensors). Therefore, adding back nesting and hierarchy is counter-productive.
        # If nesting is encountered in the future, the reasonable thing to do is likely
        # for the top level FlatParameter to absorb the nested one and keep the result flat,
        # free from hierarchy.
        if any(isinstance(p, FlatParameter) for p in params):
            raise ValueError("Nesting FlatParameter is not supported")

        data = torch.cat(
            [
                p.detach().reshape(-1) if isinstance(p, nn.Parameter) else p.reshape(-1)
                for p in params
            ],
            0,
        )

        return super(FlatParameter, cls).__new__(
            cls, data, requires_grad=requires_grad
        )  # type: ignore[call-arg]

    def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
        self._is_sharded = False
        self._param_numels = [p.numel() for p in params]
        # The total number of elements. This is equal to the sum of the
        # ``numel()`` of all the original parameters.
        self.full_numel = sum(self._param_numels)
        assert self.numel() <= self.full_numel, (
            "Parameter numbers mismatched. "
            f"The number of elements in FlatParameter: {self.numel()} vs. "
            f"the number of elements in original parameters: {self.full_numel}."
        )
        # The shape of each individual parameter.
        self._param_shapes = [p.size() for p in params]
        cumulative_sum = list(accumulate(self._param_numels))
        begin = [0] + cumulative_sum[:-1]
        end = [e - 1 for e in cumulative_sum]

        self._param_infos: List[ParamInfo] = []
        self._shared_param_infos: List[SharedParamInfo] = []

        # The element offsets (begin/end pair) in the flat parameter of each
        # individual parameter.
        self._param_offsets = list(zip(begin, end))
        # The indices (begin/end pair) of the parameters that are included in
        # this FlatParameter. The default value covers all the parameters
        # because no sharding has happened yet.
        self._param_indice_in_shard = (0, len(self._param_infos) - 1)
        # The offsets within each parameter that are included in this FlatParameter.
        self._sharded_param_offsets: List[ParamOffset] = [
            (0, numel) for numel in self._param_numels
        ]
        # The number of padding elements.
        self.num_padded = 0

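    # Illustrative sketch (toy shapes chosen for exposition, not taken from the
    # original sources): flattening a 2x3 weight and a 3-element bias yields a
    # 9-element flat parameter whose bookkeeping looks roughly like this:
    #
    #   >>> w = nn.Parameter(torch.randn(2, 3))
    #   >>> b = nn.Parameter(torch.randn(3))
    #   >>> fp = FlatParameter([w, b])
    #   >>> fp.numel(), fp.full_numel
    #   (9, 9)
    #   >>> fp._param_numels, fp._param_shapes
    #   ([6, 3], [torch.Size([2, 3]), torch.Size([3])])
    #   >>> fp._param_offsets       # inclusive (begin, end) element ranges
    #   [(0, 5), (6, 8)]
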
    def shard_by_offsets(self, start: int, end: int, num_padded: int) -> None:
        assert self._is_sharded
        if start < 0 or end < 0 or end < start:
            raise ValueError(
                f"Cannot shard the flat parameter with an invalid offset pair {(start, end)}."
            )
        _shard_size = end - start + 1
        self.num_padded = num_padded
        if self.num_padded > _shard_size:
            raise ValueError("The number of padded elements is larger than the shard size.")
        self._sharded_param_offsets.clear()

        ranges = []
        for idx, offset in enumerate(self._param_offsets):
            if start > offset[1] or end < offset[0]:
                continue
            if start <= offset[0]:
                sharded_param_start = 0
                sharded_param_end = min(offset[1], end) - offset[0]
            else:
                sharded_param_start = start - offset[0]
                sharded_param_end = min(offset[1], end) - offset[0]
            ranges.append(idx)
            self._sharded_param_offsets.append((sharded_param_start, sharded_param_end))
        if ranges:
            self._param_indice_in_shard = (ranges[0], ranges[-1])

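    # Illustrative walkthrough, continuing the toy example above (and assuming
    # the caller has already marked the parameter as sharded): a shard holding
    # flat elements 5..8 (inclusive) covers the last element of the 6-element
    # weight and all 3 elements of the bias, so
    #
    #   >>> fp._is_sharded = True
    #   >>> fp.shard_by_offsets(start=5, end=8, num_padded=0)
    #   >>> fp._param_indice_in_shard
    #   (0, 1)
    #   >>> fp._sharded_param_offsets   # within-parameter (start, end) offsets
    #   [(5, 5), (0, 2)]
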
    def _offset_to_slice(self) -> slice:
        if self._param_indice_in_shard[0] > self._param_indice_in_shard[1]:
            return slice(0, 0)
        return slice(self._param_indice_in_shard[0], self._param_indice_in_shard[1] + 1)

    def get_param_views(
        self, external_data: Optional[Tensor] = None
    ) -> Iterator[Tensor]:
        """Return a generator of views that map to the original parameters."""
        # Note, self.data could be sharded, so its numel is <= the full numel.
        assert (
            self.data.numel() <= self.full_numel
        ), f"Incorrect internal state {self.data.numel()} vs. {self.full_numel}"
        data = external_data if external_data is not None else self
        if data.numel() != self.full_numel:
            raise ValueError(
                f"Incorrect numel of supplied data: got {data.numel()} but expected {self.full_numel}"
            )
        return (
            t.view(s)
            for (t, s) in zip(data.split(self._param_numels), self._param_shapes)
        )

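    # Illustrative use (toy example continued): with a full, unsharded 9-element
    # data tensor, the views recover the original parameter shapes:
    #
    #   >>> [v.shape for v in fp.get_param_views()]
    #   [torch.Size([2, 3]), torch.Size([3])]
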
    @property
    def _num_unflattened_params(self) -> int:
        """Returns the number of unflattened parameters that comprise this
        flattened parameter."""
        assert hasattr(self, "_param_infos"), \
            "`_param_infos` has not been set, meaning this `FlatParameter` " \
            "has not been initialized yet"
        num_unflat_params = len(self._param_infos)
        assert num_unflat_params > 0, "`FlatParameter` corresponds to 0 " \
            "unflattened parameters"
        return num_unflat_params

    @property
    def param_info(self) -> List[ParamInfo]:
        return self._param_infos

    @property
    def _param_names(self):
        return [".".join([m, n]) if m else n for (m, _, n) in self._param_infos]

    def metadata(self) -> Tuple[List[str], List[torch.Size], List[int]]:
        """Return tuple of (names, shapes, numels) metadata for this flat parameter."""
        return self._param_names, self._param_shapes, self._param_numels

    def shard_metadata(
        self,
    ) -> ShardMetadata:
        """
        Return the (names, shapes, numels, offsets) metadata for the portion of
        the original parameters that falls within the local shard of this flat
        parameter.
        """
        return ShardMetadata(
            self._param_names[self._offset_to_slice()],
            self._param_shapes[self._offset_to_slice()],
            self._param_numels[self._offset_to_slice()],
            self._sharded_param_offsets[:],
        )


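# Illustrative output for the toy example above, assuming the enclosing wrapper
# has populated ``_param_infos`` so that the parameter names are 'w' and 'b'
# (names are hypothetical): after sharding flat elements 5..8, shard_metadata()
# would describe the tail of the weight plus the whole bias, roughly
#   ShardMetadata(param_names=['w', 'b'],
#                 param_shapes=[torch.Size([2, 3]), torch.Size([3])],
#                 param_numels=[6, 3],
#                 param_offsets=[(5, 5), (0, 2)])

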
class FlattenParamsWrapper(nn.Module):
    """
    A wrapper for transparently flattening a Module's parameters.

    The original implementation [1] reparameterizes a PyTorch module into a
    ReparamModule, which holds only a single flattened parameter representing
    all parameters of the wrapped module. Compared to the original
    implementation [1], this version:

    - removes tracing
    - supports shared parameters
    - is renamed to FlattenParamsWrapper

    [1] https://github.com/SsnL/PyTorch-Reparam-Module

    Args:
        module (nn.Module):
            The module to wrap.
        param_list (List[nn.Parameter]):
            Only flatten parameters appearing in the given list.
            Note, even if only a single param is in the list, it still gets
            flattened, and the original param is removed and replaced with
            the flattened one.
    """

    def __init__(self, module: nn.Module, param_list: List[nn.Parameter]):
        super().__init__()
        self._fpw_module = module
        # People may test whether this module contains parameters by checking
        # `getattr(module, "flat_param") is None`. This is not always accurate,
        # since that condition is also true while this module is unflattened.
        # `no_params` explicitly records that this module has no parameters and
        # is always correct, regardless of whether the parameters are currently
        # flattened or unflattened.
        self.no_params = True
        self.flat_param = None

        # Register hooks to be called after state_dict() to remove the
        # "_fpw_module." prefix and before load_state_dict() to add it back.
        # The hooks must be registered even if the target param_list is empty,
        # as all submodules in FlattenParamsWrapper should be pre/post processed
        # by the hooks.
        self._register_state_dict_hook(_post_state_dict_hook)
        self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)

        if len(param_list) == 0:
            return

        # The set of parameters to be flattened.
        unique_param_list = set(param_list)
        self.no_params = False

        # Convert from a list of Parameters to a set of (Module, parameter_name)
        # tuples, which will survive even if the Parameter instances are reset.
        # The set may include multiple (m, n) entries that point to the same
        # shared parameter.
        self.param_set = set()
        for m in self.modules():
            for n, p in m.named_parameters(recurse=False):
                if p in unique_param_list:
                    self.param_set.add((m, n))

        params, param_infos, shared_param_infos = self._init_flatten_params()
        self.flat_param = FlatParameter(params, params[0].requires_grad)
        self.flat_param._param_infos = param_infos
        self.flat_param._shared_param_infos = shared_param_infos

        # This attribute is used to remember the flat_param inside the
        # unflatten_params() context. With this attribute, FSDP can access the
        # flat parameter metadata even if flat_param is temporarily deleted.
        # ``orig_flat_param`` is a list to avoid being tracked by ``state_dict()``.
        self.orig_flat_param: List[Optional[FlatParameter]] = [None]
        self._flatten_params()

        # Sanity check for the string constants.
        assert getattr(self, FPW_MODULE) is self._fpw_module
        assert getattr(self, FLAT_PARAM) is self.flat_param

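    # Minimal usage sketch (hypothetical toy module, for exposition only):
    #
    #   >>> lin = nn.Linear(3, 3)
    #   >>> fpw = FlattenParamsWrapper(lin, list(lin.parameters()))
    #   >>> fpw.flat_param.numel()      # 3*3 weight + 3 bias
    #   12
    #   >>> fpw.weight.shape            # attribute access forwards to nn.Linear
    #   torch.Size([3, 3])
    #   >>> sorted(fpw.state_dict().keys())
    #   ['flat_param']
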
    @property
    def module(self) -> Any:
        """Support _fsdp_wrapped_module.module in case we are imitating DDP,
        which has a .module property pointing to the underlying module.
        """
        return self._fpw_module

    def _init_flatten_params(
        self,
    ) -> Tuple[List[nn.Parameter], List[ParamInfo], List[SharedParamInfo]]:
        """Build metadata for the parameters to be flattened and return a list
        containing those parameters.

        This also fills param_infos and shared_param_infos.
        """
        param_infos: List[ParamInfo] = []
        shared_param_infos = []
        shared_param_memo: Dict[nn.Parameter, Tuple[str, nn.Module, str]] = {}
        params = []
        for module_name, m in self.named_modules():
            for n, p in m.named_parameters(recurse=False):
                if p is not None and (m, n) in self.param_set:
                    if p in shared_param_memo:
                        mname, shared_m, shared_n = shared_param_memo[p]
                        shared_param_infos.append(
                            (module_name, mname, m, n, shared_m, shared_n)
                        )
                    else:
                        shared_param_memo[p] = (module_name, m, n)
                        param_infos.append(ParamInfo(module_name, m, n))
                        params.append(p)
        del shared_param_memo

        assert (
            len(set(p.dtype for p in params)) == 1
        ), "expects all parameters to have the same dtype"
        assert (
            len(set(p.requires_grad for p in params)) == 1
        ), "expects all parameters to have the same requires_grad"
        assert len(params) == len(set(params)), "params list should not have dups"

        return params, param_infos, shared_param_infos

    def _flatten_params(self, external_data: Optional[FlatParameter] = None) -> None:
        """Flatten the managed parameters and replace the original
        attributes with views into the flat param. If `external_data`
        is passed, it will be used as the flat_param.
        """
        # Register the flattened one.
        assert (
            getattr(self, "flat_param", None) is not None or external_data is not None
        ), "Cannot flatten params when both flat_param and external_data are None."
        if external_data is not None:
            self.flat_param = external_data
        self.register_parameter("flat_param", self.flat_param)

        assert self.flat_param is not None  # avoid mypy complaint
        # Deregister the names as parameters.
        for _, m, n in self.flat_param._param_infos:
            delattr(m, n)
        for _, _, m, n, _, _ in self.flat_param._shared_param_infos:
            delattr(m, n)

        # Register the views as plain attributes.
        self._unflatten_params_as_views()

    def _unflatten_params_as_views(self) -> None:
        """Unlike ``_unflatten_params``, this function unflattens into views and
        keeps self.flat_param unchanged.
        """
        assert (
            self.flat_param is not None
        ), "Cannot unflatten params as views when flat_param is None."
        ps = self._get_param_views()
        for (_, m, n), p in zip(self.flat_param._param_infos, ps):
            setattr(m, n, p)  # This sets p as a plain attribute.

        for (_, _, m, n, shared_m, shared_n) in self.flat_param._shared_param_infos:
            setattr(m, n, getattr(shared_m, shared_n))

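    # Note (illustrative, using the hypothetical ``fpw`` from the sketch above):
    # after flattening, the original attributes are plain tensor views rather
    # than nn.Parameters, so ``isinstance(fpw._fpw_module.weight, nn.Parameter)``
    # would be False while ``fpw._fpw_module.weight`` still shares storage with
    # ``flat_param``.
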
    def _unflatten_params(self) -> None:
        """Undo flattening and create separate parameters from the already
        flattened self.flat_param.
        """
        assert (
            self.flat_param is not None
        ), "Cannot unflatten params when flat_param is None."
        ps = self._get_param_views()
        for (_, m, n), p in zip(self.flat_param._param_infos, ps):
            if hasattr(m, n):
                delattr(m, n)
            m.register_parameter(n, nn.Parameter(p))
        for (_, _, m, n, shared_m, shared_n) in self.flat_param._shared_param_infos:
            if hasattr(m, n):
                delattr(m, n)
            m.register_parameter(n, getattr(shared_m, shared_n))

        del self.flat_param

    @contextlib.contextmanager
    def unflatten_params(self) -> Generator:
        """
        Unflatten params. If the current instance is already unflattened, then
        it will remain unflattened after the context manager exits.
        """
        if getattr(self, "flat_param", None) is None:
            yield
        else:
            self.orig_flat_param[0] = self.flat_param
            self._unflatten_params()
            # Put the yield in a try...finally so that, even if an exception
            # escapes the context and is caught and handled by the caller, the
            # flattened state is still properly restored here.
            try:
                yield
            finally:
                self._flatten_params(self.orig_flat_param[0])
                self.orig_flat_param[0] = None

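    # Typical usage (illustrative; ``fpw`` is the hypothetical wrapper from the
    # sketch above): inside the context the original nn.Parameters are restored,
    # and flat_param is re-registered when the context exits.
    #
    #   >>> with fpw.unflatten_params():
    #   ...     sorted(n for n, _ in fpw._fpw_module.named_parameters())
    #   ['bias', 'weight']
    #   >>> fpw.flat_param is not None
    #   True
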
    def _get_param_views(
        self, external_data: Optional[Tensor] = None
    ) -> Iterator[Tensor]:
        """Return a generator of views that map to the original parameters."""
        assert self.flat_param is not None
        return self.flat_param.get_param_views(external_data)

    def __getattr__(self, name: str) -> Any:
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.module, name)  # fall back to the wrapped module

    def __getitem__(self, key: int) -> Any:
        """Forward indexing calls in case the module is an nn.Sequential."""
        return self.module.__getitem__(key)

    def _unflatten_params_if_needed(self) -> None:
        if self.flat_param is not None:
            self._unflatten_params_as_views()

    def forward(self, *inputs: Any, **kwinputs: Any) -> Any:
        self._unflatten_params_if_needed()
        return self.module(*inputs, **kwinputs)


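# End-to-end sketch (illustrative only; builds on the hypothetical ``fpw``
# above): forward() refreshes the unflattened views before delegating to the
# wrapped module, and the state_dict hooks keep the key layout consistent
# across save and load:
#
#   >>> out = fpw(torch.randn(4, 3))   # forward() -> nn.Linear forward
#   >>> sd = fpw.state_dict()          # {'flat_param': tensor with 12 elements}
#   >>> fpw.load_state_dict(sd)        # pre-load hook restores the key layout
#   <All keys matched successfully>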