import collections
import warnings
from typing import (
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    no_type_check,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
)

import torch
import torch.distributed as dist
import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file
import torch.nn as nn
from torch.distributed.algorithms._comm_hooks import default_hooks
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._common_utils import (
    _apply_to_modules,
    _get_param_to_unflat_param_names,
    _is_fsdp_flattened,
    _State,
    clean_tensor_name,
    TrainingState,
)
from torch.distributed.fsdp._exec_order_utils import _ExecOrderData
from torch.distributed.fsdp._limiter_utils import _FreeEventQueue
from torch.distributed.fsdp._wrap_utils import _get_submodule_to_states
from torch.distributed.fsdp.api import (
    BackwardPrefetch,
    CPUOffload,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.flat_param import (
    _HandlesKey,
    FlatParameter,
    FlatParamHandle,
    HandleConfig,
    HandleShardingStrategy,
)
from torch.distributed.utils import _sync_params_and_buffers
from torch.utils.hooks import RemovableHandle

_TORCHDISTX_AVAIL = True
try:
    from torchdistx import deferred_init, fake  # type: ignore[import]
except ImportError:
    _TORCHDISTX_AVAIL = False

PARAM_BROADCAST_BUCKET_SIZE = int(250 * 1024 * 1024)
FSDP_SYNCED = "_fsdp_synced"

# TODO (awgu): Refactor this later
SHARDING_STRATEGY_MAP = {
    ShardingStrategy.NO_SHARD: HandleShardingStrategy.NO_SHARD,
    ShardingStrategy.FULL_SHARD: HandleShardingStrategy.FULL_SHARD,
    ShardingStrategy.SHARD_GRAD_OP: HandleShardingStrategy.SHARD_GRAD_OP,
}
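
# Example (illustrative): the map translates the public `ShardingStrategy`
# enum into the internal per-handle enum, e.g.
#   SHARDING_STRATEGY_MAP[ShardingStrategy.FULL_SHARD]
#   -> HandleShardingStrategy.FULL_SHARD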


# NOTE: Since non-self attributes cannot be type annotated, several attributes
# on `state` are defined first as local variables before being assigned.


@no_type_check
def _init_process_group_state(
    state: _State,
    process_group: Optional[dist.ProcessGroup],
) -> _State:
    state.process_group = process_group or _get_default_group()
    state.rank = state.process_group.rank()
    state.world_size = state.process_group.size()
    return state
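
# Usage sketch (illustrative, assuming `dist.init_process_group()` has already
# been called): passing `process_group=None` falls back to the default group,
# so `state.rank` and `state.world_size` come from the global process group.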


@no_type_check
def _init_ignored_module_states(
    state: _State,
    module: nn.Module,
    ignored_modules: Optional[Iterable[torch.nn.Module]],
) -> _State:
    state._ignored_modules = _get_ignored_modules(module, ignored_modules)
    state._ignored_params, state._ignored_param_names = _get_ignored_params(
        module,
        state._ignored_modules,
    )
    # TODO: FSDP's contract for buffers is not well-defined. They are
    # implicitly ignored for most functionality since they are not sharded;
    # however, FSDP still imposes some semantics on buffers (e.g. buffer mixed
    # precision). We should formalize this contract and decide if we need to
    # compute and store `_ignored_buffers`.
    return state


@no_type_check
def _init_buffer_state(
    state: _State,
    module: nn.Module,
) -> _State:
    state._buffer_names = _get_buffer_names(module)
    # Save a mapping from fully prefixed buffer name to its original dtype
    # since, when buffer mixed precision is enabled, buffers are restored to
    # their original dtype for model checkpointing
    _buffer_name_to_orig_dtype: Dict[str, torch.dtype] = {}
    state._buffer_name_to_orig_dtype = _buffer_name_to_orig_dtype
    return state
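
# Illustrative example: with buffer mixed precision, a buffer such as
# `layer1.bn.running_mean` is cast to the low-precision dtype for computation,
# and `_buffer_name_to_orig_dtype["layer1.bn.running_mean"] = torch.float32`
# lets FSDP restore the original dtype for checkpointing. The name and dtype
# here are hypothetical and for illustration only.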


@no_type_check
def _init_core_state(
    state: _State,
    sharding_strategy: Optional[ShardingStrategy],
    mixed_precision: Optional[MixedPrecision],
    cpu_offload: Optional[CPUOffload],
    limit_all_gathers: bool,
    use_orig_params: bool,
    backward_prefetch_limit: int,
    forward_prefetch_limit: int,
) -> _State:
    # We clamp the strategy to `NO_SHARD` for a world size of 1 since all
    # strategies are currently functionally equivalent in that case. This may
    # change if/when we integrate FSDP with MoE.
    if state.world_size == 1:
        sharding_strategy = ShardingStrategy.NO_SHARD
    state.sharding_strategy = sharding_strategy or ShardingStrategy.FULL_SHARD
    state.mixed_precision = mixed_precision or MixedPrecision()
    state.cpu_offload = cpu_offload or CPUOffload()
    state.limit_all_gathers = limit_all_gathers
    state._use_orig_params = use_orig_params
    state.training_state = TrainingState.IDLE
    state._is_root = None
    _streams: Dict[str, torch.cuda.Stream] = {}
    state._streams = _streams
    state._free_event_queue = _FreeEventQueue()
    state._debug_level = dist.get_debug_level()
    state._exec_order_data = _ExecOrderData(
        state._debug_level,
        backward_prefetch_limit,
        forward_prefetch_limit,
    )
    # Invariant: `state.params` contains exactly the `FlatParameter`s of the
    # handles in `state._handles`
    _handles: List[FlatParamHandle] = []
    state._handles = _handles
    params: List[FlatParameter] = []
    state.params = params
    return state


@no_type_check
def _init_runtime_state(
    state: _State,
) -> _State:
    _pre_forward_handles: List[RemovableHandle] = []
    state._pre_forward_handles = _pre_forward_handles
    _post_forward_handles: List[RemovableHandle] = []
    state._post_forward_handles = _post_forward_handles
    _module_to_handles: Dict[
        nn.Module, List[FlatParamHandle]
    ] = collections.defaultdict(list)
    state._module_to_handles = _module_to_handles
    state._sync_gradients = True
    state._communication_hook = _get_default_comm_hook(state.sharding_strategy)
    state._communication_hook_state = _get_default_comm_hook_state(state.process_group)
    state._hook_registered = False
    # Used to prevent running the pre-backward hook multiple times
    _ran_pre_backward_hook: Dict[_HandlesKey, bool] = {}
    state._ran_pre_backward_hook = _ran_pre_backward_hook
    return state


@no_type_check
def _init_prefetching_state(
    state: _State,
    backward_prefetch: BackwardPrefetch,
    forward_prefetch: bool,
) -> _State:
    state.backward_prefetch = backward_prefetch
    state.forward_prefetch = forward_prefetch
    _handles_prefetched: Dict[_HandlesKey, bool] = {}
    state._handles_prefetched = _handles_prefetched
    # Used for guarding against mistargeted backward prefetches
    _needs_pre_backward_unshard: Dict[_HandlesKey, bool] = {}
    state._needs_pre_backward_unshard = _needs_pre_backward_unshard
    # Used for guarding against mistargeted forward prefetches
    _needs_pre_forward_unshard: Dict[_HandlesKey, bool] = {}
    state._needs_pre_forward_unshard = _needs_pre_forward_unshard
    # The data structures use tuples of handles to generalize over the case
    # where a module's forward involves multiple handles.
    return state


def _init_state_dict_state(state: _State) -> _State:
    # TODO: after rebase
    return state


@no_type_check
def _init_param_handle_from_module(
    state: _State,
    root_module: nn.Module,
    device_id: Optional[Union[int, torch.device]],
    param_init_fn: Optional[Callable[[nn.Module], None]],
    sync_module_states: bool,
    module_wrapper_cls: Type,
) -> _State:
    """
    Initializes a ``FlatParamHandle`` from a module ``root_module``. This is
    the module wrapper code path.
    """
    _check_single_device_module(root_module, state._ignored_params)
    device_from_device_id = _get_device_from_device_id(device_id, state.rank)
    _materialize_module(
        root_module,
        param_init_fn,
        state._ignored_params,
        device_from_device_id,
        lambda k: not isinstance(k, module_wrapper_cls),
    )
    # TODO: Investigate refactoring `_move_module_to_device()` to
    # `_move_states_to_device()` to avoid the `device_id` + CPU offload hack
    _move_module_to_device(root_module, state._ignored_params, device_from_device_id)
    state.compute_device = _get_compute_device(
        root_module,
        state._ignored_params,
        device_from_device_id,
        state.rank,
    )
    managed_params = list(_get_orig_params(root_module, state._ignored_params))
    if sync_module_states:
        _sync_module_params_and_buffers(
            root_module, managed_params, state.process_group
        )
    _init_param_handle_from_params(state, managed_params, root_module)
    return state
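
# Initialization-order sketch (illustrative): the module wrapper constructor
# is expected to chain the `_init_*` helpers roughly as
#   _init_process_group_state -> _init_ignored_module_states ->
#   _init_buffer_state -> _init_core_state -> _init_runtime_state ->
#   _init_prefetching_state -> _init_param_handle_from_module
# since later steps read attributes set by earlier ones (e.g. `state.rank`,
# `state._ignored_params`). The exact ordering lives in the caller.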


@no_type_check
def _init_param_handles_from_module(
    state: _State,
    root_module: nn.Module,
    auto_wrap_policy: Callable,
    device_id: Optional[Union[int, torch.device]],
    param_init_fn: Optional[Callable[[nn.Module], None]],
    sync_module_states: bool,
) -> _State:
    """
    Initializes all ``FlatParamHandle`` s from a module ``root_module``. This
    is the non-module-wrapper code path.
    """
    submodule_to_states = _get_submodule_to_states(
        root_module,
        auto_wrap_policy,
        state._ignored_modules,
        state._ignored_params,
    )
    _check_single_device_module(root_module, state._ignored_params)
    device_from_device_id = _get_device_from_device_id(device_id, state.rank)
    # Initialize and shard `FlatParamHandle`s one by one following bottom-up
    # order (hence the `reversed`) to avoid increasing peak GPU memory usage
    materialized_module = False
    for submodule, (params, buffers, param_names, buffer_names) in reversed(
        submodule_to_states.items()
    ):
        materialized_module |= _materialize_module(
            submodule,
            param_init_fn,
            state._ignored_params,
            device_from_device_id,
            lambda _: True,
        )
        if materialized_module:
            # Materializing from meta device can change the parameter/buffer
            # variables, so reacquire references
            params = [
                submodule.get_parameter(param_name) for param_name in param_names
            ]
            buffers = [
                submodule.get_buffer(buffer_name) for buffer_name in buffer_names
            ]
        _move_states_to_device(params, buffers, device_from_device_id)
        if not hasattr(state, "compute_device"):  # only need to set once
            state.compute_device = _get_compute_device(
                submodule,
                state._ignored_params,
                device_from_device_id,
                state.rank,
            )
        if sync_module_states:
            _sync_module_states(params, buffers, state.process_group)
        # Pass `root_module` to have the internal FQN metadata prefix starting
        # from it instead of `submodule`
        _init_param_handle_from_params(state, params, root_module)
    # Reverse to preserve top-down order like `_fsdp_handles()`
    state._handles.reverse()
    return state


@no_type_check
def _init_param_handle_from_params(
    state: _State,
    params: List[nn.Parameter],
    root_module: nn.Module,
):
    if len(params) == 0:
        return
    handle_config = HandleConfig(
        SHARDING_STRATEGY_MAP[state.sharding_strategy],
        state.cpu_offload.offload_params,
        state.mixed_precision.param_dtype,
        state.mixed_precision.reduce_dtype,
        state.mixed_precision.keep_low_precision_grads,
    )
    handle = FlatParamHandle(
        params,
        root_module,
        state.compute_device,
        handle_config,
        state.process_group,
        state._use_orig_params,
    )
    # TODO: Can simplify by calling `shard()` in the `FlatParamHandle` ctor
    handle.shard()
    assert handle not in state._handles
    state.params.append(handle.flat_param)
    state._handles.append(handle)
    cpu_device = torch.device("cpu")
    if state.cpu_offload.offload_params and handle.flat_param.device != cpu_device:
        handle.flat_param_to(cpu_device)
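
# Note (illustrative): after `handle.shard()`, each rank holds only its
# `1 / world_size` shard of the underlying `FlatParameter` (except under
# `NO_SHARD`), which is why the optional CPU offload of `flat_param` above
# happens only after sharding.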


def _get_ignored_modules(
    root_module: nn.Module,
    _ignored_modules: Optional[Iterable[torch.nn.Module]],
) -> Set[nn.Module]:
    """
    Checks that ``_ignored_modules`` is an iterable of ``nn.Module`` s without
    any FSDP instances and returns the modules contained in their module
    subtrees as a :class:`set`. Nested FSDP instances are excluded, but their
    already-computed ignored modules are included.
    """
    if _ignored_modules is None:
        return set()
    msg_prefix = "`ignored_modules` should be an iterable of `torch.nn.Module`s "
    try:
        ignored_root_modules = set(_ignored_modules)
    except TypeError:
        raise TypeError(msg_prefix + f"but got {type(_ignored_modules)}")
    for module in ignored_root_modules:
        if not isinstance(module, torch.nn.Module):
            raise TypeError(msg_prefix + f"but got an iterable with {type(module)}")
        if isinstance(module, fsdp_file.FullyShardedDataParallel):
            raise ValueError("`ignored_modules` should not include FSDP modules")
    # Include child modules and exclude nested FSDP modules themselves
    ignored_modules = set(
        child
        for module in ignored_root_modules
        for child in module.modules()
        if not isinstance(child, fsdp_file.FullyShardedDataParallel)
    )
    if root_module in ignored_modules:
        warnings.warn(
            "Trying to ignore the top-level module passed into the FSDP "
            "constructor itself will result in all parameters being "
            f"ignored and is not well-supported: {root_module}"
        )
    # Include nested FSDP modules' ignored modules
    for submodule in root_module.modules():
        if isinstance(submodule, fsdp_file.FullyShardedDataParallel):
            assert hasattr(submodule, "_ignored_modules")
            ignored_modules.update(submodule._ignored_modules)
    return ignored_modules
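
# Usage sketch (illustrative): a caller might construct
#   FSDP(model, ignored_modules=[model.lm_head])
# to leave that submodule's parameters unsharded and unmanaged by FSDP;
# `model.lm_head` is a hypothetical attribute used only for illustration.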


def _get_ignored_params(
    root_module: torch.nn.Module,
    ignored_modules: Set[torch.nn.Module],
) -> Tuple[Set[torch.nn.Parameter], Set[str]]:
    """
    Returns the parameters of the modules in ``ignored_modules``,
    excluding any :class:`FlatParameter` s, and their fully prefixed names,
    both as :class:`set` s.
    """
    ignored_params = set(
        p for m in ignored_modules for p in m.parameters() if not _is_fsdp_flattened(p)
    )
    # Conservatively include all shared parameters' names
    param_to_unflat_param_names = _get_param_to_unflat_param_names(
        root_module,
        dedup_shared_params=False,
    )
    ignored_param_names = set()
    for param in ignored_params:
        unflat_param_names = param_to_unflat_param_names[param]
        clean_names = []
        for k in unflat_param_names:
            # Clean any module wrapper prefixes in case of nested wrapping
            clean_names.append(clean_tensor_name(k))
        ignored_param_names.update(clean_names)
    return ignored_params, ignored_param_names


def _get_buffer_names(root_module: nn.Module) -> Set[str]:
    """
    Returns the fully prefixed names of all buffers in the module hierarchy
    rooted at ``root_module`` as a :class:`set`.
    """

    def module_fn(module: nn.Module, prefix: str, buffer_names: Set[str]):
        for buffer_name, _ in module.named_buffers(recurse=False):
            # Clean module wrapper prefixes in case of nested wrapping
            prefixed_buffer_name = clean_tensor_name(prefix + buffer_name)
            buffer_names.add(prefixed_buffer_name)

    def return_fn(buffer_names: Set[str], *args):
        return buffer_names

    buffer_names: Set[str] = set()
    return _apply_to_modules(
        root_module,
        module_fn,
        return_fn,
        buffer_names,
    )


def _check_single_device_module(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
) -> None:
    """
    Raises an error if ``module`` has original parameters on multiple devices,
    ignoring the parameters in ``ignored_params``. Thus, if this method returns
    without error, the module is either fully on CPU or fully on a single
    non-CPU device.
    """
    devices = set(param.device for param in _get_orig_params(module, ignored_params))
    if len(devices) > 1:
        raise RuntimeError(
            f"FSDP only supports single device modules but got params on {devices}"
        )


def _get_device_from_device_id(
    device_id: Optional[Union[int, torch.device]],
    rank: int,
) -> Optional[torch.device]:
    """
    Processes ``device_id`` and returns either the corresponding device or
    ``None`` if ``device_id`` is ``None``.
    """
    if device_id is None:
        return None
    device = (
        device_id if isinstance(device_id, torch.device) else torch.device(device_id)
    )
    if device == torch.device("cuda"):
        warnings.warn(
            f"FSDP got the argument `device_id` {device_id} on rank "
            f"{rank}, which does not have an explicit index. "
            f"FSDP will use the current device {torch.cuda.current_device()}. "
            "If this is incorrect, please explicitly call `torch.cuda.set_device()` "
            "before FSDP initialization or pass in the explicit device "
            "index as the `device_id` argument."
        )
        device = torch.device("cuda", torch.cuda.current_device())
    return device
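
# Examples (illustrative):
#   _get_device_from_device_id(None, rank)  -> None
#   _get_device_from_device_id(3, rank)     -> torch.device("cuda", 3)
#   _get_device_from_device_id(torch.device("cuda"), rank)
#       -> warns, then returns torch.device("cuda", torch.cuda.current_device())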


def _materialize_module(
    module: nn.Module,
    param_init_fn: Optional[Callable[[nn.Module], None]],
    ignored_params: Set[nn.Parameter],
    device_from_device_id: Optional[torch.device],
    deferred_init_check_fn: Callable,
) -> bool:
    """
    Materializes the wrapped module ``module`` in place if needed: namely, if
    the module has parameters that use the meta device or that are torchdistX
    fake tensors.

    This method uses ``param_init_fn`` to materialize the module if the
    function is not ``None`` and falls back to the default behavior otherwise.
    For meta device parameters, the default moves the module to
    ``device_from_device_id`` if it is not ``None`` (or to the current device
    otherwise) and calls ``reset_parameters()``; for torchdistX fake tensors,
    the default calls ``deferred_init.materialize_module()``.

    Returns:
        bool: ``True`` if ``module`` was materialized and ``False`` if this was
        a no-op.
    """
    # Materialize the list so that the parameters can be iterated over twice
    managed_params = list(_get_orig_params(module, ignored_params))
    is_meta_module = any(param.is_meta for param in managed_params)
    is_torchdistX_deferred_init = (
        not is_meta_module
        and _TORCHDISTX_AVAIL
        and any(fake.is_fake(param) for param in managed_params)
    )
    if (is_meta_module or is_torchdistX_deferred_init) and param_init_fn is not None:
        if not callable(param_init_fn):
            raise ValueError(
                f"Expected {param_init_fn} to be callable but got {type(param_init_fn)}"
            )
        param_init_fn(module)
        return True
    elif is_meta_module:
        # Run default meta device initialization
        materialization_device = device_from_device_id or torch.device(
            torch.cuda.current_device()
        )
        module.to_empty(device=materialization_device)
        try:
            with torch.no_grad():
                module.reset_parameters()  # type: ignore[operator]
        except BaseException as e:
            warnings.warn(
                "Unable to call `reset_parameters()` for module on meta "
                f"device with error {str(e)}. Please ensure that your "
                "module implements a `reset_parameters()` method."
            )
            raise e
        return True
    elif is_torchdistX_deferred_init:
        # Run default torchdistX initialization
        deferred_init.materialize_module(module, check_fn=deferred_init_check_fn)
        return True
    return False
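
# Sketch of a user-provided `param_init_fn` (illustrative only; any such
# function must itself materialize the module's meta-device storage):
#   def my_param_init_fn(module: nn.Module) -> None:
#       device = torch.device("cuda", torch.cuda.current_device())
#       module.to_empty(device=device)
#       if callable(getattr(module, "reset_parameters", None)):
#           module.reset_parameters()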


def _move_module_to_device(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    device_from_device_id: Optional[torch.device],
) -> None:
    """
    Moves ``module`` depending on ``device_from_device_id`` and its current
    device. This includes moving ignored modules' parameters.

    - If ``device_from_device_id`` is not ``None``, then this moves
      ``module`` to the device.
    - If ``device_from_device_id`` is ``None``, then this does not move
      ``module`` but warns the user if it is on CPU.

    Precondition: ``_check_single_device_module()``.
    """
    param = next(_get_orig_params(module, ignored_params), None)
    if param is None:
        return  # no original parameters to manage
    cpu_device = torch.device("cpu")
    if device_from_device_id is not None:
        if param.device == cpu_device:
            # NOTE: This includes moving ignored modules' parameters.
            module = module.to(device_from_device_id)
            # TODO: This is a temporary fix to move already-constructed
            # `FlatParameter`s back to CPU if needed. This is needed to
            # make CPU offload work with `device_id`.
            for submodule in module.modules():
                if (
                    isinstance(submodule, fsdp_file.FullyShardedDataParallel)
                    and submodule.cpu_offload.offload_params
                ):
                    with torch.no_grad():
                        for handle in submodule._handles:
                            handle.flat_param_to(torch.device("cpu"))
    elif param.device == cpu_device:
        _warn_cpu_init()


def _move_states_to_device(
    params: List[nn.Parameter],
    buffers: List[torch.Tensor],
    device_from_device_id: Optional[torch.device],
) -> None:
    """
    Moves ``params`` and ``buffers`` to ``device_from_device_id`` if it is not
    ``None``; otherwise, warns the user if they are on CPU.

    Precondition: ``_check_single_device_module()``.
    """
    if len(params) == 0 and len(buffers) == 0:
        return
    if len(params) > 0:
        current_device = params[0].device
    elif len(buffers) > 0:
        current_device = buffers[0].device
    cpu_device = torch.device("cpu")
    if device_from_device_id is not None:
        # Move the parameters and buffers like the `.data` code path in
        # `nn.Module._apply()`, which underlies `nn.Module.to()`
        for param in params:
            with torch.no_grad():
                param.data = param.to(device_from_device_id)
                if param.grad is not None:
                    param.grad.data = param.grad.to(device_from_device_id)
        for buffer in buffers:
            buffer.data = buffer.to(device_from_device_id)
    elif current_device == cpu_device:
        _warn_cpu_init()
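
# Note: assigning to `.data` (rather than calling `nn.Module.to()`) moves the
# underlying storage while keeping the same `nn.Parameter`/buffer Python
# objects, so any existing references to them remain valid.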


def _warn_cpu_init():
    warnings.warn(
        "The passed-in `module` is on CPU and will thus have FSDP's sharding "
        "initialization run on CPU, which may be slower than on GPU. We "
        "recommend passing in the `device_id` argument for FSDP to move "
        "`module` to GPU for the sharding initialization. `module` must also "
        "be on a GPU device to work with the `sync_module_states=True` flag "
        "since that requires GPU communication."
    )


def _get_compute_device(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    device_from_device_id: Optional[torch.device],
    rank: int,
) -> torch.device:
    """
    Determines and returns this FSDP instance's compute device. If the module
    is already on a non-CPU device, then the compute device is that non-CPU
    device. If the module is on CPU, then the compute device is the current
    device.

    Since this method should be called after materializing the module, any
    non-CPU device should not be the meta device. For now, the compute device
    is always a CUDA GPU device with its explicit index.

    Precondition: ``_check_single_device_module()`` and
    ``_move_module_to_device()``.
    """
    # If the module is already on GPU, then that GPU device has priority
    # over the current device
    param = next(_get_orig_params(module, ignored_params), None)
    if param is not None and param.device.type == "cuda":
        compute_device = param.device
    else:
        compute_device = torch.device("cuda", torch.cuda.current_device())
    if device_from_device_id is not None and compute_device != device_from_device_id:
        raise ValueError(
            f"Inconsistent compute device and `device_id` on rank {rank}: "
            f"{compute_device} vs {device_from_device_id}"
        )
    return compute_device


# TODO: See how to deprecate!
def _sync_module_params_and_buffers(
    module: nn.Module,
    params: List[nn.Parameter],
    process_group: dist.ProcessGroup,
) -> None:
    """
    Synchronizes module states (i.e. parameters ``params`` and all
    not-yet-synced buffers) by broadcasting from rank 0 to all ranks.

    Precondition: ``sync_module_states == True`` and ``process_group`` has
    been initialized.
    """
    _check_params_for_sync_module_states(params)
    module_states: List[torch.Tensor] = []
    # TODO (awgu): When exposing the original parameters, we need to also
    # use this attribute to prevent re-synchronizing parameters.
    for buffer in module.buffers():
        # Avoid re-synchronizing buffers in case of nested wrapping
        if not getattr(buffer, FSDP_SYNCED, False):
            setattr(buffer, FSDP_SYNCED, True)
            module_states.append(buffer.detach())
    module_states.extend(param.detach() for param in params)
    _sync_params_and_buffers(
        process_group,
        module_states,
        PARAM_BROADCAST_BUCKET_SIZE,
        src=0,
    )


def _sync_module_states(
    params: List[nn.Parameter],
    buffers: List[torch.Tensor],
    process_group: dist.ProcessGroup,
) -> None:
    _check_params_for_sync_module_states(params)
    # Assumes that each call to this method passes in disjoint `params` and
    # `buffers` across calls, so there is no chance of re-synchronizing
    params_and_buffers = [param.detach() for param in params] + [
        buffer.detach() for buffer in buffers
    ]
    _sync_params_and_buffers(
        process_group,
        params_and_buffers,
        PARAM_BROADCAST_BUCKET_SIZE,
        src=0,
    )


def _check_params_for_sync_module_states(
    params: List[nn.Parameter],
) -> None:
    if params and any(param.device == torch.device("cpu") for param in params):
        raise ValueError(
            "The module has CPU parameters when `sync_module_states=True`, "
            "which only works when all parameters are on GPU. Please specify "
            "the `device_id` argument or move the module to GPU before passing "
            "it into FSDP."
        )


def _get_orig_params(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
) -> Iterator[nn.Parameter]:
    """
    Returns an iterator over the original parameters in ``module``, ignoring
    the parameters in ``ignored_params``, any ``FlatParameter`` s (which may be
    present due to nested FSDP wrapping), and any original parameters already
    flattened (only relevant when ``use_orig_params=True``).
    """
    param_gen = module.parameters()
    try:
        while True:
            param = next(param_gen)
            if param not in ignored_params and not _is_fsdp_flattened(param):
                yield param
    except StopIteration:
        pass
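
# Example (illustrative): under nested wrapping, an inner FSDP instance's
# `FlatParameter` appears in `module.parameters()`, but `_is_fsdp_flattened()`
# filters it out, so only not-yet-flattened original parameters are yielded.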


def _check_orig_params_flattened(
    fsdp_module,
    ignored_params: Set[nn.Parameter],
) -> None:
    """
    Checks that all original parameters have been flattened and hence made
    invisible to ``named_parameters()`` for the module hierarchy rooted at
    ``fsdp_module``. This should be called as a sanity check after flattening
    the wrapped module's parameters.
    """
    for param_name, param in fsdp_module.named_parameters():
        if param not in ignored_params and not _is_fsdp_flattened(param):
            raise RuntimeError(
                f"Found an unflattened parameter: {param_name}; "
                f"{param.size()} {param.__class__}"
            )


def _get_default_comm_hook(sharding_strategy: ShardingStrategy):
    return (
        default_hooks.allreduce_hook
        if sharding_strategy == ShardingStrategy.NO_SHARD
        else default_hooks.reduce_scatter_hook
    )
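
# Rationale: `NO_SHARD` keeps gradients unsharded, so it all-reduces them like
# DDP; the sharded strategies reduce-scatter so that each rank receives only
# its own shard of the gradient.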


def _get_default_comm_hook_state(
    process_group: dist.ProcessGroup,
) -> default_hooks.DefaultState:
    return default_hooks.DefaultState(process_group=process_group)