mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
reland of https://github.com/pytorch/pytorch/pull/133113 I have to create a new PR because the previous reverted PR could not either be rebased, or imported successfully :( ---- Moving DTensor to be in the public namespace, to formally add the documentation page that includes all the public APIs. This includes: * many path renames and path import fixes * a dedicated doc page without too much content yet (adding in the next PRs) * To preserve the BC for users still using the torch.distributed._tensor, I added a shim script to redirect old path calls to the new module The BC preserving is evidented by the fact that all DTensor tests are still working without changing the public imports. So it's safe to land the changes Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203 Approved by: https://github.com/tianyu-l
2736 lines
118 KiB
Python
2736 lines
118 KiB
Python
# mypy: allow-untyped-defs
|
|
import contextlib
|
|
import functools
|
|
import logging
|
|
import os
|
|
import warnings
|
|
from enum import auto, Enum
|
|
from itertools import accumulate, chain
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
cast,
|
|
Dict,
|
|
Generator,
|
|
Iterator,
|
|
List,
|
|
NamedTuple,
|
|
no_type_check,
|
|
Optional,
|
|
Sequence,
|
|
Set,
|
|
Tuple,
|
|
Union,
|
|
)
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch import Tensor
|
|
from torch.distributed.fsdp._common_utils import (
|
|
_FSDPDeviceHandle,
|
|
_named_parameters_with_duplicates,
|
|
_no_dispatch_record_stream,
|
|
_set_fsdp_flattened,
|
|
HandleTrainingState,
|
|
)
|
|
from torch.distributed.utils import (
|
|
_alloc_storage,
|
|
_data_ptr_allocated,
|
|
_free_storage,
|
|
_p_assert,
|
|
)
|
|
from torch.nn.parameter import _ParameterMeta # type: ignore[attr-defined]
|
|
from torch.testing._internal.distributed.fake_pg import FakeProcessGroup
|
|
|
|
from ._fsdp_extensions import (
|
|
_ext_post_unflatten_transform,
|
|
_ext_pre_flatten_transform,
|
|
FSDPExtensions,
|
|
)
|
|
|
|
|
|
__all__ = [
|
|
"FlatParameter",
|
|
"FlatParamHandle",
|
|
"FlatParamShardMetadata",
|
|
"ParamInfo",
|
|
"SharedParamInfo",
|
|
"HandleShardingStrategy",
|
|
]
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
"""
|
|
[Note: Fully Sharded Module]
|
|
We define the "fully sharded module" to be the original ``nn.Module`` that owns
|
|
a ``FlatParamHandle``. It is the *single* module logically responsible for the
|
|
*single* unshard/reshard pair for the handle's ``FlatParameter`` for a given
|
|
forward or backward pass. The fully sharded module should be passed to the
|
|
``FlatParamHandle`` constructor.
|
|
|
|
For the wrapper code path:
|
|
- The ``FullyShardedDataParallel`` module wrapping the fully sharded module
|
|
runs the unshard/reshard on behalf of the fully sharded module by overriding
|
|
``nn.Module.forward``.
|
|
- The fully sharded module is exactly the module passed to the
|
|
``FullyShardedDataParallel`` constructor's ``module`` argument.
|
|
|
|
For the non-wrapper code path:
|
|
- Hooks registered on the fully sharded module run the unshard/reshard.
|
|
- The fully sharded module may either be the direct argument to ``fully_shard``
|
|
or a submodule chosen by the provided wrapping policy.
|
|
"""
|
|
|
|
# Environment variable toggling whether to use unsafe `setattr()` for view
|
|
# setting in `_use_sharded_views()` and `_use_unsharded_views()`
|
|
# We should use 'safe' by default since it respects method overrides, but for
|
|
# special cases such as for high CPU overhead or for intentionally bypassing
|
|
# checks in the overrides, we may use 'unsafe'.
|
|
_FSDP_USE_UNSAFE_SETATTR = "FSDP_USE_UNSAFE_SETATTR"
|
|
|
|
# Environment variable toggling whether to check for parameter/gradient
|
|
# writeback in case their storages change after FSDP initialization
|
|
# We should check by default since it prevents silent correctness errors, but
|
|
# since such changes are atypical, we may want to skip the check to save CPU
|
|
# overhead, especially since the check happens in the pre-forward and
|
|
# pre-backward each iteration.
|
|
_FSDP_SKIP_WRITEBACK_CHECK = "FSDP_SKIP_WRITEBACK_CHECK"
|
|
|
|
# Env var toggling whether when model is in .eval() mode, should we run in fp32
|
|
# or the reduced precision.
|
|
_FSDP_USE_FULL_PREC_IN_EVAL = "FSDP_USE_FULL_PREC_IN_EVAL"
|
|
|
|
# Some value to set padding in tensors to for debuggability
|
|
_FLAT_PARAM_PADDING_VALUE = 42
|
|
|
|
# Environment variables for disabling the all-gather and reduce-scatter
|
|
# communication ops for ablation studies. Note that without these communication
|
|
# ops the training won't converge, and you probably need to disable correctness
|
|
# checks in your model.
|
|
_FSDP_USE_FAKE_ALL_GATHER = "FSDP_USE_FAKE_ALL_GATHER"
|
|
_FSDP_USE_FAKE_REDUCE = "FSDP_USE_FAKE_REDUCE"
|
|
|
|
|
|
# TODO: Define this for now to avoid circular imports. See if we can remove.
|
|
class HandleShardingStrategy(Enum):
|
|
FULL_SHARD = auto()
|
|
SHARD_GRAD_OP = auto()
|
|
NO_SHARD = auto()
|
|
HYBRID_SHARD = auto()
|
|
_HYBRID_SHARD_ZERO2 = auto()
|
|
|
|
|
|
RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = (
|
|
HandleShardingStrategy.FULL_SHARD,
|
|
HandleShardingStrategy.HYBRID_SHARD,
|
|
)
|
|
NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = (
|
|
HandleShardingStrategy.SHARD_GRAD_OP,
|
|
HandleShardingStrategy._HYBRID_SHARD_ZERO2,
|
|
)
|
|
|
|
|
|
class ParamInfo(NamedTuple):
|
|
"""Information for an original parameter."""
|
|
|
|
param_name: str # unprefixed
|
|
module: nn.Module
|
|
module_name: str
|
|
|
|
|
|
class SharedParamInfo(NamedTuple):
|
|
"""
|
|
Additional information for a shared parameter.
|
|
|
|
For each shared parameter, we designate one module and its parameter
|
|
variable to be the primary owner, determined as the first one encountered
|
|
in the parameter walk. These are prefixed with "prim". The primary module
|
|
and parameter do not have their own :class:`SharedParamInfo` instance.
|
|
"""
|
|
|
|
param_name: str # unprefixed
|
|
module: nn.Module
|
|
module_name: str
|
|
prim_param_name: str # unprefixed
|
|
prim_module: nn.Module
|
|
prim_module_name: str
|
|
|
|
|
|
class _ShardParamInfo(NamedTuple):
|
|
"""Shard-related information for an original parameter."""
|
|
|
|
in_shard: bool
|
|
# Use to index into the sharded flat parameter, e.g.
|
|
# `flat_param[offset_in_shard : offset_in_shard + numel_in_shard]`
|
|
offset_in_shard: Optional[int]
|
|
numel_in_shard: Optional[int]
|
|
# Use to get part of the parameter in the local shard from a flattened
|
|
# version of the unsharded parameter, e.g.
|
|
# `param.flatten()[intra_param_start_idx : intra_param_end_idx + 1]`
|
|
intra_param_start_idx: Optional[int]
|
|
intra_param_end_idx: Optional[int] # inclusive
|
|
|
|
|
|
class FlatParamShardMetadata(NamedTuple):
|
|
"""
|
|
This holds metadata specific to this rank's shard of the flat parameter.
|
|
|
|
Attributes:
|
|
param_names (Tuple[str, ...]): Prefixed parameter names of this rank's
|
|
shard of the parameters; see :class:`FlatParameter`.
|
|
param_shapes (Tuple[torch.Size, ...]): Parameter shapes of this rank's
|
|
shard of the parameters; see :class:`FlatParameter`.
|
|
param_numels (Tuple[int, ...]): Parameter numels of this rank's shard
|
|
of the parameters; see :class:`FlatParameter`.
|
|
param_offsets (Tuple[Tuple[int, int], ...]): [start, end] offsets (in
|
|
units of numels) giving this rank's part of each flattened
|
|
original parameter.
|
|
"""
|
|
|
|
param_names: Tuple[str, ...]
|
|
param_shapes: Tuple[torch.Size, ...]
|
|
param_numels: Tuple[int, ...]
|
|
param_offsets: Tuple[Tuple[int, int], ...]
|
|
|
|
|
|
class _FlatParameterMeta(_ParameterMeta):
|
|
# Make `isinstance(t, FlatParameter)` return True for custom tensor
|
|
# instances that have the _is_flat_param flag for BC
|
|
def __instancecheck__(self, instance):
|
|
# NB: do NOT test the super implementation
|
|
return isinstance(instance, torch.Tensor) and getattr(
|
|
instance, "_is_flat_param", False
|
|
)
|
|
|
|
|
|
class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
|
|
"""
|
|
This is the flat parameter used by :class:`FullyShardedDataParallel`.
|
|
|
|
It is comprised of one or more original parameters, which are flattened and
|
|
concatenated to construct the flat parameter.
|
|
|
|
Under the current design, this parameter logically represents both the
|
|
unsharded and sharded flat parameter, and its data changes storages
|
|
dynamically.
|
|
- In the :class:`FullyShardedDataParallel` constructor, the parameter
|
|
is initialized as unsharded and then sharded in-place.
|
|
- At runtime, the parameter is lazily (re)-initialized. The sharded
|
|
parameter data is saved in ``self._local_shard``, and a new ``Tensor``
|
|
``self._full_param_padded`` is created, which is the all-gather
|
|
destination and owns the unsharded parameter storage thereafter. (See
|
|
:meth:`FlatParamHandle.init_flat_param_attributes`.)
|
|
- Throughout runtime, the parameter data changes storages as needed,
|
|
e.g. to the sharded flat parameter, low precision sharded flat
|
|
parameter, or the unsharded flat parameter.
|
|
|
|
NOTE: Since ``use_orig_params=True`` supports intra-``FlatParameter``
|
|
padding, we have two versions of the per-parameter numels, one that
|
|
includes the padding (``_numels_with_padding``) and one that does not
|
|
(``_numels``). The former may have length longer than the other data
|
|
structures, while the latter has the same length as the number of actual
|
|
original parameters like the other per-parameter data structures.
|
|
|
|
NOTE: This is not a real class; instead, you will always get a Parameter
|
|
back out if you try to create one of these. This is similar to the trick
|
|
we implemented for Parameter to get it to work with subclasses; this
|
|
is primarily so that FlatParameter supports combination with FakeTensor.
|
|
|
|
Attributes:
|
|
_unpadded_unsharded_size (torch.Size): Unsharded flat parameter's size
|
|
without right-hand-side padding for divisibility by the world size.
|
|
For ``use_orig_params=True``, this includes alignment padding.
|
|
_padded_unsharded_size (torch.Size): Unsharded flat parameter's size
|
|
with right-hand-side padding for divisibility by the world size.
|
|
For ``use_orig_params=True``, this includes alignment padding. This
|
|
is only set for sharded strategies since they require padding for
|
|
the all-gather.
|
|
_sharded_size (torch.Size): Sharded flat parameter's size with padding.
|
|
This is also set for ``NO_SHARD``, in which case it is the same as
|
|
the unsharded sizes. (We omit "padded" because there is no
|
|
analogous unpadded one.)
|
|
|
|
_num_params (int): Number of original parameters flattened into this
|
|
flat parameter. This is the length of the per-parameter data
|
|
structures.
|
|
_param_infos (Tuple[ParamInfo, ...]): Each parameter's parameter info
|
|
entry; see :class:`ParamInfo` for details.
|
|
_shapes (Tuple[torch.Size, ...]): Each parameter's original shape.
|
|
_fqns (Tuple[str, ...]): Each parameter's fully-qualified name (FQN)
|
|
prefixed from the ``_fully_sharded_module``. The names are
|
|
guaranteed to be unique in the subtree rooted at that module.
|
|
_param_extensions (Tuple[Optional[Any], ...]): Each parameter's
|
|
extension (i.e. some per-parameter state) used to customize
|
|
pre-flatten and post-unflatten behavior or ``None``. This is
|
|
experimental, and users should not depend on its existence in the
|
|
future.
|
|
_numels_with_padding (Tuple[int, ...]): Each parameter's numel
|
|
including entries for the padding. This is used to construct views
|
|
into the flat parameter via ``torch.split()``. This may have length
|
|
longer than ``_num_params``.
|
|
_numels (Tuple[int, ...]): Each parameter's numel excluding entries for
|
|
padding. This has length equal to ``_num_params``.
|
|
_shard_param_infos (Tuple[_ShardParamInfo, ...]): Each parameter's
|
|
shard parameter info; see :class:`_ShardParamInfo` for details.
|
|
_shared_param_infos (Tuple[SharedParamInfo, ...]): Shared parameter
|
|
info entries; see :class:`SharedParamInfo` for details.
|
|
_modules (Set[nn.Module]): Modules that contain some original parameter
|
|
that is flattened into the flat parameter.
|
|
|
|
_shard_numel_padded (int): Numel padded for this rank's sharded flat
|
|
parameter.
|
|
_local_shard (Tensor): Sharded flat parameter with padding if using a
|
|
sharded strategy. If using ``NO_SHARD``, then this is the unpadded
|
|
unsharded flat parameter, and there is no notion of a sharded flat
|
|
parameter or padded unsharded flat parameter.
|
|
_full_param_padded (Tensor): Unsharded flat parameter with padding.
|
|
This is not defined for ``NO_SHARD``. When using mixed precision
|
|
for parameters, this has the low precision.
|
|
_full_prec_full_param_padded (Tensor): Full precision unsharded flat
|
|
parameter with padding. This is used for unsharding outside of
|
|
computation when using mixed precision for parameters. This is
|
|
never defined for ``NO_SHARD``.
|
|
_post_backward_hook_handle (RemovableHandle):
|
|
Flat parameter's post-backward hook handle. (Compile only)
|
|
_post_backward_hook_state (Tuple[AccumulateGrad, RemovableHandle]):
|
|
Flat parameter's :class:`AccumulateGrad` object and post-backward
|
|
hook handle. (Eager only)
|
|
_mp_shard (Tensor): Low precision sharded flat parameter with padding.
|
|
This is only defined when parameter mixed precision is enabled. For
|
|
``NO_SHARD``, this is used for computation.
|
|
_cpu_grad (Tensor): Sharded gradient with padding stored on CPU.
|
|
This is only defined when offloading parameters is enabled.
|
|
_saved_grad_shard (Tensor): Sharded gradient with padding from previous
|
|
iterations for gradient accumulation without :meth:`no_sync`.
|
|
|
|
_params (Optional[List[nn.Parameter]]): If ``use_orig_params=True``,
|
|
then each original parameter variable; otherwise, ``None``. This
|
|
does not include any padding tensors.
|
|
_shared_params (Optional[List[nn.Parameter]]): The original shared
|
|
parameter variables if ``use_orig_params=True`` and ``None``
|
|
otherwise.
|
|
_tensors (Optional[List[Optional[Tensor]]]): This saves the ``Tensor``
|
|
views created in the forward and tracked by autograd when
|
|
``use_orig_params=True`` and is ``None`` otherwise. This is to
|
|
preserve those ``Tensor`` variables for the backward to ensure that
|
|
the ``FlatParameter`` 's ``AccumulateGrad`` object does not change
|
|
in which case the post-backward hook does not run. This is relevant
|
|
for cases like reentrant activation checkpointing.
|
|
_is_grad_none_mask (Optional[List[bool]]): If ``use_orig_params=True``,
|
|
a mask over the original parameters' gradients indicating if it is
|
|
logically ``None`` or not; otherwise, ``None``. This does not
|
|
include entries for padding. This mask is needed because only some
|
|
of the parameters may have ``None`` gradient, in which case the
|
|
flat gradient must be non-``None`` and must use zeros to
|
|
approximate those original ``None`` gradients. This mask informs
|
|
FSDP to set the original parameter gradients to ``None`` (instead
|
|
of zeros) as needed.
|
|
"""
|
|
|
|
_unpadded_unsharded_size: torch.Size
|
|
_padded_unsharded_size: torch.Size
|
|
_sharded_size: torch.Size
|
|
_num_params: int
|
|
_param_infos: Tuple[ParamInfo, ...]
|
|
_shapes: Tuple[torch.Size, ...]
|
|
_fqns: Tuple[str, ...]
|
|
_param_extensions: Tuple[Optional[Any], ...]
|
|
_numels_with_padding: Tuple[int, ...]
|
|
_numels: Tuple[int, ...]
|
|
_shard_param_infos: Tuple[_ShardParamInfo, ...]
|
|
_shared_param_infos: Tuple[SharedParamInfo, ...]
|
|
_modules: Set[nn.Module]
|
|
_shard_numel_padded: int
|
|
_local_shard: Tensor
|
|
_full_param_padded: Tensor
|
|
_full_prec_full_param_padded: Tensor
|
|
# Eager only
|
|
_post_backward_hook_state: Tuple[Any, Any]
|
|
# Compile only
|
|
_post_backward_hook_handle: Any
|
|
_mp_shard: Tensor
|
|
_cpu_grad: Tensor
|
|
_saved_grad_shard: Tensor
|
|
_params: Optional[List[nn.Parameter]]
|
|
_shared_params: Optional[List[nn.Parameter]]
|
|
_tensors: Optional[List[Optional[Tensor]]]
|
|
_is_grad_none_mask: Optional[List[bool]]
|
|
|
|
_is_padding_mask: List[bool]
|
|
|
|
def __new__(cls, data=None, requires_grad=True):
|
|
assert cls is FlatParameter, "subclasses FlatParameter not supported"
|
|
r = nn.Parameter.__new__(nn.Parameter, data, requires_grad) # type: ignore[call-arg]
|
|
r._is_flat_param = True # type: ignore[attr-defined]
|
|
return r
|
|
|
|
# NB: This is not a regular method, because FlatParameters are not actually
|
|
# instances of this class (see __new__ above). So you must indirectly
|
|
# call this directly through the classmethod.
|
|
@classmethod
|
|
def _init_metadata(
|
|
cls,
|
|
self,
|
|
param_infos: List[ParamInfo],
|
|
numels: List[int],
|
|
shapes: List[torch.Size],
|
|
fqns: List[str],
|
|
shared_param_infos: List[SharedParamInfo],
|
|
param_extensions: List[Optional[Any]],
|
|
params: Optional[List[nn.Parameter]],
|
|
shared_params: Optional[List[nn.Parameter]],
|
|
is_padding_mask: List[bool],
|
|
) -> None:
|
|
"""
|
|
Initialize attributes holding metadata about the original parameters comprising the flat parameter.
|
|
|
|
We expose this method separate from the constructor to keep the
|
|
constructor only responsible for the flat parameter's tensor data. This
|
|
method should only be called once per model, while the constructor may
|
|
be called multiple times, e.g. when reloading from a checkpoint, in
|
|
which case only the tensor data needs to be passed to the constructor.
|
|
Since :meth:`load_state_dict` is implemented via :meth:`copy_`, the
|
|
metadata is correctly assumed to be unchanged.
|
|
|
|
Args:
|
|
See the Attributes in the class docstring.
|
|
"""
|
|
assert len(param_infos) == len(shapes)
|
|
assert len(param_infos) == len(fqns)
|
|
assert len(param_infos) == len(param_extensions)
|
|
self._num_params = len(param_infos)
|
|
self._param_infos = param_infos
|
|
self._shapes = shapes
|
|
self._fqns = fqns
|
|
self._param_extensions = param_extensions
|
|
self._is_padding_mask = is_padding_mask
|
|
|
|
numels_without_padding: List[int] = []
|
|
for numel, is_padding in zip(numels, is_padding_mask):
|
|
if not is_padding:
|
|
numels_without_padding.append(numel)
|
|
self._numels = tuple(numels_without_padding)
|
|
self._numels_with_padding = tuple(numels)
|
|
assert len(self._numels) == self._num_params
|
|
|
|
self._shared_param_infos = tuple(shared_param_infos)
|
|
self._modules = {pi.module for pi in self._param_infos}.union(
|
|
{spi.module for spi in self._shared_param_infos}
|
|
)
|
|
assert (params is None) == (shared_params is None)
|
|
if params is not None:
|
|
assert shared_params is not None and len(shared_params) == len(
|
|
shared_param_infos
|
|
)
|
|
self._params = []
|
|
for param, is_padding in zip(params, is_padding_mask):
|
|
if not is_padding:
|
|
self._params.append(param)
|
|
self._shared_params = shared_params
|
|
# Mark the original parameters to avoid flattening them into
|
|
# another `FlatParameter` during recursive construction
|
|
for param in chain(self._params, self._shared_params):
|
|
_set_fsdp_flattened(param)
|
|
self._is_grad_none_mask = [False for _ in range(self._num_params)]
|
|
self._tensors = [None for _ in range(self._num_params)]
|
|
else:
|
|
self._params = None
|
|
self._shared_params = None
|
|
self._is_grad_none_mask = None
|
|
self._tensors = None
|
|
self._unpadded_unsharded_size = self.size()
|
|
_set_fsdp_flattened(self)
|
|
# Tracks whether the `FlatParameter`'s post-backward hook has been
|
|
# called to modify the behavior of the post-backward callback
|
|
self._post_backward_called = False
|
|
|
|
|
|
class FlatParamHandle:
|
|
"""
|
|
A handle that manages a flat parameter (:class:`FlatParameter`).
|
|
|
|
This includes sharding and view management.
|
|
|
|
Args:
|
|
params (Sequence[nn.Parameter]): The parameters to flatten into the
|
|
flat parameter.
|
|
fully_sharded_module (nn.Module): See [Note: Fully Sharded Module].
|
|
device (torch.device): The compute and communication device, which
|
|
should be a non-CPU device. We refer to it as the compute device.
|
|
sharding_strategy (ShardingStrategy): Sharding strategy to apply to
|
|
this handle's ``FlatParameter``.
|
|
offload_params (bool): Whether to offload the handle's
|
|
``FlatParameter`` to CPU.
|
|
mp_param_dtype (Optional[torch.dtype]): Parameter mixed precision
|
|
setting passed to the FSDP constructor.
|
|
mp_reduce_dtype (Optional[torch.dtype]): Gradient reduction mixed
|
|
precision setting passed to the FSDP constructor.
|
|
keep_low_precision_grads (bool): Whether to keep gradients in low
|
|
precision.
|
|
use_orig_params (bool): If ``True``, then FSDP preserves the original
|
|
parameter variables and returns them from ``named_parameters()``
|
|
(e.g. to support different optimizer hyperparameters within one
|
|
:class:`FlatParameter`). If ``False``, then FSDP reconstructs the
|
|
parameters every iteration and returns the :class:`FlatParameter` s
|
|
from ``named_parameters()``.
|
|
"""
|
|
|
|
##################
|
|
# INITIALIZATION #
|
|
##################
|
|
def __init__(
|
|
self,
|
|
params: Sequence[Union[nn.Parameter, Tensor]],
|
|
fully_sharded_module: nn.Module,
|
|
device: torch.device,
|
|
sharding_strategy: HandleShardingStrategy,
|
|
offload_params: bool,
|
|
mp_param_dtype: Optional[torch.dtype],
|
|
mp_reduce_dtype: Optional[torch.dtype],
|
|
keep_low_precision_grads: bool,
|
|
process_group: dist.ProcessGroup,
|
|
use_orig_params: bool,
|
|
*,
|
|
fsdp_extension: Optional[FSDPExtensions] = None,
|
|
):
|
|
super().__init__()
|
|
params = list(params)
|
|
if len(params) == 0:
|
|
raise ValueError(
|
|
f"Cannot construct a {self.__class__.__name__} with an empty parameter list"
|
|
)
|
|
self._init_setattr_fns()
|
|
self._skip_writeback_check = (
|
|
os.environ.get(_FSDP_SKIP_WRITEBACK_CHECK, "") == "1"
|
|
)
|
|
self._use_full_prec_in_eval = (
|
|
os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1"
|
|
)
|
|
self._use_fake_all_gather = os.environ.get(_FSDP_USE_FAKE_ALL_GATHER, "") == "1"
|
|
self._use_fake_reduce = os.environ.get(_FSDP_USE_FAKE_REDUCE, "") == "1"
|
|
if self._skip_writeback_check:
|
|
_warn_skip_writeback_check(
|
|
logger,
|
|
f"Since {_FSDP_SKIP_WRITEBACK_CHECK}=1, FSDP will not check "
|
|
"for parameter or gradient writeback. Changing parameter or "
|
|
"gradient storages may lead to silent correctness errors.",
|
|
)
|
|
if self._use_fake_all_gather:
|
|
_warn_use_fake_all_gather(
|
|
logger,
|
|
f"Since {_FSDP_USE_FAKE_ALL_GATHER}=1, FSDP will not execute "
|
|
"all-gather ops. Your training will be incorrect, but "
|
|
"can reveal how much time spent on all-gather ops.",
|
|
)
|
|
if self._use_fake_reduce:
|
|
_warn_use_fake_reduce(
|
|
logger,
|
|
f"Since {_FSDP_USE_FAKE_REDUCE}=1, FSDP will not execute "
|
|
"reduce-scatter ops. Your training will be incorrect, but "
|
|
"can reveal how much time spent on reduce-scatter ops.",
|
|
)
|
|
# Only align addresses for `use_orig_params=True` (for now)
|
|
align_addresses = use_orig_params
|
|
self._init_get_unflat_views_fn(align_addresses)
|
|
self.device = device
|
|
self._device_handle = _FSDPDeviceHandle.from_device(self.device)
|
|
self.process_group = process_group
|
|
if self._use_fake_all_gather or self._use_fake_reduce:
|
|
self._fake_process_group = FakeProcessGroup(
|
|
rank=process_group.rank(), world_size=process_group.size()
|
|
)
|
|
self.rank = process_group.rank()
|
|
self.world_size = process_group.size()
|
|
self._sharding_strategy = sharding_strategy
|
|
self._offload_params = offload_params
|
|
self._use_orig_params = use_orig_params
|
|
self._keep_low_precision_grads = keep_low_precision_grads
|
|
self._training_state = HandleTrainingState.IDLE
|
|
self._debug_level = dist.get_debug_level()
|
|
self._fully_sharded_module = fully_sharded_module
|
|
# For strategies that do not free after forward, we skip using sharded
|
|
# views after forward since the unsharded data exists. We still switch
|
|
# `self.flat_param` to point to the sharded flat parameter since what
|
|
# it points to parameterizes behavior. We use the following attribute
|
|
# to track which tensor data the parameters are unsharded views into.
|
|
self._unsharded_flat_param_for_skipped_views: Optional[Tensor] = None
|
|
# The index in the state's `all_handles`, which must be the
|
|
# same across ranks for the execution order validation to work
|
|
self._handle_index: Optional[int] = None
|
|
# Index in handles_to_pre_forward_order
|
|
self._pre_forward_order_index: Optional[int] = None
|
|
# Index in `handles_post_forward_order`
|
|
self._post_forward_index: Optional[int] = None
|
|
# Used for guarding against mistargeted forward prefetches
|
|
self._needs_pre_forward_unshard = False
|
|
# Used for guarding against mistargeted backward prefetches
|
|
self._needs_pre_backward_unshard = False
|
|
# Was the handle prefetched? Set on successful _prefetch_handle and unshard
|
|
self._prefetched = False
|
|
# Optimistically assume a valid input `params` and set dtype attributes
|
|
# before `_init_flat_param()`, which performs the actual validation
|
|
self._orig_param_dtype = params[0].dtype
|
|
self._init_param_reduce_dtypes(mp_param_dtype, mp_reduce_dtype)
|
|
assert self._fwd_bwd_param_dtype is not None # mypy
|
|
self._aligned_numel = (
|
|
_get_aligned_numel(unsharded_dtype=self._fwd_bwd_param_dtype)
|
|
if align_addresses
|
|
else 0
|
|
)
|
|
self._fsdp_extension = fsdp_extension
|
|
self._init_flat_param_and_metadata(
|
|
params, fully_sharded_module, self._aligned_numel, use_orig_params # type: ignore[arg-type]
|
|
)
|
|
self._use_unsharded_views(as_params=False)
|
|
|
|
def _init_setattr_fns(self):
|
|
use_unsafe_setattr = os.environ.get(_FSDP_USE_UNSAFE_SETATTR, "") == "1"
|
|
self._setattr_tensor: Callable[[nn.Module, str, Tensor], None]
|
|
self._setattr_param: Callable[[nn.Module, str, nn.Parameter], None]
|
|
if use_unsafe_setattr:
|
|
self._setattr_tensor = _unsafe_setattr_tensor
|
|
self._setattr_param = _unsafe_setattr_param
|
|
else:
|
|
self._setattr_tensor = _safe_setattr_tensor_or_param
|
|
self._setattr_param = _safe_setattr_tensor_or_param
|
|
|
|
def _init_get_unflat_views_fn(self, align_addresses: bool):
|
|
self._get_unflat_views = (
|
|
self._get_unflat_views_aligned
|
|
if align_addresses
|
|
else self._get_unflat_views_unaligned
|
|
)
|
|
|
|
def _init_flat_param_and_metadata(
|
|
self,
|
|
params: List[Union[Tensor, nn.Parameter]],
|
|
module: nn.Module,
|
|
aligned_numel: int,
|
|
use_orig_params: bool,
|
|
) -> None:
|
|
"""
|
|
Initialize the ``FlatParameter`` and its metadata.
|
|
|
|
NOTE: This should only be called once at construction time, after which
|
|
the ``FlatParameter`` metadata is assumed to be static.
|
|
|
|
NOTE: The elements of ``params`` should only be ``Tensor`` s when
|
|
composing with ``DTensor`` -based tensor parallelism, in which case the
|
|
elements may be ``DTensor`` local shards.
|
|
"""
|
|
if len(params) == 0:
|
|
raise ValueError("Expects non-empty `params`")
|
|
if aligned_numel < 0:
|
|
raise ValueError(
|
|
f"Expects non-negative `aligned_numel` but got {aligned_numel}"
|
|
)
|
|
(
|
|
dtype,
|
|
flat_param_requires_grad,
|
|
device,
|
|
) = self._validate_tensors_to_flatten(params)
|
|
params_set = set(params)
|
|
# For alignment padding, only `numels` gets strictly non-`None`
|
|
# elements, and all other lists get `None` elements for padding.
|
|
param_infos: List[ParamInfo] = []
|
|
numels: List[int] = []
|
|
shapes: List[torch.Size] = []
|
|
fqns: List[str] = []
|
|
shared_param_infos: List[SharedParamInfo] = []
|
|
shared_param_memo: Dict[
|
|
Union[Tensor, nn.Parameter], Tuple[nn.Module, str, str]
|
|
] = {}
|
|
params_to_flatten: List[Union[Tensor, nn.Parameter]] = []
|
|
shared_params: List[Union[Tensor, nn.Parameter]] = []
|
|
param_extensions: List[Any] = []
|
|
is_padding_mask: List[bool] = []
|
|
total_numel = total_numel_without_padding = 0
|
|
for submodule_name, submodule in module.named_modules(remove_duplicate=False):
|
|
for param_name, param in _named_parameters_with_duplicates(
|
|
submodule, recurse=False
|
|
):
|
|
if param not in params_set:
|
|
continue
|
|
if param in shared_param_memo: # shared reference
|
|
prim_module, prim_module_name, prim_param_name = shared_param_memo[
|
|
param
|
|
]
|
|
shared_params.append(param)
|
|
shared_param_infos.append(
|
|
SharedParamInfo(
|
|
param_name,
|
|
submodule,
|
|
submodule_name,
|
|
prim_param_name,
|
|
prim_module,
|
|
prim_module_name,
|
|
)
|
|
)
|
|
else:
|
|
if aligned_numel > 0:
|
|
numel_to_pad = aligned_numel - (total_numel % aligned_numel)
|
|
if numel_to_pad > 0 and numel_to_pad < aligned_numel:
|
|
padding_tensor = _construct_padding_tensor(
|
|
numel_to_pad, dtype, False, device
|
|
)
|
|
params_to_flatten.append(padding_tensor)
|
|
is_padding_mask.append(True)
|
|
numels.append(numel_to_pad)
|
|
total_numel += numel_to_pad
|
|
transform_t, extension = _ext_pre_flatten_transform(
|
|
param,
|
|
self._fsdp_extension,
|
|
)
|
|
param = cast(nn.Parameter, transform_t)
|
|
param_extensions.append(extension)
|
|
shared_param_memo[param] = (submodule, submodule_name, param_name)
|
|
params_to_flatten.append(param)
|
|
is_padding_mask.append(False)
|
|
param_infos.append(ParamInfo(param_name, submodule, submodule_name))
|
|
numels.append(param.numel())
|
|
shapes.append(param.shape)
|
|
fqn = (
|
|
submodule_name + "." + param_name
|
|
if submodule_name
|
|
else param_name
|
|
)
|
|
fqns.append(fqn)
|
|
total_numel += param.numel()
|
|
total_numel_without_padding += param.numel()
|
|
if len(params_to_flatten) == 0:
|
|
raise ValueError(
|
|
f"`params` were not found in `module`'s tree"
|
|
f"params: {params}\nmodule: {module}"
|
|
)
|
|
if (
|
|
self.rank == 0
|
|
and aligned_numel > 0
|
|
and total_numel != total_numel_without_padding
|
|
):
|
|
logger.debug(
|
|
"FSDP FlatParameter address alignment created "
|
|
"%s numel of padding (%s vs. %s)",
|
|
total_numel - total_numel_without_padding,
|
|
total_numel,
|
|
total_numel_without_padding,
|
|
)
|
|
if aligned_numel > 0:
|
|
# Pad to be divisible by world size to avoid a copy for the
|
|
# post-backward reduce-scatter
|
|
numel_to_pad = self.world_size - (total_numel % self.world_size)
|
|
if numel_to_pad > 0 and numel_to_pad < self.world_size:
|
|
if self.rank == 0:
|
|
logger.info(
|
|
"FSDP FlatParameter world size divisibility created "
|
|
"%s numel of padding",
|
|
numel_to_pad,
|
|
)
|
|
padding_tensor = _construct_padding_tensor(
|
|
numel_to_pad, dtype, False, device
|
|
)
|
|
params_to_flatten.append(padding_tensor)
|
|
is_padding_mask.append(True)
|
|
numels.append(numel_to_pad)
|
|
total_numel += numel_to_pad
|
|
# Pass `aligned_numel=0` since we already included padding tensors
|
|
self.flat_param: FlatParameter = self.flatten_tensors_into_flat_param(
|
|
params_to_flatten,
|
|
aligned_numel=0,
|
|
requires_grad=flat_param_requires_grad,
|
|
)
|
|
FlatParameter._init_metadata(
|
|
self.flat_param,
|
|
param_infos,
|
|
numels,
|
|
shapes,
|
|
fqns,
|
|
shared_param_infos,
|
|
param_extensions,
|
|
_convert_to_params(params_to_flatten) if use_orig_params else None,
|
|
_convert_to_params(shared_params) if use_orig_params else None,
|
|
is_padding_mask,
|
|
)
|
|
|
|
def _validate_tensors_to_flatten(
|
|
self, tensors: List[Union[Tensor, nn.Parameter]]
|
|
) -> Tuple:
|
|
"""Validate the tensors to flatten and returns any necessary metadata."""
|
|
dtype: Optional[torch.dtype] = None
|
|
# Return as the logical OR over each tensor's value
|
|
flat_param_requires_grad: Optional[bool] = None
|
|
device: Optional[torch.device] = None
|
|
# For `use_orig_params=True`, permit non-uniform `requires_grad`
|
|
for tensor in tensors:
|
|
if isinstance(tensor, FlatParameter):
|
|
raise ValueError("Cannot flatten a `FlatParameter`")
|
|
if dtype is None and not tensor.is_floating_point():
|
|
raise ValueError("Cannot flatten integer dtype tensors")
|
|
if dtype is not None and tensor.dtype != dtype:
|
|
raise ValueError(
|
|
f"Must flatten tensors with uniform dtype but got {dtype} "
|
|
f"and {tensor.dtype}"
|
|
)
|
|
if (
|
|
not self._use_orig_params
|
|
and flat_param_requires_grad is not None
|
|
and tensor.requires_grad != flat_param_requires_grad
|
|
):
|
|
raise ValueError(
|
|
"Must flatten tensors with uniform `requires_grad` when "
|
|
"`use_orig_params=False`"
|
|
)
|
|
if device is not None and tensor.device != device:
|
|
raise ValueError(
|
|
"Must flatten tensors on the same device but got both "
|
|
f"{device} and {tensor.device}"
|
|
)
|
|
dtype = tensor.dtype
|
|
flat_param_requires_grad = flat_param_requires_grad or tensor.requires_grad
|
|
device = tensor.device
|
|
assert flat_param_requires_grad is not None, "Requires non-empty `tensors` list"
|
|
return dtype, flat_param_requires_grad, device
|
|
|
|
def flatten_tensors(
|
|
self,
|
|
tensors: List[Tensor],
|
|
aligned_numel: int,
|
|
) -> Tensor:
|
|
"""
|
|
Flatten ``tensors`` into a single flat tensor.
|
|
|
|
The flattening optionally includes
|
|
padding if ``aligned_numel`` is greater than 0, where ``aligned_numel``
|
|
gives the numel required to have address alignment.
|
|
|
|
NOTE: The padding alignment algorithm must be kept in sync with
|
|
:meth:`_init_flat_param_metadata`. We separate the two methods because
|
|
the initialization happens once, whereas this method may be called
|
|
multiple times throughout training (e.g. for checkpointing).
|
|
"""
|
|
if len(tensors) == 0:
|
|
raise ValueError("Expects non-empty `tensors`")
|
|
if aligned_numel < 0:
|
|
raise ValueError(
|
|
f"Expects non-negative `aligned_numel` but got {aligned_numel}"
|
|
)
|
|
dtype, _, device = self._validate_tensors_to_flatten(tensors)
|
|
flat_tensors: List[Tensor] = []
|
|
if aligned_numel > 0:
|
|
total_numel = 0
|
|
for tensor in tensors:
|
|
numel_to_pad = aligned_numel - (total_numel % aligned_numel)
|
|
if numel_to_pad > 0 and numel_to_pad < aligned_numel:
|
|
padding_tensor = _construct_padding_tensor(
|
|
numel_to_pad, dtype, False, device
|
|
)
|
|
flat_tensors.append(padding_tensor)
|
|
total_numel += numel_to_pad
|
|
flat_tensors.append(torch.flatten(_detach_if_needed(tensor)))
|
|
total_numel += tensor.numel()
|
|
numel_to_pad = self.world_size - (total_numel % self.world_size)
|
|
if numel_to_pad > 0 and numel_to_pad < self.world_size:
|
|
padding_tensor = _construct_padding_tensor(
|
|
numel_to_pad, dtype, False, device
|
|
)
|
|
flat_tensors.append(padding_tensor)
|
|
total_numel += numel_to_pad
|
|
else:
|
|
flat_tensors = [
|
|
torch.flatten(_detach_if_needed(tensor)) for tensor in tensors
|
|
]
|
|
return torch.cat(flat_tensors, dim=0)
|
|
|
|
def flatten_tensors_into_flat_param(
|
|
self,
|
|
tensors: List[Tensor],
|
|
aligned_numel: int,
|
|
requires_grad: bool,
|
|
) -> FlatParameter:
|
|
flat_param_data = self.flatten_tensors(tensors, aligned_numel)
|
|
return FlatParameter(flat_param_data, requires_grad=requires_grad)
|
|
|
|
def _init_param_reduce_dtypes(
|
|
self,
|
|
mp_param_dtype: Optional[torch.dtype],
|
|
mp_reduce_dtype: Optional[torch.dtype],
|
|
) -> None:
|
|
"""
|
|
Initialize param and reduce dtypes.
|
|
|
|
Precondition: ``self.flat_param`` is set. This ensures that this
|
|
handle's parameters have a single dtype.
|
|
|
|
Postcondition: This sets ``self._fwd_bwd_param_dtype`` and
|
|
``self._reduce_dtype``. If ``mp_param_dtype`` or ``mp_reduce_dtype``
|
|
is ``None``, then we assume the original parameter dtype. One special
|
|
case is if ``mp_param_dtype`` is not ``None`` and ``mp_reduce_dtype``
|
|
is ``None``, in which case we assume the gradient reduction dtype
|
|
matches the forward/backward parameter dtype.
|
|
"""
|
|
# Save whether these dtypes were specified so that we permit the
|
|
# parameter dtype to change up until the lazy initialization
|
|
self._low_prec_param_dtype_specified = mp_param_dtype is not None
|
|
self._low_prec_reduce_dtype_specified = mp_reduce_dtype is not None
|
|
if (
|
|
self._low_prec_param_dtype_specified
|
|
and not self._low_prec_reduce_dtype_specified
|
|
):
|
|
# Special case: infer gradient reduction mixed precision
|
|
self._fwd_bwd_param_dtype = mp_param_dtype
|
|
self._reduce_dtype = self._fwd_bwd_param_dtype
|
|
else:
|
|
self._fwd_bwd_param_dtype = mp_param_dtype or self._orig_param_dtype
|
|
self._reduce_dtype = mp_reduce_dtype or self._orig_param_dtype
|
|
assert self._fwd_bwd_param_dtype is not None
|
|
assert self._reduce_dtype is not None
|
|
|
|
###################################
|
|
# SHARD INITIALIZATION & METADATA #
|
|
###################################
|
|
@torch.no_grad()
|
|
def shard(self):
|
|
"""
|
|
Shard the handle's ``FlatParameter``.
|
|
|
|
This allocates new memory for
|
|
the sharded flat parameter and frees the unsharded flat parameter's
|
|
storage.
|
|
|
|
Postcondition: ``self.flat_param`` is the sharded flat parameter. Shard
|
|
metadata attributes are set for all sharding strategies.
|
|
"""
|
|
flat_param = self.flat_param
|
|
if not self.uses_sharded_strategy:
|
|
self._init_shard_metadata(0, 0, flat_param.numel() - 1)
|
|
else:
|
|
_p_assert(
|
|
flat_param.storage_offset() == 0,
|
|
"The `FlatParameter` is not the sole occupant of its storage",
|
|
)
|
|
sharded_flat_param, numel_padded = FlatParamHandle._get_shard(
|
|
flat_param, self.rank, self.world_size
|
|
)
|
|
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
|
|
allocated = flat_param._typed_storage()._size() > 0
|
|
if allocated:
|
|
flat_param._typed_storage()._resize_(0)
|
|
flat_param.set_(sharded_flat_param) # type: ignore[call-overload]
|
|
start_idx = sharded_flat_param.numel() * self.rank
|
|
end_idx = sharded_flat_param.numel() * (self.rank + 1) - 1 # inclusive
|
|
self._init_shard_metadata(numel_padded, start_idx, end_idx)
|
|
if self._use_orig_params:
|
|
self._use_sharded_views()
|
|
|
|
def _init_shard_metadata(
|
|
self,
|
|
numel_padded: int,
|
|
unsharded_start_idx: int,
|
|
unsharded_end_idx: int,
|
|
) -> None:
|
|
"""
|
|
Initialize shard-related metadata for this rank's shard of the flat parameter.
|
|
|
|
This includes ``_sharded_size``, ``_shard_param_infos``, and ``_shard_numel_padded``.
|
|
|
|
Args:
|
|
numel_padded (int): Numel padded for this rank's sharded flat
|
|
parameter.
|
|
unsharded_start_idx (int): Start index in the unsharded flat
|
|
parameter assigned to this rank.
|
|
unsharded_end_idx (int): End index (inclusive) in the unsharded
|
|
flat parameter assigned to this rank.
|
|
|
|
Precondition: ``self.flat_param`` 's data is the sharded flat
|
|
parameter.
|
|
"""
|
|
flat_param = self.flat_param
|
|
flat_param._sharded_size = flat_param.size() # type: ignore[attr-defined]
|
|
sharded_flat_param_numel = flat_param.numel() # includes `numel_padded`
|
|
_p_assert(
|
|
unsharded_start_idx >= 0 and unsharded_start_idx <= unsharded_end_idx,
|
|
f"unsharded_start_idx: {unsharded_start_idx} unsharded_end_idx: {unsharded_end_idx}",
|
|
)
|
|
_p_assert(
|
|
numel_padded <= sharded_flat_param_numel,
|
|
f"numel_padded: {numel_padded} "
|
|
f"sharded_flat_param_numel: {sharded_flat_param_numel}",
|
|
)
|
|
shard_param_infos = self._get_shard_metadata(
|
|
unsharded_start_idx, unsharded_end_idx
|
|
)
|
|
assert (
|
|
len(shard_param_infos) == flat_param._num_params
|
|
), f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}"
|
|
flat_param._shard_param_infos = shard_param_infos # type: ignore[attr-defined]
|
|
flat_param._shard_numel_padded = numel_padded # type: ignore[attr-defined]
|
|
|
|
def _get_shard_metadata(
|
|
self,
|
|
unsharded_start_idx: int,
|
|
unsharded_end_idx: int,
|
|
) -> Tuple[_ShardParamInfo, ...]:
|
|
"""
|
|
Compute the shard metadata based on ``unsharded_start_idx`` and ``unsharded_end_idx`` (inclusive).
|
|
|
|
``unsharded_start_idx`` and ``unsharded_end_idx`` give the interval of the
|
|
unsharded flat parameter specifying the shard.
|
|
"""
|
|
flat_param_offsets = self._get_flat_param_offsets()
|
|
assert len(flat_param_offsets) == len(
|
|
self.flat_param._numels_with_padding
|
|
), f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
|
|
shard_param_infos: List[_ShardParamInfo] = []
|
|
sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1
|
|
# `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices
|
|
# into the unsharded flat parameter (inclusive) of the given parameter
|
|
for i, (
|
|
(unsharded_param_start_idx, unsharded_param_end_idx),
|
|
is_padding,
|
|
) in enumerate(zip(flat_param_offsets, self.flat_param._is_padding_mask)):
|
|
if is_padding:
|
|
continue
|
|
in_sharded_flat_param = (
|
|
unsharded_start_idx <= unsharded_param_end_idx
|
|
and unsharded_end_idx >= unsharded_param_start_idx
|
|
)
|
|
if not in_sharded_flat_param:
|
|
shard_param_info = _ShardParamInfo(False, None, None, None, None)
|
|
else:
|
|
if unsharded_start_idx <= unsharded_param_start_idx:
|
|
# This branch can only happen once since the rank's
|
|
# unsharded start index can only intersect one parameter
|
|
intra_param_start_idx = 0
|
|
offset_in_shard = unsharded_param_start_idx - unsharded_start_idx
|
|
else:
|
|
intra_param_start_idx = (
|
|
unsharded_start_idx - unsharded_param_start_idx
|
|
)
|
|
offset_in_shard = 0
|
|
assert (
|
|
offset_in_shard >= 0 and offset_in_shard < sharded_flat_param_numel
|
|
), (
|
|
f"Invalid `offset_in_shard` of {offset_in_shard} for "
|
|
f"sharded flat parameter with {sharded_flat_param_numel} numel"
|
|
)
|
|
intra_param_end_idx = (
|
|
min(unsharded_param_end_idx, unsharded_end_idx)
|
|
- unsharded_param_start_idx
|
|
)
|
|
numel_in_shard = intra_param_end_idx - intra_param_start_idx + 1
|
|
shard_param_info = _ShardParamInfo(
|
|
True,
|
|
offset_in_shard,
|
|
numel_in_shard,
|
|
intra_param_start_idx,
|
|
intra_param_end_idx,
|
|
)
|
|
shard_param_infos.append(shard_param_info)
|
|
return tuple(shard_param_infos)
|
|
|
|
@staticmethod
|
|
def _get_unpadded_shard(
|
|
tensor: Tensor,
|
|
rank: int,
|
|
world_size: int,
|
|
) -> Tuple[Tensor, int]:
|
|
"""
|
|
Return the unpadded shard of ``tensor`` for the given ``rank`` and ``world_size``.
|
|
|
|
The returned value is a tuple of the shard of ``tensor`` without any
|
|
padding and the numel to pad for that shard.
|
|
|
|
If ``tensor`` is already flattened or may be viewed in the flattened
|
|
shape (which is true in the expected usage), then this method does not
|
|
allocate any new tensor memory.
|
|
"""
|
|
chunks = torch.flatten(tensor).chunk(world_size)
|
|
if len(chunks) < (rank + 1):
|
|
# This rank gets an empty chunk fully padded with zeros since there
|
|
# are not enough chunks across ranks
|
|
chunk = chunks[0].new_empty(0)
|
|
else:
|
|
chunk = chunks[rank]
|
|
numel_to_pad = chunks[0].numel() - chunk.numel()
|
|
assert (
|
|
numel_to_pad >= 0
|
|
), "Chunk's size should be at most the first chunk's size"
|
|
return chunk, numel_to_pad
|
|
|
|
@staticmethod
|
|
def _get_shard(
|
|
tensor: Tensor,
|
|
rank: int,
|
|
world_size: int,
|
|
) -> Tuple[Tensor, int]:
|
|
"""
|
|
Return the shard of ``tensor`` with padding for the given ``rank`` and ``world_size`` and the numel padded for that shard.
|
|
|
|
This method allocates new memory (via :meth:`clone`) since the
|
|
unsharded ``tensor`` may be deallocated after this method returns.
|
|
"""
|
|
chunk, numel_to_pad = FlatParamHandle._get_unpadded_shard(
|
|
tensor, rank, world_size
|
|
)
|
|
shard = chunk.clone()
|
|
if numel_to_pad > 0:
|
|
shard = F.pad(shard, [0, numel_to_pad])
|
|
return shard, numel_to_pad
|
|
|
|
@staticmethod
|
|
def _get_sharded_size(tensor: Tensor, rank: int, world_size: int) -> torch.Size:
|
|
"""
|
|
Return the shape of ``tensor`` after sharding including padding.
|
|
|
|
This requires ``tensor`` to have 1D shape and ensures that the returned
|
|
shape is 1D.
|
|
"""
|
|
assert len(tensor.shape) == 1, f"{tensor.shape}"
|
|
unpadded_sharded_tensor, numel_to_pad = FlatParamHandle._get_unpadded_shard(
|
|
tensor, rank, world_size
|
|
)
|
|
unpadded_sharded_size = unpadded_sharded_tensor.size()
|
|
assert len(unpadded_sharded_size) == 1, f"{unpadded_sharded_size}"
|
|
return torch.Size([unpadded_sharded_size[0] + numel_to_pad])
|
|
|
|
def _get_flat_param_offsets(self) -> List[Tuple[int, int]]:
|
|
"""
|
|
Return [start, end] offsets of each original parameter's flattened data in the unsharded flat parameter (without padding).
|
|
|
|
NOTE: The returned list includes elements for alignment padding.
|
|
"""
|
|
cumulative_sum = list(accumulate(self.flat_param._numels_with_padding))
|
|
starts = [0] + cumulative_sum[:-1]
|
|
ends = [end - 1 for end in cumulative_sum] # inclusive
|
|
param_offsets = list(zip(starts, ends))
|
|
return param_offsets
|
|
|
|
@no_type_check
|
|
def shard_metadata(
|
|
self,
|
|
) -> FlatParamShardMetadata:
|
|
"""
|
|
Return the shard-related metadata specific to this rank's shard of the flat parameter.
|
|
|
|
NOTE: The returned tuple does not include elements for alignment
|
|
padding but does account for the padding.
|
|
"""
|
|
fqns_list = []
|
|
shapes_list = []
|
|
numels_list = []
|
|
shard_param_offsets = []
|
|
for fqn, shape, numel, shard_param_info in zip(
|
|
self.flat_param._fqns,
|
|
self.flat_param._shapes,
|
|
self.flat_param._numels,
|
|
self.flat_param._shard_param_infos,
|
|
):
|
|
if not shard_param_info.in_shard:
|
|
continue
|
|
fqns_list.append(fqn)
|
|
shapes_list.append(shape)
|
|
numels_list.append(numel)
|
|
shard_param_offsets.append(
|
|
(
|
|
shard_param_info.intra_param_start_idx,
|
|
shard_param_info.intra_param_end_idx,
|
|
)
|
|
)
|
|
return FlatParamShardMetadata(
|
|
tuple(fqns_list),
|
|
tuple(shapes_list),
|
|
tuple(numels_list),
|
|
tuple(shard_param_offsets),
|
|
)
|
|
|
|
@no_type_check
|
|
@torch.no_grad()
|
|
def init_flat_param_attributes(self) -> None:
|
|
"""
|
|
This initializes some attributes on the handle's ``FlatParameter``.
|
|
This should be called during lazy initialization since it requires the
|
|
parameter to be on the compute device if not offloading to CPU and we
|
|
want to give users the chance to move the parameter appropriately after
|
|
the FSDP constructor.
|
|
|
|
For each tensor attribute on the ``FlatParameter``, see the unshard and
|
|
reshard methods in this class for the allocation and free pattern.
|
|
"""
|
|
flat_param = self.flat_param
|
|
if flat_param.dtype != self._orig_param_dtype:
|
|
# Entering this branch means that the user changed the parameter
|
|
# dtype after FSDP initialization, in which case we may need to
|
|
# refresh some saved dtype attributes (dtypes specified as a part
|
|
# of mixed precision take precedence).
|
|
if not self._low_prec_param_dtype_specified:
|
|
self._fwd_bwd_param_dtype = flat_param.dtype
|
|
# For `reduce_dtype`, require `param_dtype` was not specified since
|
|
# then we infer the `reduce_dtype` from the specified `param_dtype`
|
|
if (
|
|
not self._low_prec_reduce_dtype_specified
|
|
and not self._low_prec_param_dtype_specified
|
|
):
|
|
self._reduce_dtype = flat_param.dtype
|
|
self._orig_param_dtype = flat_param.dtype
|
|
cpu_device = torch.device("cpu")
|
|
if self._offload_params:
|
|
_p_assert(
|
|
flat_param.device == cpu_device,
|
|
f"Expects the `FlatParameter` to be on CPU when parameter CPU "
|
|
f"offloading is enabled, not {flat_param.device}",
|
|
)
|
|
else:
|
|
self._check_on_compute_device(self.flat_param)
|
|
flat_param._local_shard = flat_param.data
|
|
if self._offload_params:
|
|
# Pin the memory for faster H2D transfer
|
|
flat_param._local_shard = flat_param._local_shard.pin_memory(
|
|
device=self.device
|
|
)
|
|
# Pre-allocate the sharded gradient on CPU to enable non-blocking
|
|
# D2H transfer during the backward pass
|
|
flat_param._cpu_grad = torch.zeros_like(
|
|
flat_param._local_shard, device=cpu_device
|
|
).pin_memory(device=self.device)
|
|
if self._uses_param_mixed_precision:
|
|
# For parameter mixed precision, we maintain a low precision
|
|
# sharded tensor on the compute device to be all-gathered (for
|
|
# sharded strategies) or directly used (for `NO_SHARD`) for
|
|
# computation.
|
|
flat_param._mp_shard = torch.empty_like(
|
|
flat_param._local_shard,
|
|
device=self.device,
|
|
dtype=self._fwd_bwd_param_dtype,
|
|
)
|
|
_free_storage(flat_param._mp_shard)
|
|
if self.uses_sharded_strategy:
|
|
# We maintain a padded unsharded tensor that serves as the
|
|
# all-gather destination and owns the original parameter storages.
|
|
unsharded_param_dtype = (
|
|
self._fwd_bwd_param_dtype
|
|
if self._uses_param_mixed_precision
|
|
else flat_param.dtype
|
|
) # use low precision if parameter mixed precision is enabled
|
|
padded_unsharded_numel = flat_param.numel() * self.world_size
|
|
flat_param._full_param_padded = torch.empty(
|
|
padded_unsharded_numel,
|
|
device=self.device,
|
|
dtype=unsharded_param_dtype,
|
|
)
|
|
flat_param._padded_unsharded_size = flat_param._full_param_padded.size()
|
|
_free_storage(flat_param._full_param_padded)
|
|
|
|
if self._uses_param_mixed_precision:
|
|
# For parameter mixed precision, we maintain a full precision
|
|
# padded unsharded tensor for when we force full precision.
|
|
flat_param._full_prec_full_param_padded = torch.empty(
|
|
padded_unsharded_numel,
|
|
device=self.device,
|
|
dtype=flat_param.dtype, # full precision
|
|
)
|
|
_free_storage(flat_param._full_prec_full_param_padded)
|
|
|
|
###################
|
|
# UNSHARD/RESHARD #
|
|
###################
|
|
def pre_unshard(self) -> bool:
|
|
"""
|
|
Return ``False`` if this is a no-op and ``True`` otherwise.
|
|
|
|
Postcondition: ``self.flat_param`` 's data is on the device for
|
|
communication and is what should be all-gathered. This means that it
|
|
matches the dtype of the expected unsharded parameter.
|
|
"""
|
|
if (
|
|
self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS
|
|
and self._skipped_use_sharded_views
|
|
):
|
|
# Since this path imposes special semantics for the unsharded flat
|
|
# parameter (e.g. forcing full precision), use sharded views to
|
|
# reuse the existing logic for that special handling
|
|
self._use_sharded_views()
|
|
ret = False
|
|
if self._use_orig_params and not self._skip_writeback_check:
|
|
ret = self._writeback_orig_params()
|
|
if (
|
|
self.uses_sharded_strategy
|
|
and not self._offload_params
|
|
and not self.needs_unshard()
|
|
):
|
|
pass # no-op
|
|
elif self._uses_param_mixed_precision and not self._force_full_precision:
|
|
self._use_low_precision_shard()
|
|
ret = True
|
|
elif self._offload_params and self.flat_param.device != self.device:
|
|
# NOTE: This creates a new tensor distinct from any attributes.
|
|
self.flat_param_to(self.device, non_blocking=True)
|
|
ret = True
|
|
self._check_on_compute_device(self.flat_param)
|
|
return ret
|
|
|
|
def _use_low_precision_shard(self):
|
|
"""Allocate on the compute device and switch to using the low precision sharded flat parameter."""
|
|
self._check_low_precision_shard()
|
|
flat_param = self.flat_param
|
|
_alloc_storage(
|
|
flat_param._mp_shard, flat_param._local_shard.size() # type: ignore[attr-defined]
|
|
)
|
|
# `copy_()` implicitly casts to the low precision
|
|
flat_param._mp_shard.copy_( # type: ignore[attr-defined]
|
|
flat_param._local_shard.to( # type: ignore[attr-defined]
|
|
self.device, non_blocking=True
|
|
)
|
|
)
|
|
# Invariant: `_mp_shard` is always on the compute device.
|
|
flat_param.data = flat_param._mp_shard # type: ignore[attr-defined]
|
|
|
|
def unshard(self):
|
|
"""
|
|
Run the unshard logic.
|
|
|
|
This includes all-gathering the flat parameter
|
|
and switching to using the unsharded flat parameter. If the handle does
|
|
not need unsharding, then this only switches to using the unsharded
|
|
flat parameter. For ``NO_SHARD``, this is a no-op.
|
|
|
|
If FSDP is in :meth:`summon_full_params` and the handle uses parameter
|
|
mixed precision, then the parameter is forced to full precision.
|
|
"""
|
|
if not self.needs_unshard():
|
|
# Even when not needing an unshard, we should switch to using
|
|
# the unsharded flat parameter
|
|
unsharded_flat_param = (
|
|
self._get_padded_unsharded_flat_param()
|
|
if self.uses_sharded_strategy
|
|
else self.flat_param
|
|
)
|
|
self._use_unsharded_flat_param(unsharded_flat_param)
|
|
return
|
|
unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
|
|
padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)
|
|
self._use_unsharded_flat_param(padded_unsharded_flat_param)
|
|
|
|
def needs_unshard(self) -> bool:
|
|
"""Return if the handle's flat parameter needs to be unsharded."""
|
|
if not self.uses_sharded_strategy:
|
|
return False
|
|
unsharded_flat_param = self._get_padded_unsharded_flat_param()
|
|
already_unsharded = _same_storage_size(
|
|
unsharded_flat_param, unsharded_flat_param.numel()
|
|
)
|
|
return not already_unsharded
|
|
|
|
def _alloc_padded_unsharded_flat_param(self):
|
|
"""
|
|
Allocate the *padded* unsharded flat parameter.
|
|
|
|
The unpadded unsharded
|
|
flat parameter is always a view into the padded one. This padded
|
|
parameter is saved to a different attribute on the ``FlatParameter``
|
|
depending on if we force full precision.
|
|
"""
|
|
self._check_sharded_strategy()
|
|
flat_param = self.flat_param
|
|
unsharded_flat_param = self._get_padded_unsharded_flat_param()
|
|
self._check_storage_freed(unsharded_flat_param)
|
|
_alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size) # type: ignore[attr-defined]
|
|
return unsharded_flat_param
|
|
|
|
def _get_padded_unsharded_flat_param(self) -> torch.Tensor:
|
|
"""
|
|
Return a reference to the padded unsharded flat parameter depending on the calling context.
|
|
|
|
This should only be called if using a sharded strategy.
|
|
"""
|
|
self._check_sharded_strategy()
|
|
flat_param = self.flat_param
|
|
if self._force_full_precision and self._uses_param_mixed_precision:
|
|
# When parameter mixed precision is enabled, we use a different
|
|
# tensor as the all-gather destination to preserve the invariant
|
|
# that `_full_param_padded` is in the low precision
|
|
unsharded_flat_param = flat_param._full_prec_full_param_padded # type: ignore[attr-defined]
|
|
_p_assert(
|
|
unsharded_flat_param.dtype != self._fwd_bwd_param_dtype,
|
|
f"Expects full precision but got {self._fwd_bwd_param_dtype}",
|
|
)
|
|
# For no-reshard-after-forward strategies, `_full_param_padded` may
|
|
# still be allocated from a previous forward. As we are forcing
|
|
# full precision here, the full-precision unsharded copy may be
|
|
# modified, invalidating the existing low-precision unsharded copy,
|
|
# so we should free it here to ensure a new all-gather for the next
|
|
# forward/backward computation to persist the modifications.
|
|
if flat_param._full_param_padded.untyped_storage().size() > 0:
|
|
_free_storage(flat_param._full_param_padded)
|
|
else:
|
|
unsharded_flat_param = flat_param._full_param_padded # type: ignore[attr-defined]
|
|
return unsharded_flat_param
|
|
|
|
def _all_gather_flat_param(
|
|
self,
|
|
padded_unsharded_flat_param: Tensor,
|
|
) -> Tensor:
|
|
"""
|
|
All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``.
|
|
|
|
Then switch to use the all-gathered tensor.
|
|
"""
|
|
_p_assert(
|
|
hasattr(self, "process_group") and hasattr(self, "world_size"),
|
|
"Expects a process group and world size to have been set via `shard()`",
|
|
)
|
|
sharded_flat_param = self.flat_param.data
|
|
expected_numel = sharded_flat_param.numel() * self.world_size
|
|
_p_assert(
|
|
padded_unsharded_flat_param.numel() == expected_numel,
|
|
f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
|
|
)
|
|
|
|
pg = (
|
|
self._fake_process_group
|
|
if self._use_fake_all_gather
|
|
else self.process_group
|
|
)
|
|
|
|
# HACK this should be handled by C10D
|
|
if sharded_flat_param.is_cpu: # type: ignore[attr-defined]
|
|
tensor_list = list(
|
|
torch.chunk(padded_unsharded_flat_param, dist.get_world_size(pg))
|
|
)
|
|
dist.all_gather(tensor_list, sharded_flat_param, group=pg)
|
|
else:
|
|
dist.all_gather_into_tensor(
|
|
padded_unsharded_flat_param,
|
|
sharded_flat_param,
|
|
pg,
|
|
)
|
|
|
|
if self._offload_params:
|
|
# In case of offloading, `flat_param.data` (i.e. sharded param) is
|
|
# created on the pre-unshard stream. We need to hand it over to the
|
|
# unshard stream for all-gather
|
|
_no_dispatch_record_stream(
|
|
sharded_flat_param,
|
|
self._device_handle.current_stream(), # unshard_stream
|
|
)
|
|
return padded_unsharded_flat_param
|
|
|
|
def _use_unsharded_flat_param(
|
|
self,
|
|
padded_unsharded_flat_param: torch.Tensor,
|
|
) -> None:
|
|
"""
|
|
Switch to use the *unpadded* unsharded flat parameter.
|
|
|
|
This is a view into the *padded* unsharded flat parameter.
|
|
"""
|
|
unsharded_size = self.flat_param._unpadded_unsharded_size
|
|
flat_param_part = padded_unsharded_flat_param[: unsharded_size.numel()]
|
|
# slicing [:] is not visible to autograd because of .data
|
|
self.flat_param.data = flat_param_part
|
|
in_forward = self._training_state == HandleTrainingState.FORWARD
|
|
in_pre_backward = self._training_state == HandleTrainingState.BACKWARD_PRE
|
|
if self._use_orig_params:
|
|
if self._skipped_use_sharded_views and in_pre_backward:
|
|
# This call corresponds to the complementary pre-backward
|
|
# `_use_unsharded_views()` to the skipped pre-forward
|
|
# `_use_sharded_views()`, so we should skip this one too.
|
|
return
|
|
# We use `Tensor` views in the forward so that they are tracked by
|
|
# autograd. We use them in the pre-backward as well to support
|
|
# reentrant activation checkpointing, which needs the views to be
|
|
# tracked by autograd in the backward pass's recomputed forward.
|
|
self._use_unsharded_views(
|
|
as_params=(not in_forward and not in_pre_backward)
|
|
)
|
|
elif in_forward:
|
|
self._use_unsharded_views(as_params=False)
|
|
|
|
def post_unshard(self):
|
|
"""
|
|
Run the post-unshard logic.
|
|
|
|
This includes freeing the low precision shard if needed.
|
|
"""
|
|
if self._uses_param_mixed_precision and self.uses_sharded_strategy:
|
|
self._free_low_precision_sharded_param()
|
|
self._check_on_compute_device(self.flat_param)
|
|
|
|
def _free_low_precision_sharded_param(self):
|
|
"""Frees the low precision sharded flat parameter."""
|
|
self._check_low_precision_shard()
|
|
# `_mp_shard` is allocated in the pre-unshard stream, consumed in the
|
|
# unshard stream for sharded strategies, and consumed in both the
|
|
# unshard and default streams for `NO_SHARD`. For sharded strategies,
|
|
# the current stream here is the unshard stream, and for `NO_SHARD`,
|
|
# it is the default stream. For `NO_SHARD`, only recording for the
|
|
# default stream suffices since the default stream waits for the
|
|
# unshard stream.
|
|
_no_dispatch_record_stream(
|
|
self.flat_param._mp_shard, self._device_handle.current_stream() # type: ignore[attr-defined]
|
|
)
|
|
_free_storage(self.flat_param._mp_shard) # type: ignore[attr-defined]
|
|
|
|
@torch.no_grad()
|
|
def unshard_grad(self):
|
|
"""
|
|
Unshard the handle's ``FlatParameter``'s gradient.
|
|
|
|
If all ranks have
|
|
``None`` gradient, then all original parameters will as well. This
|
|
method performs an all-reduce and an all-gather. The additional
|
|
all-reduce is tolerable since this method is not meant to be used on
|
|
the computation critical path.
|
|
|
|
Postcondition: ``_saved_grad_shard`` is defined and contains the value
|
|
to set ``flat_param.grad`` after gradients are resharded.
|
|
"""
|
|
if not self.uses_sharded_strategy:
|
|
self._use_unsharded_grad_views()
|
|
return
|
|
flat_param = self.flat_param
|
|
self._check_unsharded(flat_param)
|
|
|
|
# Check if all ranks have a `None` gradient
|
|
num_grad_none = torch.zeros(1, dtype=torch.int32, device=self.device)
|
|
num_grad_none[0] = flat_param.grad is None
|
|
dist.all_reduce(num_grad_none, group=self.process_group)
|
|
if num_grad_none[0] == self.world_size:
|
|
flat_param._saved_grad_shard = None # type: ignore[assignment]
|
|
self._use_unsharded_grad_views()
|
|
return
|
|
|
|
if flat_param.grad is None:
|
|
# In the case that only some ranks have `None` gradient, we use
|
|
# zeros to approximate as a best effort attempt
|
|
if self._debug_level == dist.DebugLevel.INFO:
|
|
warnings.warn(
|
|
f"[Rank {self.rank}] Only some but not all ranks have a "
|
|
"`None` `FlatParameter` gradient, so FSDP is using zeros to "
|
|
"approximate those ranks' sharded gradients being `None`"
|
|
)
|
|
flat_param._saved_grad_shard = None # type: ignore[assignment]
|
|
sharded_grad = torch.zeros(flat_param._sharded_size, device=self.device) # type: ignore[attr-defined]
|
|
else:
|
|
self._check_sharded(flat_param.grad)
|
|
flat_param._saved_grad_shard = flat_param.grad # type: ignore[attr-defined]
|
|
sharded_grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
|
|
padded_unsharded_grad = torch.empty(
|
|
flat_param._padded_unsharded_size, # type: ignore[attr-defined]
|
|
device=self.device,
|
|
dtype=sharded_grad.dtype,
|
|
)
|
|
dist.all_gather_into_tensor(
|
|
padded_unsharded_grad, sharded_grad, self.process_group
|
|
)
|
|
unsharded_size = self.flat_param._unpadded_unsharded_size
|
|
flat_param.grad = padded_unsharded_grad[: unsharded_size.numel()].view(
|
|
unsharded_size
|
|
)
|
|
self._use_unsharded_grad_views()
|
|
|
|
def reshard_grad(self):
|
|
if self._use_orig_params:
|
|
self._use_sharded_grad_views()
|
|
if not self.uses_sharded_strategy:
|
|
return
|
|
self.flat_param.grad = self.flat_param._saved_grad_shard # type: ignore[attr-defined]
|
|
delattr(self.flat_param, "_saved_grad_shard")
|
|
|
|
def prepare_gradient_for_backward(self):
|
|
"""
|
|
Prepare the gradient for the backward computation.
|
|
|
|
This is done by saving and clearing any existing sharded gradient
|
|
in ``.grad`` to enable computing a new unsharded gradient.
|
|
"""
|
|
_p_assert(
|
|
self._training_state
|
|
in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE),
|
|
"Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)",
|
|
)
|
|
flat_param = self.flat_param
|
|
if flat_param.grad is not None and (
|
|
flat_param.grad.size() != flat_param._unpadded_unsharded_size
|
|
or flat_param.grad.device != flat_param.device # grad on CPU
|
|
):
|
|
self._check_on_compute_device(self.flat_param)
|
|
grad_offloaded = flat_param.grad.device != self.device
|
|
_p_assert(
|
|
not grad_offloaded or self._offload_params,
|
|
f"Expects the sharded gradient to be on {self.device} "
|
|
f"but got {flat_param.grad.device}",
|
|
)
|
|
prev_iter_synced_gradients = (
|
|
flat_param.grad.size()
|
|
== flat_param._local_shard.size() # type: ignore[attr-defined]
|
|
)
|
|
if prev_iter_synced_gradients:
|
|
# TODO (awgu): Gradient accumulation outside `no_sync()`
|
|
# does not work with CPU offloading. The issue should be
|
|
# that, in the post-backward hook, we cannot do an addition
|
|
# between a CPU tensor (the existing sharded gradient) and
|
|
# a GPU tensor (the new sharded gradient).
|
|
if not grad_offloaded:
|
|
flat_param._saved_grad_shard = flat_param.grad.data # type: ignore[attr-defined]
|
|
sharded_grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
|
|
else:
|
|
_p_assert(
|
|
hasattr(flat_param, "_cpu_grad"),
|
|
"`_cpu_grad` should be defined if the gradient is on CPU",
|
|
)
|
|
sharded_grad = flat_param._cpu_grad # type: ignore[attr-defined]
|
|
# If user specified to keep the gradient in low precision, then
|
|
# the gradient may still be of the low precision dtype if the
|
|
# user did not set the gradient to `None` after the previous
|
|
# backward, in which case FSDP should cast back to the full
|
|
# precision dtype so that FSDP can accumulate in that dtype in
|
|
# the post-backward hook and assign to `.grad` in that dtype in
|
|
# the post-backward callback.
|
|
local_shard_dtype = flat_param._local_shard.dtype # type: ignore[attr-defined]
|
|
if (
|
|
self._keep_low_precision_grads
|
|
and sharded_grad.dtype != local_shard_dtype
|
|
):
|
|
sharded_grad.data = sharded_grad.to(local_shard_dtype)
|
|
else:
|
|
padded_unsharded_size = flat_param._padded_unsharded_size # type: ignore[attr-defined]
|
|
_p_assert(
|
|
flat_param.grad.size() == padded_unsharded_size,
|
|
"Expects `.grad` to be the unsharded gradient in "
|
|
f"`no_sync()` with size {padded_unsharded_size} "
|
|
f"but got size {flat_param.grad.size()}",
|
|
)
|
|
flat_param.grad = None
|
|
|
|
def prepare_gradient_for_optim(self):
|
|
"""Prepare the gradient for optimizer computation by moving the sharded gradient to the ``.grad`` attribute."""
|
|
|
|
def cast_grad_to_param_dtype_if_needed(flat_param):
|
|
# TODO (rohan-varma): test for full precision with keep_low_precision_grads
|
|
if not self._force_full_precision and self._keep_low_precision_grads:
|
|
_p_assert(flat_param.grad is not None, "Unexpected None grad!")
|
|
if flat_param.grad.dtype != self._fwd_bwd_param_dtype:
|
|
flat_param.grad.data = flat_param.grad.to(self._fwd_bwd_param_dtype)
|
|
if self._use_orig_params:
|
|
self._use_sharded_grad_views()
|
|
|
|
flat_param = self.flat_param
|
|
# TODO (awgu): We should replace these conditional checks to encode
|
|
# the logical intention more directly.
|
|
if hasattr(flat_param, "_cpu_grad"):
|
|
# NOTE: This branch includes `NO_SHARD`.
|
|
self._check_sharded(flat_param)
|
|
self._check_on_cpu(flat_param)
|
|
flat_param.grad = flat_param._cpu_grad # type: ignore[attr-defined]
|
|
cast_grad_to_param_dtype_if_needed(flat_param)
|
|
elif hasattr(flat_param, "_saved_grad_shard"):
|
|
self._check_sharded(flat_param)
|
|
self._check_on_compute_device(flat_param)
|
|
if flat_param._saved_grad_shard is not None:
|
|
self._check_on_compute_device(flat_param._saved_grad_shard) # type: ignore[attr-defined]
|
|
# If no sharded gradient was computed this iteration, then there is
|
|
# no need to forward `_saved_grad_shard` to `grad`
|
|
if flat_param._post_backward_called: # type: ignore[attr-defined]
|
|
flat_param.grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
|
|
if flat_param.grad is not None:
|
|
cast_grad_to_param_dtype_if_needed(flat_param)
|
|
else:
|
|
_p_assert(
|
|
not self.uses_sharded_strategy
|
|
or not flat_param._post_backward_called, # type: ignore[attr-defined]
|
|
"All sharded parameters that received a gradient in the "
|
|
"post-backward should use `_saved_grad_shard`",
|
|
)
|
|
# Delete `_saved_grad_shard` since its existence indicates a previous
|
|
# gradient to accumulate with in the post-backward hook
|
|
if hasattr(flat_param, "_saved_grad_shard"):
|
|
delattr(flat_param, "_saved_grad_shard")
|
|
|
|
@contextlib.contextmanager
|
|
def to_cpu(self):
|
|
"""
|
|
Move the unpadded unsharded flat parameter to CPU while in the context and moves it back to the previous device upon exit.
|
|
|
|
For now, this assumes the ``FlatParameter`` is the unpadded unsharded flat parameter
|
|
since (1) there is no reason to include the padding in the copy and (2)
|
|
there is no use case for the sharded flat parameter.
|
|
|
|
Precondition: ``self.flat_param`` 's data is the unpadded unsharded
|
|
flat parameter on the compute device, and the handle uses a sharded
|
|
strategy.
|
|
Postcondition: Same as the precondition.
|
|
"""
|
|
self._check_sharded_strategy()
|
|
_p_assert(
|
|
self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
|
|
f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
|
|
)
|
|
self._check_on_compute_device(self.flat_param)
|
|
# Check that the unpadded unsharded flat parameter is a view into the
|
|
# padded unsharded flat parameter as expected
|
|
# NOTE: This check is not strictly needed for correctness but is a
|
|
# useful sanity check since the tensor should only be used internally.
|
|
_p_assert(
|
|
_same_storage(self.flat_param, self._get_padded_unsharded_flat_param()),
|
|
"Expects the unpadded parameter to be a view into the padded parameter",
|
|
)
|
|
self.flat_param_to(torch.device("cpu"))
|
|
self._free_unsharded_flat_param()
|
|
try:
|
|
yield
|
|
finally:
|
|
_p_assert(
|
|
self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
|
|
f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
|
|
)
|
|
padded_unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
|
|
# Copy from CPU to the compute device
|
|
padded_unsharded_flat_param[: self.flat_param.numel()].copy_(
|
|
self.flat_param
|
|
)
|
|
self._use_unsharded_flat_param(padded_unsharded_flat_param)
|
|
|
|
def reshard(self, free_unsharded_flat_param: bool):
|
|
"""
|
|
Run the reshard logic.
|
|
|
|
This includes freeing the unsharded flat
|
|
parameter if ``free_unsharded_flat_param`` and switching to using the
|
|
sharded flat parameter. Note that this also implicitly offloads
|
|
the sharded flat parameter (if CPU offload is enabled) by pointing
|
|
it to the ``_local_shard`` attribute which resides on CPU.
|
|
"""
|
|
# Switch to the sharded `FlatParameter` before freeing to prevent
|
|
# "use-after-free"-type bugs with external profiling tools, where for
|
|
# `use_orig_params=True`, the `param` does not point to valid memory
|
|
# when setting `param.data = ...` in `_use_sharded_views()`.
|
|
self._use_sharded_flat_param()
|
|
if free_unsharded_flat_param:
|
|
self._free_unsharded_flat_param()
|
|
|
|
def post_reshard(self):
|
|
"""
|
|
Run the post-reshard logic.
|
|
|
|
This includes freeing any memory that
|
|
can now be freed given that the ``FlatParameter`` points to the full
|
|
precision sharded flat parameter.
|
|
|
|
Precondition: ``self.flat_param`` 's data points to the full precision
|
|
sharded flat parameter.
|
|
"""
|
|
# For `NO_SHARD`, `_mp_shard` is not freed in the post-unshard since it
|
|
# is also the low precision *unsharded* flat parameter. Hence, we delay
|
|
# the free until the reshard.
|
|
if (
|
|
self._uses_param_mixed_precision
|
|
and not self.uses_sharded_strategy
|
|
and not self._force_full_precision # did not use the low precision shard
|
|
):
|
|
self._free_low_precision_sharded_param()
|
|
|
|
def _free_unsharded_flat_param(self):
|
|
"""
|
|
Free the padded unsharded flat parameter. We allow this
|
|
function to be called even when storage is not allocated
|
|
|
|
The tensor to free depends
|
|
on the calling context since the unshard may have forced full
|
|
precision, in which case a different tensor is used.
|
|
"""
|
|
self._check_sharded_strategy()
|
|
unsharded_flat_param = self._get_padded_unsharded_flat_param()
|
|
self._check_on_compute_device(unsharded_flat_param)
|
|
# Do not free the memory until all ops in the current stream finish
|
|
_no_dispatch_record_stream(
|
|
unsharded_flat_param, self._device_handle.current_stream()
|
|
)
|
|
_free_storage(unsharded_flat_param)
|
|
|
|
def _use_sharded_flat_param(self) -> None:
|
|
"""Switches to using the sharded flat parameter."""
|
|
flat_param = self.flat_param
|
|
if self._use_orig_params:
|
|
in_forward = self._training_state == HandleTrainingState.FORWARD
|
|
skip_use_sharded_views = (
|
|
torch.is_grad_enabled()
|
|
and in_forward
|
|
and self._sharding_strategy
|
|
in NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
|
|
)
|
|
# Only incur the extra `.data` call if needed
|
|
if skip_use_sharded_views:
|
|
unsharded_flat_param = flat_param.data
|
|
if self._offload_params:
|
|
device = flat_param._local_shard.device # type: ignore[attr-defined]
|
|
_p_assert(
|
|
device == torch.device("cpu"),
|
|
f"Expects the local shard to be on CPU but got {device}",
|
|
)
|
|
flat_param.data = flat_param._local_shard # type: ignore[attr-defined]
|
|
if self._use_orig_params:
|
|
if skip_use_sharded_views: # type: ignore[possibly-undefined]
|
|
self._unsharded_flat_param_for_skipped_views = unsharded_flat_param # type: ignore[possibly-undefined]
|
|
else:
|
|
self._use_sharded_views()
|
|
# For the post-forward reshard, we may try to use sharded gradient
|
|
# views (or unsharded gradient views if a gradient was accumulated
|
|
# in `no_sync()`), but for the post-backward reshard, we delay the
|
|
# call to after the reduce-scatter.
|
|
if (
|
|
in_forward # type: ignore[possibly-undefined]
|
|
# Skip using gradient views if skipped using sharded views
|
|
# since exposing unsharded parameters with sharded gradients
|
|
# may be confusing to the user
|
|
and not self._skipped_use_sharded_views
|
|
):
|
|
# TODO: Change `_unpadded_unsharded_size` if we change the
|
|
# gradient to be computed directly with padding.
|
|
accumulated_grad_in_no_sync = (
|
|
flat_param.grad is not None
|
|
and self.uses_sharded_strategy
|
|
and flat_param.grad.shape == flat_param._unpadded_unsharded_size
|
|
)
|
|
if accumulated_grad_in_no_sync:
|
|
self._use_unsharded_grad_views()
|
|
else:
|
|
self._use_sharded_grad_views()
|
|
|
|
#########
|
|
# VIEWS #
|
|
#########
|
|
@no_type_check
|
|
def _get_unflat_views_unaligned(
|
|
self,
|
|
tensor: Optional[torch.Tensor] = None,
|
|
) -> Iterator[Tensor]:
|
|
"""
|
|
Return unflattened ``Tensor`` views into ``tensor``.
|
|
|
|
If `tensor`` is ``None``, ``flat_param`` is used. The unflattening is based
|
|
on ``flat_param`` 's metadata.
|
|
|
|
Examples for ``tensor`` include ``flat_param.grad`` or unsharded
|
|
tensor optimizer state.
|
|
"""
|
|
flat_param = self.flat_param
|
|
if tensor is None:
|
|
tensor = flat_param
|
|
views = (
|
|
_ext_post_unflatten_transform(
|
|
subtensor.view(shape),
|
|
param_extension,
|
|
self._fsdp_extension,
|
|
)
|
|
for (subtensor, shape, param_extension) in zip(
|
|
torch.split(tensor, flat_param._numels, dim=0),
|
|
flat_param._shapes,
|
|
flat_param._param_extensions,
|
|
)
|
|
)
|
|
return views
|
|
|
|
@no_type_check
|
|
def _get_unflat_views_aligned(
|
|
self,
|
|
tensor: Optional[Tensor] = None,
|
|
) -> List[Tensor]:
|
|
"""
|
|
Return unflattened ``Tensor`` views into ``tensor`` with handling for padding.
|
|
|
|
This method has the same contract as :meth:`_get_unflat_views_unaligned`
|
|
except it checks for ``None`` placeholders representing padding for
|
|
alignment, which may incur slightly more CPU overhead.
|
|
"""
|
|
flat_param = self.flat_param
|
|
if tensor is None:
|
|
tensor = flat_param
|
|
splits: List[Tensor] = torch.split(
|
|
tensor, flat_param._numels_with_padding, dim=0
|
|
)
|
|
idx = 0
|
|
views: List[Tensor] = []
|
|
for split, is_padding in zip(splits, flat_param._is_padding_mask):
|
|
if is_padding:
|
|
continue
|
|
views.append(
|
|
_ext_post_unflatten_transform(
|
|
split.view(flat_param._shapes[idx]),
|
|
flat_param._param_extensions[idx],
|
|
self._fsdp_extension,
|
|
)
|
|
)
|
|
idx += 1
|
|
return views
|
|
|
|
@no_type_check
|
|
@torch.enable_grad()
|
|
def _use_unsharded_views(self, as_params: bool) -> None:
|
|
"""
|
|
Unflatten the unsharded flat parameter by setting the original parameter variables to be views into it.
|
|
|
|
Args:
|
|
as_params (bool): If ``True``, then registers the original
|
|
parameters as ``nn.Parameter`` s; if ``False``, then registers
|
|
the original parameters only as ``Tensor`` s. ``False`` should
|
|
be used during forward/backward computation and when hiding the
|
|
original parameters from :meth:`nn.Module.named_parameters`.
|
|
|
|
Note:
|
|
when prefetching for next forward, current forward may be
|
|
annotated with `@torch.no_grad()`
|
|
`@torch.enable_grad()` ensures non-empty `view.grad_fn`
|
|
otherwise `_post_backward_hook` will not get called
|
|
"""
|
|
flat_param = self.flat_param
|
|
self._check_unsharded(flat_param)
|
|
views = self._get_unflat_views()
|
|
from torch.distributed.tensor import DTensor
|
|
|
|
for i, (view, (param_name, module, _)) in enumerate(
|
|
zip(views, flat_param._param_infos)
|
|
):
|
|
if self._use_orig_params and as_params:
|
|
if type(view) is DTensor:
|
|
# A `DTensor` `view` is not compatible with assigning
|
|
# `param.data = view`, so we cannot preserve the parameter
|
|
# variable.
|
|
self._setattr_param(
|
|
module,
|
|
param_name,
|
|
nn.Parameter(view, requires_grad=flat_param.requires_grad),
|
|
)
|
|
continue
|
|
param = self.flat_param._params[i]
|
|
self._setattr_param(module, param_name, param)
|
|
param.data = view
|
|
elif as_params:
|
|
self._setattr_param(
|
|
module,
|
|
param_name,
|
|
nn.Parameter(view, requires_grad=flat_param.requires_grad),
|
|
)
|
|
else: # `as_params=False`
|
|
param_var: Tensor = view
|
|
if self._use_orig_params:
|
|
if self._training_state == HandleTrainingState.FORWARD:
|
|
# Save the `Tensor` for the pre-backward
|
|
self.flat_param._tensors[i] = view # save for pre-backward
|
|
elif self._training_state == HandleTrainingState.BACKWARD_PRE:
|
|
# Use the saved `Tensor` variable from the forward to
|
|
# preserve the autograd graph so that the post-backward
|
|
# hook fires (e.g. for reentrant AC)
|
|
tensor = self.flat_param._tensors[i]
|
|
tensor.data = view
|
|
param_var = tensor
|
|
self._setattr_tensor(module, param_name, param_var)
|
|
if (
|
|
self._use_orig_params
|
|
and self._training_state == HandleTrainingState.FORWARD
|
|
):
|
|
module._parameters[param_name] = param_var
|
|
for i, (
|
|
param_name,
|
|
module,
|
|
_,
|
|
prim_param_name,
|
|
prim_module,
|
|
_,
|
|
) in enumerate(self.flat_param._shared_param_infos):
|
|
prim_param: Union[Tensor, nn.Parameter] = getattr(
|
|
prim_module, prim_param_name
|
|
)
|
|
_p_assert(
|
|
not as_params or isinstance(prim_param, nn.Parameter),
|
|
f"as_params={as_params} type(prim_param)={type(prim_param)}",
|
|
)
|
|
if self._use_orig_params and as_params:
|
|
shared_param = self.flat_param._shared_params[i]
|
|
self._setattr_param(module, param_name, shared_param)
|
|
shared_param.data = prim_param
|
|
elif as_params:
|
|
self._setattr_param(module, param_name, prim_param)
|
|
else:
|
|
self._setattr_tensor(module, param_name, prim_param)
|
|
if (
|
|
self._use_orig_params
|
|
and self._training_state == HandleTrainingState.FORWARD
|
|
):
|
|
module._parameters[param_name] = prim_param
|
|
|
|
@no_type_check
|
|
def _use_unsharded_grad_views(self) -> None:
|
|
"""
|
|
Unflatten the unsharded flat parameter's gradient.
|
|
|
|
The original parameter variables' gradients are set to be views into
|
|
the unsharded flat parameter's gradient.
|
|
"""
|
|
# Expects the gradient to be in `flat_param.grad`
|
|
if self.flat_param.grad is None:
|
|
for param in chain(self.flat_param._params, self.flat_param._shared_params):
|
|
param.grad = None
|
|
return
|
|
self._check_unsharded(self.flat_param.grad)
|
|
views = self._get_unflat_views(self.flat_param.grad)
|
|
for i, (view, (param_name, module, _)) in enumerate(
|
|
zip(views, self.flat_param._param_infos)
|
|
):
|
|
_p_assert(
|
|
hasattr(module, param_name),
|
|
f"{self.flat_param._fqns[i]} is missing",
|
|
)
|
|
param = getattr(module, param_name)
|
|
if (
|
|
param.shape != view.shape
|
|
or param.dtype != view.dtype
|
|
or param.device != view.device
|
|
):
|
|
# NOTE: This is a hack using `.data` to side step the check
|
|
# that parameter/gradient sizes/dtypes/devices match. From
|
|
# calling `reshard()`, `param` has the sharded size, has the
|
|
# full precision dtype, and if CPU offloading is enabled, is on
|
|
# CPU. Thus, one or more of the following cases can hold when
|
|
# in `no_sync()`, where `view` is the original parameter's
|
|
# gradient:
|
|
# 1. `view` can have the unsharded size.
|
|
# 2. `view` can have the parameter low precision dtype.
|
|
# 3. `view` can be on GPU.
|
|
if param.grad is None:
|
|
param.grad = torch.empty_like(param)
|
|
param.grad.data = view
|
|
else:
|
|
param.grad = view
|
|
for i, (
|
|
param_name,
|
|
module,
|
|
module_name,
|
|
prim_param_name,
|
|
prim_module,
|
|
_,
|
|
) in enumerate(self.flat_param._shared_param_infos):
|
|
_p_assert(
|
|
hasattr(module, param_name),
|
|
f"{module_name + '.' + param_name if module_name else param_name} is missing",
|
|
) # did not save FQN info in `_shared_param_infos`
|
|
param = getattr(module, param_name)
|
|
prim_param = getattr(prim_module, prim_param_name)
|
|
if (
|
|
param.shape != prim_param.grad.shape
|
|
or param.dtype != prim_param.grad.dtype
|
|
or param.device != prim_param.grad.device
|
|
):
|
|
# NOTE: This is the same hack to use `.data` to side step the
|
|
# size check.
|
|
if param.grad is None:
|
|
param.grad = torch.empty_like(param)
|
|
param.grad.data = prim_param.grad
|
|
else:
|
|
param.grad = prim_param.grad
|
|
|
|
@contextlib.contextmanager
|
|
def unflatten_as_params(self) -> Generator:
|
|
"""
|
|
Unflatten the original parameters.
|
|
|
|
The function assumes that the flat parameter is unsharded. When in the context,
|
|
unflattens the original parameters as ``nn.Parameter`` views into the
|
|
flat parameter, and after the context, restores the original parameters
|
|
as ``Tensor`` views into the flat parameter.
|
|
"""
|
|
self._use_unsharded_views(as_params=True)
|
|
try:
|
|
yield
|
|
finally:
|
|
self._use_unsharded_views(as_params=False)
|
|
|
|
@no_type_check
|
|
@torch.no_grad()
|
|
def _use_sharded_views(self) -> None:
|
|
"""
|
|
Set the original parameter variables' data to be flattened views into the sharded flat parameter.
|
|
|
|
The views are kept as flattened to simplify the case where a parameter
|
|
is sharded across ranks. Parameters whose data is not present in the
|
|
sharded flat parameter have their data set to a size-0 empty tensor. We
|
|
do not delete them to ensure to preserve expected behaviors like model
|
|
printability. Parameters whose data is present must preserve their
|
|
variables to be passable to an optimizer.
|
|
"""
|
|
self._unsharded_flat_param_for_skipped_views = None
|
|
if not self.uses_sharded_strategy:
|
|
# For `NO_SHARD`, use the *unflattened* unsharded views since we
|
|
# have the unsharded parameter
|
|
self._use_unsharded_views(as_params=True)
|
|
return
|
|
flat_param = self.flat_param
|
|
self._check_sharded(flat_param)
|
|
# Construct once and reuse for all parameters not in the local shard
|
|
size_0_empty_tensor = torch.empty(
|
|
0,
|
|
dtype=self.flat_param.dtype, # in case `flat_param` changed dtype
|
|
device=self.flat_param.device,
|
|
requires_grad=False,
|
|
)
|
|
for param, shard_param_info, (param_name, module, _) in zip(
|
|
flat_param._params, flat_param._shard_param_infos, flat_param._param_infos
|
|
):
|
|
self._setattr_param(module, param_name, param)
|
|
if not shard_param_info.in_shard:
|
|
# Allow the original data to be freed via garbage collection
|
|
param.data = size_0_empty_tensor
|
|
else:
|
|
offset = shard_param_info.offset_in_shard
|
|
numel_in_shard = shard_param_info.numel_in_shard
|
|
param.data = flat_param[offset : offset + numel_in_shard]
|
|
assert self.flat_param._shared_params is not None
|
|
for i, (
|
|
param,
|
|
(param_name, module, _, prim_param_name, prim_module, _),
|
|
) in enumerate(
|
|
zip(self.flat_param._shared_params, self.flat_param._shared_param_infos)
|
|
):
|
|
self._setattr_param(module, param_name, param)
|
|
prim_param = getattr(prim_module, prim_param_name)
|
|
param.data = prim_param # could be both empty and non-empty
|
|
if self._training_state == HandleTrainingState.BACKWARD_POST:
|
|
# Clear the saved `Tensor`s since they are unneeded now
|
|
for i in range(len(self.flat_param._tensors)):
|
|
self.flat_param._tensors[i] = None
|
|
|
|
@no_type_check
|
|
@torch.no_grad()
|
|
def _use_sharded_grad_views(self) -> None:
|
|
"""
|
|
Set the original parameter variables' gradients to be flattened views into the sharded flat parameter's gradient.
|
|
|
|
This is a no-op if there is no gradient.
|
|
|
|
Parameters whose data is not present in the sharded flat parameter and
|
|
parameters with ``requires_grad=False`` have their gradients set to
|
|
``None``. Since the gradient variables do not need to be preserved,
|
|
this method does not manipulate existing ``Tensor`` data directly and
|
|
creates new ``Tensor`` variables instead.
|
|
"""
|
|
flat_param = self.flat_param
|
|
self._check_sharded(flat_param)
|
|
grad = self.sharded_grad
|
|
if grad is None:
|
|
for param in chain(flat_param._params, flat_param._shared_params):
|
|
param.grad = None
|
|
return
|
|
self._check_sharded(grad)
|
|
for param, shard_param_info, is_grad_none in zip(
|
|
flat_param._params,
|
|
flat_param._shard_param_infos,
|
|
flat_param._is_grad_none_mask,
|
|
):
|
|
if not shard_param_info.in_shard:
|
|
param.grad = None
|
|
else:
|
|
numel_in_shard = shard_param_info.numel_in_shard
|
|
if param.requires_grad and not is_grad_none:
|
|
offset = shard_param_info.offset_in_shard
|
|
if self._keep_low_precision_grads or param.dtype != grad.dtype:
|
|
# NOTE: This is a hack using `.data` to side step the
|
|
# check that parameter/gradient dtypes match. Here,
|
|
# `param` has full precision; `grad` has low precision.
|
|
if param.grad is None:
|
|
# `.grad` must have the same shape as `param`
|
|
param.grad = torch.empty_like(param)
|
|
param.grad.data = grad[
|
|
offset : offset + numel_in_shard
|
|
].reshape(param.shape)
|
|
else:
|
|
param.grad = grad[offset : offset + numel_in_shard].reshape(
|
|
param.shape
|
|
)
|
|
else:
|
|
param.grad = None
|
|
assert flat_param._shared_params is not None
|
|
for i, (param, (_, _, _, prim_param_name, prim_module, _)) in enumerate(
|
|
zip(flat_param._shared_params, flat_param._shared_param_infos)
|
|
):
|
|
in_sharded_flat_param = hasattr(prim_module, prim_param_name)
|
|
if in_sharded_flat_param and param.requires_grad:
|
|
prim_param = getattr(prim_module, prim_param_name)
|
|
param.grad = prim_param.grad # share the same reference
|
|
else:
|
|
param.grad = None
|
|
|
|
@no_type_check
|
|
@torch.no_grad()
|
|
def _writeback_orig_params(self) -> bool:
|
|
"""
|
|
Write back any parameters that changed storage to the handle's ``FlatParameter``.
|
|
|
|
Iterates over the original parameters and writes back any parameters
|
|
that changed storages (due to a non-inplace operator) to the handle's
|
|
``FlatParameter``. This method preserves the ``FlatParameter` 's
|
|
device even if an original parameter's device changes.
|
|
|
|
Raises:
|
|
RuntimeError: If an original parameter or gradient changes storages
|
|
but no longer has the expected flattened shape.
|
|
Returns: ``True`` if some writeback happened, and ``False`` otherwise.
|
|
"""
|
|
if (
|
|
self.uses_sharded_strategy
|
|
and not self.is_sharded(self.flat_param)
|
|
and not self._skipped_use_sharded_views
|
|
):
|
|
# For `NO_SHARD`, we may still need to writeback
|
|
return False
|
|
flat_param = self.flat_param
|
|
wroteback = False
|
|
if self._skipped_use_sharded_views and self.uses_sharded_strategy:
|
|
# NOTE: We must use the unsharded flat parameter from which the
|
|
# unsharded views were computed, not the one from the current
|
|
# calling context (`_get_padded_unsharded_flat_param()`) since that
|
|
# may be different (e.g. the model changed from train to eval).
|
|
flat_param_tensor = self._unsharded_flat_param_for_skipped_views
|
|
_p_assert(
|
|
_data_ptr_allocated(flat_param_tensor),
|
|
"If skipped using sharded views, the unsharded flat parameter "
|
|
"should be allocated",
|
|
)
|
|
else:
|
|
flat_param_tensor = flat_param
|
|
# NOTE: Since this method is called in the pre-unshard, which is only
|
|
# called during computation in the pre-forward or pre-backward, the
|
|
# sharded gradient should be guaranteed to be in `.grad`, not in
|
|
# `._saved_grad_shard`.
|
|
flat_param_grad = (
|
|
flat_param.grad
|
|
if self.uses_sharded_strategy or not self._offload_params
|
|
else flat_param._cpu_grad
|
|
)
|
|
for i, (
|
|
param,
|
|
(in_shard, offset_in_shard, numel_in_shard, _, _),
|
|
(param_name, module, _),
|
|
) in enumerate(
|
|
zip(
|
|
flat_param._params,
|
|
flat_param._shard_param_infos,
|
|
flat_param._param_infos,
|
|
)
|
|
):
|
|
if not in_shard:
|
|
continue
|
|
if not hasattr(module, param_name):
|
|
# Do not writeback if original parameters are deregistered
|
|
# (e.g. during model checkpointing)
|
|
continue
|
|
|
|
# Check for parameter writeback
|
|
if self._skipped_use_sharded_views:
|
|
param = flat_param._tensors[i]
|
|
_p_assert(
|
|
param is not None,
|
|
f"Expects to have saved tensor for {flat_param._fqns[i]}",
|
|
)
|
|
param_changed = getattr(module, param_name) is not param
|
|
needs_param_writeback = (
|
|
param_changed # changed parameter variable itself
|
|
or not _same_storage(param, flat_param_tensor)
|
|
)
|
|
if self._skipped_use_sharded_views and (
|
|
param_changed or needs_param_writeback
|
|
):
|
|
raise AssertionError(
|
|
"FSDP does not support changing the parameters between "
|
|
f"forward and backward for {self._sharding_strategy}"
|
|
)
|
|
if param_changed:
|
|
# NOTE: The gradient is not preserved after a parameter change.
|
|
param = getattr(module, param_name)
|
|
flat_param._params[i] = param
|
|
if needs_param_writeback:
|
|
expected_shape = torch.Size([numel_in_shard])
|
|
self._writeback_tensor(
|
|
param, flat_param, i, expected_shape, offset_in_shard, True
|
|
)
|
|
wroteback = True
|
|
|
|
# Check for gradient writeback
|
|
if self._skipped_use_sharded_views:
|
|
# Skip the writeback check because we do not expose gradients
|
|
# when we skipped using sharded views
|
|
continue
|
|
if param.grad is None and flat_param.grad is not None:
|
|
expected_shape = torch.Size([numel_in_shard])
|
|
self._writeback_tensor(
|
|
None, flat_param.grad, i, expected_shape, offset_in_shard, False
|
|
)
|
|
elif param.grad is not None:
|
|
# For `NO_SHARD` + CPU offloading, `_cpu_grad` is always in
|
|
# memory and owns the gradient storage, so it will never
|
|
# require gradient writeback.
|
|
if not self.uses_sharded_strategy and self._offload_params:
|
|
# Explicitly continue to handle the case of `no_sync()`,
|
|
# where `param.grad` is a view into the GPU gradient
|
|
# referenced by `flat_param.grad`, while `flat_param_grad`
|
|
# is `flat_param._cpu_grad`, which is on CPU
|
|
continue
|
|
|
|
needs_grad_writeback = flat_param_grad is None or not _same_storage(
|
|
param.grad, flat_param_grad
|
|
)
|
|
if needs_grad_writeback:
|
|
if flat_param_grad is None:
|
|
flat_param_grad = torch.zeros_like(flat_param)
|
|
expected_shape = torch.Size([numel_in_shard])
|
|
self._writeback_tensor(
|
|
param.grad,
|
|
flat_param_grad,
|
|
i,
|
|
expected_shape,
|
|
offset_in_shard,
|
|
False,
|
|
)
|
|
flat_param.grad = flat_param_grad
|
|
flat_param_grad = flat_param.grad
|
|
|
|
# TODO: If we want to handle shared parameters, we need to re-generate
|
|
# the shared parameter data structures in case sharedness changed.
|
|
for i, (
|
|
param_name,
|
|
module,
|
|
_,
|
|
prim_param_name,
|
|
prim_module,
|
|
_,
|
|
) in enumerate(flat_param._shared_param_infos):
|
|
if getattr(module, param_name) is not getattr(prim_module, prim_param_name):
|
|
raise NotImplementedError(
|
|
"Changing shared parameters is not supported yet"
|
|
)
|
|
return wroteback
|
|
|
|
def _writeback_tensor(
|
|
self,
|
|
src_tensor: Optional[Tensor],
|
|
dst_tensor: Tensor,
|
|
tensor_index: int,
|
|
expected_shape: torch.Size,
|
|
offset: int,
|
|
is_param: bool, # else gradient
|
|
) -> None:
|
|
"""
|
|
Write back ``src_tensor`` to ``dst_tensor`` at offset ``offset``, where ``src_tensor`` should have shape ``expected_shape``.
|
|
|
|
``is_param`` indicates if the tensor is the parameter (if ``True``) or gradient (if
|
|
``False``). If ``src_tensor`` is ``None``, then the effect is zeroing
|
|
instead of copying. ``tensor_index`` gives the index of ``src_tensor``
|
|
in the metadata structures.
|
|
|
|
Raises:
|
|
RuntimeError: If the ``src_tensor`` does not have the expected
|
|
shape.
|
|
"""
|
|
_p_assert(
|
|
len(expected_shape) == 1,
|
|
f"Expects a 1D expected shape but got {expected_shape}",
|
|
)
|
|
if self._debug_level == dist.DebugLevel.INFO:
|
|
rank = self.rank if hasattr(self, "rank") else dist.get_rank()
|
|
src_shape = src_tensor.shape if src_tensor is not None else None
|
|
src_device = src_tensor.device if src_tensor is not None else None
|
|
warnings.warn(
|
|
f"[Rank {rank}] {'Parameter' if is_param else 'Gradient'} needs "
|
|
f"writeback in {self._training_state}\n"
|
|
f"expected shape={expected_shape} shape={src_shape} "
|
|
f"expected device={dst_tensor.device} device={src_device}"
|
|
)
|
|
if src_tensor is not None and src_tensor.shape != expected_shape:
|
|
# NOTE: Gradient shape mismatch is not possible in practice since
|
|
# the gradient shape is enforced to match that of the parameter and
|
|
# we already check for parameter shape mismatch.
|
|
raise RuntimeError(
|
|
f"Cannot writeback when the {'parameter' if is_param else 'gradient'} "
|
|
f"shape changes\nExpects {expected_shape} but got {src_tensor.shape}"
|
|
)
|
|
if src_tensor is not None:
|
|
dst_tensor[offset : offset + expected_shape.numel()].copy_(src_tensor)
|
|
else:
|
|
dst_tensor[offset : offset + expected_shape.numel()].zero_()
|
|
assert self.flat_param._is_grad_none_mask is not None
|
|
self.flat_param._is_grad_none_mask[tensor_index] = True
|
|
|
|
def _reset_flat_param_grad_info_if_needed(self):
|
|
"""
|
|
Reset ``flat_param.grad`` if needed.
|
|
|
|
When ``use_orig_params=True``:
|
|
(1) sets the underlying ``flat_param.grad`` to ``None`` if *all* of the
|
|
original parameters' ``.grad`` are ``None``, and
|
|
(2) sets ``flat_param.requires_grad=False`` if *none* of the original
|
|
parameters require gradient.
|
|
For (1), this is targeting ``optim.zero_grad(set_to_none=True)``, in
|
|
which case we want to free the gradients as soon after the
|
|
``zero_grad()`` call as possible.
|
|
"""
|
|
if not self._use_orig_params:
|
|
return
|
|
flat_param = self.flat_param
|
|
assert flat_param._params is not None # mypy
|
|
all_grad_none = True
|
|
requires_grad = False
|
|
for param in flat_param._params:
|
|
all_grad_none &= param.grad is None
|
|
requires_grad |= param.requires_grad
|
|
if all_grad_none:
|
|
flat_param.grad = None
|
|
# As long as one parameter requires gradient, then the flat parameter
|
|
# must require gradient
|
|
flat_param.requires_grad = requires_grad
|
|
|
|
def _deregister_orig_params(self):
|
|
for param_info in self.flat_param._param_infos:
|
|
param_name, module, _ = param_info
|
|
if hasattr(module, param_name):
|
|
delattr(module, param_name)
|
|
for param_name, module, _, _, _, _ in self.flat_param._shared_param_infos:
|
|
if hasattr(module, param_name):
|
|
delattr(module, param_name)
|
|
|
|
###########
|
|
# HELPERS #
|
|
###########
|
|
def flat_param_to(self, *args, **kwargs):
|
|
"""Wrap an in-place call to ``.to()`` for ``self.flat_param``."""
|
|
self.flat_param.data = self.flat_param.to(*args, **kwargs)
|
|
if self._use_orig_params:
|
|
# Refresh the views because their storage may have changed
|
|
if self.is_sharded(self.flat_param):
|
|
self._use_sharded_views()
|
|
else:
|
|
self._use_unsharded_views(as_params=True)
|
|
|
|
def _get_modules(self) -> Set[nn.Module]:
|
|
"""Return a :class:`set` of the modules whose parameters are included in this handle's flat parameter."""
|
|
return {pi.module for pi in self.flat_param._param_infos}.union(
|
|
{spi.module for spi in self.flat_param._shared_param_infos}
|
|
)
|
|
|
|
def is_sharded(self, tensor: Tensor) -> bool:
|
|
"""
|
|
Return whether ``tensor`` is *currently* sharded.
|
|
|
|
For ``NO_SHARD``, we choose to have this always return ``False`` for clarity.
|
|
"""
|
|
if (
|
|
not hasattr(self.flat_param, "_sharded_size")
|
|
or not self.uses_sharded_strategy
|
|
):
|
|
# `_sharded_size` is defined iff `handle.shard()` has been called
|
|
return False
|
|
sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined]
|
|
return tensor.size() == sharded_size
|
|
|
|
def param_module_names(self) -> Iterator[Tuple[str, str]]:
|
|
shared_param_infos = [
|
|
ParamInfo(param_name, module, module_name)
|
|
for (
|
|
param_name,
|
|
module,
|
|
module_name,
|
|
_,
|
|
_,
|
|
_,
|
|
) in self.flat_param._shared_param_infos
|
|
]
|
|
for param_info in chain(self.flat_param._param_infos, shared_param_infos):
|
|
param_name, _, module_name = param_info # type: ignore[misc]
|
|
yield (param_name, module_name)
|
|
|
|
def shared_param_module_names(self) -> Iterator[Tuple[str, str]]:
|
|
for param_name, _, module_name in [
|
|
ParamInfo(param_name, module, module_name)
|
|
for (
|
|
param_name,
|
|
module,
|
|
module_name,
|
|
_,
|
|
_,
|
|
_,
|
|
) in self.flat_param._shared_param_infos
|
|
]:
|
|
yield (param_name, module_name)
|
|
|
|
@property
|
|
def _fqns_in_shard(self) -> List[str]:
|
|
"""Return the FQNs of the parameters present in this rank's shard."""
|
|
fqns_in_shard: List[str] = []
|
|
for fqn, shard_param_info in zip(
|
|
self.flat_param._fqns, self.flat_param._shard_param_infos # type: ignore[attr-defined]
|
|
):
|
|
if shard_param_info.in_shard:
|
|
fqns_in_shard.append(fqn)
|
|
return fqns_in_shard
|
|
|
|
@property
|
|
def sharded_grad(self) -> Optional[Tensor]:
|
|
"""Return the handle's sharded gradient."""
|
|
flat_param = self.flat_param
|
|
# Priority for non-`None`: `_cpu_grad` > `_saved_grad_shard` > `grad`
|
|
# - CPU offloading: `_cpu_grad`
|
|
# - No CPU offloading + sharded strategies: `_saved_grad_shard`
|
|
# - No CPU offloading + `NO_SHARD`: `grad`
|
|
grad: Optional[Tensor]
|
|
if hasattr(flat_param, "_cpu_grad"):
|
|
grad = flat_param._cpu_grad # type: ignore[attr-defined]
|
|
elif hasattr(flat_param, "_saved_grad_shard"):
|
|
# In the post-backward hook, the sharded gradient is still in
|
|
# `_saved_grad_shard`.
|
|
grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
|
|
else:
|
|
# If in IDLE or in FORWARD states, then there may be an
|
|
# (accumulated) gradient. If accessed in IDLE, then this should
|
|
# be due to re-registering the original parameters (e.g. in state
|
|
# dict load).
|
|
_p_assert(
|
|
flat_param.grad is None
|
|
or not self.uses_sharded_strategy
|
|
or self._training_state
|
|
in (HandleTrainingState.FORWARD, HandleTrainingState.IDLE),
|
|
"Sharded strategies should use `_cpu_grad` or `_saved_grad_shard` "
|
|
"unless in IDLE or FORWARD",
|
|
)
|
|
grad = flat_param.grad
|
|
return grad
|
|
|
|
def _reset_is_grad_none(self) -> None:
|
|
"""
|
|
Reset ``_is_grad_none_mask`` as needed.
|
|
|
|
This method should only be
|
|
called in the post-backward after gradient computation, in which case
|
|
if a parameter requires gradient, then it will surely receive a
|
|
gradient and we may reset its mask entry to ``False``.
|
|
"""
|
|
if not self._use_orig_params:
|
|
return
|
|
_p_assert(
|
|
self._training_state == HandleTrainingState.BACKWARD_POST,
|
|
"Expects to only be called in the post-backward after gradient computation",
|
|
)
|
|
flat_param = self.flat_param
|
|
assert flat_param._params is not None # mypy
|
|
for i, param in enumerate(flat_param._params): # type: ignore[arg-type]
|
|
# As long as the parameter requires gradient, it should receive a
|
|
# meaningful gradient (even if the gradient happens to be zeros)
|
|
if param.requires_grad:
|
|
assert flat_param._is_grad_none_mask is not None # mypy
|
|
flat_param._is_grad_none_mask[i] = False
|
|
|
|
#######################
|
|
# CHECKS & INVARIANTS #
|
|
#######################
|
|
def _check_sharded_strategy(self):
|
|
_p_assert(self.uses_sharded_strategy, "Expects sharded strategy")
|
|
|
|
def _check_on_compute_device(self, tensor: Tensor):
|
|
_p_assert(
|
|
tensor.device == self.device,
|
|
f"Expects tensor to be on the compute device {self.device}, was on {tensor.device}",
|
|
)
|
|
|
|
def _check_on_cpu(self, tensor: Tensor):
|
|
_p_assert(
|
|
tensor.device == torch.device("cpu"),
|
|
f"Expects tensor to be on CPU but got {tensor.device}",
|
|
)
|
|
|
|
@staticmethod
|
|
def _check_storage_freed(tensor: Tensor):
|
|
# Compile does not resize during trace
|
|
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
|
|
_p_assert(
|
|
_same_storage_size(tensor, 0),
|
|
"Expects storage to be freed but got storage with size > 0",
|
|
)
|
|
|
|
@staticmethod
|
|
def _check_storage_allocated(tensor: Tensor):
|
|
_p_assert(_storage_size_allocated(tensor), "Expects storage to be allocated")
|
|
|
|
def _check_low_precision_shard(self):
|
|
_p_assert(
|
|
self._uses_param_mixed_precision,
|
|
"Not using low precision for parameters",
|
|
)
|
|
_p_assert(
|
|
getattr(self.flat_param, "_mp_shard", None) is not None,
|
|
"Expects `_mp_shard` to exist",
|
|
)
|
|
device = self.flat_param._mp_shard.device # type: ignore[attr-defined]
|
|
_p_assert(
|
|
device == self.device,
|
|
f"Expects the low precision shard to be on {self.device} but got {device}",
|
|
)
|
|
|
|
def _check_unsharded(self, tensor: Tensor):
|
|
msg_prefix = "Expects tensor to be unsharded "
|
|
_p_assert(tensor is not None, msg_prefix + "but got `None`")
|
|
unsharded_size = self.flat_param._unpadded_unsharded_size
|
|
_p_assert(
|
|
tensor.size() == unsharded_size,
|
|
msg_prefix + f"with size {unsharded_size} but got {tensor.size()}",
|
|
)
|
|
|
|
def _check_sharded(self, tensor: Tensor):
|
|
msg_prefix = "Expects tensor to be sharded "
|
|
_p_assert(tensor is not None, msg_prefix + "but got `None`")
|
|
sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined]
|
|
_p_assert(
|
|
tensor.size() == sharded_size,
|
|
msg_prefix + f"with size {sharded_size} but got {tensor.size()}",
|
|
)
|
|
|
|
##############
|
|
# PROPERTIES #
|
|
##############
|
|
@property
|
|
def uses_sharded_strategy(self) -> bool:
|
|
return self._sharding_strategy != HandleShardingStrategy.NO_SHARD
|
|
|
|
@property
|
|
def _uses_param_mixed_precision(self) -> bool:
|
|
return self._fwd_bwd_param_dtype != self._orig_param_dtype
|
|
|
|
@property
|
|
def _uses_reduce_mixed_precision(self) -> bool:
|
|
return self._reduce_dtype != self._orig_param_dtype
|
|
|
|
@property
|
|
def _force_full_precision(self) -> bool:
|
|
return (
|
|
self._uses_param_mixed_precision or self._uses_reduce_mixed_precision
|
|
) and (
|
|
self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS
|
|
or
|
|
# Also disable mixed precision in model eval mode, if configured
|
|
(not self._fully_sharded_module.training and self._use_full_prec_in_eval)
|
|
)
|
|
|
|
@property
|
|
def _skipped_use_sharded_views(self) -> bool:
|
|
"""
|
|
This property is used for sharding strategies that do not free after forward with ``use_orig_params=True``.
|
|
|
|
This returns if this handle is
|
|
currently in a state where it has skipped using sharded views, in which
|
|
case it can restore view invariants via ``_use_sharded_views()``.
|
|
"""
|
|
return self._unsharded_flat_param_for_skipped_views is not None
|
|
|
|
|
|
# NOTE: These are hacks to bypass `nn.Module.__setattr__` checks.
|
|
def _unsafe_setattr_param(
|
|
module: nn.Module, param_name: str, param: nn.Parameter
|
|
) -> None:
|
|
module._parameters[param_name] = param
|
|
# This bypasses any overrides in case `module` is an instance of an
|
|
# `nn.Module` subclass
|
|
super(nn.Module, module).__setattr__(param_name, param)
|
|
|
|
|
|
def _unsafe_setattr_tensor(module: nn.Module, param_name: str, tensor: Tensor) -> None:
|
|
module._parameters.pop(param_name, None)
|
|
# This bypasses any overrides in case `module` is an instance of an
|
|
# `nn.Module` subclass
|
|
super(nn.Module, module).__setattr__(param_name, tensor)
|
|
|
|
|
|
def _safe_setattr_tensor_or_param(
|
|
module: nn.Module, param_name: str, tensor_or_param: Union[Tensor, nn.Parameter]
|
|
):
|
|
# Call `delattr()` and `setattr()` to go through `nn.Module` checks
|
|
if hasattr(module, param_name):
|
|
delattr(module, param_name)
|
|
setattr(module, param_name, tensor_or_param)
|
|
|
|
|
|
def _convert_to_params(
|
|
tensors: List[Union[torch.Tensor, nn.Parameter]]
|
|
) -> List[nn.Parameter]:
|
|
return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors]
|
|
|
|
|
|
def _detach_if_needed(param_or_tensor: Union[nn.Parameter, Tensor]) -> Tensor:
|
|
return (
|
|
param_or_tensor.detach()
|
|
if isinstance(param_or_tensor, nn.Parameter)
|
|
else param_or_tensor
|
|
)
|
|
|
|
|
|
def _get_aligned_numel(unsharded_dtype: torch.dtype):
|
|
# NOTE: This alignment constraint comes from TorchInductor.
|
|
ALIGNMENT = 16 # bytes
|
|
unsharded_dtype_size = _get_dtype_size(unsharded_dtype)
|
|
aligned_numel = ALIGNMENT // unsharded_dtype_size
|
|
return aligned_numel
|
|
|
|
|
|
@functools.lru_cache(8)
|
|
def _get_dtype_size(dtype):
|
|
return torch.empty((), dtype=dtype).element_size()
|
|
|
|
|
|
def _construct_padding_tensor(
|
|
padding_numel: int, dtype: torch.dtype, requires_grad: bool, device: torch.device
|
|
):
|
|
# NOTE: Set the padding value as a magic number for debuggability. The
|
|
# value itself should never be used in any user-facing computation.
|
|
return (
|
|
torch.ones(
|
|
(padding_numel,), dtype=dtype, requires_grad=requires_grad, device=device
|
|
)
|
|
* _FLAT_PARAM_PADDING_VALUE
|
|
)
|
|
|
|
|
|
# Use `lru_cache(1)` to only log the warning once (assuming the fixed warning
|
|
# messasge is passed in)
|
|
@functools.lru_cache(1)
|
|
def _warn_skip_writeback_check(log: logging.Logger, warning: str):
|
|
logger.warning(warning)
|
|
|
|
|
|
# Use `lru_cache(1)` to only log the warning once
|
|
@functools.lru_cache(1)
|
|
def _warn_use_fake_all_gather(log: logging.Logger, warning: str):
|
|
logger.warning(warning)
|
|
|
|
|
|
# Use `lru_cache(1)` to only log the warning once
|
|
@functools.lru_cache(1)
|
|
def _warn_use_fake_reduce(log: logging.Logger, warning: str):
|
|
logger.warning(warning)
|
|
|
|
|
|
def _same_storage(a, b):
|
|
# Params are DTensors in backward
|
|
# with SHARD_GRAD_OP + TP
|
|
from torch.distributed.tensor import DTensor
|
|
|
|
if isinstance(a, DTensor):
|
|
a = a._local_tensor
|
|
if isinstance(b, DTensor):
|
|
b = b._local_tensor
|
|
return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
|
|
|
|
|
|
def _same_storage_size(a: torch.Tensor, b: int):
|
|
return a.untyped_storage().size() // a.element_size() == b
|
|
|
|
|
|
def _storage_size_allocated(tensor: Tensor):
|
|
storage_size: int = tensor.untyped_storage().size()
|
|
return storage_size > 0
|