import contextlib
import os
import warnings
from enum import auto, Enum
from itertools import accumulate, chain
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterator,
    List,
    NamedTuple,
    no_type_check,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed._tensor import DTensor
from torch.distributed.fsdp._common_utils import (
    _set_fsdp_flattened,
    HandleTrainingState,
)
from torch.distributed.utils import _alloc_storage, _free_storage, _p_assert

from ._fsdp_extensions import _ext_post_unflatten_transform, _ext_pre_flatten_transform
from ._utils import _no_dispatch_record_stream, _same_storage

__all__ = [
    "FlatParameter",
    "FlatParamHandle",
    "FlatParamShardMetadata",
    "ParamInfo",
    "SharedParamInfo",
    "HandleShardingStrategy",
]


"""
[Note: Fully Sharded Module]
We define the "fully sharded module" to be the original ``nn.Module`` that owns
a ``FlatParamHandle``. It is the *single* module logically responsible for the
*single* unshard/reshard pair for the handle's ``FlatParameter`` for a given
forward or backward pass. The fully sharded module should be passed to the
``FlatParamHandle`` constructor.

For the wrapper code path:
- The ``FullyShardedDataParallel`` module wrapping the fully sharded module
  runs the unshard/reshard on behalf of the fully sharded module by overriding
  ``nn.Module.forward``.
- The fully sharded module is exactly the module passed to the
  ``FullyShardedDataParallel`` constructor's ``module`` argument.

For the non-wrapper code path:
- Hooks registered on the fully sharded module run the unshard/reshard.
- The fully sharded module may either be the direct argument to ``fully_shard``
  or a submodule chosen by the provided wrapping policy.
"""

# Environment variable toggling whether to use unsafe `setattr()` for view
# setting in `_use_sharded_views()` and `_use_unsharded_views()`
# We should use 'safe' by default since it respects method overrides, but for
# special cases such as for high CPU overhead or for intentionally bypassing
# checks in the overrides, we may use 'unsafe'.
_FSDP_USE_UNSAFE_SETATTR = "FSDP_USE_UNSAFE_SETATTR"

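# Illustrative usage sketch (not part of the original file): the handle
# constructor below reads this environment variable once at construction time
# and checks for the string "1", so a hypothetical experiment opting into the
# unsafe path might look like:
#
#   os.environ["FSDP_USE_UNSAFE_SETATTR"] = "1"
#   # ... construct FSDP / `FlatParamHandle` afterwards ...
#
# Only the constant above is defined by this file; the snippet is a sketch.
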
class ParamInfo(NamedTuple):
    """Information for an original module parameter."""

    param_name: str  # unprefixed
    module: nn.Module
    module_name: str


class SharedParamInfo(NamedTuple):
    """
    Additional information for a shared parameter.

    For each shared parameter, we designate one module and its parameter
    variable to be the primary owner, determined as the first one encountered
    in the parameter walk. These are prefixed with "prim". The primary module
    and parameter do not have their own :class:`SharedParamInfo` instance.
    """

    param_name: str  # unprefixed
    module: nn.Module
    module_name: str
    prim_param_name: str  # unprefixed
    prim_module: nn.Module
    prim_module_name: str


class FlatParamShardMetadata(NamedTuple):
    """
    This holds metadata specific to this rank's shard of the flattened
    parameter.

    Attributes:
        param_names (Tuple[str, ...]): Prefixed parameter names of this rank's
            shard of the parameters; see :class:`FlatParameter`.
        param_shapes (Tuple[torch.Size, ...]): Parameter shapes of this rank's
            shard of the parameters; see :class:`FlatParameter`.
        param_numels (Tuple[int, ...]): Parameter numels of this rank's shard
            of the parameters; see :class:`FlatParameter`.
        param_offsets (Tuple[Tuple[int, int], ...]): [start, end] offsets (in
            units of numels) giving this rank's part of each flattened
            original module parameter.
    """

    param_names: Tuple[str, ...]
    param_shapes: Tuple[torch.Size, ...]
    param_numels: Tuple[int, ...]
    param_offsets: Tuple[Tuple[int, int], ...]

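# Illustrative example (not part of the original file; the parameter names are
# hypothetical): for two original parameters "lin.weight" with shape (4, 5) and
# "lin.bias" with shape (5,), flattened to 25 numel and sharded over
# world_size=2 (13 numel per rank with 1 numel of padding on rank 1), rank 1's
# metadata would look roughly like
#
#   FlatParamShardMetadata(
#       param_names=("lin.weight", "lin.bias"),
#       param_shapes=(torch.Size([4, 5]), torch.Size([5])),
#       param_numels=(20, 5),
#       param_offsets=((13, 19), (0, 4)),
#   )
#
# i.e. rank 1 holds elements [13, 19] of the flattened weight and all of the bias.
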
# TODO (awgu): Prefix these with "Handle" for now to avoid circular imports and
# inadvertent misuses; coalesce with those in fully_sharded_data_parallel.py
# later
class HandleShardingStrategy(Enum):
    FULL_SHARD = auto()
    SHARD_GRAD_OP = auto()
    NO_SHARD = auto()
    HYBRID_SHARD = auto()
    _HYBRID_SHARD_ZERO2 = auto()

class FlatParameter(nn.Parameter):
    """
    This is the flattened parameter used by :class:`FullyShardedDataParallel`.
    It is composed of one or more original parameters, which are flattened
    and concatenated to construct the flattened parameter.

    Under the current design, this parameter logically represents both the
    unsharded and sharded flattened parameter, and its data changes storages
    dynamically.
    - In the :class:`FullyShardedDataParallel` constructor, the parameter
      is initialized as unsharded and then sharded in-place.
    - At runtime, the parameter is lazily (re)-initialized. The sharded
      parameter data is saved in ``self._local_shard``, and a new ``Tensor``
      ``self._full_param_padded`` is created, which is the all-gather
      destination and owns the unsharded parameter storage thereafter. (See
      :meth:`FlatParamHandle.init_flat_param_attributes`.)
    - Throughout runtime, the parameter data changes storages as needed,
      e.g. to the sharded flattened parameter, reduced-precision sharded
      flattened parameter, or the unsharded flattened parameter.

    Attributes:
        _unpadded_unsharded_size (torch.Size): Unsharded flattened parameter's
            size without padding.
        _padded_unsharded_size (torch.Size): Unsharded flattened parameter's
            size with padding. This is only set for sharded strategies since
            they require padding for the all-gather.
        _sharded_size (torch.Size): Sharded flattened parameter's size with
            padding. This is also set for ``NO_SHARD``, in which case it is the
            same as the unsharded sizes. (We omit "padded" because there is no
            analogous unpadded one.)

        _param_infos (Tuple[ParamInfo, ...]): Each parameter's parameter info
            entry; see :class:`ParamInfo`.
        _numels (Tuple[int, ...]): Each parameter's numel.
        _shapes (Tuple[torch.Size, ...]): Each parameter's shape.
        _fqns (Tuple[str, ...]): The original parameters' FQNs prefixed from
            the owning handle's ``_fully_sharded_module``. The names are
            guaranteed to be unique within the subtree rooted at that module.
        _num_params (int): Number of original parameters flattened into this
            flattened parameter; this is the length of ``_param_infos``,
            ``_numels``, ``_shapes``, and ``_fqns``.
        _shared_param_infos (Tuple[SharedParamInfo, ...]): Shared parameter
            info entries; see :class:`SharedParamInfo`.
        _param_extensions (Tuple[Optional[Any], ...]): Parameter extensions
            (i.e. some per-parameter state) used to customize pre-flatten and
            post-unflatten behavior. This is experimental, and users should not
            depend on its existence in the future.
        _modules (Set[nn.Module]): Modules that contain some original parameter
            that is flattened into the ``FlatParameter``.

        _shard_param_offsets (List[Tuple[int, int]]): [start, end] offsets (in
            units of numel) giving this rank's part of each flattened original
            module parameter; for any parameter ``p`` that is not sharded
            across ranks, this will be [0, ``p.numel()``-1].
        _shard_indices (Tuple[int, int]): [start, end] indices (in units of
            parameters) for this rank's shard of the original model parameters,
            where the parameters follow the order in which they were originally
            flattened; this indexes appropriately into any data structure that
            follows the flattening order (e.g. ``_param_infos``, ``_numels``,
            etc.).
        _shard_numel_padded (int): Numel padded for this rank's sharded
            flattened parameter.

        _local_shard (Tensor): Sharded flattened parameter with padding if
            using a sharded strategy. If using ``NO_SHARD``, then this is the
            unpadded unsharded flattened parameter, and there is no notion of a
            sharded flattened parameter or padded unsharded flattened
            parameter.
        _full_param_padded (Tensor): Unsharded flattened parameter with
            padding. This is not defined for ``NO_SHARD``. When using mixed
            precision for parameters, this has the low precision.
        _full_prec_full_param_padded (Tensor): Full precision unsharded
            flattened parameter with padding. This is used for unsharding
            outside of computation when using mixed precision for parameters.
            This is never defined for ``NO_SHARD``.
        _post_backward_hook_state (Tuple[AccumulateGrad, RemovableHandle]):
            Flattened parameter's :class:`AccumulateGrad` object and
            post-backward hook handle.
        _mp_shard (Tensor): Low precision sharded flattened parameter with
            padding. This is only defined when parameter mixed precision is
            enabled. For ``NO_SHARD``, this is used for computation.
        _cpu_grad (Tensor): Sharded gradient with padding stored on CPU.
            This is only defined when offloading parameters is enabled.
        _saved_grad_shard (Tensor): Sharded gradient with padding from previous
            iterations for gradient accumulation without :meth:`no_sync`.

        _params (Optional[List[nn.Parameter]]): The original parameter
            variables if ``use_orig_params=True`` and ``None`` otherwise.
        _shared_params (Optional[List[nn.Parameter]]): The original shared
            parameter variables if ``use_orig_params=True`` and ``None``
            otherwise.
        _tensors (Optional[List[Optional[Tensor]]]): This saves the ``Tensor``
            views created in the forward and tracked by autograd when
            ``use_orig_params=True`` and is ``None`` otherwise. This is to
            preserve those ``Tensor`` variables for the backward to ensure that
            the ``FlatParameter`` 's ``AccumulateGrad`` object does not change,
            in which case the post-backward hook would not run. This is relevant
            for cases like reentrant activation checkpointing.
        _is_grad_none (Optional[List[bool]]): A mask over the original
            parameters' gradients indicating whether each is logically ``None``
            if ``use_orig_params=True`` and ``None`` otherwise. This is needed
            because only some of the parameters may have ``None`` gradient, in
            which case the ``FlatParameter`` gradient must be non-``None`` and
            must use zeros to approximate those original ``None`` gradients.
            This mask informs FSDP to set the original parameter gradients to
            ``None`` (instead of zeros) as needed.
    """

    def _init_metadata(
        self,
        param_infos: List[ParamInfo],
        numels: List[int],
        shapes: List[torch.Size],
        fqns: List[str],
        shared_param_infos: List[SharedParamInfo],
        param_extensions: List[Any],
        params: Optional[List[nn.Parameter]],
        shared_params: Optional[List[nn.Parameter]],
    ) -> None:
        """
        Initializes attributes holding metadata about the original parameters
        comprising the flattened parameter.

        We expose this method separate from the constructor to keep the
        constructor only responsible for the flattened parameter's tensor data.
        This method should only be called once per model, while the constructor
        may be called multiple times, e.g. when reloading from a checkpoint, in
        which case only the tensor data needs to be passed to the constructor.
        Since :meth:`load_state_dict` is implemented via :meth:`copy_`, the
        metadata is correctly assumed to be unchanged.

        Args:
            See the Attributes in the class docstring.
        """
        assert len(param_infos) == len(numels)
        assert len(param_infos) == len(shapes)
        assert len(param_infos) == len(fqns)
        assert len(param_infos) == len(param_extensions)
        self._num_params = len(param_infos)
        self._param_infos = tuple(param_infos)
        self._numels = tuple(numels)
        self._shapes = tuple(shapes)
        self._fqns = tuple(fqns)
        self._shared_param_infos = tuple(shared_param_infos)
        self._param_extensions = tuple(param_extensions)
        self._modules = {pi.module for pi in self._param_infos}.union(
            {spi.module for spi in self._shared_param_infos}
        )
        assert (params is None) == (shared_params is None)
        if params is not None:
            assert shared_params is not None and len(shared_params) == len(
                shared_param_infos
            )
            self._params: Optional[List[nn.Parameter]] = params
            self._shared_params: Optional[List[nn.Parameter]] = shared_params
            # Mark the original parameters to avoid flattening them into
            # another `FlatParameter` during recursive construction
            for param in chain(self._params, self._shared_params):
                _set_fsdp_flattened(param)
            self._is_grad_none: Optional[List[bool]] = [
                False for _ in range(len(params))
            ]
            self._tensors: Optional[List[Optional[Tensor]]] = [
                None for _ in range(len(self._params))
            ]
        else:
            self._params = None
            self._shared_params = None
            self._is_grad_none = None
            self._tensors = None
        self._unpadded_unsharded_size = self.size()
        _set_fsdp_flattened(self)
        # Tracks whether the `FlatParameter`'s post-backward hook has been
        # called to modify the behavior of the post-backward callback
        self._post_backward_called = False

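# Construction-flow sketch (illustrative, not part of the original file): per
# the docstrings above and in `FlatParamHandle` below, the tensor data is built
# first and the metadata is attached exactly once, roughly:
#
#   flat_param = FlatParamHandle.flatten_params(params_to_flatten, requires_grad)
#   flat_param._init_metadata(
#       param_infos, numels, shapes, fqns,
#       shared_param_infos, param_extensions,
#       params=None, shared_params=None,
#   )
#
# On checkpoint reload, only the tensor data is recreated; the metadata is
# assumed unchanged since `load_state_dict` is implemented via `copy_`.
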
class FlatParamHandle:
    """
    This handle manages a flattened parameter (:class:`FlatParameter`). This
    includes sharding and view management.

    Args:
        params (Sequence[nn.Parameter]): The parameters to use for the
            flattened parameter.
        fully_sharded_module (nn.Module): See [Note: Fully Sharded Module].
        device (torch.device): The compute and communication device, which
            should be a non-CPU device. We refer to it as the compute device.
        sharding_strategy (ShardingStrategy): Sharding strategy to apply to
            this handle's ``FlatParameter``.
        offload_params (bool): Whether to offload the handle's
            ``FlatParameter`` to CPU.
        mp_param_dtype (Optional[torch.dtype]): Parameter mixed precision
            setting passed to the FSDP constructor.
        mp_reduce_dtype (Optional[torch.dtype]): Gradient reduction mixed
            precision setting passed to the FSDP constructor.
        keep_low_precision_grads (bool): Whether to keep gradients in low
            precision.
        use_orig_params (bool): If ``True``, then FSDP preserves the original
            parameter variables and returns them from ``named_parameters()``
            (e.g. to support different optimizer hyperparameters within one
            :class:`FlatParameter`). If ``False``, then FSDP reconstructs the
            parameter every iteration and returns the :class:`FlatParameter` s
            from ``named_parameters()``.
    """

    ##################
    # INITIALIZATION #
    ##################
    def __init__(
        self,
        params: Sequence[nn.Parameter],
        fully_sharded_module: nn.Module,
        device: torch.device,
        sharding_strategy: HandleShardingStrategy,
        offload_params: bool,
        mp_param_dtype: Optional[torch.dtype],
        mp_reduce_dtype: Optional[torch.dtype],
        keep_low_precision_grads: bool,
        process_group: dist.ProcessGroup,
        use_orig_params: bool,
    ):
        super().__init__()
        use_unsafe_setattr = os.environ.get(_FSDP_USE_UNSAFE_SETATTR, "") == "1"
        self._setattr_tensor: Callable[[nn.Module, str, Tensor], None]
        self._setattr_param: Callable[[nn.Module, str, nn.Parameter], None]
        if use_unsafe_setattr:
            self._setattr_tensor = _unsafe_setattr_tensor
            self._setattr_param = _unsafe_setattr_param
        else:
            self._setattr_tensor = _safe_setattr_tensor_or_param
            self._setattr_param = _safe_setattr_tensor_or_param
        self.device = device
        self.process_group = process_group
        self.rank = process_group.rank()
        self.world_size = process_group.size()
        self._sharding_strategy = sharding_strategy
        self._offload_params = offload_params
        self._use_orig_params = use_orig_params
        self._keep_low_precision_grads = keep_low_precision_grads
        self._training_state = HandleTrainingState.IDLE
        self._debug_level = dist.get_debug_level()
        self._fully_sharded_module = fully_sharded_module
        self._init_flat_param(params, fully_sharded_module, use_orig_params)
        self._orig_param_dtype = self.flat_param.dtype
        self._use_unsharded_views(as_params=False)
        self._init_param_reduce_dtypes(mp_param_dtype, mp_reduce_dtype)

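    # Illustrative construction sketch (not part of the original file; the
    # local variable names are hypothetical). A caller such as FSDP would do
    # roughly:
    #
    #   handle = FlatParamHandle(
    #       params=list(module.parameters()),
    #       fully_sharded_module=module,
    #       device=torch.device("cuda", torch.cuda.current_device()),
    #       sharding_strategy=HandleShardingStrategy.FULL_SHARD,
    #       offload_params=False,
    #       mp_param_dtype=None,
    #       mp_reduce_dtype=None,
    #       keep_low_precision_grads=False,
    #       process_group=pg,  # a `dist.ProcessGroup`
    #       use_orig_params=False,
    #   )
    #   handle.shard()
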
    def _init_flat_param(
        self,
        params: Sequence[Optional[nn.Parameter]],
        module: nn.Module,
        use_orig_params: bool,
    ) -> None:
        """
        Initializes the flattened parameter ``self.flat_param`` by flattening
        the parameters in ``params`` into a single :class:`FlatParameter` and
        saves relevant metadata. Shared parameters are only included in the
        flattened parameter once.

        This checks that all comprising parameters have the same dtype and
        ``requires_grad`` and does not support nested construction of
        :class:`FlatParameter` s.

        Args:
            See the Args in the class docstring.
        """
        params_set = set(params)
        params_set.discard(None)
        if len(params_set) == 0:
            raise ValueError(
                "Cannot initialize a `FlatParameter` from an empty parameter list"
            )
        param_infos: List[ParamInfo] = []
        numels: List[int] = []
        shapes: List[torch.Size] = []
        fqns: List[str] = []
        shared_param_infos: List[SharedParamInfo] = []
        shared_param_memo: Dict[nn.Parameter, Tuple[nn.Module, str, str]] = {}
        params_to_flatten: List[Union[torch.Tensor, nn.Parameter]] = []
        shared_params: List[Union[torch.Tensor, nn.Parameter]] = []
        param_extensions: List[Any] = []
        dtype: Optional[torch.dtype] = None
        requires_grad: Optional[bool] = None
        for submodule_name, submodule in module.named_modules():
            for param_name, param in submodule.named_parameters(recurse=False):
                if param not in params_set:
                    continue
                if param in shared_param_memo:  # shared reference
                    prim_module, prim_module_name, prim_param_name = shared_param_memo[
                        param
                    ]
                    shared_params.append(param)
                    shared_param_infos.append(
                        SharedParamInfo(
                            param_name,
                            submodule,
                            submodule_name,
                            prim_param_name,
                            prim_module,
                            prim_module_name,
                        )
                    )
                else:
                    if type(param) is FlatParameter:
                        raise ValueError("`FlatParameter` does not support nesting")
                    if dtype is not None and param.dtype != dtype:
                        raise ValueError(
                            "`FlatParameter` requires uniform dtype but got "
                            f"{dtype} and {param.dtype}"
                        )
                    if dtype is None and not param.is_floating_point():
                        raise ValueError("Integer parameters are unsupported")
                    if (
                        requires_grad is not None
                        and param.requires_grad != requires_grad
                    ):
                        raise ValueError(
                            "`FlatParameter` requires uniform `requires_grad`"
                        )
                    param, extension = _ext_pre_flatten_transform(param)
                    param_extensions.append(extension)
                    dtype = param.dtype
                    requires_grad = param.requires_grad
                    shared_param_memo[param] = (submodule, submodule_name, param_name)
                    params_to_flatten.append(param)
                    param_infos.append(ParamInfo(param_name, submodule, submodule_name))
                    numels.append(param.numel())
                    shapes.append(param.shape)
                    fqn = (
                        submodule_name + "." + param_name
                        if submodule_name
                        else param_name
                    )
                    fqns.append(fqn)
        assert requires_grad is not None, (
            "Passed-in `params` were not found in the module tree\n"
            f"params: {params}\nmodule: {module}"
        )
        self.flat_param = FlatParamHandle.flatten_params(
            params_to_flatten, requires_grad
        )
        # For `use_orig_params=True`, ensure that the logical parameters are
        # `nn.Parameter`s (and not plain `torch.Tensor`)

        def convert_to_params(
            tensors: List[Union[torch.Tensor, nn.Parameter]]
        ) -> List[nn.Parameter]:
            return [
                t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors
            ]

        self.flat_param._init_metadata(
            param_infos,
            numels,
            shapes,
            fqns,
            shared_param_infos,
            param_extensions,
            convert_to_params(params_to_flatten) if use_orig_params else None,
            convert_to_params(shared_params) if use_orig_params else None,
        )

    @staticmethod
    def flatten_params(
        params: Sequence[torch.Tensor],
        requires_grad: bool,
    ) -> FlatParameter:
        """
        Flattens the parameters in ``params`` into a single
        :class:`FlatParameter`. This should be the only way to construct
        :class:`FlatParameter` s.

        We expose this factory method for checkpointing (e.g. sharded state
        dict). The flattened parameter's metadata should only be initialized
        once (see :meth:`_init_metadata`), but its tensor data may be reloaded.
        """
        with torch.no_grad():
            flat_params = [
                p.detach().reshape(-1) if isinstance(p, nn.Parameter) else p.reshape(-1)
                for p in params
            ]
            flat_param_data = torch.cat(flat_params, dim=0)
        flat_param = FlatParameter(flat_param_data, requires_grad=requires_grad)
        return flat_param

    def _init_param_reduce_dtypes(
        self,
        mp_param_dtype: Optional[torch.dtype],
        mp_reduce_dtype: Optional[torch.dtype],
    ) -> None:
        """
        Precondition: ``self.flat_param`` is set via :meth:`_init_flat_param`.
        This ensures that this handle's parameters have a single dtype.

        Postcondition: This sets ``self._fwd_bwd_param_dtype`` and
        ``self._reduce_dtype``. If ``mp_param_dtype`` or ``mp_reduce_dtype``
        is ``None``, then we assume the original parameter dtype. One special
        case is if ``mp_param_dtype`` is not ``None`` and ``mp_reduce_dtype``
        is ``None``, in which case we assume the gradient reduction dtype
        matches the forward/backward parameter dtype.
        """
        # Save whether these dtypes were specified so that we permit the
        # parameter dtype to change up until the lazy initialization
        self._low_prec_param_dtype_specified = mp_param_dtype is not None
        self._low_prec_reduce_dtype_specified = mp_reduce_dtype is not None
        if (
            self._low_prec_param_dtype_specified
            and not self._low_prec_reduce_dtype_specified
        ):
            # Special case: infer gradient reduction mixed precision
            self._fwd_bwd_param_dtype = mp_param_dtype
            self._reduce_dtype = self._fwd_bwd_param_dtype
        else:
            self._fwd_bwd_param_dtype = mp_param_dtype or self._orig_param_dtype
            self._reduce_dtype = mp_reduce_dtype or self._orig_param_dtype
        assert self._fwd_bwd_param_dtype is not None
        assert self._reduce_dtype is not None

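    # Illustrative dtype-resolution examples (not part of the original file),
    # derived from the branches above, assuming the original parameter dtype
    # is torch.float32:
    #   mp_param_dtype=torch.float16, mp_reduce_dtype=None
    #       -> _fwd_bwd_param_dtype=float16, _reduce_dtype=float16  (special case)
    #   mp_param_dtype=None, mp_reduce_dtype=torch.bfloat16
    #       -> _fwd_bwd_param_dtype=float32, _reduce_dtype=bfloat16
    #   mp_param_dtype=None, mp_reduce_dtype=None
    #       -> _fwd_bwd_param_dtype=float32, _reduce_dtype=float32
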
    ###################################
    # SHARD INITIALIZATION & METADATA #
    ###################################
    @torch.no_grad()
    def shard(self):
        """
        Shards the handle's ``FlatParameter``. In terms of memory, this
        allocates new memory for the sharded flattened parameter and frees the
        unsharded flattened parameter's storage.

        Postcondition: ``self.flat_param`` is the sharded flattened parameter.
        Shard metadata attributes are set for all sharding strategies.
        ``process_group``, ``rank``, and ``world_size`` attributes are set if
        using a sharded strategy.
        """
        flat_param = self.flat_param
        if not self.uses_sharded_strategy:
            self._init_shard_metadata(0, 0, flat_param.numel() - 1)
        else:
            _p_assert(
                flat_param.storage_offset() == 0,
                "The `FlatParameter` is not the sole occupant of its storage",
            )
            orig_storage = flat_param._typed_storage()
            sharded_flat_param, numel_padded = FlatParamHandle._get_shard(
                flat_param, self.rank, self.world_size
            )
            flat_param.set_(sharded_flat_param)  # type: ignore[call-overload]
            start = sharded_flat_param.numel() * self.rank
            end = sharded_flat_param.numel() * (self.rank + 1) - 1  # inclusive
            self._init_shard_metadata(numel_padded, start, end)
            if orig_storage._size() > 0:
                orig_storage._resize_(0)
        if self._use_orig_params:
            self._use_sharded_views()

    def _init_shard_metadata(
        self,
        numel_padded: int,
        start: int,
        end: int,
    ) -> None:
        """
        Initializes shard-related metadata for this rank's shard of the
        flattened parameter: ``_sharded_size``, ``_shard_param_offsets``,
        ``_shard_indices``, and ``_shard_numel_padded``.

        Args:
            numel_padded (int): Numel padded for this rank's sharded flattened
                parameter.
            start (int): Start index in the sharded flattened parameter
                assigned to this rank.
            end (int): End index (inclusive) in the sharded flattened parameter
                assigned to this rank. If this exceeds the sharded flattened
                parameter's numel, then it is truncated.

        Precondition: ``self.flat_param`` 's data is the sharded flattened
        parameter.
        """
        self.flat_param._sharded_size = self.flat_param.size()  # type: ignore[attr-defined]
        sharded_flat_param_numel = self.flat_param.numel()  # includes `numel_padded`
        _p_assert(start >= 0 and start <= end, f"start: {start} end: {end}")
        _p_assert(
            numel_padded <= sharded_flat_param_numel,
            f"numel_padded: {numel_padded} "
            f"sharded_flat_param_numel: {sharded_flat_param_numel}",
        )
        (
            self.flat_param._shard_param_offsets,  # type: ignore[attr-defined]
            self.flat_param._shard_indices,  # type: ignore[attr-defined]
        ) = self._get_shard_metadata(start, end)
        self.flat_param._shard_numel_padded = numel_padded  # type: ignore[attr-defined]

    def _get_shard_metadata(
        self,
        start: int,
        end: int,
    ) -> Tuple[Tuple[Tuple[int, int], ...], Tuple[int, int]]:
        """
        Computes the shard metadata based on ``start`` and ``end``, which give
        the closed interval of the unsharded flattened parameter specifying the
        shard.

        Args:
            start (int): Start index (in units of numel) of this rank's shard
                of the flattened parameter.
            end (int): End index (in units of numel and inclusive) of this
                rank's shard of the flattened parameter.

        Return:
            Tuple[Tuple[Tuple[int, int], ...], Tuple[int, int]]: See
            ``_shard_param_offsets`` and ``_shard_indices`` in
            :class:`FlatParameter` 's docstring.
        """
        flat_param_offsets = self._get_flat_param_offsets()
        # Indices of the original parameters in this rank's sharded flattened
        # parameter
        shard_param_indices_range = []  # elements will be consecutive
        # [start, end] offsets giving this rank's part of the flattened
        # original module parameter (which will be [0, `p.numel()`-1] for any
        # parameter that is not sharded across ranks)
        shard_param_offsets = []
        for i, (param_start, param_end) in enumerate(flat_param_offsets):
            if start > param_end or end < param_start:
                continue
            if start <= param_start:
                intra_param_start = 0
            else:
                intra_param_start = start - param_start
            intra_param_end = min(param_end, end) - param_start
            shard_param_indices_range.append(i)
            shard_param_offsets.append(
                (intra_param_start, intra_param_end)
            )  # both inclusive
        if len(shard_param_indices_range) == 0:
            shard_param_indices = (0, 0)
            assert len(shard_param_offsets) == 0
        else:
            shard_param_indices = (
                shard_param_indices_range[0],
                shard_param_indices_range[-1],
            )
            assert (
                len(shard_param_offsets)
                == shard_param_indices[-1] - shard_param_indices[0] + 1
            )
        return tuple(shard_param_offsets), shard_param_indices

    @staticmethod
    def _get_unpadded_shard(
        tensor: Tensor,
        rank: int,
        world_size: int,
    ) -> Tuple[Tensor, int]:
        """
        Returns the shard of ``tensor`` without any padding for the given
        ``rank`` and ``world_size`` and the numel to pad for that shard.

        If ``tensor`` is already flattened or may be viewed in the flattened
        shape (which is true in the expected usage), then this method does not
        allocate any new tensor memory.
        """
        chunks = torch.flatten(tensor).chunk(world_size)
        if len(chunks) < (rank + 1):
            # This rank gets an empty chunk fully padded with zeros since there
            # are not enough chunks across ranks
            chunk = chunks[0].new_empty(0)
        else:
            chunk = chunks[rank]
        numel_to_pad = chunks[0].numel() - chunk.numel()
        assert (
            numel_to_pad >= 0
        ), "Chunk's size should be at most the first chunk's size"
        return chunk, numel_to_pad

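    # Illustrative example (not part of the original file): for a tensor with
    # numel 10 and world_size=4, `torch.flatten(tensor).chunk(4)` yields chunks
    # of sizes [3, 3, 3, 1], so
    #   ranks 0-2 -> chunk numel 3, numel_to_pad 0
    #   rank 3    -> chunk numel 1, numel_to_pad 2  (padded up to the first chunk's size)
    # With world_size=6 the same tensor yields only 5 chunks of size 2, so rank 5
    # receives an empty chunk and numel_to_pad 2.
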
    @staticmethod
    def _get_shard(
        tensor: Tensor,
        rank: int,
        world_size: int,
    ) -> Tuple[Tensor, int]:
        """
        Returns the shard of ``tensor`` with padding for the given ``rank`` and
        ``world_size`` and the numel padded for that shard.

        This method allocates new memory (via :meth:`clone`) since the
        unsharded ``tensor`` may be deallocated after this method returns.
        """
        chunk, numel_to_pad = FlatParamHandle._get_unpadded_shard(
            tensor, rank, world_size
        )
        shard = chunk.clone()
        if numel_to_pad > 0:
            shard = F.pad(shard, [0, numel_to_pad])
        return shard, numel_to_pad

    @staticmethod
    def _get_sharded_size(tensor: Tensor, rank: int, world_size: int) -> torch.Size:
        """
        Returns the shape of ``tensor`` after sharding including padding. This
        requires ``tensor`` to have 1D shape and ensures that the returned
        shape is 1D.
        """
        assert len(tensor.shape) == 1, f"{tensor.shape}"
        unpadded_sharded_tensor, numel_to_pad = FlatParamHandle._get_unpadded_shard(
            tensor, rank, world_size
        )
        unpadded_sharded_size = unpadded_sharded_tensor.size()
        assert len(unpadded_sharded_size) == 1, f"{unpadded_sharded_size}"
        return torch.Size([unpadded_sharded_size[0] + numel_to_pad])

    def _get_flat_param_offsets(self) -> List[Tuple[int, int]]:
        """Returns [start, end] offsets of each original parameter's flattened
        data in the unsharded flattened parameter (without padding)."""
        cumulative_sum = list(accumulate(self.flat_param._numels))
        starts = [0] + cumulative_sum[:-1]
        ends = [end - 1 for end in cumulative_sum]  # inclusive
        param_offsets = list(zip(starts, ends))
        return param_offsets

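    # Illustrative example (not part of the original file): for
    # `_numels = (6, 4)`, the cumulative sums are [6, 10], so this returns
    # [(0, 5), (6, 9)], i.e. inclusive [start, end] offsets of each original
    # parameter in the unpadded unsharded flattened parameter.
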
    def shard_metadata(
        self,
    ) -> FlatParamShardMetadata:
        """Returns shard-related metadata specific to this rank's shard of the
        flattened parameter."""
        assert hasattr(self.flat_param, "_shard_indices") and hasattr(
            self.flat_param, "_shard_param_offsets"
        ), "Shard metadata has not been initialized"
        shard_param_start_index = self.flat_param._shard_indices[0]  # type: ignore[attr-defined]
        shard_param_end_index = self.flat_param._shard_indices[1]  # type: ignore[attr-defined]
        sl = (
            slice(shard_param_start_index, shard_param_end_index + 1)
            if shard_param_start_index <= shard_param_end_index
            else slice(0, 0)
        )
        return FlatParamShardMetadata(
            self.flat_param._fqns[sl],
            self.flat_param._shapes[sl],
            self.flat_param._numels[sl],
            self.flat_param._shard_param_offsets[:],  # type: ignore[attr-defined]
        )

    @no_type_check
    @torch.no_grad()
    def init_flat_param_attributes(self) -> None:
        """
        This initializes some attributes on the handle's ``FlatParameter``.
        This should be called during lazy initialization since it requires the
        parameter to be on the compute device if not offloading to CPU and we
        want to give users the chance to move the parameter appropriately after
        the FSDP constructor.

        For each tensor attribute on the ``FlatParameter``, see the unshard and
        reshard methods in this class for the allocation and free pattern.
        """
        flat_param = self.flat_param
        if flat_param.dtype != self._orig_param_dtype:
            # Entering this branch means that the user changed the parameter
            # dtype after FSDP initialization, in which case we may need to
            # refresh some saved dtype attributes (dtypes specified as a part
            # of mixed precision take precedence).
            if not self._low_prec_param_dtype_specified:
                self._fwd_bwd_param_dtype = flat_param.dtype
            # For `reduce_dtype`, require `param_dtype` was not specified since
            # then we infer the `reduce_dtype` from the specified `param_dtype`
            if (
                not self._low_prec_reduce_dtype_specified
                and not self._low_prec_param_dtype_specified
            ):
                self._reduce_dtype = flat_param.dtype
            self._orig_param_dtype = flat_param.dtype
        cpu_device = torch.device("cpu")
        if self._offload_params:
            _p_assert(
                flat_param.device == cpu_device,
                f"Expects the `FlatParameter` to be on CPU when parameter CPU "
                f"offloading is enabled, not {flat_param.device}",
            )
        else:
            self._check_on_compute_device(self.flat_param)
        flat_param._local_shard = flat_param.data
        if self._offload_params:
            # Pin the memory for faster H2D transfer
            flat_param._local_shard = flat_param._local_shard.pin_memory()
            # Pre-allocate the sharded gradient on CPU to enable non-blocking
            # D2H transfer during the backward pass
            flat_param._cpu_grad = torch.zeros_like(
                flat_param._local_shard, device=cpu_device
            ).pin_memory()
        if self._uses_param_mixed_precision:
            # For parameter mixed precision, we maintain a low precision
            # sharded tensor on the compute device to be all-gathered (for
            # sharded strategies) or directly used (for `NO_SHARD`) for
            # computation.
            flat_param._mp_shard = torch.zeros_like(
                flat_param._local_shard,
                device=self.device,
                dtype=self._fwd_bwd_param_dtype,
            )
            _free_storage(flat_param._mp_shard)
        if self.uses_sharded_strategy:
            # We maintain a padded unsharded tensor that serves as the
            # all-gather destination and owns the original parameter storages.
            unsharded_param_dtype = (
                self._fwd_bwd_param_dtype
                if self._uses_param_mixed_precision
                else flat_param.dtype
            )  # use low precision if parameter mixed precision is enabled
            padded_unsharded_numel = flat_param.numel() * self.world_size
            flat_param._full_param_padded = torch.zeros(
                padded_unsharded_numel,
                device=self.device,
                dtype=unsharded_param_dtype,
            )
            flat_param._padded_unsharded_size = flat_param._full_param_padded.size()
            _free_storage(flat_param._full_param_padded)

            if self._uses_param_mixed_precision:
                # For parameter mixed precision, we maintain a full precision
                # padded unsharded tensor for when we force full precision.
                flat_param._full_prec_full_param_padded = torch.zeros(
                    padded_unsharded_numel,
                    device=self.device,
                    dtype=flat_param.dtype,  # full precision
                )
                _free_storage(flat_param._full_prec_full_param_padded)

    ###################
    # UNSHARD/RESHARD #
    ###################
    def pre_unshard(self) -> bool:
        """
        Returns: ``False`` if this is a no-op and ``True`` otherwise.

        Postcondition: ``self.flat_param`` 's data is on the device for
        communication and is what should be all-gathered. This means that it
        matches the dtype of the expected unsharded parameter.
        """
        ret = False
        if self._use_orig_params:
            ret = self._writeback_orig_params()
        if (
            self.uses_sharded_strategy
            and not self._offload_params
            and not self.needs_unshard()
        ):
            pass  # no-op
        elif self._uses_param_mixed_precision and not self._force_full_precision:
            self._use_low_precision_shard()
            ret = True
        elif self._offload_params and self.flat_param.device != self.device:
            # NOTE: This creates a new tensor distinct from any attributes.
            self.flat_param_to(self.device, non_blocking=True)
            ret = True
        self._check_on_compute_device(self.flat_param)
        return ret

    def _use_low_precision_shard(self):
        """
        Allocates the low precision shard directly on the compute device and
        switches to using the low precision sharded flattened parameter.
        """
        self._check_low_precision_shard()
        flat_param = self.flat_param
        _alloc_storage(
            flat_param._mp_shard, flat_param._local_shard.size()  # type: ignore[attr-defined]
        )
        # `copy_()` implicitly casts to the low precision
        flat_param._mp_shard.copy_(  # type: ignore[attr-defined]
            flat_param._local_shard.to(  # type: ignore[attr-defined]
                self.device, non_blocking=True
            )
        )
        # Invariant: `_mp_shard` is always on the compute device.
        flat_param.data = flat_param._mp_shard  # type: ignore[attr-defined]

    def unshard(self):
        """
        Runs the unshard logic. This includes all-gathering the flattened
        parameter and switching to using the unsharded flattened parameter. If
        the handle does not need unsharding, then this only switches to using
        the unsharded flattened parameter. For ``NO_SHARD``, this is a no-op.

        If FSDP is in :meth:`summon_full_params` and the handle uses parameter
        mixed precision, then the parameter is forced to full precision.
        """
        if not self.needs_unshard():
            # Even when not needing an unshard, we should switch to using
            # the unsharded flattened parameter
            unsharded_flat_param = (
                self._get_padded_unsharded_flat_param()
                if self.uses_sharded_strategy
                else self.flat_param
            )
            self._use_unsharded_flat_param(unsharded_flat_param)
            return
        unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
        padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)
        self._use_unsharded_flat_param(padded_unsharded_flat_param)

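    # Rough call-order sketch (illustrative, based on the method docstrings in
    # this class; not part of the original file):
    #
    #   handle.pre_unshard()    # move/cast the sharded data as needed
    #   handle.unshard()        # all-gather into the padded unsharded parameter
    #   handle.post_unshard()   # free the low precision shard if needed
    #   ...                     # forward/backward compute with unsharded views
    #   handle.reshard(free_unsharded_flat_param=True)
    #   handle.post_reshard()
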
    def needs_unshard(self) -> bool:
        """Returns if the handle's flattened parameter needs to be unsharded."""
        if not self.uses_sharded_strategy:
            return False
        unsharded_flat_param = self._get_padded_unsharded_flat_param()
        already_unsharded = (
            unsharded_flat_param._typed_storage()._size()
            == unsharded_flat_param.numel()
        )
        return not already_unsharded

    def _alloc_padded_unsharded_flat_param(self):
        """
        Allocates the *padded* unsharded flattened parameter. The unpadded
        unsharded flattened parameter is always a view into the padded one.
        This padded parameter is saved to a different attribute on the
        ``FlatParameter`` depending on if we force full precision.
        """
        self._check_sharded_strategy()
        flat_param = self.flat_param
        unsharded_flat_param = self._get_padded_unsharded_flat_param()
        self._check_storage_freed(unsharded_flat_param)
        _alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size)  # type: ignore[attr-defined]
        return unsharded_flat_param

    def _get_padded_unsharded_flat_param(self) -> torch.Tensor:
        """
        Returns a reference to the padded unsharded flattened parameter
        depending on the calling context. This should only be called if using a
        sharded strategy.
        """
        self._check_sharded_strategy()
        flat_param = self.flat_param
        if self._force_full_precision:
            # When parameter mixed precision is enabled, we use a different
            # tensor as the all-gather destination to preserve the invariant
            # that `_full_param_padded` is in the low precision
            unsharded_flat_param = flat_param._full_prec_full_param_padded  # type: ignore[attr-defined]
            _p_assert(
                unsharded_flat_param.dtype != self._fwd_bwd_param_dtype,
                f"Expects full precision but got {self._fwd_bwd_param_dtype}",
            )
        else:
            unsharded_flat_param = flat_param._full_param_padded  # type: ignore[attr-defined]
        return unsharded_flat_param

    def _all_gather_flat_param(
        self,
        padded_unsharded_flat_param: Tensor,
    ) -> Tensor:
        """
        All-gathers the handle's flattened parameter to the destination
        ``padded_unsharded_flat_param``, and switches to using the all-gathered
        tensor.
        """
        _p_assert(
            hasattr(self, "process_group") and hasattr(self, "world_size"),
            "Expects a process group and world size to have been set via `shard()`",
        )
        sharded_flat_param = self.flat_param.data
        expected_numel = sharded_flat_param.numel() * self.world_size
        _p_assert(
            padded_unsharded_flat_param.numel() == expected_numel,
            f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
        )
        dist.all_gather_into_tensor(
            padded_unsharded_flat_param,
            sharded_flat_param,
            self.process_group,
        )
        return padded_unsharded_flat_param

    def _use_unsharded_flat_param(
        self,
        padded_unsharded_flat_param: torch.Tensor,
    ) -> None:
        """
        Switches to using the *unpadded* unsharded flattened parameter, which
        is a view into the *padded* unsharded flattened parameter.
        """
        unsharded_size = self.flat_param._unpadded_unsharded_size
        self.flat_param.data = padded_unsharded_flat_param[
            : unsharded_size.numel()
        ].view(
            unsharded_size
        )  # this `.view()` is not autograd visible
        in_forward = self._training_state == HandleTrainingState.FORWARD
        in_pre_backward = self._training_state == HandleTrainingState.BACKWARD_PRE
        if self._use_orig_params:
            # We use `Tensor` views in the forward so that they are tracked by
            # autograd. We use them in the pre-backward as well to support
            # reentrant activation checkpointing, which needs the views to be
            # tracked by autograd in the backward pass's recomputed forward.
            self._use_unsharded_views(
                as_params=(not in_forward and not in_pre_backward)
            )
        elif in_forward:
            self._use_unsharded_views(as_params=False)

    def post_unshard(self):
        """
        Runs the post-unshard logic. This includes freeing the low precision
        shard if needed.
        """
        if self._uses_param_mixed_precision and self.uses_sharded_strategy:
            self._free_low_precision_sharded_param()
        self._check_on_compute_device(self.flat_param)

    def _free_low_precision_sharded_param(self):
        """Frees the low precision sharded flattened parameter."""
        self._check_low_precision_shard()
        # `_mp_shard` is allocated in the pre-unshard stream, consumed in the
        # unshard stream for sharded strategies, and consumed in both the
        # unshard and default streams for `NO_SHARD`. For sharded strategies,
        # the current stream here is the unshard stream, and for `NO_SHARD`,
        # it is the default stream. For `NO_SHARD`, only recording for the
        # default stream suffices since the default stream waits for the
        # unshard stream.
        _no_dispatch_record_stream(
            self.flat_param._mp_shard, torch.cuda.current_stream()  # type: ignore[attr-defined]
        )
        _free_storage(self.flat_param._mp_shard)  # type: ignore[attr-defined]

    @torch.no_grad()
    def unshard_grad(self):
        """
        Unshards the handle's ``FlatParameter`` 's gradient. If all ranks have
        ``None`` gradient, then all original parameters will as well. This
        method performs an all-reduce and an all-gather. The additional
        all-reduce is tolerable since this method is not meant to be used on
        the computation critical path.

        Postcondition: ``_saved_grad_shard`` is defined and contains the value
        to set ``flat_param.grad`` after gradients are resharded.
        """
        if not self.uses_sharded_strategy:
            self._use_unsharded_grad_views()
            return
        flat_param = self.flat_param
        self._check_unsharded(flat_param)

        # Check if all ranks have a `None` gradient
        num_grad_none = torch.zeros(1, dtype=torch.int32, device=self.device)
        num_grad_none[0] = flat_param.grad is None
        dist.all_reduce(num_grad_none, group=self.process_group)
        if num_grad_none[0] == self.world_size:
            flat_param._saved_grad_shard = None  # type: ignore[attr-defined]
            self._use_unsharded_grad_views()
            return

        padded_unsharded_grad = torch.empty(
            flat_param._padded_unsharded_size,  # type: ignore[attr-defined]
            device=self.device,
        )
        if flat_param.grad is None:
            # In the case that only some ranks have `None` gradient, we use
            # zeros to approximate as a best effort attempt
            if self._debug_level == dist.DebugLevel.DETAIL:
                warnings.warn(
                    f"[Rank {self.rank}] Only some but not all ranks have a "
                    "`None` `FlatParameter` gradient, so FSDP is using zeros to "
                    "approximate those ranks' sharded gradients being `None`"
                )
            flat_param._saved_grad_shard = None  # type: ignore[attr-defined]
            sharded_grad = torch.zeros(flat_param._sharded_size, device=self.device)  # type: ignore[attr-defined]
        else:
            self._check_sharded(flat_param.grad)
            flat_param._saved_grad_shard = flat_param.grad  # type: ignore[attr-defined]
            sharded_grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
        dist.all_gather_into_tensor(
            padded_unsharded_grad, sharded_grad, self.process_group
        )
        unsharded_size = self.flat_param._unpadded_unsharded_size
        flat_param.grad = padded_unsharded_grad[: unsharded_size.numel()].view(
            unsharded_size
        )
        self._use_unsharded_grad_views()

    def reshard_grad(self):
        if self._use_orig_params:
            self._use_sharded_grad_views()
        if not self.uses_sharded_strategy:
            return
        self.flat_param.grad = self.flat_param._saved_grad_shard  # type: ignore[attr-defined]
        delattr(self.flat_param, "_saved_grad_shard")

    def prepare_gradient_for_backward(self):
        """
        Prepares the gradient for the backward computation by saving and
        clearing any existing sharded gradient in ``.grad`` to enable computing
        a new unsharded gradient.
        """
        _p_assert(
            self._training_state
            in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE),
            "Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)",
        )
        flat_param = self.flat_param
        if flat_param.grad is not None and (
            flat_param.grad.size() != flat_param._unpadded_unsharded_size
            or flat_param.grad.device != flat_param.device  # grad on CPU
        ):
            self._check_on_compute_device(self.flat_param)
            grad_offloaded = flat_param.grad.device != self.device
            _p_assert(
                not grad_offloaded or self._offload_params,
                f"Expects the sharded gradient to be on {self.device} "
                f"but got {flat_param.grad.device}",
            )
            prev_iter_synced_gradients = (
                flat_param.grad.size()
                == flat_param._local_shard.size()  # type: ignore[attr-defined]
            )
            if prev_iter_synced_gradients:
                # TODO (awgu): Gradient accumulation outside `no_sync()`
                # does not work with CPU offloading. The issue should be
                # that, in the post-backward hook, we cannot do an addition
                # between a CPU tensor (the existing sharded gradient) and
                # a GPU tensor (the new sharded gradient).
                if not grad_offloaded:
                    flat_param._saved_grad_shard = flat_param.grad.data  # type: ignore[attr-defined]
                    sharded_grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
                else:
                    _p_assert(
                        hasattr(flat_param, "_cpu_grad"),
                        "`_cpu_grad` should be defined if the gradient is on CPU",
                    )
                    sharded_grad = flat_param._cpu_grad  # type: ignore[attr-defined]
                # If user specified to keep the gradient in low precision, then
                # the gradient may still be of the low precision dtype if the
                # user did not set the gradient to `None` after the previous
                # backward, in which case FSDP should cast back to the full
                # precision dtype so that FSDP can accumulate in that dtype in
                # the post-backward hook and assign to `.grad` in that dtype in
                # the post-backward callback.
                local_shard_dtype = flat_param._local_shard.dtype  # type: ignore[attr-defined]
                if (
                    self._keep_low_precision_grads
                    and sharded_grad.dtype != local_shard_dtype
                ):
                    sharded_grad.data = sharded_grad.to(local_shard_dtype)
            else:
                padded_unsharded_size = flat_param._padded_unsharded_size  # type: ignore[attr-defined]
                _p_assert(
                    flat_param.grad.size() == padded_unsharded_size,
                    "Expects `.grad` to be the unsharded gradient in "
                    f"`no_sync()` with size {padded_unsharded_size} "
                    f"but got size {flat_param.grad.size()}",
                )
            flat_param.grad = None

    def prepare_gradient_for_optim(self):
        """
        Prepares the gradient for optimizer computation by moving the sharded
        gradient to the ``.grad`` attribute.
        """

        def cast_grad_to_param_dtype_if_needed(flat_param):
            if self._keep_low_precision_grads:
                assert flat_param.grad is not None  # mypy
                if flat_param.grad.dtype != self._fwd_bwd_param_dtype:
                    flat_param.grad.data = flat_param.grad.to(self._fwd_bwd_param_dtype)
                    if self._use_orig_params:
                        self._use_sharded_grad_views()

        flat_param = self.flat_param
        # TODO (awgu): We should replace these conditional checks to encode
        # the logical intention more directly.
        if hasattr(flat_param, "_cpu_grad"):
            # NOTE: This branch includes `NO_SHARD`.
            self._check_sharded(flat_param)
            self._check_on_cpu(flat_param)
            flat_param.grad = flat_param._cpu_grad  # type: ignore[attr-defined]
            cast_grad_to_param_dtype_if_needed(flat_param)
        elif hasattr(flat_param, "_saved_grad_shard"):
            self._check_sharded(flat_param)
            self._check_on_compute_device(flat_param)
            self._check_on_compute_device(flat_param._saved_grad_shard)  # type: ignore[attr-defined]
            # If no sharded gradient was computed this iteration, then there is
            # no need to forward `_saved_grad_shard` to `grad`
            if flat_param._post_backward_called:  # type: ignore[attr-defined]
                flat_param.grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
                cast_grad_to_param_dtype_if_needed(flat_param)
        else:
            _p_assert(
                not self.uses_sharded_strategy
                or not flat_param._post_backward_called,  # type: ignore[attr-defined]
                "All sharded parameters that received a gradient in the "
                "post-backward should use `_saved_grad_shard`",
            )
        # Delete `_saved_grad_shard` since its existence indicates a previous
        # gradient to accumulate with in the post-backward hook
        if hasattr(flat_param, "_saved_grad_shard"):
            delattr(flat_param, "_saved_grad_shard")

    @contextlib.contextmanager
    def to_cpu(self):
        """
        Moves the unpadded unsharded flattened parameter to CPU while in the
        context and moves it back to the previous device upon exit. For now,
        this assumes the ``FlatParameter`` is the unpadded unsharded flattened
        parameter since (1) there is no reason to include the padding in the
        copy and (2) there is no use case for the sharded flattened parameter.

        Precondition: ``self.flat_param`` 's data is the unpadded unsharded
        flattened parameter on the compute device, and the handle uses a
        sharded strategy.
        Postcondition: Same as the precondition.
        """
        self._check_sharded_strategy()
        _p_assert(
            self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
            f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
        )
        self._check_on_compute_device(self.flat_param)
        # Check that the unpadded unsharded flattened parameter is a view into
        # the padded unsharded flattened parameter as expected
        # NOTE: This check is not strictly needed for correctness but is a
        # useful sanity check since the tensor should only be used internally.
        unpadded_storage_ptr = self.flat_param._typed_storage()._data_ptr()
        padded_storage_ptr = (
            self._get_padded_unsharded_flat_param()._typed_storage()._data_ptr()
        )
        _p_assert(
            unpadded_storage_ptr == padded_storage_ptr,
            "Expects the unpadded parameter to be a view into the padded parameter",
        )
        self.flat_param_to(torch.device("cpu"))
        self._free_unsharded_flat_param()
        try:
            yield
        finally:
            _p_assert(
                self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
                f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
            )
            padded_unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
            # Copy from CPU to the compute device
            padded_unsharded_flat_param[: self.flat_param.numel()].copy_(
                self.flat_param
            )
            self._use_unsharded_flat_param(padded_unsharded_flat_param)

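    # Illustrative usage sketch (not part of the original file): `to_cpu()` is
    # a context manager, so a caller holding the unpadded unsharded parameter
    # on the compute device could do roughly:
    #
    #   with handle.to_cpu():
    #       ...  # `flat_param` temporarily lives on CPU here
    #   # on exit, it is copied back into the re-allocated padded unsharded tensor
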
    def reshard(self, free_unsharded_flat_param: bool):
        """
        Runs the reshard logic. This includes freeing the unsharded flattened
        parameter if ``free_unsharded_flat_param`` and switching to using the
        sharded flattened parameter.
        """
        # Switch to the sharded `FlatParameter` before freeing to prevent
        # "use-after-free"-type bugs with external profiling tools, where for
        # `use_orig_params=True`, the `param` does not point to valid memory
        # when setting `param.data = ...` in `_use_sharded_views()`.
        self._use_sharded_flat_param()
        if free_unsharded_flat_param:
            self._free_unsharded_flat_param()

    def post_reshard(self):
        """
        Runs the post-reshard logic. This includes freeing any memory that
        can now be freed given that the ``FlatParameter`` points to the full
        precision sharded flattened parameter.

        Precondition: ``self.flat_param`` 's data points to the full precision
        sharded flattened parameter.
        """
        # For `NO_SHARD`, `_mp_shard` is not freed in the post-unshard since
        # it is also the low precision *unsharded* flattened parameter. Hence,
        # we delay the free until the reshard.
        if (
            self._uses_param_mixed_precision
            and not self.uses_sharded_strategy
            and not self._force_full_precision  # did not use the low precision shard
        ):
            self._free_low_precision_sharded_param()

    def _free_unsharded_flat_param(self):
        """
        Frees the padded unsharded flattened parameter. The tensor to free
        depends on the calling context since the unshard may have forced full
        precision, in which case a different tensor is used.
        """
        self._check_sharded_strategy()
        unsharded_flat_param = self._get_padded_unsharded_flat_param()
        self._check_storage_allocated(unsharded_flat_param)
        self._check_on_compute_device(unsharded_flat_param)
        # Do not free the memory until all ops in the current stream finish
        _no_dispatch_record_stream(unsharded_flat_param, torch.cuda.current_stream())
        _free_storage(unsharded_flat_param)

    def _use_sharded_flat_param(self) -> None:
        """Switches to using the sharded flattened parameter."""
        flat_param = self.flat_param
        if self._offload_params:
            device = flat_param._local_shard.device  # type: ignore[attr-defined]
            _p_assert(
                device == torch.device("cpu"),
                f"Expects the local shard to be on CPU but got {device}",
            )
        flat_param.data = flat_param._local_shard  # type: ignore[attr-defined]
        if self._use_orig_params:
            self._use_sharded_views()
            # For the post-forward reshard, we may try to use sharded gradient
            # views (or unsharded gradient views if a gradient was accumulated
            # in `no_sync()`), but for the post-backward reshard, we delay the
            # call to after the reduce-scatter.
            if self._training_state == HandleTrainingState.FORWARD:
                # TODO: Change `_unpadded_unsharded_size` if we change the
                # gradient to be computed directly with padding.
                accumulated_grad_in_no_sync = (
                    flat_param.grad is not None
                    and self.uses_sharded_strategy
                    and flat_param.grad.shape == flat_param._unpadded_unsharded_size
                )
                if accumulated_grad_in_no_sync:
                    self._use_unsharded_grad_views()
                else:
                    self._use_sharded_grad_views()

    #########
    # VIEWS #
    #########
    @staticmethod
    def _get_unflat_views(
        flat_param: FlatParameter,
        tensor: Optional[torch.Tensor] = None,
    ) -> Iterator[Tensor]:
        """
        Returns unflattened ``Tensor`` views into ``tensor`` if it is not
        ``None`` or ``flat_param`` otherwise, where the unflattening is based
        on ``flat_param`` 's metadata.

        In other words, to get views into the unsharded flattened parameter,
        pass ``tensor`` as ``None``, but to get views into tensor optimizer
        state, pass ``tensor`` as the optimizer state tensor.
        """
        if tensor is None:
            tensor = flat_param
        _p_assert(
            tensor.numel() == flat_param._unpadded_unsharded_size.numel(),
            f"Expects {flat_param._unpadded_unsharded_size.numel()} numel but got "
            f"{tensor.numel()} numel",
        )
        views = (
            _ext_post_unflatten_transform(subtensor.view(shape), param_extension)
            for (subtensor, shape, param_extension) in zip(
                torch.split(tensor, flat_param._numels, dim=0),  # type: ignore[arg-type]
                flat_param._shapes,
                flat_param._param_extensions,
            )
        )
        return views

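    # Illustrative example (not part of the original file): with
    # `_numels = (6, 4)` and `_shapes = (torch.Size([2, 3]), torch.Size([4]))`,
    # a 10-numel unsharded flat tensor is split into subtensors of numel 6 and
    # 4, which are then viewed as (2, 3) and (4,) respectively (plus any
    # post-unflatten extension transform).
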
    def _use_unsharded_views(self, as_params: bool) -> None:
        """
        Unflattens the unsharded flattened parameter by setting the original
        module parameter variables to be views into it.

        Args:
            as_params (bool): If ``True``, then registers the original
                parameters as ``nn.Parameter`` s; if ``False``, then registers
                the original parameters only as ``Tensor`` s. ``False`` should
                be used during forward/backward computation and when hiding the
                original parameters from :meth:`nn.Module.named_parameters`.
        """
        self._check_unsharded(self.flat_param)
        views = self._get_unflat_views(self.flat_param)
        for i, (view, (param_name, module, _)) in enumerate(
            zip(views, self.flat_param._param_infos)
        ):
            if self._use_orig_params and as_params:
                if type(view) is DTensor:
                    # A `DTensor` `view` is not compatible with assigning
                    # `param.data = view`, so we cannot preserve the parameter
                    # variable.
                    self._setattr_param(module, param_name, nn.Parameter(view))
                    continue
                param = self.flat_param._params[i]  # type: ignore[index]
                self._setattr_param(module, param_name, param)
                param.data = view
            elif as_params:
                self._setattr_param(module, param_name, nn.Parameter(view))
            else:  # `as_params=False`
                param_var: Tensor = view
                if self._use_orig_params:
                    if self._training_state == HandleTrainingState.FORWARD:
                        assert self.flat_param._tensors is not None
                        # Save the `Tensor` for the pre-backward
                        self.flat_param._tensors[i] = view  # save for pre-backward
                    elif self._training_state == HandleTrainingState.BACKWARD_PRE:
                        # Use the saved `Tensor` variable from the forward to
                        # preserve the autograd graph so that the post-backward
                        # hook fires (e.g. for reentrant AC)
                        assert self.flat_param._tensors is not None  # mypy
                        tensor = self.flat_param._tensors[i]
                        _p_assert(
                            tensor is not None,
                            "Expects `Tensor` to have been saved in forward",
                        )
                        tensor.data = view  # type: ignore[union-attr]
                        assert tensor is not None  # mypy
                        param_var = tensor
                self._setattr_tensor(module, param_name, param_var)
                if (
                    self._use_orig_params
                    and self._training_state == HandleTrainingState.FORWARD
                ):
                    module._parameters[param_name] = param_var  # type: ignore[assignment]
        for i, (
            param_name,
            module,
            _,
            prim_param_name,
            prim_module,
            prim_module_name,
        ) in enumerate(self.flat_param._shared_param_infos):
            if hasattr(module, param_name):
                delattr(module, param_name)
            _p_assert(
                hasattr(prim_module, prim_param_name),
                f"Module {prim_module_name} is missing parameter {prim_param_name}",
            )
            prim_param: Union[Tensor, nn.Parameter] = getattr(
                prim_module, prim_param_name
            )
            _p_assert(
                not as_params or isinstance(prim_param, nn.Parameter),
                f"as_params={as_params} type(prim_param)={type(prim_param)}",
            )
            if self._use_orig_params and as_params:
                shared_param = self.flat_param._shared_params[i]  # type: ignore[index]
                self._setattr_param(module, param_name, shared_param)
                shared_param.data = prim_param
            elif as_params:
                assert isinstance(prim_param, nn.Parameter)
                self._setattr_param(module, param_name, prim_param)
            else:
                self._setattr_tensor(module, param_name, prim_param)
                if (
                    self._use_orig_params
                    and self._training_state == HandleTrainingState.FORWARD
                ):
                    module._parameters[param_name] = prim_param  # type: ignore[assignment]

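    # Editorial note (illustrative, not part of FSDP): `as_params` above
    # controls whether the unflattened views are visible as parameters. A
    # rough standalone analogy with a plain module and a hypothetical view:
    #
    #   lin = torch.nn.Linear(3, 3, bias=False)
    #   view = torch.zeros(3, 3)
    #   del lin.weight
    #   lin.weight = torch.nn.Parameter(view)
    #   # `as_params=True` analog: "weight" appears in `lin.named_parameters()`
    #   del lin.weight
    #   lin.weight = view
    #   # `as_params=False` analog: plain `Tensor` attribute, hidden from
    #   # `lin.named_parameters()` but still usable in `lin(x)`
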
    def _use_unsharded_grad_views(self) -> None:
        """
        Unflattens the unsharded flattened parameter's gradient by setting the
        original module parameter variables' gradients to be views into it.
        """
        # Expects the gradient to be in `flat_param.grad`
        if self.flat_param.grad is None:
            assert self.flat_param._params is not None  # mypy
            assert self.flat_param._shared_params is not None  # mypy
            for param in chain(
                self.flat_param._params,  # type: ignore[attr-defined]
                self.flat_param._shared_params,  # type: ignore[attr-defined]
            ):
                param.grad = None
            return
        self._check_unsharded(self.flat_param.grad)
        views = self._get_unflat_views(self.flat_param, self.flat_param.grad)
        for i, (view, (param_name, module, _)) in enumerate(
            zip(views, self.flat_param._param_infos)
        ):
            _p_assert(
                hasattr(module, param_name),
                f"{self.flat_param._fqns[i]} is missing",
            )
            param = getattr(module, param_name)
            if param.shape != view.shape or param.dtype != view.dtype:
                # NOTE: This is a hack using `.data` to side step the
                # check that parameter/gradient sizes and dtypes match. Here,
                # `param` can have the sharded size, and `grad` can have the
                # unsharded size. Orthogonally, `param` can have the full
                # precision dtype from `reshard()`, and `grad` can have the
                # parameter low precision dtype. Both of these mismatches
                # happen when running in `no_sync()`.
                if param.grad is None:
                    param.grad = torch.empty_like(param)
                param.grad.data = view
            else:
                param.grad = view
        for i, (
            param_name,
            module,
            module_name,
            prim_param_name,
            prim_module,
            _,
        ) in enumerate(self.flat_param._shared_param_infos):
            _p_assert(
                hasattr(module, param_name),
                f"{module_name + '.' + param_name if module_name else param_name} is missing",
            )  # did not save FQN info in `_shared_param_infos`
            param = getattr(module, param_name)
            prim_param = getattr(prim_module, prim_param_name)
            if (
                param.shape != prim_param.grad.shape
                or param.dtype != prim_param.grad.dtype
            ):
                # NOTE: This is the same hack to use `.data` to side step the
                # size check.
                if param.grad is None:
                    param.grad = torch.empty_like(param)
                param.grad.data = prim_param.grad
            else:
                param.grad = prim_param.grad

    @contextlib.contextmanager
    def unflatten_as_params(self) -> Generator:
        """
        Assumes the flattened parameter is unsharded. When in the context,
        unflattens the original parameters as ``nn.Parameter`` views into the
        flattened parameter, and after the context, restores the original
        parameters as ``Tensor`` views into the flattened parameter.
        """
        self._use_unsharded_views(as_params=True)
        try:
            yield
        finally:
            self._use_unsharded_views(as_params=False)

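    # Editorial note (illustrative, not part of FSDP): `unflatten_as_params`
    # above is a context manager, so a caller holding a handle with an
    # unsharded flat parameter would use it roughly as:
    #
    #   with handle.unflatten_as_params():
    #       # original parameters are registered as `nn.Parameter` views here
    #       ...
    #   # on exit, they revert to plain `Tensor` views
    #
    # where `handle` is a hypothetical `FlatParamHandle` instance.
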
    @torch.no_grad()
    def _use_sharded_views(self) -> None:
        """
        Sets the original module parameter variables' data to be flattened
        views into the sharded flattened parameter.

        The views are kept as flattened to simplify the case where a parameter
        is sharded across ranks. Parameters whose data is not present in the
        sharded flattened parameter have their data set to a size-0 empty
        tensor. We do not delete them in order to preserve expected behaviors
        like model printability. Parameters whose data is present must
        preserve their variables to be passable to an optimizer.
        """
        if not self.uses_sharded_strategy:
            # For `NO_SHARD`, use the *unflattened* unsharded views since we
            # have the unsharded parameter
            self._use_unsharded_views(as_params=True)
            return
        self._check_sharded(self.flat_param)
        start, end = self.flat_param._shard_indices  # type: ignore[attr-defined]
        offset = 0
        assert self.flat_param._params is not None
        for i, (param, (param_name, module, _)) in enumerate(
            zip(self.flat_param._params, self.flat_param._param_infos)
        ):
            self._setattr_param(module, param_name, param)
            in_sharded_flat_param = (
                i >= start
                and i <= end
                and self.flat_param._shard_param_offsets  # type: ignore[attr-defined]
            )
            if in_sharded_flat_param:
                param_start, param_end = self.flat_param._shard_param_offsets[i - start]  # type: ignore[attr-defined]
                numel_in_shard = param_end - param_start + 1
                param.data = self.flat_param[offset : offset + numel_in_shard]
                offset += numel_in_shard
            else:
                # Allow the original data to be freed via garbage collection
                param.data = torch.empty(
                    0,
                    dtype=self.flat_param.dtype,  # in case `flat_param` changed dtype
                    device=self.flat_param.device,
                    requires_grad=False,
                )
        assert self.flat_param._shared_params is not None
        for i, (
            param,
            (param_name, module, _, prim_param_name, prim_module, _),
        ) in enumerate(
            zip(self.flat_param._shared_params, self.flat_param._shared_param_infos)
        ):
            self._setattr_param(module, param_name, param)
            prim_param = getattr(prim_module, prim_param_name)
            param.data = prim_param  # could be both empty and non-empty
        if self._training_state == HandleTrainingState.BACKWARD_POST:
            assert self.flat_param._tensors is not None  # mypy
            # Clear the saved `Tensor`s since they are unneeded now
            for i in range(len(self.flat_param._tensors)):
                self.flat_param._tensors[i] = None  # type: ignore[index]

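    # Editorial note (illustrative, not part of FSDP): the offset arithmetic
    # above walks the rank's shard of the flat parameter. For a hypothetical
    # shard whose first and last original parameters are only partially
    # present, e.g. `_shard_param_offsets == [(2, 5), (0, 9), (0, 3)]`, the
    # per-parameter view sizes are 4, 10, and 4 elements (end - start + 1),
    # taken back-to-back from offsets 0, 4, and 14 of the sharded flat
    # parameter.
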
    @torch.no_grad()
    def _use_sharded_grad_views(self) -> None:
        """
        Sets the original module parameter variables' gradients to be flattened
        views into the sharded flattened parameter's gradient. This is a no-op
        if there is no gradient.

        Parameters whose data is not present in the sharded flattened parameter
        and parameters with ``requires_grad=False`` have their gradients set to
        ``None``. Since the gradient variables do not need to be preserved,
        this method does not manipulate existing ``Tensor`` data directly and
        creates new ``Tensor`` variables instead.
        """
        flat_param = self.flat_param
        self._check_sharded(flat_param)
        grad = self.sharded_grad
        if grad is None:
            assert flat_param._params is not None  # mypy
            assert flat_param._shared_params is not None  # mypy
            for param in chain(flat_param._params, flat_param._shared_params):  # type: ignore[attr-defined]
                param.grad = None
            return
        self._check_sharded(grad)
        start, end = flat_param._shard_indices  # type: ignore[attr-defined]
        offset = 0
        assert flat_param._params is not None
        for i, param in enumerate(flat_param._params):
            in_sharded_flat_param = (
                i >= start
                and i <= end
                and flat_param._shard_param_offsets  # type: ignore[attr-defined]
            )
            if in_sharded_flat_param:
                param_start, param_end = flat_param._shard_param_offsets[i - start]  # type: ignore[attr-defined]
                numel_in_shard = param_end - param_start + 1
                assert flat_param._is_grad_none is not None  # mypy
                if param.requires_grad and not flat_param._is_grad_none[i]:
                    if self._keep_low_precision_grads or param.dtype != grad.dtype:
                        # NOTE: This is a hack using `.data` to side step the
                        # check that parameter/gradient dtypes match. Here,
                        # `param` has full precision; `grad` has low precision.
                        if param.grad is None:
                            # `.grad` must have the same shape as `param`
                            param.grad = torch.empty_like(param)
                        param.grad.data = grad[
                            offset : offset + numel_in_shard
                        ].reshape(param.shape)
                    else:
                        param.grad = grad[offset : offset + numel_in_shard].reshape(
                            param.shape
                        )
                else:
                    param.grad = None
                offset += numel_in_shard
            else:
                param.grad = None
        assert flat_param._shared_params is not None
        for i, (param, (_, _, _, prim_param_name, prim_module, _)) in enumerate(
            zip(flat_param._shared_params, flat_param._shared_param_infos)
        ):
            in_sharded_flat_param = hasattr(prim_module, prim_param_name)
            if in_sharded_flat_param and param.requires_grad:
                prim_param = getattr(prim_module, prim_param_name)
                param.grad = prim_param.grad  # share the same reference
            else:
                param.grad = None

    @torch.no_grad()
    def _writeback_orig_params(self) -> bool:
        """
        Iterates over the original parameters and writes back any parameters
        that changed storages (due to a non-inplace operator) to the handle's
        ``FlatParameter``. This method preserves the ``FlatParameter`` 's
        device even if an original parameter's device changes.

        Raises:
            RuntimeError: If an original parameter or gradient changes storages
                but no longer has the expected flattened shape.
        Returns: ``True`` if some writeback happened, and ``False`` otherwise.
        """
        if self.uses_sharded_strategy and not self.is_sharded(self.flat_param):
            # For `NO_SHARD`, we may still need to writeback
            return False
        flat_param = self.flat_param
        start, end = flat_param._shard_indices  # type: ignore[attr-defined]
        offset = 0
        assert flat_param._params is not None
        wroteback = False
        for i, (param, (param_name, module, _)) in enumerate(
            zip(flat_param._params, flat_param._param_infos)
        ):
            if not hasattr(module, param_name):
                # Do not writeback if original parameters are deregistered
                # (e.g. during model checkpointing)
                continue
            in_sharded_flat_param = (
                i >= start
                and i <= end
                and self.flat_param._shard_param_offsets  # type: ignore[attr-defined]
            )
            if not in_sharded_flat_param:
                continue
            param_start, param_end = flat_param._shard_param_offsets[i - start]  # type: ignore[attr-defined]
            numel_in_shard = param_end - param_start + 1

            # Check for parameter writeback
            param_changed = getattr(module, param_name) is not param
            needs_param_writeback = (
                param_changed  # changed parameter variable itself
                or not _same_storage(param, flat_param)  # changed `.data`
            )
            if param_changed:
                # NOTE: The gradient is not preserved after a parameter change.
                param = getattr(module, param_name)
                flat_param._params[i] = param
            if needs_param_writeback:
                expected_shape = torch.Size([numel_in_shard])
                self._writeback_tensor(
                    param, flat_param, i, expected_shape, offset, True
                )
                wroteback = True

            # Check for gradient writeback
            # NOTE: Since this method is called in the pre-unshard, which is
            # only called during computation in the pre-forward or
            # pre-backward, the sharded gradient should be guaranteed to be in
            # `.grad`, not in `._saved_grad_shard`.
            if param.grad is None and flat_param.grad is not None:
                expected_shape = torch.Size([numel_in_shard])
                self._writeback_tensor(
                    None, flat_param.grad, i, expected_shape, offset, False
                )
            elif param.grad is not None:
                # For `NO_SHARD` + CPU offloading, `_cpu_grad` is always in
                # memory and owns the gradient storage, so it will never
                # require gradient writeback.
                flat_param_grad = (
                    flat_param.grad
                    if self.uses_sharded_strategy or not self._offload_params
                    else flat_param._cpu_grad  # type: ignore[attr-defined]
                )
                needs_grad_writeback = flat_param_grad is None or not _same_storage(
                    param.grad, flat_param_grad
                )
                if needs_grad_writeback:
                    if flat_param_grad is None:
                        flat_param_grad = torch.zeros_like(flat_param)
                    expected_shape = torch.Size([numel_in_shard])
                    self._writeback_tensor(
                        param.grad, flat_param_grad, i, expected_shape, offset, False
                    )
                    flat_param.grad = flat_param_grad
            offset += numel_in_shard
        # TODO (awgu): Handle shared parameters. We need to re-generate the
        # shared parameter data structures in case sharedness changed.
        for i, (
            param_name,
            module,
            _,
            prim_param_name,
            prim_module,
            _,
        ) in enumerate(flat_param._shared_param_infos):
            if getattr(module, param_name) is not getattr(prim_module, prim_param_name):
                raise NotImplementedError(
                    "Changing shared parameters is not supported yet"
                )
        return wroteback

    def _writeback_tensor(
        self,
        src_tensor: Optional[Tensor],
        dst_tensor: Tensor,
        tensor_index: int,
        expected_shape: torch.Size,
        offset: int,
        is_param: bool,  # else gradient
    ) -> None:
        """
        Writes back ``src_tensor`` to ``dst_tensor`` at offset ``offset``,
        where ``src_tensor`` should have shape ``expected_shape``. ``is_param``
        indicates if the tensor is the parameter (if ``True``) or gradient (if
        ``False``). If ``src_tensor`` is ``None``, then the effect is zeroing
        instead of copying. ``tensor_index`` gives the index of ``src_tensor``
        in the metadata structures.

        Raises:
            RuntimeError: If the ``src_tensor`` does not have the expected
                shape.
        """
        _p_assert(
            len(expected_shape) == 1,
            f"Expects a 1D expected shape but got {expected_shape}",
        )
        if self._debug_level == dist.DebugLevel.DETAIL:
            rank = self.rank if hasattr(self, "rank") else dist.get_rank()
            src_shape = src_tensor.shape if src_tensor is not None else None
            src_device = src_tensor.device if src_tensor is not None else None
            warnings.warn(
                f"[Rank {rank}] {'Parameter' if is_param else 'Gradient'} needs "
                f"writeback in {self._training_state}\n"
                f"expected shape={expected_shape} shape={src_shape} "
                f"expected device={dst_tensor.device} device={src_device}"
            )
        if src_tensor is not None and src_tensor.shape != expected_shape:
            # NOTE: Gradient shape mismatch is not possible in practice since
            # the gradient shape is enforced to match that of the parameter and
            # we already check for parameter shape mismatch.
            raise RuntimeError(
                f"Cannot writeback when the {'parameter' if is_param else 'gradient'} "
                f"shape changes\nExpects {expected_shape} but got {src_tensor.shape}"
            )
        if src_tensor is not None:
            dst_tensor[offset : offset + expected_shape.numel()].copy_(src_tensor)
        else:
            dst_tensor[offset : offset + expected_shape.numel()].zero_()
            assert self.flat_param._is_grad_none is not None
            self.flat_param._is_grad_none[tensor_index] = True

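    # Editorial note (illustrative, not part of FSDP): the copy/zero step in
    # `_writeback_tensor` above is a plain slice update on the flat tensor.
    # For a hypothetical destination of 10 elements, a 3-element source, and
    # `offset == 4`:
    #
    #   dst = torch.zeros(10)
    #   src = torch.tensor([1.0, 2.0, 3.0])
    #   dst[4 : 4 + src.numel()].copy_(src)  # writeback
    #   dst[4 : 4 + src.numel()].zero_()     # zeroing when `src_tensor is None`
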
    def _clear_grads_if_needed(self):
        """
        When ``use_orig_params=True``, sets the underlying ``flat_param.grad``
        to ``None`` if *all* of the original parameters' ``.grad`` are
        ``None``. This is targeting ``optim.zero_grad(set_to_none=True)``, in
        which case we want to free the gradients as soon after the
        ``zero_grad()`` call as possible.
        """
        if not self._use_orig_params:
            return
        flat_param = self.flat_param
        assert flat_param._params is not None
        if all(param.grad is None for param in flat_param._params):
            flat_param.grad = None

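    # Editorial note (illustrative, not part of FSDP): the scenario targeted
    # above is the standard training-loop idiom with a hypothetical optimizer
    # over the original parameters:
    #
    #   optim.zero_grad(set_to_none=True)  # sets every original `.grad` to None
    #   # ... at which point this handle can also drop `flat_param.grad`,
    #   # releasing the sharded gradient memory before the next backward
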
    def _deregister_orig_params(self):
        for param_name, module, _ in self.flat_param._param_infos:
            if hasattr(module, param_name):
                delattr(module, param_name)
        for param_name, module, _, _, _, _ in self.flat_param._shared_param_infos:
            if hasattr(module, param_name):
                delattr(module, param_name)

    ###########
    # HELPERS #
    ###########
    def flat_param_to(self, *args, **kwargs):
        """Wraps an in-place call to ``.to()`` for ``self.flat_param``."""
        self.flat_param.data = self.flat_param.to(*args, **kwargs)
        if self._use_orig_params:
            # Refresh the views because their storage may have changed
            if self.is_sharded(self.flat_param):
                self._use_sharded_views()
            else:
                self._use_unsharded_views(as_params=True)

    def _get_modules(self) -> Set[nn.Module]:
        """Returns a :class:`set` of the modules whose parameters are included
        in this handle's flattened parameter."""
        return {pi.module for pi in self.flat_param._param_infos}.union(
            {spi.module for spi in self.flat_param._shared_param_infos}
        )

    def is_sharded(self, tensor: Tensor) -> bool:
        """
        Returns whether ``tensor`` is *currently* sharded. For ``NO_SHARD``,
        we choose to have this always return ``False`` for clarity.
        """
        if (
            not hasattr(self.flat_param, "_sharded_size")
            or not self.uses_sharded_strategy
        ):
            # `_sharded_size` is defined iff `handle.shard()` has been called
            return False
        sharded_size = self.flat_param._sharded_size  # type: ignore[attr-defined]
        return tensor.size() == sharded_size

    def parameter_module_names(self) -> Iterator[Tuple[str, str]]:
        shared_param_infos = [
            ParamInfo(param_name, module, module_name)
            for (
                param_name,
                module,
                module_name,
                _,
                _,
                _,
            ) in self.flat_param._shared_param_infos
        ]
        for param_name, _, module_name in chain(
            self.flat_param._param_infos, shared_param_infos
        ):
            yield (param_name, module_name)

    def shared_parameter_module_names(self) -> Iterator[Tuple[str, str]]:
        for param_name, _, module_name in [
            ParamInfo(param_name, module, module_name)
            for (
                param_name,
                module,
                module_name,
                _,
                _,
                _,
            ) in self.flat_param._shared_param_infos
        ]:
            yield (param_name, module_name)

    @property
    def _fqns_in_shard(self) -> List[str]:
        """Returns the FQNs of the parameters present in this rank's shard."""
        fqns_in_shard: List[str] = []
        start, end = self.flat_param._shard_indices  # type: ignore[attr-defined]
        for i in range(len(self.flat_param._fqns)):
            if i >= start and i <= end and self.flat_param._shard_param_offsets:  # type: ignore[attr-defined]
                fqns_in_shard.append(self.flat_param._fqns[i])
        return fqns_in_shard

    @property
    def sharded_grad(self) -> Optional[Tensor]:
        """Returns the handle's sharded gradient."""
        flat_param = self.flat_param
        # Priority for non-`None`: `_cpu_grad` > `_saved_grad_shard` > `grad`
        # - CPU offloading: `_cpu_grad`
        # - No CPU offloading + sharded strategies: `_saved_grad_shard`
        # - No CPU offloading + `NO_SHARD`: `grad`
        if hasattr(flat_param, "_cpu_grad"):
            grad = flat_param._cpu_grad  # type: ignore[attr-defined]
        elif hasattr(flat_param, "_saved_grad_shard"):
            # In the post-backward hook, the sharded gradient is still in
            # `_saved_grad_shard`.
            grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
        else:
            # If in IDLE or in FORWARD states, then there may be an
            # (accumulated) gradient. If accessed in IDLE, then this should
            # be due to re-registering the original parameters (e.g. in state
            # dict load).
            _p_assert(
                flat_param.grad is None
                or not self.uses_sharded_strategy
                or self._training_state
                in (HandleTrainingState.FORWARD, HandleTrainingState.IDLE),
                "Sharded strategies should use `_cpu_grad` or `_saved_grad_shard` "
                "unless in IDLE or FORWARD",
            )
            grad = flat_param.grad
        return grad

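    # Editorial note (illustrative, not part of FSDP): the attribute priority
    # above can be exercised directly. For a hypothetical handle `handle`
    # whose flat parameter carries a CPU-offloaded gradient:
    #
    #   handle.flat_param._cpu_grad = torch.zeros(8)  # assumed attribute
    #   assert handle.sharded_grad is handle.flat_param._cpu_grad
    #
    # i.e. `_cpu_grad` wins over `_saved_grad_shard` and `.grad` when present.
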
    def _reset_is_grad_none(self) -> None:
        """
        Resets the ``_is_grad_none`` mask as needed. This method should only
        be called in the post-backward after gradient computation, in which
        case if a parameter requires gradient, then it will surely receive a
        gradient and we may reset its mask entry to ``False``.
        """
        if not self._use_orig_params:
            return
        _p_assert(
            self._training_state == HandleTrainingState.BACKWARD_POST,
            "Expects to only be called in the post-backward after gradient computation",
        )
        flat_param = self.flat_param
        assert flat_param._params is not None  # mypy
        for i, param in enumerate(flat_param._params):
            # As long as the parameter requires gradient, it should receive a
            # meaningful gradient (even if the gradient happens to be zeros)
            if param.requires_grad:
                assert flat_param._is_grad_none is not None  # mypy
                flat_param._is_grad_none[i] = False

    #######################
    # CHECKS & INVARIANTS #
    #######################
    def _check_sharded_strategy(self):
        _p_assert(self.uses_sharded_strategy, "Expects sharded strategy")

    def _check_on_compute_device(self, tensor: Tensor):
        _p_assert(
            tensor.device == self.device,
            f"Expects tensor to be on the compute device {self.device}",
        )

    def _check_on_cpu(self, tensor: Tensor):
        _p_assert(
            tensor.device == torch.device("cpu"),
            f"Expects tensor to be on CPU but got {tensor.device}",
        )

    @staticmethod
    def _check_storage_freed(tensor: Tensor):
        storage_size: int = tensor._typed_storage()._size()
        _p_assert(
            storage_size == 0,
            f"Expects storage to be freed but got storage with size {storage_size}",
        )

    @staticmethod
    def _check_storage_allocated(tensor: Tensor):
        storage_size: int = tensor._typed_storage()._size()
        _p_assert(storage_size > 0, "Expects storage to be allocated")

    def _check_low_precision_shard(self):
        _p_assert(
            self._uses_param_mixed_precision,
            "Not using low precision for parameters",
        )
        _p_assert(
            getattr(self.flat_param, "_mp_shard", None) is not None,
            "Expects `_mp_shard` to exist",
        )
        device = self.flat_param._mp_shard.device  # type: ignore[attr-defined]
        _p_assert(
            device == self.device,
            f"Expects the low precision shard to be on {self.device} but got {device}",
        )

    def _check_unsharded(self, tensor: Tensor):
        msg_prefix = "Expects tensor to be unsharded "
        _p_assert(tensor is not None, msg_prefix + "but got `None`")
        unsharded_size = self.flat_param._unpadded_unsharded_size
        _p_assert(
            tensor.size() == unsharded_size,
            msg_prefix + f"with size {unsharded_size} but got {tensor.size()}",
        )

    def _check_sharded(self, tensor: Tensor):
        msg_prefix = "Expects tensor to be sharded "
        _p_assert(tensor is not None, msg_prefix + "but got `None`")
        sharded_size = self.flat_param._sharded_size  # type: ignore[attr-defined]
        _p_assert(
            tensor.size() == sharded_size,
            msg_prefix + f"with size {sharded_size} but got {tensor.size()}",
        )

    ##############
    # PROPERTIES #
    ##############
    @property
    def uses_sharded_strategy(self) -> bool:
        return self._sharding_strategy != HandleShardingStrategy.NO_SHARD

    @property
    def _uses_param_mixed_precision(self) -> bool:
        return self._fwd_bwd_param_dtype != self._orig_param_dtype

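    # Editorial note (illustrative, not part of FSDP): the mixed-precision
    # properties here and below are pure dtype comparisons. For a hypothetical
    # handle constructed from fp32 parameters with fp16 compute and reduction
    # dtypes:
    #
    #   _orig_param_dtype == torch.float32
    #   _fwd_bwd_param_dtype == torch.float16  ->  _uses_param_mixed_precision
    #   _reduce_dtype == torch.float16         ->  _uses_reduce_mixed_precision
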
    @property
    def _uses_reduce_mixed_precision(self) -> bool:
        return self._reduce_dtype != self._orig_param_dtype

    @property
    def _force_full_precision(self) -> bool:
        return (
            self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS
            and self._uses_param_mixed_precision
        )


# NOTE: These are hacks to bypass `nn.Module.__setattr__` checks.
def _unsafe_setattr_param(
    module: nn.Module, param_name: str, param: nn.Parameter
) -> None:
    module._parameters[param_name] = param
    # This bypasses any overrides in case `module` is an instance of an
    # `nn.Module` subclass
    super(nn.Module, module).__setattr__(param_name, param)


def _unsafe_setattr_tensor(module: nn.Module, param_name: str, tensor: Tensor) -> None:
    module._parameters.pop(param_name, None)
    # This bypasses any overrides in case `module` is an instance of an
    # `nn.Module` subclass
    super(nn.Module, module).__setattr__(param_name, tensor)


def _safe_setattr_tensor_or_param(
    module: nn.Module, param_name: str, tensor_or_param: Union[Tensor, nn.Parameter]
):
    # Call `delattr()` and `setattr()` to go through `nn.Module` checks
    if hasattr(module, param_name):
        delattr(module, param_name)
    setattr(module, param_name, tensor_or_param)

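# Editorial note (illustrative, not part of FSDP): the "safe" and "unsafe"
# helpers above differ only in whether `nn.Module.__setattr__`/`__delattr__`
# (and any subclass overrides of them) run. A rough sketch with a plain module:
#
#   lin = torch.nn.Linear(2, 2)
#   new_w = torch.nn.Parameter(torch.zeros(2, 2))
#   _safe_setattr_tensor_or_param(lin, "weight", new_w)  # goes through checks
#   _unsafe_setattr_param(lin, "weight", new_w)          # bypasses overrides
#
# The unsafe path trades those checks for lower CPU overhead.
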
# A handles key represents the group of `FlatParamHandle`s involved in a given
# module's forward. These will be all-gathered together in the pre-forward and
# pre-backward.
_HandlesKey = Tuple[FlatParamHandle, ...]