PEP585 update - torch/distributed/fsdp (#145162)
See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145162
Approved by: https://github.com/bobrenjc93
This commit is contained in:
parent 371a361db9
commit c64e657632
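As background, the annotation style this PR migrates to (PEP 585, available on Python 3.9+) replaces the deprecated typing aliases (List, Dict, Set, Type, Deque) with the builtin generics, and imports abstract containers such as Iterable and Generator from collections.abc. A minimal illustrative sketch, not taken from the diff (the names below are hypothetical):

import collections
from collections.abc import Iterable
from typing import Optional

# Old style: Dict[str, List[int]], Set[str], Deque[str] from `typing`.
# New style: builtin generics are subscriptable at runtime on Python >= 3.9.
def group_name_lengths(names: Iterable[str]) -> dict[str, list[int]]:
    lengths: dict[str, list[int]] = {}
    for name in names:
        lengths.setdefault(name[0], []).append(len(name))
    return lengths

visited: set[str] = set()
queue: collections.deque[str] = collections.deque(["root"])
maybe: Optional[dict[str, list[int]]] = group_name_lengths(queue)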
@ -6,22 +6,10 @@ import logging
|
|||
import traceback
|
||||
import warnings
|
||||
import weakref
|
||||
from collections.abc import Generator, Iterable
|
||||
from enum import auto, Enum
|
||||
from functools import partial
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
no_type_check,
|
||||
Optional,
|
||||
Set,
|
||||
Type,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
from typing import Any, Callable, cast, no_type_check, Optional, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -118,10 +106,10 @@ class _FSDPState(_State):
|
|||
def __init__(self) -> None:
|
||||
# TODO: Move all the attributes to this class to enable typing for
|
||||
# FSDP/fully_shard.
|
||||
self._ignored_modules: Set[nn.Module] = set()
|
||||
self._ignored_params: Set[nn.Parameter] = set()
|
||||
self._ignored_modules: set[nn.Module] = set()
|
||||
self._ignored_params: set[nn.Parameter] = set()
|
||||
# Buffer names are cleaned (without wrapper prefixes)
|
||||
self._ignored_buffer_names: Set[str] = set()
|
||||
self._ignored_buffer_names: set[str] = set()
|
||||
self.process_group: Optional[dist.ProcessGroup] = None
|
||||
self.rank: int = -1
|
||||
self.world_size: int = -1
|
||||
|
|
@ -129,13 +117,13 @@ class _FSDPState(_State):
|
|||
self.sharding_strategy = ShardingStrategy.FULL_SHARD
|
||||
self._use_orig_params: bool = False
|
||||
self.training_state = TrainingState.IDLE
|
||||
self._unshard_params_ctx: Dict[nn.Module, Generator] = {}
|
||||
self._unshard_params_ctx: dict[nn.Module, Generator] = {}
|
||||
self._state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT
|
||||
self._state_dict_config: StateDictConfig = FullStateDictConfig()
|
||||
self._optim_state_dict_config: OptimStateDictConfig = FullOptimStateDictConfig()
|
||||
self._is_root: Optional[bool] = None
|
||||
self._handle: Optional[flat_param_file.FlatParamHandle] = None
|
||||
self._fully_sharded_module_to_handle: Dict[
|
||||
self._fully_sharded_module_to_handle: dict[
|
||||
nn.Module, Optional[flat_param_file.FlatParamHandle]
|
||||
] = {}
|
||||
self.compute_device: Optional[torch.device] = None
|
||||
|
|
@ -149,8 +137,8 @@ class _FSDPState(_State):
|
|||
self._device_handle: _FSDPDeviceHandle = _UninitializedDeviceHandle()
|
||||
# All following attributes should only be used for root states:
|
||||
# Save these static lists to avoid the repeated tree traversals
|
||||
self._all_fsdp_states: List[_FSDPState] = []
|
||||
self._all_handles: List[flat_param_file.FlatParamHandle] = []
|
||||
self._all_fsdp_states: list[_FSDPState] = []
|
||||
self._all_handles: list[flat_param_file.FlatParamHandle] = []
|
||||
self._fsdp_extension: Optional[FSDPExtensions] = None
@ -262,7 +250,7 @@ def _is_fsdp_flattened(tensor: torch.Tensor) -> bool:
|
|||
|
||||
def _named_parameters_with_duplicates(
|
||||
module: nn.Module, **kwargs: Any
|
||||
) -> List[tuple[str, nn.Parameter]]:
|
||||
) -> list[tuple[str, nn.Parameter]]:
|
||||
"""
|
||||
This API is required as some modules overwrite `named_parameters()` but do not support
|
||||
`remove_duplicate`.
|
||||
|
|
@ -282,7 +270,7 @@ def _named_parameters_with_duplicates(
|
|||
def _get_param_to_fqns(
|
||||
model: torch.nn.Module,
|
||||
dedup_shared_params: bool = True,
|
||||
) -> Dict[nn.Parameter, List[str]]:
|
||||
) -> dict[nn.Parameter, list[str]]:
|
||||
"""
|
||||
Constructs a mapping from parameter to a list of its \"canonical\" FQNs. Here,
|
||||
we use canonical to mean the fully-qualified name assigned to the parameter
|
||||
|
|
@ -352,7 +340,7 @@ def _get_param_to_fqns(
|
|||
def return_fn(param_to_fqns):
|
||||
return param_to_fqns
|
||||
|
||||
param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {}
|
||||
param_to_unflat_param_names: dict[torch.nn.Parameter, list[str]] = {}
|
||||
return _apply_to_modules(
|
||||
model,
|
||||
module_fn,
|
||||
|
|
@ -377,7 +365,7 @@ def _log_post_backward_hook(
|
|||
@no_type_check
|
||||
def _get_handle_fqns_from_root(
|
||||
state: _FSDPState, handle: "FlatParamHandle"
|
||||
) -> Optional[List[str]]:
|
||||
) -> Optional[list[str]]:
|
||||
if handle is None:
|
||||
return None
|
||||
param_to_fqn = state._exec_order_data.param_to_fqn
|
||||
|
|
@ -392,7 +380,7 @@ def _apply_to_modules(
|
|||
root_module: torch.nn.Module,
|
||||
module_fn: Callable,
|
||||
return_fn: Callable,
|
||||
filter_fqns: Optional[List[str]] = None,
|
||||
filter_fqns: Optional[list[str]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
|
|
@ -443,7 +431,7 @@ def _apply_to_modules(
|
|||
@no_type_check
|
||||
def _assert_in_training_states(
|
||||
state: _FSDPState,
|
||||
training_states: List[TrainingState],
|
||||
training_states: list[TrainingState],
|
||||
) -> None:
|
||||
"""Asserts that FSDP is in the states ``_training_states``."""
|
||||
# Raise a `ValueError` instead of using `assert` to ensure that these
|
||||
|
|
@ -462,7 +450,7 @@ def _assert_in_training_states(
|
|||
raise ValueError(msg)
|
||||
|
||||
|
||||
def _get_root_modules(modules: Set[nn.Module]) -> Set[nn.Module]:
|
||||
def _get_root_modules(modules: set[nn.Module]) -> set[nn.Module]:
|
||||
"""
|
||||
Returns:
|
||||
Set[nn.Module]: The subset of ``modules`` that are root modules (i.e.
|
||||
|
|
@ -470,7 +458,7 @@ def _get_root_modules(modules: Set[nn.Module]) -> Set[nn.Module]:
|
|||
words, these are the modules in ``modules`` that are not the child of
|
||||
any other module in ``modules``.
|
||||
"""
|
||||
root_modules: Set[nn.Module] = set()
|
||||
root_modules: set[nn.Module] = set()
|
||||
module_to_submodules = {module: set(module.modules()) for module in modules}
|
||||
for candidate_module in modules:
|
||||
is_root_module = True
|
||||
|
|
@ -488,12 +476,12 @@ def _get_root_modules(modules: Set[nn.Module]) -> Set[nn.Module]:
|
|||
|
||||
def _override_module_mixed_precision(
|
||||
root: torch.nn.Module,
|
||||
module_classes_to_override: Iterable[Type[nn.Module]],
|
||||
wrap_override_dict: Dict[str, Any] = {"mixed_precision": None}, # noqa: B006
|
||||
) -> Set[Type[nn.Module]]:
|
||||
module_classes_to_override: Iterable[type[nn.Module]],
|
||||
wrap_override_dict: dict[str, Any] = {"mixed_precision": None}, # noqa: B006
|
||||
) -> set[type[nn.Module]]:
|
||||
module_classes_to_override = tuple(set(module_classes_to_override))
|
||||
# Return a set of the actually overridden module classes
|
||||
overridden_module_classes: Set[Type[nn.Module]] = set()
|
||||
overridden_module_classes: set[type[nn.Module]] = set()
|
||||
for mod in root.modules():
|
||||
if isinstance(mod, module_classes_to_override):
|
||||
overridden_module_classes.add(type(mod))
@ -2,9 +2,9 @@
|
|||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from typing import Dict, Iterator, List, Set
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -28,8 +28,8 @@ class SimpleProfiler:
|
|||
H2D = "H2D"
|
||||
D2H = "D2H"
|
||||
|
||||
results: Dict[str, float] = defaultdict(float)
|
||||
profiling: Set[str] = set()
|
||||
results: dict[str, float] = defaultdict(float)
|
||||
profiling: set[str] = set()
|
||||
|
||||
@classmethod
|
||||
def reset(cls) -> None:
|
||||
|
|
@ -65,7 +65,7 @@ class SimpleProfiler:
|
|||
|
||||
def _get_sharded_module_tree_with_module_name_to_fqns(
|
||||
model: torch.nn.Module,
|
||||
) -> tuple[str, Dict[str, List[str]]]:
|
||||
) -> tuple[str, dict[str, list[str]]]:
|
||||
"""
|
||||
It is used for the composable fully_shard() code path; it returns
|
||||
1. sharded module tree info: each line represents a submodule name that contains the
|
||||
|
|
@ -143,10 +143,10 @@ def _get_sharded_module_tree_with_module_name_to_fqns(
|
|||
return sharded_tree_info[0], sharded_module_name_to_fqns
|
||||
|
||||
# Use List to mutate its value in place while running the recursive functions
|
||||
sharded_tree_info: List[str] = [
|
||||
sharded_tree_info: list[str] = [
|
||||
"",
|
||||
]
|
||||
sharded_module_name_to_fqns: Dict[str, List[str]] = {}
|
||||
sharded_module_name_to_fqns: dict[str, list[str]] = {}
|
||||
return _apply_to_modules(
|
||||
model,
|
||||
module_fn,
@ -1,11 +1,9 @@
|
|||
from typing import Set
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def _annotate_modules_for_dynamo(
|
||||
module: nn.Module,
|
||||
ignored_modules: Set[nn.Module],
|
||||
ignored_modules: set[nn.Module],
|
||||
use_orig_params: bool,
|
||||
) -> None:
|
||||
"""
@ -2,7 +2,7 @@
|
|||
import itertools
|
||||
import warnings
|
||||
from enum import auto, Enum
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -37,9 +37,9 @@ class _ExecOrderData:
|
|||
) -> None:
|
||||
# Tracks the (static) pre-forward order for execution order validation
|
||||
# and forward prefetching
|
||||
self.handles_pre_forward_order: List[FlatParamHandle] = []
|
||||
self.handles_pre_forward_order: list[FlatParamHandle] = []
|
||||
# Tracks the post-forward order for pre-backward prefetching
|
||||
self.handles_post_forward_order: List[Optional[FlatParamHandle]] = []
|
||||
self.handles_post_forward_order: list[Optional[FlatParamHandle]] = []
|
||||
self._iter = 0
|
||||
|
||||
# Gives the max number of backward/forward prefetched all-gathers by a
|
||||
|
|
@ -51,9 +51,9 @@ class _ExecOrderData:
|
|||
self._checking_order: bool = debug_level == dist.DebugLevel.DETAIL
|
||||
self.process_group: Optional[dist.ProcessGroup] = None
|
||||
self.world_size: Optional[int] = None
|
||||
self.all_handles: List[FlatParamHandle] = []
|
||||
self.all_handles: list[FlatParamHandle] = []
|
||||
# Names are prefixed from the root module
|
||||
self.param_to_fqn: Dict[nn.Parameter, List[str]] = {}
|
||||
self.param_to_fqn: dict[nn.Parameter, list[str]] = {}
|
||||
# Current index in the pre-forward execution order
|
||||
self.current_order_index = 0
|
||||
self.warn_status = _ExecOrderWarnStatus.NONE
|
||||
|
|
@ -197,7 +197,7 @@ class _ExecOrderData:
|
|||
num_valid_indices = sum(
|
||||
(index is not None) for index in optional_local_indices
|
||||
)
|
||||
tensor_kwargs: Dict[str, Union[torch.dtype, torch.device]] = {
|
||||
tensor_kwargs: dict[str, Union[torch.dtype, torch.device]] = {
|
||||
"dtype": torch.int32,
|
||||
"device": device,
|
||||
}
|
||||
|
|
@ -313,7 +313,7 @@ class _ExecOrderData:
|
|||
corresponding to the handles in ``handle``. An entry in the
|
||||
returned tuple is ``None`` if the handle is invalid.
|
||||
"""
|
||||
indices: List[Optional[int]] = []
|
||||
indices: list[Optional[int]] = []
|
||||
if handle:
|
||||
indices.append(handle._handle_index)
|
||||
return tuple(indices)
|
||||
|
|
@ -321,13 +321,13 @@ class _ExecOrderData:
|
|||
def _get_names_from_handle_indices(
|
||||
self,
|
||||
handle_indices: tuple[int, ...],
|
||||
) -> List[List[str]]:
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
Returns a list of FQNs for each handle in ``handle_indices``. If a
|
||||
handle index is invalid, then its FQNs are omitted from the returned
|
||||
list.
|
||||
"""
|
||||
fqns: List[List[str]] = []
|
||||
fqns: list[list[str]] = []
|
||||
for index in handle_indices:
|
||||
if index is None or index < 0 or index >= len(self.all_handles):
|
||||
continue
|
||||
|
|
@ -339,12 +339,12 @@ class _ExecOrderData:
|
|||
def _get_names_from_handles(
|
||||
self,
|
||||
handle: FlatParamHandle,
|
||||
) -> List[List[str]]:
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
Returns a list of FQNs for each handle in ``handles_key``. If a handle
|
||||
is invalid, then its FQNs are omitted from the returned list.
|
||||
"""
|
||||
fqns: List[List[str]] = []
|
||||
fqns: list[list[str]] = []
|
||||
if handle:
|
||||
flat_param = handle.flat_param
|
||||
if flat_param in self.param_to_fqn:
@ -4,24 +4,10 @@ import functools
|
|||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Generator, Iterator, Sequence
|
||||
from enum import auto, Enum
|
||||
from itertools import accumulate, chain
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterator,
|
||||
List,
|
||||
NamedTuple,
|
||||
no_type_check,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, cast, NamedTuple, no_type_check, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -354,7 +340,7 @@ class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
|
|||
_numels: tuple[int, ...]
|
||||
_shard_param_infos: tuple[_ShardParamInfo, ...]
|
||||
_shared_param_infos: tuple[SharedParamInfo, ...]
|
||||
_modules: Set[nn.Module]
|
||||
_modules: set[nn.Module]
|
||||
_shard_numel_padded: int
|
||||
_local_shard: Tensor
|
||||
_full_param_padded: Tensor
|
||||
|
|
@ -366,12 +352,12 @@ class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
|
|||
_mp_shard: Tensor
|
||||
_cpu_grad: Tensor
|
||||
_saved_grad_shard: Tensor
|
||||
_params: Optional[List[nn.Parameter]]
|
||||
_shared_params: Optional[List[nn.Parameter]]
|
||||
_tensors: Optional[List[Optional[Tensor]]]
|
||||
_is_grad_none_mask: Optional[List[bool]]
|
||||
_params: Optional[list[nn.Parameter]]
|
||||
_shared_params: Optional[list[nn.Parameter]]
|
||||
_tensors: Optional[list[Optional[Tensor]]]
|
||||
_is_grad_none_mask: Optional[list[bool]]
|
||||
|
||||
_is_padding_mask: List[bool]
|
||||
_is_padding_mask: list[bool]
|
||||
|
||||
def __new__(cls, data=None, requires_grad=True):
|
||||
assert cls is FlatParameter, "subclasses FlatParameter not supported"
|
||||
|
|
@ -386,17 +372,17 @@ class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
|
|||
def _init_metadata(
|
||||
cls,
|
||||
self,
|
||||
param_infos: List[ParamInfo],
|
||||
numels: List[int],
|
||||
shapes: List[torch.Size],
|
||||
strides: List[tuple[int, ...]],
|
||||
contiguities: List[bool],
|
||||
fqns: List[str],
|
||||
shared_param_infos: List[SharedParamInfo],
|
||||
param_extensions: List[Optional[Any]],
|
||||
params: Optional[List[nn.Parameter]],
|
||||
shared_params: Optional[List[nn.Parameter]],
|
||||
is_padding_mask: List[bool],
|
||||
param_infos: list[ParamInfo],
|
||||
numels: list[int],
|
||||
shapes: list[torch.Size],
|
||||
strides: list[tuple[int, ...]],
|
||||
contiguities: list[bool],
|
||||
fqns: list[str],
|
||||
shared_param_infos: list[SharedParamInfo],
|
||||
param_extensions: list[Optional[Any]],
|
||||
params: Optional[list[nn.Parameter]],
|
||||
shared_params: Optional[list[nn.Parameter]],
|
||||
is_padding_mask: list[bool],
|
||||
) -> None:
|
||||
"""
|
||||
Initialize attributes holding metadata about the original parameters comprising the flat parameter.
|
||||
|
|
@ -426,7 +412,7 @@ class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
|
|||
self._param_extensions = param_extensions
|
||||
self._is_padding_mask = is_padding_mask
|
||||
|
||||
numels_without_padding: List[int] = []
|
||||
numels_without_padding: list[int] = []
|
||||
for numel, is_padding in zip(numels, is_padding_mask):
|
||||
if not is_padding:
|
||||
numels_without_padding.append(numel)
|
||||
|
|
@ -624,7 +610,7 @@ class FlatParamHandle:
|
|||
|
||||
def _init_flat_param_and_metadata(
|
||||
self,
|
||||
params: List[Union[Tensor, nn.Parameter]],
|
||||
params: list[Union[Tensor, nn.Parameter]],
|
||||
module: nn.Module,
|
||||
aligned_numel: int,
|
||||
use_orig_params: bool,
|
||||
|
|
@ -653,20 +639,20 @@ class FlatParamHandle:
|
|||
params_set = set(params)
|
||||
# For alignment padding, only `numels` gets strictly non-`None`
|
||||
# elements, and all other lists get `None` elements for padding.
|
||||
param_infos: List[ParamInfo] = []
|
||||
numels: List[int] = []
|
||||
shapes: List[torch.Size] = []
|
||||
strides: List[tuple[int, ...]] = []
|
||||
contiguities: List[bool] = []
|
||||
fqns: List[str] = []
|
||||
shared_param_infos: List[SharedParamInfo] = []
|
||||
shared_param_memo: Dict[
|
||||
param_infos: list[ParamInfo] = []
|
||||
numels: list[int] = []
|
||||
shapes: list[torch.Size] = []
|
||||
strides: list[tuple[int, ...]] = []
|
||||
contiguities: list[bool] = []
|
||||
fqns: list[str] = []
|
||||
shared_param_infos: list[SharedParamInfo] = []
|
||||
shared_param_memo: dict[
|
||||
Union[Tensor, nn.Parameter], tuple[nn.Module, str, str]
|
||||
] = {}
|
||||
params_to_flatten: List[Union[Tensor, nn.Parameter]] = []
|
||||
shared_params: List[Union[Tensor, nn.Parameter]] = []
|
||||
param_extensions: List[Any] = []
|
||||
is_padding_mask: List[bool] = []
|
||||
params_to_flatten: list[Union[Tensor, nn.Parameter]] = []
|
||||
shared_params: list[Union[Tensor, nn.Parameter]] = []
|
||||
param_extensions: list[Any] = []
|
||||
is_padding_mask: list[bool] = []
|
||||
total_numel = total_numel_without_padding = 0
|
||||
for submodule_name, submodule in module.named_modules(remove_duplicate=False):
|
||||
for param_name, param in _named_parameters_with_duplicates(
|
||||
|
|
@ -779,8 +765,8 @@ class FlatParamHandle:
|
|||
)
|
||||
|
||||
def _validate_tensors_to_flatten(
|
||||
self, tensors: List[Union[Tensor, nn.Parameter]]
|
||||
) -> Tuple:
|
||||
self, tensors: list[Union[Tensor, nn.Parameter]]
|
||||
) -> tuple:
|
||||
"""Validate the tensors to flatten and returns any necessary metadata."""
|
||||
dtype: Optional[torch.dtype] = None
|
||||
# Return as the logical OR over each tensor's value
|
||||
|
|
@ -819,7 +805,7 @@ class FlatParamHandle:
|
|||
|
||||
def flatten_tensors(
|
||||
self,
|
||||
tensors: List[Tensor],
|
||||
tensors: list[Tensor],
|
||||
aligned_numel: int,
|
||||
) -> Tensor:
|
||||
"""
|
||||
|
|
@ -841,7 +827,7 @@ class FlatParamHandle:
|
|||
f"Expects non-negative `aligned_numel` but got {aligned_numel}"
|
||||
)
|
||||
dtype, _, device = self._validate_tensors_to_flatten(tensors)
|
||||
flat_tensors: List[Tensor] = []
|
||||
flat_tensors: list[Tensor] = []
|
||||
if aligned_numel > 0:
|
||||
total_numel = 0
|
||||
for tensor in tensors:
|
||||
|
|
@ -876,7 +862,7 @@ class FlatParamHandle:
|
|||
|
||||
def flatten_tensors_into_flat_param(
|
||||
self,
|
||||
tensors: List[Tensor],
|
||||
tensors: list[Tensor],
|
||||
aligned_numel: int,
|
||||
requires_grad: bool,
|
||||
) -> FlatParameter:
|
||||
|
|
@ -1013,7 +999,7 @@ class FlatParamHandle:
|
|||
assert len(flat_param_offsets) == len(
|
||||
self.flat_param._numels_with_padding
|
||||
), f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
|
||||
shard_param_infos: List[_ShardParamInfo] = []
|
||||
shard_param_infos: list[_ShardParamInfo] = []
|
||||
sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1
|
||||
# `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices
|
||||
# into the unsharded flat parameter (inclusive) of the given parameter
|
||||
|
|
@ -1130,7 +1116,7 @@ class FlatParamHandle:
|
|||
assert len(unpadded_sharded_size) == 1, f"{unpadded_sharded_size}"
|
||||
return torch.Size([unpadded_sharded_size[0] + numel_to_pad])
|
||||
|
||||
def _get_flat_param_offsets(self) -> List[tuple[int, int]]:
|
||||
def _get_flat_param_offsets(self) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Return [start, end] offsets of each original parameter's flattened data in the unsharded flat parameter (without padding).
|
||||
|
||||
|
|
@ -1884,7 +1870,7 @@ class FlatParamHandle:
|
|||
def _get_unflat_views_aligned(
|
||||
self,
|
||||
tensor: Optional[Tensor] = None,
|
||||
) -> List[Tensor]:
|
||||
) -> list[Tensor]:
|
||||
"""
|
||||
Return unflattened ``Tensor`` views into ``tensor`` with handling for padding.
|
||||
|
||||
|
|
@ -1895,11 +1881,11 @@ class FlatParamHandle:
|
|||
flat_param = self.flat_param
|
||||
if tensor is None:
|
||||
tensor = flat_param
|
||||
splits: List[Tensor] = torch.split(
|
||||
splits: list[Tensor] = torch.split(
|
||||
tensor, flat_param._numels_with_padding, dim=0
|
||||
)
|
||||
idx = 0
|
||||
views: List[Tensor] = []
|
||||
views: list[Tensor] = []
|
||||
for split, is_padding in zip(splits, flat_param._is_padding_mask):
|
||||
if is_padding:
|
||||
continue
|
||||
|
|
@ -2462,7 +2448,7 @@ class FlatParamHandle:
|
|||
else:
|
||||
self._use_unsharded_views(as_params=True)
|
||||
|
||||
def _get_modules(self) -> Set[nn.Module]:
|
||||
def _get_modules(self) -> set[nn.Module]:
|
||||
"""Return a :class:`set` of the modules whose parameters are included in this handle's flat parameter."""
|
||||
return {pi.module for pi in self.flat_param._param_infos}.union(
|
||||
{spi.module for spi in self.flat_param._shared_param_infos}
|
||||
|
|
@ -2514,9 +2500,9 @@ class FlatParamHandle:
|
|||
yield (param_name, module_name)
|
||||
|
||||
@property
|
||||
def _fqns_in_shard(self) -> List[str]:
|
||||
def _fqns_in_shard(self) -> list[str]:
|
||||
"""Return the FQNs of the parameters present in this rank's shard."""
|
||||
fqns_in_shard: List[str] = []
|
||||
fqns_in_shard: list[str] = []
|
||||
for fqn, shard_param_info in zip(
|
||||
self.flat_param._fqns, self.flat_param._shard_param_infos # type: ignore[attr-defined]
|
||||
):
|
||||
|
|
@ -2708,8 +2694,8 @@ def _safe_setattr_tensor_or_param(
|
|||
|
||||
|
||||
def _convert_to_params(
|
||||
tensors: List[Union[torch.Tensor, nn.Parameter]]
|
||||
) -> List[nn.Parameter]:
|
||||
tensors: list[Union[torch.Tensor, nn.Parameter]]
|
||||
) -> list[nn.Parameter]:
|
||||
return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors]
@ -1,5 +1,5 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -64,7 +64,7 @@ class FSDPExtensions(ABC):
|
|||
def pre_load_state_dict_transform(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, List[Shard]]:
|
||||
) -> tuple[torch.Tensor, list[Shard]]:
|
||||
"""
|
||||
This is to be called before loading a *sharded* model state dict and
|
||||
should return the tensor and list of shards from which to load data.
|
||||
|
|
@ -157,7 +157,7 @@ def _ext_chunk_dtensor(
|
|||
def _ext_pre_load_state_dict_transform(
|
||||
tensor: torch.Tensor,
|
||||
fsdp_extension: Optional[FSDPExtensions] = None,
|
||||
) -> tuple[torch.Tensor, List[Shard]]:
|
||||
) -> tuple[torch.Tensor, list[Shard]]:
|
||||
if fsdp_extension is not None:
|
||||
return fsdp_extension.pre_load_state_dict_transform(tensor)
@ -1,4 +1,4 @@
|
|||
from typing import cast, List, NamedTuple, Optional, Union
|
||||
from typing import cast, NamedTuple, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -20,12 +20,12 @@ class AllGatherResult(NamedTuple):
|
|||
all_gather_event: Optional[torch.Event]
|
||||
all_gather_work: Optional[dist.distributed_c10d.Work]
|
||||
# For each parameter, the all-gather input dtype for each input
|
||||
param_all_gather_input_dtypes: List[List[torch.dtype]]
|
||||
param_all_gather_input_dtypes: list[list[torch.dtype]]
|
||||
# For each parameter, the all-gather input numel for each input
|
||||
param_all_gather_input_numels: List[List[int]]
|
||||
param_all_gather_input_numels: list[list[int]]
|
||||
# 1D flattened version of `param_all_gather_input_numels` saved to avoid
|
||||
# CPU overhead from recomputing
|
||||
all_gather_input_split_sizes: List[int]
|
||||
all_gather_input_split_sizes: list[int]
|
||||
|
||||
|
||||
lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901
|
||||
|
|
@ -47,8 +47,8 @@ lib.define(
|
|||
|
||||
@torch.library.impl(lib, "all_gather_copy_in", "Meta")
|
||||
def all_gather_copy_in_meta(
|
||||
all_gather_inputs: List[torch.Tensor],
|
||||
inp_split_sizes: List[int],
|
||||
all_gather_inputs: list[torch.Tensor],
|
||||
inp_split_sizes: list[int],
|
||||
all_gather_input_numel: int,
|
||||
world_size: int,
|
||||
rank: int,
|
||||
|
|
@ -68,8 +68,8 @@ def all_gather_copy_in_meta(
|
|||
@torch.library.impl(lib, "all_gather_copy_in", "XPU")
|
||||
@torch.library.impl(lib, "all_gather_copy_in", "CPU")
|
||||
def all_gather_copy_in_cuda(
|
||||
all_gather_inputs: List[torch.Tensor],
|
||||
inp_split_sizes: List[int],
|
||||
all_gather_inputs: list[torch.Tensor],
|
||||
inp_split_sizes: list[int],
|
||||
all_gather_input_numel: int,
|
||||
world_size: int,
|
||||
rank: int,
|
||||
|
|
@ -99,9 +99,9 @@ lib.define(
|
|||
@torch.library.impl(lib, "split_with_sizes_copy", "CPU")
|
||||
def split_with_sizes_copy(
|
||||
all_gather_output: torch.Tensor,
|
||||
all_gather_input_split_sizes: List[int],
|
||||
all_gather_input_split_sizes: list[int],
|
||||
dim: int,
|
||||
out: List[torch.Tensor],
|
||||
out: list[torch.Tensor],
|
||||
) -> None:
|
||||
torch.split_with_sizes_copy(
|
||||
all_gather_output, all_gather_input_split_sizes, dim=dim, out=out
|
||||
|
|
@ -118,7 +118,7 @@ lib.define(
|
|||
@torch.library.impl(lib, "chunk_cat", "XPU")
|
||||
@torch.library.impl(lib, "chunk_cat", "CPU")
|
||||
def chunk_cat(
|
||||
tensors: List[torch.Tensor],
|
||||
tensors: list[torch.Tensor],
|
||||
dim: int,
|
||||
num_chunks: int,
|
||||
out: torch.Tensor,
|
||||
|
|
@ -128,7 +128,7 @@ def chunk_cat(
|
|||
|
||||
@torch.no_grad()
|
||||
def foreach_all_gather(
|
||||
fsdp_params: List[FSDPParam],
|
||||
fsdp_params: list[FSDPParam],
|
||||
group: dist.ProcessGroup,
|
||||
async_op: bool,
|
||||
all_gather_copy_in_stream: torch.Stream,
|
||||
|
|
@ -183,8 +183,8 @@ def foreach_all_gather(
|
|||
|
||||
@torch.no_grad()
|
||||
def _get_param_all_gather_inputs(
|
||||
fsdp_params: List[FSDPParam],
|
||||
) -> List[List[torch.Tensor]]:
|
||||
fsdp_params: list[FSDPParam],
|
||||
) -> list[list[torch.Tensor]]:
|
||||
if compiled_autograd_enabled():
|
||||
return [fsdp_param.all_gather_inputs for fsdp_param in fsdp_params]
|
||||
|
||||
|
|
@ -198,10 +198,10 @@ def _get_param_all_gather_inputs(
|
|||
and not hasattr(fsdp_param._sharded_local_tensor, "fsdp_pre_all_gather")
|
||||
)
|
||||
|
||||
param_all_gather_inputs: List[List[torch.Tensor]] = [[] for _ in fsdp_params]
|
||||
foreach_copy_indices: List[int] = []
|
||||
foreach_copy_inputs: List[torch.Tensor] = []
|
||||
foreach_copy_input_numels: List[int] = []
|
||||
param_all_gather_inputs: list[list[torch.Tensor]] = [[] for _ in fsdp_params]
|
||||
foreach_copy_indices: list[int] = []
|
||||
foreach_copy_inputs: list[torch.Tensor] = []
|
||||
foreach_copy_input_numels: list[int] = []
|
||||
|
||||
# 1st pass: for foreach-copy parameters, get inputs and metadata for the
|
||||
# foreach copy, and for the others, actually get their all-gather inputs
|
||||
|
|
@ -236,7 +236,7 @@ def _get_param_all_gather_inputs(
|
|||
@torch.no_grad()
|
||||
def foreach_all_gather_copy_out(
|
||||
all_gather_result: AllGatherResult,
|
||||
fsdp_params: List[FSDPParam],
|
||||
fsdp_params: list[FSDPParam],
|
||||
group: dist.ProcessGroup,
|
||||
) -> None:
|
||||
(
|
||||
|
|
@ -255,8 +255,8 @@ def foreach_all_gather_copy_out(
|
|||
all_gather_work.wait()
|
||||
world_size, device = group.size(), all_gather_output.device
|
||||
|
||||
split_with_sizes_out: List[torch.Tensor] = []
|
||||
shard_i_copy_infos: List[tuple[FSDPParam, List[torch.Tensor]]] = []
|
||||
split_with_sizes_out: list[torch.Tensor] = []
|
||||
shard_i_copy_infos: list[tuple[FSDPParam, list[torch.Tensor]]] = []
|
||||
for all_gather_input_numels, all_gather_input_dtypes, fsdp_param in zip(
|
||||
param_all_gather_input_numels, param_all_gather_input_dtypes, fsdp_params
|
||||
):
|
||||
|
|
@ -320,8 +320,8 @@ def foreach_all_gather_copy_out(
|
|||
|
||||
@torch.no_grad()
|
||||
def foreach_reduce(
|
||||
fsdp_params: List[FSDPParam],
|
||||
unsharded_grads: List[torch.Tensor],
|
||||
fsdp_params: list[FSDPParam],
|
||||
unsharded_grads: list[torch.Tensor],
|
||||
reduce_scatter_group: dist.ProcessGroup,
|
||||
reduce_scatter_stream: torch.Stream,
|
||||
orig_dtype: torch.dtype,
|
||||
|
|
@ -488,7 +488,7 @@ def foreach_reduce(
|
|||
|
||||
|
||||
def foreach_reduce_scatter_copy_in(
|
||||
unsharded_grads: List[torch.Tensor],
|
||||
unsharded_grads: list[torch.Tensor],
|
||||
reduce_scatter_input: torch.Tensor,
|
||||
world_size: int,
|
||||
) -> None:
|
||||
|
|
@ -499,14 +499,14 @@ def foreach_reduce_scatter_copy_in(
|
|||
|
||||
|
||||
def _get_all_gather_input_metadatas(
|
||||
param_all_gather_inputs: List[List[torch.Tensor]],
|
||||
) -> tuple[List[List[torch.dtype]], List[List[int]], torch.dtype]:
|
||||
param_all_gather_input_dtypes: List[List[torch.dtype]] = []
|
||||
param_all_gather_input_numels: List[List[int]] = []
|
||||
param_all_gather_inputs: list[list[torch.Tensor]],
|
||||
) -> tuple[list[list[torch.dtype]], list[list[int]], torch.dtype]:
|
||||
param_all_gather_input_dtypes: list[list[torch.dtype]] = []
|
||||
param_all_gather_input_numels: list[list[int]] = []
|
||||
all_gather_dtype = param_all_gather_inputs[0][0].dtype
|
||||
for all_gather_inputs in param_all_gather_inputs:
|
||||
input_dtypes: List[torch.dtype] = []
|
||||
input_numels: List[int] = []
|
||||
input_dtypes: list[torch.dtype] = []
|
||||
input_numels: list[int] = []
|
||||
for all_gather_input in all_gather_inputs:
|
||||
if all_gather_input.dtype != all_gather_dtype:
|
||||
all_gather_dtype = torch.uint8
@ -3,7 +3,7 @@ import math
|
|||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from enum import auto, Enum
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -120,7 +120,7 @@ def _get_dim0_padded_size(tensor_size: torch.Size, dim0_factor: int) -> torch.Si
|
|||
|
||||
def _chunk_with_empty(
|
||||
tensor: torch.Tensor, num_chunks: int, dim: int
|
||||
) -> List[torch.Tensor]:
|
||||
) -> list[torch.Tensor]:
|
||||
chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
|
||||
while len(chunks) < num_chunks:
|
||||
chunks.append(chunks[0].new_empty(0))
@ -1,5 +1,5 @@
|
|||
import itertools
|
||||
from typing import List, Optional, Set, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -70,11 +70,11 @@ def _get_device_from_mesh(mesh: DeviceMesh) -> torch.device:
|
|||
return torch.device(mesh.device_type, device_handle.current_device())
|
||||
|
||||
|
||||
def _get_managed_modules(root_modules: tuple[nn.Module, ...]) -> List[nn.Module]:
|
||||
modules: List[nn.Module] = []
|
||||
def _get_managed_modules(root_modules: tuple[nn.Module, ...]) -> list[nn.Module]:
|
||||
modules: list[nn.Module] = []
|
||||
root_modules_set = set(root_modules)
|
||||
# Track visited modules to avoid visiting shared modules multiple times
|
||||
visited_modules: Set[nn.Module] = set()
|
||||
visited_modules: set[nn.Module] = set()
|
||||
|
||||
def dfs(module: nn.Module) -> None:
|
||||
"""
|
||||
|
|
@ -113,14 +113,14 @@ def _verify_managed_param(name: str, param: nn.Parameter) -> None:
|
|||
|
||||
|
||||
def _get_managed_states(
|
||||
modules: List[nn.Module],
|
||||
) -> tuple[List[nn.Parameter], List[torch.Tensor]]:
|
||||
params: List[nn.Parameter] = []
|
||||
buffers: List[torch.Tensor] = []
|
||||
modules: list[nn.Module],
|
||||
) -> tuple[list[nn.Parameter], list[torch.Tensor]]:
|
||||
params: list[nn.Parameter] = []
|
||||
buffers: list[torch.Tensor] = []
|
||||
# Track visited parameters/buffers to avoid visiting shared parameters and
|
||||
# buffers multiple times
|
||||
visited_params: Set[nn.Parameter] = set()
|
||||
visited_buffers: Set[torch.Tensor] = set()
|
||||
visited_params: set[nn.Parameter] = set()
|
||||
visited_buffers: set[torch.Tensor] = set()
|
||||
for module in modules:
|
||||
for name, param in module.named_parameters(recurse=False):
|
||||
if param not in visited_params:
|
||||
|
|
@ -135,8 +135,8 @@ def _get_managed_states(
|
|||
|
||||
|
||||
def _move_states_to_device(
|
||||
params: List[nn.Parameter],
|
||||
buffers: List[torch.Tensor],
|
||||
params: list[nn.Parameter],
|
||||
buffers: list[torch.Tensor],
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
"""
|
@ -1,9 +1,10 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import inspect
|
||||
import itertools
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from enum import auto, Enum
|
||||
from typing import Any, Callable, cast, List, Optional, Sequence
|
||||
from typing import Any, Callable, cast, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
|
@ -171,8 +172,8 @@ class ParamModuleInfo:
|
|||
# Parameter names are unprefixed, e.g. "weight", not "lin.weight"
|
||||
module: nn.Module
|
||||
param_name: str
|
||||
shared_modules: List[nn.Module] = field(default_factory=list)
|
||||
shared_param_names: List[str] = field(default_factory=list)
|
||||
shared_modules: list[nn.Module] = field(default_factory=list)
|
||||
shared_param_names: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -211,10 +212,10 @@ class FSDPParam:
|
|||
_sharding_spec: DTensorSpec
|
||||
# DTensor attributes (only defined for DTensor `param`):
|
||||
_tp_spec: DTensorSpec
|
||||
all_gather_outputs: List[torch.Tensor] # 1D
|
||||
all_gather_outputs: list[torch.Tensor] # 1D
|
||||
# All-gather extension attributes
|
||||
_extensions_data: ExtensionsData
|
||||
_unsharded_inner_tensors: List[torch.Tensor]
|
||||
_unsharded_inner_tensors: list[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -241,7 +242,7 @@ class FSDPParam:
|
|||
if self.post_forward_mesh_info:
|
||||
self._init_sharded_post_forward_param_metadata(param)
|
||||
self._init_extensions()
|
||||
self.all_gather_outputs: List[torch.Tensor] = []
|
||||
self.all_gather_outputs: list[torch.Tensor] = []
|
||||
self.unsharded_accumulated_grad = None
|
||||
self._param_fqn: Optional[str] = None # prefixed from root module
|
||||
# TODO: Remove this padding logic once DTensor pads the local tensor:
|
||||
|
|
@ -439,12 +440,12 @@ class FSDPParam:
|
|||
)
|
||||
if has_fsdp_pre_all_gather:
|
||||
self._extensions_data = ExtensionsData()
|
||||
self._unsharded_inner_tensors: List[torch.Tensor] = []
|
||||
self._unsharded_inner_tensors: list[torch.Tensor] = []
|
||||
|
||||
def init_all_gather_outputs(
|
||||
self,
|
||||
all_gather_input_numels: List[int],
|
||||
all_gather_input_dtypes: List[torch.dtype],
|
||||
all_gather_input_numels: list[int],
|
||||
all_gather_input_dtypes: list[torch.dtype],
|
||||
world_size: int,
|
||||
device: torch.device,
|
||||
force_recreate: bool = False,
|
||||
|
|
@ -680,7 +681,7 @@ class FSDPParam:
|
|||
free_storage(tensor)
|
||||
|
||||
@property
|
||||
def all_gather_inputs(self) -> List[torch.Tensor]: # 1D
|
||||
def all_gather_inputs(self) -> list[torch.Tensor]: # 1D
|
||||
self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD)
|
||||
if self.sharded_state == ShardedState.SHARDED:
|
||||
if not compiled_autograd_enabled() and hasattr(
@ -1,7 +1,7 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import contextlib
|
||||
import logging
|
||||
from typing import Any, Callable, cast, Dict, List, NamedTuple, Optional, Set
|
||||
from typing import Any, Callable, cast, NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -31,7 +31,7 @@ from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState
|
|||
|
||||
logger = logging.getLogger("torch.distributed.fsdp.fully_shard")
|
||||
|
||||
_ModuleToHandleDict = Dict[nn.Module, RemovableHandle] # for state dict
|
||||
_ModuleToHandleDict = dict[nn.Module, RemovableHandle] # for state dict
|
||||
|
||||
|
||||
"""
|
||||
|
|
@ -77,7 +77,7 @@ class FSDPCommContext:
|
|||
self.all_gather_state: Optional[AllGatherState] = None
|
||||
self.reduce_scatter_state: Optional[ReduceScatterState] = None
|
||||
# Post-forward order for explicit backward prefetching
|
||||
self.post_forward_order: List[FSDPParamGroup] = [] # will cause ref cycles
|
||||
self.post_forward_order: list[FSDPParamGroup] = [] # will cause ref cycles
|
||||
|
||||
def get_all_gather_streams(
|
||||
self, async_op: bool, training_state: TrainingState
|
||||
|
|
@ -116,7 +116,7 @@ class FSDPParamGroup:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
params: List[nn.Parameter],
|
||||
params: list[nn.Parameter],
|
||||
modules: tuple[nn.Module, ...],
|
||||
mesh_info: FSDPMeshInfo,
|
||||
post_forward_mesh_info: Optional[FSDPMeshInfo],
|
||||
|
|
@ -162,7 +162,7 @@ class FSDPParamGroup:
|
|||
# - Communication and communication/computation overlap
|
||||
self.comm_ctx = FSDPCommContext()
|
||||
# Group's indices in the shared post-forward order
|
||||
self._post_forward_indices: List[int] = []
|
||||
self._post_forward_indices: list[int] = []
|
||||
# Whether to reduce gradients at all (whether for FSDP or HSDP)
|
||||
self.reduce_grads: bool = True
|
||||
# Whether to all-reduce gradients for HSDP; only used if
|
||||
|
|
@ -325,8 +325,8 @@ class FSDPParamGroup:
|
|||
self._to_sharded()
|
||||
|
||||
def pre_forward(
|
||||
self, module: nn.Module, args: tuple[Any, ...], kwargs: Dict[str, Any]
|
||||
) -> tuple[tuple[Any, ...], Dict[str, Any]]:
|
||||
self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> tuple[tuple[Any, ...], dict[str, Any]]:
|
||||
if not compiled_autograd_enabled():
|
||||
logger.debug("%s", self._with_fqn("FSDP::pre_forward"))
|
||||
with record_function(self._with_fqn("FSDP::pre_forward")):
|
||||
|
|
@ -389,8 +389,8 @@ class FSDPParamGroup:
|
|||
return
|
||||
# Save the autograd-computed gradients before resharding to only
|
||||
# access the unsharded parameters when their data is present
|
||||
fsdp_params_with_grad: List[FSDPParam] = []
|
||||
unsharded_grads: List[torch.Tensor] = []
|
||||
fsdp_params_with_grad: list[FSDPParam] = []
|
||||
unsharded_grads: list[torch.Tensor] = []
|
||||
for fsdp_param in self.fsdp_params:
|
||||
if not hasattr(fsdp_param, "_unsharded_param"):
|
||||
continue
|
||||
|
|
@ -543,8 +543,8 @@ class FSDPParamGroup:
|
|||
|
||||
# Hook Registration #
|
||||
def _register_post_backward_hook(
|
||||
self, args: tuple[Any, ...], kwargs: Dict[str, Any]
|
||||
) -> tuple[tuple[Any, ...], Dict[str, Any]]:
|
||||
self, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> tuple[tuple[Any, ...], dict[str, Any]]:
|
||||
# Traceable FSDP2 relies on `root_post_backward_callback` to call each
|
||||
# `FSDPParamGroup.post_backward`
|
||||
if (not torch._dynamo.config.skip_fsdp_hooks) or compiled_autograd_enabled():
|
||||
|
|
@ -554,8 +554,8 @@ class FSDPParamGroup:
|
|||
args_list, args_spec = tree_flatten(args)
|
||||
kwargs_list, kwargs_spec = tree_flatten(kwargs)
|
||||
args_kwargs_list = list(args_list) + list(kwargs_list)
|
||||
inp_tensor_indices: List[int] = []
|
||||
inp_tensors: List[torch.Tensor] = []
|
||||
inp_tensor_indices: list[int] = []
|
||||
inp_tensors: list[torch.Tensor] = []
|
||||
for i, obj in enumerate(args_kwargs_list):
|
||||
if torch.is_tensor(obj) and obj.requires_grad:
|
||||
inp_tensor_indices.append(i)
|
||||
|
|
@ -579,7 +579,7 @@ class FSDPParamGroup:
|
|||
), f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}"
|
||||
if num_pre_save_hooks > 0:
|
||||
return # already registered
|
||||
modules_with_fsdp_params: Set[nn.Module] = {
|
||||
modules_with_fsdp_params: set[nn.Module] = {
|
||||
fsdp_param._module_info.module for fsdp_param in self.fsdp_params
|
||||
}
|
||||
|
||||
|
|
@ -670,8 +670,8 @@ class FSDPParamGroup:
|
|||
|
||||
|
||||
def _get_param_module_infos(
|
||||
params: List[nn.Parameter], modules: tuple[nn.Module, ...]
|
||||
) -> List[ParamModuleInfo]:
|
||||
params: list[nn.Parameter], modules: tuple[nn.Module, ...]
|
||||
) -> list[ParamModuleInfo]:
|
||||
"""
|
||||
Shared parameter: lin1.weight = lin2.weight
|
||||
Shared module: mlp.lin1 = mlp.lin2
|
||||
|
|
@ -679,7 +679,7 @@ def _get_param_module_infos(
|
|||
find shared modules' parameters and shared parameters within a module.
|
||||
"""
|
||||
params_set = set(params)
|
||||
param_to_module_info: Dict[nn.Parameter, ParamModuleInfo] = {}
|
||||
param_to_module_info: dict[nn.Parameter, ParamModuleInfo] = {}
|
||||
for module in modules:
|
||||
for _, submodule in module.named_modules(remove_duplicate=False):
|
||||
for param_name, param in _named_parameters_with_duplicates(
@ -2,7 +2,8 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import functools
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, TYPE_CHECKING
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
|
@ -40,7 +41,7 @@ class FSDPStateContext:
|
|||
|
||||
def __init__(self) -> None:
|
||||
# All FSDP states in the root state's module tree
|
||||
self.all_states: List[FSDPState] = []
|
||||
self.all_states: list[FSDPState] = []
|
||||
# Iteration's forward root runs the once-per-forward logic; this root
|
||||
# may not be the overall root set by lazy initialization in cases where
|
||||
# only a submodule runs forward (e.g. encoder-only for eval)
|
||||
|
|
@ -73,9 +74,9 @@ class FSDPState(_State):
|
|||
self._state_ctx = FSDPStateContext()
|
||||
self._comm_ctx = FSDPCommContext()
|
||||
self._training_state: TrainingState = TrainingState.IDLE
|
||||
self._states_to_forward_prefetch: List[FSDPState] = []
|
||||
self._states_to_backward_prefetch: List[FSDPState] = []
|
||||
self._modules_to_run_forward: Set[nn.Module] = set()
|
||||
self._states_to_forward_prefetch: list[FSDPState] = []
|
||||
self._states_to_backward_prefetch: list[FSDPState] = []
|
||||
self._modules_to_run_forward: set[nn.Module] = set()
|
||||
|
||||
# Define a separate init since `__init__` is called in the contract
|
||||
def init(
|
||||
|
|
@ -108,8 +109,8 @@ class FSDPState(_State):
|
|||
self._post_forward_hook_handle = hook_handle
|
||||
|
||||
def _root_pre_forward(
|
||||
self, module: nn.Module, args: tuple[Any, ...], kwargs: Dict[str, Any]
|
||||
) -> tuple[tuple[Any, ...], Dict[str, Any]]:
|
||||
self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> tuple[tuple[Any, ...], dict[str, Any]]:
|
||||
self._lazy_init()
|
||||
if self._state_ctx.iter_forward_root is not None:
|
||||
return args, kwargs
|
||||
|
|
@ -150,7 +151,7 @@ class FSDPState(_State):
|
|||
)
|
||||
detect_compiled_autograd()
|
||||
root_module = self._modules[0]
|
||||
visited_states: Set[FSDPState] = set()
|
||||
visited_states: set[FSDPState] = set()
|
||||
for module_name, module in root_module.named_modules():
|
||||
if (state := _get_module_fsdp_state(module)) is None:
|
||||
continue
|
||||
|
|
@ -188,8 +189,8 @@ class FSDPState(_State):
|
|||
"""Sets module and parameter FQN attributes for debugging."""
|
||||
assert self._is_root
|
||||
root_module = self._modules[0]
|
||||
param_to_fsdp_param: Dict[nn.Parameter, FSDPParam] = {}
|
||||
module_to_fsdp_param_group: Dict[nn.Module, FSDPParamGroup] = {}
|
||||
param_to_fsdp_param: dict[nn.Parameter, FSDPParam] = {}
|
||||
module_to_fsdp_param_group: dict[nn.Module, FSDPParamGroup] = {}
|
||||
for state in self._state_ctx.all_states:
|
||||
if fsdp_param_group := state._fsdp_param_group:
|
||||
for fsdp_param in fsdp_param_group.fsdp_params:
|
||||
|
|
@ -211,8 +212,8 @@ class FSDPState(_State):
|
|||
|
||||
@disable_if_config_true
|
||||
def _pre_forward(
|
||||
self, module: nn.Module, args: tuple[Any, ...], kwargs: Dict[str, Any]
|
||||
) -> tuple[tuple[Any, ...], Dict[str, Any]]:
|
||||
self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> tuple[tuple[Any, ...], dict[str, Any]]:
|
||||
# When composing with module-hook-based activation checkpointing, the
|
||||
# the pre-backward hook is responsible for the unshard
|
||||
if self._training_state == TrainingState.PRE_BACKWARD:
|
||||
|
|
@ -343,7 +344,7 @@ def _register_group_forward_hooks(
|
|||
modules: Sequence[nn.Module],
|
||||
pre_hook: Callable,
|
||||
post_hook: Callable,
|
||||
modules_to_run: Set[nn.Module],
|
||||
modules_to_run: set[nn.Module],
|
||||
):
|
||||
"""
|
||||
Registers group forward pre and post-hooks. The pre-hook runs upon the
@ -1,18 +1,8 @@
|
|||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
import functools
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable, cast, NoReturn, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
|
@ -42,14 +32,14 @@ __all__ = [
|
|||
]
|
||||
|
||||
|
||||
cls_to_fsdp_cls: Dict[Type, Type] = {}
|
||||
cls_to_fsdp_cls: dict[type, type] = {}
|
||||
|
||||
|
||||
# The decorator adds a state object to `module` that can be accessed via
|
||||
# `fully_shard.state(module)`. The state object and module are 1:1.
|
||||
@contract(state_cls=FSDPState)
|
||||
def fully_shard(
|
||||
module: Union[nn.Module, List[nn.Module]],
|
||||
module: Union[nn.Module, list[nn.Module]],
|
||||
*,
|
||||
mesh: Optional[DeviceMesh] = None,
|
||||
reshard_after_forward: Union[bool, int] = True,
|
||||
|
|
@ -331,7 +321,7 @@ class FSDPModule:
|
|||
if fsdp_param_group := state._fsdp_param_group:
|
||||
fsdp_param_group.reshard_after_backward = reshard_after_backward
|
||||
|
||||
def set_modules_to_forward_prefetch(self, modules: List["FSDPModule"]) -> None:
|
||||
def set_modules_to_forward_prefetch(self, modules: list["FSDPModule"]) -> None:
|
||||
"""
|
||||
Sets the FSDP modules for which this FSDP module should explicitly
|
||||
prefetch all-gathers in forward. The prefetching runs after this
|
||||
|
|
@ -351,7 +341,7 @@ class FSDPModule:
|
|||
module._get_fsdp_state() for module in modules
|
||||
]
|
||||
|
||||
def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None:
|
||||
def set_modules_to_backward_prefetch(self, modules: list["FSDPModule"]) -> None:
|
||||
"""
|
||||
Sets the FSDP modules for which this FSDP module should explicitly
|
||||
prefetch all-gathers in backward. This overrides the default backward
@ -3,21 +3,8 @@ import collections
|
|||
import itertools
|
||||
import os
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Deque,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
no_type_check,
|
||||
Optional,
|
||||
Set,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
from collections.abc import Generator, Iterable, Iterator
|
||||
from typing import Any, Callable, no_type_check, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -329,7 +316,7 @@ def _init_ignored_module_states(
|
|||
|
||||
|
||||
def _check_ignored_states(
|
||||
ignored_states: List[Any], passed_as_ignored_states: bool
|
||||
ignored_states: list[Any], passed_as_ignored_states: bool
|
||||
) -> None:
|
||||
"""
|
||||
Check that the ignored states are uniformly parameters or uniformly modules.
|
||||
|
|
@ -361,7 +348,7 @@ def _check_ignored_states(
|
|||
def _init_device_handle(
|
||||
state: _FSDPState,
|
||||
module: nn.Module,
|
||||
ignored_params: Set[nn.Parameter],
|
||||
ignored_params: set[nn.Parameter],
|
||||
device_id: Optional[Union[int, torch.device]],
|
||||
) -> _FSDPState:
|
||||
"""
|
||||
|
|
@ -416,7 +403,7 @@ def _init_buffer_state(
|
|||
# `module`) to its original dtype for restoring that dtype during model
|
||||
# checkpointing when buffer mixed precision is enabled. The names should
|
||||
# be clean since the casting happens in a `summon_full_params()` context.
|
||||
_buffer_name_to_orig_dtype: Dict[str, torch.dtype] = {}
|
||||
_buffer_name_to_orig_dtype: dict[str, torch.dtype] = {}
|
||||
for buffer_name, buffer in module.named_buffers():
|
||||
buffer_name = clean_tensor_name(buffer_name)
|
||||
_buffer_name_to_orig_dtype[buffer_name] = buffer.dtype
|
||||
|
|
@ -479,13 +466,13 @@ def _init_core_state(
|
|||
state._unshard_event = None
|
||||
# Mapping from fully sharded module to the handles it is responsible to
|
||||
# unshard and reshard (see [Note: Fully Sharded Module])
|
||||
_fully_sharded_module_to_handle: Dict[nn.Module, FlatParamHandle] = {}
|
||||
_fully_sharded_module_to_handle: dict[nn.Module, FlatParamHandle] = {}
|
||||
state._fully_sharded_module_to_handle = _fully_sharded_module_to_handle
|
||||
# Invariant: `state.params` contains exactly the `FlatParameter`s of the
|
||||
# handles in `state._handle`
|
||||
_handle: Optional[FlatParamHandle] = None
|
||||
state._handle = _handle
|
||||
params: List[FlatParameter] = []
|
||||
params: list[FlatParameter] = []
|
||||
state.params = params
|
||||
return state
|
||||
|
||||
|
|
@ -494,11 +481,11 @@ def _init_core_state(
|
|||
def _init_runtime_state(
|
||||
state: _FSDPState,
|
||||
) -> _FSDPState:
|
||||
_root_pre_forward_handles: List[RemovableHandle] = []
|
||||
_root_pre_forward_handles: list[RemovableHandle] = []
|
||||
state._root_pre_forward_handles = _root_pre_forward_handles
|
||||
_pre_forward_handles: List[RemovableHandle] = []
|
||||
_pre_forward_handles: list[RemovableHandle] = []
|
||||
state._pre_forward_handles = _pre_forward_handles
|
||||
_post_forward_handles: List[RemovableHandle] = []
|
||||
_post_forward_handles: list[RemovableHandle] = []
|
||||
state._post_forward_handles = _post_forward_handles
|
||||
state._sync_gradients = True
|
||||
state._comm_hook = None
|
||||
|
|
@ -542,13 +529,13 @@ def _init_state_dict_state(state: _FSDPState) -> _FSDPState:
|
|||
state_dict_config: StateDictConfig = FullStateDictConfig()
|
||||
state._optim_state_dict_config = FullOptimStateDictConfig()
|
||||
state._state_dict_config = state_dict_config
|
||||
unshard_params_ctx: Dict[nn.Module, Generator] = {}
|
||||
unshard_params_ctx: dict[nn.Module, Generator] = {}
|
||||
state._unshard_params_ctx = unshard_params_ctx
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def _verify_managed_params(module: nn.Module, params: List[nn.Parameter]) -> None:
|
||||
def _verify_managed_params(module: nn.Module, params: list[nn.Parameter]) -> None:
|
||||
"""
|
||||
Verify if the parameters are accepted by FSDP. The only restriction now
|
||||
is that the parameter cannot be a scalar tensor (param.shape == []).
|
||||
|
|
@ -639,7 +626,7 @@ def _init_param_handle_from_module(
|
|||
@no_type_check
|
||||
def _init_param_handle_from_params(
|
||||
state: _FSDPState,
|
||||
params: List[nn.Parameter],
|
||||
params: list[nn.Parameter],
|
||||
fully_sharded_module: nn.Module,
|
||||
):
|
||||
if len(params) == 0:
|
||||
|
|
@ -670,7 +657,7 @@ def _init_param_handle_from_params(
|
|||
def _get_ignored_modules(
|
||||
root_module: nn.Module,
|
||||
_ignored_modules: Optional[Iterable[torch.nn.Module]],
|
||||
) -> Set[nn.Module]:
|
||||
) -> set[nn.Module]:
|
||||
"""
|
||||
Check that ``_ignored_modules`` is an iterable of ``nn.Module`` s without any FSDP instances.
|
||||
|
||||
|
|
@ -726,15 +713,15 @@ def _get_ignored_modules(
|
|||
|
||||
def _get_ignored_params(
|
||||
root_module: torch.nn.Module,
|
||||
ignored_modules: Set[torch.nn.Module],
|
||||
ignored_modules: set[torch.nn.Module],
|
||||
ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
|
||||
) -> Set[torch.nn.Parameter]:
|
||||
) -> set[torch.nn.Parameter]:
|
||||
"""
|
||||
Return the parameters of the modules in ``ignored_modules`` and the parameters in ``ignored_parameters``.
|
||||
|
||||
:class:`FlatParameter` s are excluded from the result.
|
||||
"""
|
||||
all_ignored_params: Set[torch.nn.Parameter] = set()
|
||||
all_ignored_params: set[torch.nn.Parameter] = set()
|
||||
|
||||
params_in_ignored_modules = {
|
||||
p for m in ignored_modules for p in m.parameters() if not _is_fsdp_flattened(p)
|
||||
|
|
@ -760,10 +747,10 @@ def _get_ignored_params(
|
|||
|
||||
def _get_ignored_buffer_names(
|
||||
root_module: torch.nn.Module,
|
||||
ignored_modules: Set[torch.nn.Module],
|
||||
) -> Set[str]:
|
||||
ignored_modules: set[torch.nn.Module],
|
||||
) -> set[str]:
|
||||
"""Return the cleaned buffer FQNs in ``ignored_modules``."""
|
||||
all_ignored_buffer_names: Set[str] = set()
|
||||
all_ignored_buffer_names: set[str] = set()
|
||||
|
||||
buffers_in_ignored_modules = {
buffer for m in ignored_modules for buffer in m.buffers()

@@ -787,7 +774,7 @@ def _get_ignored_buffer_names(
return all_ignored_buffer_names

def _get_buffer_names(root_module: nn.Module) -> Set[str]:
def _get_buffer_names(root_module: nn.Module) -> set[str]:
"""Return the fully prefixed names of all buffers in the module hierarchy rooted at ``root_module`` as a class:`set`."""
return {
clean_tensor_name(buffer_name) for buffer_name, _ in root_module.named_buffers()

@@ -796,7 +783,7 @@ def _get_buffer_names(root_module: nn.Module) -> Set[str]:
def _check_single_device_module(
module: nn.Module,
ignored_params: Set[nn.Parameter],
ignored_params: set[nn.Parameter],
device_id: Optional[Union[int, torch.device]],
) -> None:
"""

@@ -855,8 +842,8 @@ def _get_device_from_device_id(
def _need_to_materialize_module(
module: nn.Module,
ignored_params: Set[nn.Parameter],
ignored_modules: Set[nn.Module],
ignored_params: set[nn.Parameter],
ignored_modules: set[nn.Module],
) -> tuple[bool, bool]:
"""
Return if ``module`` has parameters on meta device and if ``module`` is using torchdistX deferred initialization.

@@ -886,7 +873,7 @@ def _need_to_materialize_module(
def _materialize_with_param_init_fn(
root_module: nn.Module,
param_init_fn: Callable[[nn.Module], None],
ignored_modules: Set[nn.Module],
ignored_modules: set[nn.Module],
) -> None:
if not callable(param_init_fn):
raise ValueError(

@@ -900,7 +887,7 @@ def _materialize_with_param_init_fn(
def _materialize_meta_module(
root_module: nn.Module,
device_from_device_id: Optional[torch.device],
ignored_modules: Set[nn.Module],
ignored_modules: set[nn.Module],
device_handle: _FSDPDeviceHandle,
):
# Run default meta device initialization

@@ -933,13 +920,13 @@ def _materialize_meta_module(
def _get_modules_to_materialize(
root_module: nn.Module, ignored_modules: Set[nn.Module]
) -> List[nn.Module]:
root_module: nn.Module, ignored_modules: set[nn.Module]
) -> list[nn.Module]:
# Run BFS to collect the modules to materialize via `reset_parameters()`,
# stopping at any module with FSDP already applied or at ignored modules.
modules_to_materialize: List[nn.Module] = []
modules_to_materialize: list[nn.Module] = []
queue = collections.deque([root_module])
visited_modules: Set[nn.Module] = {root_module}
visited_modules: set[nn.Module] = {root_module}
while queue:
module = queue.popleft()
modules_to_materialize.append(module)

@@ -956,8 +943,8 @@ def _get_modules_to_materialize(
def _move_module_to_device(
module: nn.Module,
ignored_params: Set[nn.Parameter],
ignored_buffers: Set[torch.Tensor],
ignored_params: set[nn.Parameter],
ignored_buffers: set[torch.Tensor],
device_from_device_id: Optional[torch.device],
) -> None:
"""

@@ -976,10 +963,10 @@ def _move_module_to_device(
if device_from_device_id is not None:
# BFS from `module` without traversing any nested FSDP instances to
# collect the parameters/buffers that have not yet been managed
queue: Deque[nn.Module] = collections.deque()
queue: collections.deque[nn.Module] = collections.deque()
queue.append(module)
params: List[nn.Parameter] = []
buffers: List[torch.Tensor] = []
params: list[nn.Parameter] = []
buffers: list[torch.Tensor] = []
while queue:
curr_module = queue.popleft()
# NOTE: We include a check to only move parameters/buffers that are

@@ -1009,8 +996,8 @@ def _move_module_to_device(
def _move_states_to_device(
params: List[nn.Parameter],
buffers: List[torch.Tensor],
params: list[nn.Parameter],
buffers: list[torch.Tensor],
device_from_device_id: Optional[torch.device],
) -> None:
"""

@@ -1053,7 +1040,7 @@ def _warn_cpu_init():
def _get_compute_device(
module: nn.Module,
ignored_params: Set[nn.Parameter],
ignored_params: set[nn.Parameter],
device_from_device_id: Optional[torch.device],
rank: int,
device_handle: _FSDPDeviceHandle,

@@ -1088,7 +1075,7 @@ def _get_compute_device(
# TODO: See how to deprecate!
def _sync_module_params_and_buffers(
module: nn.Module,
params: List[nn.Parameter],
params: list[nn.Parameter],
process_group: dist.ProcessGroup,
) -> None:
"""

@@ -1097,7 +1084,7 @@ def _sync_module_params_and_buffers(
Precondition: ``sync_module_states == True`` and ``self.process_group`` has
been set.
"""
module_states: List[torch.Tensor] = []
module_states: list[torch.Tensor] = []
for buffer in module.buffers():
# Avoid re-synchronizing buffers in case of nested wrapping
if not getattr(buffer, FSDP_SYNCED, False):

@@ -1131,7 +1118,7 @@ def _sync_module_params_and_buffers(
def _check_module_states_for_sync_module_states(
module_states: List[torch.Tensor],
module_states: list[torch.Tensor],
) -> None:
if module_states and any(
tensor.device == torch.device("cpu") for tensor in module_states

@@ -1145,7 +1132,7 @@ def _check_module_states_for_sync_module_states(
def _get_orig_params(
module: nn.Module,
ignored_params: Set[nn.Parameter],
ignored_params: set[nn.Parameter],
) -> Iterator[nn.Parameter]:
"""
Return an iterator over the original parameters in ``module``.

@@ -1167,7 +1154,7 @@ def _get_orig_params(
def _check_orig_params_flattened(
fsdp_module,
ignored_params: Set[nn.Parameter],
ignored_params: set[nn.Parameter],
) -> None:
"""
Check that original parameters in ``fsdp_module`` have been flattened.
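Reviewer note (not part of the diff): the hunks above only swap the deprecated `typing.Set`/`typing.List`/`typing.Deque` aliases for the built-in generics standardized by PEP 585, which are valid annotations at runtime on Python 3.9+. A minimal, self-contained sketch of the pattern; the helper and its names are illustrative, not taken from this PR:

    import torch.nn as nn

    def count_managed(params: list[nn.Parameter], ignored: set[int]) -> dict[str, int]:
        # `ignored` holds id()s of parameters to skip; no typing imports are needed
        # for the container annotations themselves.
        managed = [p for p in params if id(p) not in ignored]
        return {"managed": len(managed), "ignored": len(params) - len(managed)}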
@@ -1,5 +1,5 @@
import collections
from typing import Deque, Optional
from typing import Optional

import torch

@@ -12,7 +12,7 @@ class _FreeEventQueue:
"""

def __init__(self) -> None:
self._queue: Deque[torch.Event] = collections.deque()
self._queue: collections.deque[torch.Event] = collections.deque()
self._max_num_inflight_all_gathers = 2 # empirically chosen

def enqueue(self, free_event: torch.Event) -> None:
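For context (again not part of the diff): `collections.deque` itself became subscriptable at runtime in Python 3.9 under PEP 585, which is why `typing.Deque` can be dropped here. A small illustrative sketch:

    import collections

    # The class object can be parameterized directly; no `typing.Deque` needed.
    events: collections.deque[str] = collections.deque(maxlen=2)
    events.append("all_gather_0")
    events.append("all_gather_1")
    events.append("all_gather_2")  # oldest entry is evicted once maxlen is reached
    print(list(events))  # ['all_gather_1', 'all_gather_2']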
@@ -3,23 +3,10 @@ import copy
import functools
import logging
import warnings
from collections.abc import Iterable, Iterator, Sequence
from contextlib import ExitStack
from dataclasses import dataclass, field
from typing import (
Any,
cast,
Dict,
Iterable,
Iterator,
List,
NamedTuple,
no_type_check,
Optional,
Sequence,
Set,
TYPE_CHECKING,
Union,
)
from typing import Any, cast, NamedTuple, no_type_check, Optional, TYPE_CHECKING, Union

import torch
import torch.distributed as dist

@@ -66,11 +53,11 @@ logger = logging.getLogger(__name__)
class FSDPParamInfo:
state: _FSDPState
handle: FlatParamHandle
param_indices: Dict[str, int]
param_requires_grad: List[bool]
param_indices: dict[str, int]
param_requires_grad: list[bool]

def sorted_items(dictionary: Dict[str, Any]) -> Iterator[tuple[str, Any]]:
def sorted_items(dictionary: dict[str, Any]) -> Iterator[tuple[str, Any]]:
keys = sorted(dictionary.keys())
for k in keys:
yield k, dictionary[k]

@@ -97,9 +84,9 @@ class _ConsolidatedOptimState:
name to its value.
"""

tensor_state: Dict[str, torch.Tensor] = field(default_factory=dict)
zero_dim_tensor_state: Dict[str, torch.Tensor] = field(default_factory=dict)
non_tensor_state: Dict[str, Any] = field(default_factory=dict)
tensor_state: dict[str, torch.Tensor] = field(default_factory=dict)
zero_dim_tensor_state: dict[str, torch.Tensor] = field(default_factory=dict)
non_tensor_state: dict[str, Any] = field(default_factory=dict)

class _PosDimTensorInfo(NamedTuple):

@@ -131,11 +118,11 @@ class _OptimStateKey(NamedTuple):

def _unflatten_optim_state(
fsdp_param_info: FSDPParamInfo,
flat_param_state: Dict[str, Any],
flat_param_state: dict[str, Any],
to_save: bool,
shard_state: bool,
cpu_offload: bool,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
Unflattens the optimizer state, consisting of the "state" part and the
"param_groups" part. Unflattening the "state" part involves consolidating

@@ -190,7 +177,7 @@ def _is_zero_dim_tensor(x: Any) -> bool:

def _communicate_optim_state(
fsdp_param_info: FSDPParamInfo,
flat_param_state: Dict[str, Any],
flat_param_state: dict[str, Any],
) -> _ConsolidatedOptimState:
"""
Communicates the optimizer state for a flat parameter across ranks. All

@@ -262,7 +249,7 @@ def _unflatten_communicated_optim_state(
fsdp_param_info: FSDPParamInfo,
state: _ConsolidatedOptimState,
shard_state: bool,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
Unflattens the communicated optimizer state (given by ``tensor_state``,
``non_tensor_state``, and ``zero_dim_tensor_state``) for a single flat

@@ -283,8 +270,8 @@ def _unflatten_communicated_optim_state(
fsdp_state = fsdp_param_info.state
handle = fsdp_param_info.handle
flat_param = handle.flat_param
unflat_param_state: List[Dict[str, Any]] = []
flat_param_views: Dict[str, Iterator] = {}
unflat_param_state: list[dict[str, Any]] = []
flat_param_views: dict[str, Iterator] = {}
num_unflat_params = flat_param._num_params
tensor_state, zero_dim_tensor_state, non_tensor_state = (
state.tensor_state,

@@ -337,10 +324,10 @@ def _unflatten_communicated_optim_state(

def _broadcast_processed_state(
fsdp_state: _FSDPState,
optim_state: Dict[str, Any],
optim_state: dict[str, Any],
group: Optional[dist.ProcessGroup],
) -> Dict[str, Any]:
objects: List[Any] = [None]
) -> dict[str, Any]:
objects: list[Any] = [None]
if dist.get_rank(group) == 0:
objects[0] = tree_map_only(
torch.Tensor,

@@ -380,8 +367,8 @@ def _broadcast_state(
def _shard_orig_param_state(
fsdp_param_info: FSDPParamInfo,
fqn: str,
optim_state: Dict[str, Any],
) -> Dict[str, Any]:
optim_state: dict[str, Any],
) -> dict[str, Any]:
"""
Shard the optimizer state for the original parameter with the name ``fqn``.
This API should only be used when ``use_orig_params`` is True.

@@ -398,7 +385,7 @@ def _shard_orig_param_state(
if not shard_param_info.in_shard:
return {}
# Flatten and shard the state.
new_optim_state: Dict[str, Any] = {}
new_optim_state: dict[str, Any] = {}
intra_param_start_idx = shard_param_info.intra_param_start_idx
intra_param_end_idx = shard_param_info.intra_param_end_idx
for state_name, value in optim_state.items():

@@ -413,13 +400,13 @@ def _shard_orig_param_state(

def _flatten_optim_state_dict(
optim_state_dict: Dict[str, Any],
optim_state_dict: dict[str, Any],
model: nn.Module,
use_orig_params: bool = False,
optim: Optional[torch.optim.Optimizer] = None,
rank0_only: bool = False,
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Flattens the full optimizer state dict, still keying by unflattened parameter
names.

@@ -461,7 +448,7 @@ def _flatten_optim_state_dict(
unflat_osd = _broadcast_processed_state(fsdp_state, unflat_osd, group=group)

# Construct the "state" part
flat_osd_state: Dict[Union[_OptimStateKey, str], Any] = {}
flat_osd_state: dict[Union[_OptimStateKey, str], Any] = {}
unflat_osd_state = unflat_osd["state"]
all_state_keys = set(unflat_osd_state.keys())

@@ -557,9 +544,9 @@ def _flatten_optim_state_dict(

def _flatten_optim_state(
fsdp_param_info: FSDPParamInfo,
unflat_osd_state: Dict[str, Dict[str, Any]],
unflat_param_names: List[str],
) -> Dict[str, Any]:
unflat_osd_state: dict[str, dict[str, Any]],
unflat_param_names: list[str],
) -> dict[str, Any]:
"""
Flattens the optimizer state in ``full_optim_state_dict`` for a single
flat parameter in ``fsdp_param_info`` corresponding to the unflattened

@@ -629,7 +616,7 @@ def _flatten_optim_state(
assert state_names is not None

# Flatten the state
flat_state: Dict[str, Optional[torch.Tensor]] = {}
flat_state: dict[str, Optional[torch.Tensor]] = {}
for state_name in state_names:
state_values = [
unflat_param_state[state_name] if unflat_param_state is not None else None

@@ -695,8 +682,8 @@ def _flatten_optim_state(

def _flatten_tensor_optim_state(
state_name: str,
pos_dim_tensors: List[torch.Tensor],
unflat_param_names: List[str],
pos_dim_tensors: list[torch.Tensor],
unflat_param_names: list[str],
unflat_param_shapes: Sequence[torch.Size],
handle: FlatParamHandle,
) -> torch.Tensor:

@@ -780,8 +767,8 @@ def _flatten_tensor_optim_state(

def _flatten_zero_dim_tensor_optim_state(
state_name: str,
zero_dim_tensors: List[torch.Tensor],
unflat_param_names: List[str],
zero_dim_tensors: list[torch.Tensor],
unflat_param_names: list[str],
) -> torch.Tensor:
"""
Flattens the zero-dimension tensor optimizer state given by the values

@@ -834,8 +821,8 @@ def _flatten_zero_dim_tensor_optim_state(

def _flatten_non_tensor_optim_state(
state_name: str,
non_tensors: List[Any],
unflat_param_names: List[str],
non_tensors: list[Any],
unflat_param_names: list[str],
) -> Any:
"""
Flattens the non-tensor optimizer state given by the values ``non_tensors``

@@ -872,18 +859,18 @@ def _flatten_non_tensor_optim_state(

def _rekey_sharded_optim_state_dict(
sharded_osd: Dict[str, Any],
sharded_osd: dict[str, Any],
model: nn.Module,
optim: torch.optim.Optimizer,
optim_input: Optional[
Union[
List[Dict[str, Any]],
list[dict[str, Any]],
Iterable[nn.Parameter],
]
],
using_optim_input: bool,
is_named_optimizer: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Rekeys the optimizer state dict from unflattened parameter names to flat
parameter IDs according to the calling rank's ``optim``, which may be

@@ -892,8 +879,8 @@ def _rekey_sharded_optim_state_dict(
"""
param_to_fqns = _get_param_to_fqns(model)
flat_param_to_fqn = _get_flat_param_to_fqn(model)
param_to_param_key: Dict[nn.Parameter, Union[int, str]] = cast(
Dict[nn.Parameter, Union[int, str]],
param_to_param_key: dict[nn.Parameter, Union[int, str]] = cast(
dict[nn.Parameter, Union[int, str]],
(
_get_param_to_param_id_from_optim_input(model, optim_input)
if using_optim_input

@@ -907,10 +894,10 @@ def _rekey_sharded_optim_state_dict(
# passed to the optimizer
assert len(param_to_param_key) <= len(param_to_fqns)

unflat_param_names_to_flat_param_key: Dict[
unflat_param_names_to_flat_param_key: dict[
tuple[str, ...], Union[int, str]
] = {} # for "state"
unflat_param_name_to_flat_param_key: Dict[
unflat_param_name_to_flat_param_key: dict[
str, Union[int, str]
] = {} # for "param_groups"
for param, unflat_param_names in param_to_fqns.items():

@@ -923,7 +910,7 @@ def _rekey_sharded_optim_state_dict(
unflat_param_name_to_flat_param_key[unflat_param_name] = flat_param_key

sharded_osd_state = sharded_osd["state"]
rekeyed_osd_state: Dict[Union[str, int], Any] = {}
rekeyed_osd_state: dict[Union[str, int], Any] = {}
for key, param_state in sharded_osd_state.items():
if isinstance(key, str):
rekeyed_osd_state[key] = param_state

@@ -935,7 +922,7 @@ def _rekey_sharded_optim_state_dict(

# Only process param_groups if it exists in sharded_osd
if "param_groups" in sharded_osd:
rekeyed_osd_param_groups: List[Dict[str, Any]] = []
rekeyed_osd_param_groups: list[dict[str, Any]] = []
for unflat_param_group in sharded_osd["param_groups"]:
flat_param_group = copy.deepcopy(unflat_param_group)
flat_param_keys = sorted(

@@ -955,11 +942,11 @@ def _get_param_id_to_param_from_optim_input(
model: nn.Module,
optim_input: Optional[
Union[
List[Dict[str, Any]],
list[dict[str, Any]],
Iterable[nn.Parameter],
]
] = None,
) -> Dict[int, nn.Parameter]:
) -> dict[int, nn.Parameter]:
"""
Constructs a mapping from parameter IDs to parameters. This may be used
both for models with ``FlatParameter`` s and without.

@@ -993,7 +980,7 @@ def _get_param_id_to_param_from_optim_input(
if optim_input is None:
return dict(enumerate(model.parameters()))
try:
params = cast(List[nn.Parameter], list(optim_input))
params = cast(list[nn.Parameter], list(optim_input))
except TypeError as e:
raise TypeError(
"Optimizer input should be an iterable of Tensors or dicts, "

@@ -1013,7 +1000,7 @@ def _get_param_id_to_param_from_optim_input(
if all_tensors:
return dict(enumerate(params))
assert all_dicts
param_id_to_param: List[nn.Parameter] = []
param_id_to_param: list[nn.Parameter] = []
for param_group in params:
has_params_key = "params" in param_group # type: ignore[operator]
assert has_params_key, (

@@ -1026,7 +1013,7 @@ def _get_param_id_to_param_from_optim_input(
return dict(enumerate(param_id_to_param))

def _get_flat_param_to_fqn(model: torch.nn.Module) -> Dict[FlatParameter, str]:
def _get_flat_param_to_fqn(model: torch.nn.Module) -> dict[FlatParameter, str]:
"""
Constructs a mapping from ``FlatParameter`` to a cleaned (devoid of prefixes
from wrappers) fully qualified name (FQN). Note that this FQN is "non-canonical"

@@ -1053,7 +1040,7 @@ def _get_flat_param_to_fqn(model: torch.nn.Module) -> Dict[FlatParameter, str]:
def return_fn(flat_param_to_fqn):
return flat_param_to_fqn

flat_param_to_fqn_ret: Dict[FlatParameter, str] = {}
flat_param_to_fqn_ret: dict[FlatParameter, str] = {}
return _apply_to_modules(
model,
module_fn,

@@ -1067,16 +1054,16 @@ def _get_param_key_to_param(
optim: torch.optim.Optimizer,
model: Optional[nn.Module] = None,
is_named_optimizer: bool = False,
param_to_fqns: Optional[Dict[nn.Parameter, List[str]]] = None,
flat_param_to_fqn: Optional[Dict[FlatParameter, str]] = None,
) -> Dict[Union[int, str], nn.Parameter]:
param_to_fqns: Optional[dict[nn.Parameter, list[str]]] = None,
flat_param_to_fqn: Optional[dict[FlatParameter, str]] = None,
) -> dict[Union[int, str], nn.Parameter]:
"""
Constructs a mapping from parameter keys to parameters. For the regular
optimizers, the keys are parameter IDs. For NamedOptimizer, the keys
are FQNs. This API may be used both for models with ``FlatParameter`` s and
without.
"""
clean_fqn_to_curr_fqn: Dict[str, str] = {}
clean_fqn_to_curr_fqn: dict[str, str] = {}
if is_named_optimizer:
assert (
param_to_fqns is not None and flat_param_to_fqn is not None

@@ -1085,7 +1072,7 @@ def _get_param_key_to_param(
for key, _ in _named_parameters_with_duplicates(model):
clean_fqn_to_curr_fqn[clean_tensor_name(key)] = key

param_key_to_param: Dict[Union[str, int], nn.Parameter] = {}
param_key_to_param: dict[Union[str, int], nn.Parameter] = {}
pid = 0
for param_group in optim.param_groups:
if is_named_optimizer:

@@ -1118,9 +1105,9 @@ def _get_param_to_param_key(
optim: torch.optim.Optimizer,
model: Optional[nn.Module] = None,
is_named_optimizer: bool = False,
param_to_fqns: Optional[Dict[nn.Parameter, List[str]]] = None,
flat_param_to_fqn: Optional[Dict[FlatParameter, str]] = None,
) -> Dict[nn.Parameter, Union[int, str]]:
param_to_fqns: Optional[dict[nn.Parameter, list[str]]] = None,
flat_param_to_fqn: Optional[dict[FlatParameter, str]] = None,
) -> dict[nn.Parameter, Union[int, str]]:
"""
Constructs the inverse mapping of :func:`_get_param_key_to_param`. This API
only supports the case where `optim` is a regular optimizer, not NamedOptimizer.

@@ -1136,25 +1123,25 @@ def _get_param_to_param_id_from_optim_input(
model: nn.Module,
optim_input: Optional[
Union[
List[Dict[str, Any]],
list[dict[str, Any]],
Iterable[nn.Parameter],
]
] = None,
) -> Dict[nn.Parameter, int]:
) -> dict[nn.Parameter, int]:
"""Constructs the inverse mapping of :func:`_get_param_id_to_param_from_optim_input`."""
param_id_to_param = _get_param_id_to_param_from_optim_input(model, optim_input)
return {param: param_id for param_id, param in param_id_to_param.items()}

def _check_missing_keys_on_rank(
r0_optim_state_keys: List[_OptimStateKey],
optim_state_key_to_param_key: Dict[_OptimStateKey, Union[str, int]],
param_key_to_param: Dict[Union[str, int], nn.Parameter],
r0_optim_state_keys: list[_OptimStateKey],
optim_state_key_to_param_key: dict[_OptimStateKey, Union[str, int]],
param_key_to_param: dict[Union[str, int], nn.Parameter],
group: Optional[dist.ProcessGroup],
) -> None:
# Ensure that all ranks have at least the optimizer states needed by
# rank 0's optimizer
missing_keys: List[_OptimStateKey] = []
missing_keys: list[_OptimStateKey] = []
for r0_optim_state_key in r0_optim_state_keys:
if r0_optim_state_key not in optim_state_key_to_param_key:
# A parameter from rank 0's optimizer does not exist for this

@@ -1179,7 +1166,7 @@ def _check_missing_keys_on_rank(
"are missing some of those states"
)
for rank, keys in enumerate(obj_list):
keys = cast(List[_OptimStateKey], keys)
keys = cast(list[_OptimStateKey], keys)
if len(keys) > 0:
error_msg += (
f"\nRank {rank} is missing states for the parameters: "

@@ -1189,13 +1176,13 @@ def _check_missing_keys_on_rank(

def _map_param_key_to_optim_keys(
optim_state_dict: Dict[str, Any],
optim_state_dict: dict[str, Any],
group: Optional[dist.ProcessGroup],
param_key_to_param: Dict[Union[int, str], nn.Parameter],
param_to_fqns: Dict[nn.Parameter, List[str]],
fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
param_key_to_param: dict[Union[int, str], nn.Parameter],
param_to_fqns: dict[nn.Parameter, list[str]],
fqn_to_fsdp_param_info: dict[str, FSDPParamInfo],
merge_keys: bool = False,
) -> tuple[List[_OptimStateKey], Dict[_OptimStateKey, Union[int, str]]]:
) -> tuple[list[_OptimStateKey], dict[_OptimStateKey, Union[int, str]]]:
"""
Construct the local mapping between the ``_OptimStateKey`` and parameter keys
and all the ``_OptimStateKey`` across ranks. If ``merge_keys`` is False, rank0

@@ -1203,8 +1190,8 @@ def _map_param_key_to_optim_keys(
Note that ``merge_keys`` should equal to ``use_orig_params``.
"""
rank = dist.get_rank(group)
optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]] = {} # local
all_optim_state_keys: List[_OptimStateKey] = []
optim_state_key_to_param_key: dict[_OptimStateKey, Union[int, str]] = {} # local
all_optim_state_keys: list[_OptimStateKey] = []

for param_key, param in param_key_to_param.items():
# Do not include parameters without state to avoid empty mappings

@@ -1228,7 +1215,7 @@ def _map_param_key_to_optim_keys(
optim_state_key_to_param_key[optim_state_key] = param_key

if merge_keys:
all_keys: List[List[_OptimStateKey]] = [
all_keys: list[list[_OptimStateKey]] = [
[] for _ in range(dist.get_world_size(group))
]
dist.all_gather_object(all_keys, all_optim_state_keys, group=group)

@@ -1237,7 +1224,7 @@ def _map_param_key_to_optim_keys(
]
all_optim_state_keys = sorted(set(merge_all_optim_state_keys))
else:
key_obj_list: List[Optional[List[_OptimStateKey]]] = (
key_obj_list: list[Optional[list[_OptimStateKey]]] = (
[all_optim_state_keys] if rank == 0 else [None]
)
dist.broadcast_object_list(key_obj_list, src=0, group=group)

@@ -1254,11 +1241,11 @@ def _map_param_key_to_optim_keys(

def _unflatten_param_groups(
state_dict: Dict[str, Any],
param_key_to_param: Dict[Union[int, str], nn.Parameter],
param_to_fqns: Dict[nn.Parameter, List[str]],
) -> List[Dict[str, Any]]:
param_groups: List[Dict[str, Any]] = []
state_dict: dict[str, Any],
param_key_to_param: dict[Union[int, str], nn.Parameter],
param_to_fqns: dict[nn.Parameter, list[str]],
) -> list[dict[str, Any]]:
param_groups: list[dict[str, Any]] = []
for flat_param_group in state_dict["param_groups"]:
unflat_param_group = copy.deepcopy(flat_param_group)
param_group_params = [

@@ -1277,7 +1264,7 @@ def _unflatten_param_groups(
return param_groups

def _is_named_optimizer(optim_state_dict: Dict[str, Any]) -> bool:
def _is_named_optimizer(optim_state_dict: dict[str, Any]) -> bool:
"""
Returns whether the state_dict is from a NamedOptimizer.
This function checks that the keys in the state_dict['state'] are strings

@@ -1299,22 +1286,22 @@ def _is_named_optimizer(optim_state_dict: Dict[str, Any]) -> bool:
@dataclass
class StateInfo:
# The key of these dictionaries are the state name, e.g., `exp_avg`.
tensors: Dict[str, _PosDimTensorInfo]
scalar_tensors: Dict[str, torch.Tensor]
non_tensors: Dict[str, Any]
tensors: dict[str, _PosDimTensorInfo]
scalar_tensors: dict[str, torch.Tensor]
non_tensors: dict[str, Any]

def _allgather_state_info(
fsdp_state: _FSDPState,
input_states: Dict[str, Any],
) -> List[Dict[str, StateInfo]]:
input_states: dict[str, Any],
) -> list[dict[str, StateInfo]]:
"""
Given the ``input_states``, allgather StateInfo for each state. The function
uses all_gather_object to gather StateInfo so no GPU tensors are sent.
"""

processed_state_dict: Dict[str, StateInfo] = {}
gathered_state_info: List[Dict[str, StateInfo]] = [
processed_state_dict: dict[str, StateInfo] = {}
gathered_state_info: list[dict[str, StateInfo]] = [
{} for _ in range(fsdp_state.world_size)
]

@@ -1343,10 +1330,10 @@ def _allgather_state_info(

def _convert_all_state_info(
fsdp_param_info: FSDPParamInfo,
gathered_state_info: List[Dict[str, StateInfo]],
input_states: Dict[str, Any],
output_states: Dict[str, Dict[str, Any]],
) -> tuple[Optional[torch.dtype], Dict[str, List[Optional[torch.Tensor]]]]:
gathered_state_info: list[dict[str, StateInfo]],
input_states: dict[str, Any],
output_states: dict[str, dict[str, Any]],
) -> tuple[Optional[torch.dtype], dict[str, list[Optional[torch.Tensor]]]]:
"""
Given the ``gathered_state_info`` and ``input_states``, the API converted
the StateInfo into the original state if the state is not a non-scalar

@@ -1354,20 +1341,20 @@ def _convert_all_state_info(
``state_buffer`` in a correct order for later allgather purpose.
"""

state_buffers: Dict[str, List[Optional[torch.Tensor]]] = {}
state_buffers: dict[str, list[Optional[torch.Tensor]]] = {}

for fqn, gathered_state in output_states.items():
state_info = [s[fqn] for s in gathered_state_info]
all_tensor_states = sorted(
{n for state in state_info for n in state.tensors.keys()}
)
empty_ranks: Set[int] = set()
empty_ranks: set[int] = set()
dtype: Optional[torch.dtype] = None
# First check all the non-scalar states and get the information of
# states on each rank.
for state_name in all_tensor_states:
numels = []
_empty_ranks: Set[int] = set()
_empty_ranks: set[int] = set()
for rank, object_state in enumerate(state_info):
numels.append(0)
info = object_state.tensors.get(state_name, None)

@@ -1426,7 +1413,7 @@ def _convert_all_state_info(

def _unflatten_orig_param_states(
fsdp_param_info: FSDPParamInfo,
output_states: Dict[str, Dict[str, Any]],
output_states: dict[str, dict[str, Any]],
state_name: str,
shard_state: bool,
to_save: bool,

@@ -1498,12 +1485,12 @@ def _unflatten_orig_param_states(

def _allgather_orig_param_states(
fsdp_param_info: FSDPParamInfo,
gathered_state_info: List[Dict[str, StateInfo]],
input_states: Dict[str, Any],
gathered_state_info: list[dict[str, StateInfo]],
input_states: dict[str, Any],
shard_state: bool,
to_save: bool,
cpu_offload: bool,
) -> Dict[str, Dict[str, Any]]:
) -> dict[str, dict[str, Any]]:
"""
Given the ``gathered_state_info`` and ``input_states``, the API allgathers
all tensor states and restore non-tensor states from ``gathered_state_info``.

@@ -1515,7 +1502,7 @@ def _allgather_orig_param_states(
fsdp_state._device_handle.memory_summary(),
)

output_states: Dict[str, Dict[str, Any]] = {fqn: {} for fqn in input_states.keys()}
output_states: dict[str, dict[str, Any]] = {fqn: {} for fqn in input_states.keys()}

dtype, state_buffers = _convert_all_state_info(
fsdp_param_info, gathered_state_info, input_states, output_states

@@ -1524,7 +1511,7 @@ def _allgather_orig_param_states(
if len(state_buffers) == 0:
return output_states

has_state_params: List[bool] = [
has_state_params: list[bool] = [
True if fqn in output_states else False
for fqn, idx in fsdp_param_info.param_indices.items()
]

@@ -1545,7 +1532,7 @@ def _allgather_orig_param_states(
# Synchronize can be slow but this will be easier for us to debug.
fsdp_state._device_handle.synchronize()
for state_name, buffers in state_buffers.items():
local_buffers: List[torch.Tensor] = []
local_buffers: list[torch.Tensor] = []
begin = fsdp_state.rank * flat_param._sharded_size.numel()
# End is inclusive.
end = begin + flat_param._sharded_size.numel() - 1

@@ -1666,11 +1653,11 @@ def _allgather_orig_param_states(

def _gather_all_orig_param_state(
fsdp_param_info: FSDPParamInfo,
input_states: Dict[str, Any],
input_states: dict[str, Any],
shard_state: bool,
to_save: bool,
cpu_offload: bool,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Given a optimizer state dict, ``input_states``, which the keys are FQNs to the
original parameters (not FlatParameters nor parmeter ID), gather all the

@@ -1715,20 +1702,20 @@ def _gather_all_orig_param_state(

def _convert_state_with_orig_params(
all_optim_state_keys: List[_OptimStateKey],
optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]],
fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
optim_state_dict: Dict[Union[str, int], Any],
all_optim_state_keys: list[_OptimStateKey],
optim_state_key_to_param_key: dict[_OptimStateKey, Union[int, str]],
fqn_to_fsdp_param_info: dict[str, FSDPParamInfo],
optim_state_dict: dict[Union[str, int], Any],
to_save: bool,
shard_state: bool,
cpu_offload: bool = True,
) -> Dict[str, Any]:
fsdp_osd_state: Dict[str, Any] = {}
) -> dict[str, Any]:
fsdp_osd_state: dict[str, Any] = {}
# This variable is used to deduplicate the FSDPParamInfo as one FSDPParamInfo
# usually corresponds to multiple parameters. We could not use FSDPParamInfo
# as the key because FSDPParamInfo is not hashable. As a result, we fall back
# to `id(FSDPParamInfo)`, which the type is an integer.
all_states: Dict[int, Dict[str, Any]] = {}
all_states: dict[int, dict[str, Any]] = {}
# Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers
# across ranks
for optim_state_key in all_optim_state_keys:

@@ -1806,15 +1793,15 @@ def _convert_state_with_orig_params(

def _convert_state_with_flat_params(
all_optim_state_keys: List[_OptimStateKey],
optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]],
fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
optim_state_dict: Dict[Union[str, int], Any],
all_optim_state_keys: list[_OptimStateKey],
optim_state_key_to_param_key: dict[_OptimStateKey, Union[int, str]],
fqn_to_fsdp_param_info: dict[str, FSDPParamInfo],
optim_state_dict: dict[Union[str, int], Any],
to_save: bool,
shard_state: bool,
cpu_offload: bool = True,
) -> Dict[str, Any]:
fsdp_osd_state: Dict[str, Any] = {}
) -> dict[str, Any]:
fsdp_osd_state: dict[str, Any] = {}
# Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers
# across ranks
for optim_state_key in all_optim_state_keys:

@@ -1866,10 +1853,10 @@ def _convert_state_with_flat_params(
def _optim_state_dict(
model: nn.Module,
optim: torch.optim.Optimizer,
optim_state_dict: Dict[str, Any],
optim_state_dict: dict[str, Any],
optim_input: Optional[
Union[
List[Dict[str, Any]],
list[dict[str, Any]],
Iterable[nn.Parameter],
]
],

@@ -1879,7 +1866,7 @@ def _optim_state_dict(
using_optim_input: bool,
use_orig_params: bool = False,
cpu_offload: bool = True,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Consolidates the optimizer state and returns it as a :class:`dict`
following the convention of :meth:`torch.optim.Optimizer.state_dict`,

@@ -1940,7 +1927,7 @@ def _optim_state_dict(
is_named_optimizer = _is_named_optimizer(optim_state_dict)

param_key_to_param = cast(
Dict[Union[int, str], nn.Parameter],
dict[Union[int, str], nn.Parameter],
(
_get_param_id_to_param_from_optim_input(model, optim_input)
if using_optim_input

@@ -1985,7 +1972,7 @@ def _optim_state_dict(
if not to_save:
return {}

fsdp_osd: Dict[str, Any] = {"state": fsdp_osd_state}
fsdp_osd: dict[str, Any] = {"state": fsdp_osd_state}

flat_param_fqns = set(flat_param_to_fqn.values())
for key, value in optim_state_dict["state"].items():

@@ -2019,7 +2006,7 @@ def _optim_state_dict(
return fsdp_osd

def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]:
def _get_fqn_to_fsdp_param_info(model: nn.Module) -> dict[str, FSDPParamInfo]:
"""
Construct the mapping from a param's fqn to its corresponding ``FSDPParamInfo``
if the param is managed by FSDP. Shared parameters, or original parameters that

@@ -2055,7 +2042,7 @@ def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]:
def return_fn(fqn_to_param_info):
return fqn_to_param_info

fqn_to_param_info: Dict[str, FSDPParamInfo] = {}
fqn_to_param_info: dict[str, FSDPParamInfo] = {}
# FlatParameter._fqns stores the local fqn, starting from the root of the
# FSDP. Using _apply_to_modules() with model (may not be the FSDP root
# module) allows us to construct the global fqn.
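A side note on the optimizer-state hunks above (editor commentary, not part of the PR): nested annotations such as `dict[str, list[Optional[torch.Tensor]]]` and the `cast(...)` calls work unchanged with the built-in generics, because `list[...]`/`dict[...]` are ordinary subscripted objects at runtime on Python 3.9+. A tiny sketch with made-up values:

    from typing import Optional, Union, cast

    state_buffers: dict[str, list[Optional[float]]] = {"exp_avg": [0.1, None]}
    # `cast` only informs the type checker; a built-in generic is a valid first argument.
    keys = cast(list[Union[int, str]], list(state_buffers.keys()))
    print(keys)  # ['exp_avg']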
@@ -2,7 +2,7 @@
import functools
import logging
from enum import auto, Enum
from typing import Any, Callable, Dict, List, no_type_check, Optional, Set
from typing import Any, Callable, no_type_check, Optional

import torch
import torch.distributed as dist

@@ -57,7 +57,7 @@ class _PrefetchMode(Enum):

def _get_fsdp_root_states_with_modules(
module: nn.Module,
) -> tuple[List[_FSDPState], List[nn.Module]]:
) -> tuple[list[_FSDPState], list[nn.Module]]:
"""
Returns a tuple containing:
1. A list of the root ``_FSDPState`` instances in the module tree rooted at

@@ -70,9 +70,9 @@ def _get_fsdp_root_states_with_modules(
must call :func:`_is_fsdp_root` to force a lazy initialization to determine
the FSDP root in case lazy initialization has not yet happened.
"""
fsdp_root_states: List[_FSDPState] = []
fsdp_root_modules: List[nn.Module] = []
visited_fsdp_states: Set[_FSDPState] = set()
fsdp_root_states: list[_FSDPState] = []
fsdp_root_modules: list[nn.Module] = []
visited_fsdp_states: set[_FSDPState] = set()
# NOTE: This function assumes that `module.modules()` proceeds top-down.
for submodule in module.modules():
optional_state = _get_module_fsdp_state(submodule)

@@ -87,7 +87,7 @@ def _get_fsdp_root_states_with_modules(
return fsdp_root_states, fsdp_root_modules

def _get_fsdp_root_states(module: nn.Module) -> List[_FSDPState]:
def _get_fsdp_root_states(module: nn.Module) -> list[_FSDPState]:
"""See :func:`_get_fsdp_root_states_with_modules`."""
fsdp_root_states, _ = _get_fsdp_root_states_with_modules(module)
return fsdp_root_states

@@ -178,7 +178,7 @@ def _share_state_and_init_handle_attrs(
handle = root_state._handle
if handle:
handle.init_flat_param_attributes()
attr_name_to_values: Dict[str, Set[Any]] = {}
attr_name_to_values: dict[str, set[Any]] = {}
for attr_name in HOMOGENEOUS_ATTR_NAMES:
attr_name_to_values[attr_name] = set()
root_state._all_handles = root_state._exec_order_data.all_handles # share reference

@@ -347,8 +347,8 @@ def _pre_forward(
unshard_fn: Callable,
module: nn.Module,
args: tuple[Any, ...],
kwargs: Dict[str, Any],
) -> tuple[tuple[Any, ...], Dict[str, Any]]:
kwargs: dict[str, Any],
) -> tuple[tuple[Any, ...], dict[str, Any]]:
"""
Runs the pre-forward logic. This includes an opportunity to unshard
currently sharded parameters such as those for the current forward and

@@ -1469,7 +1469,7 @@ def _register_post_backward_reshard_only_hook(
state: _FSDPState,
handle: Optional[FlatParamHandle],
args: tuple[Any, ...],
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
) -> None:
"""
Registers post-backward hooks to reshard flat parameters that do not

@@ -1483,7 +1483,7 @@ def _register_post_backward_reshard_only_hook(
return
# Construct `inp_tensors` lazily to avoid CPU overhead in typical case
# where each flat parameter requires gradient
inp_tensors: Optional[List[torch.Tensor]] = None
inp_tensors: Optional[list[torch.Tensor]] = None
if not handle:
return
flat_param = handle.flat_param

@@ -1555,7 +1555,7 @@ def _wait_for_computation_stream(

def _reset_flat_param_grad_info_if_needed(
handles: List[FlatParamHandle],
handles: list[FlatParamHandle],
):
"""
Clears the original parameters' gradients if needed. This method's CPU

@@ -1573,7 +1573,7 @@ def _reset_flat_param_grad_info_if_needed(
def _get_buffers_and_dtypes_for_computation(
state: _FSDPState,
root_module: nn.Module,
) -> tuple[List[torch.Tensor], List[Optional[torch.dtype]]]:
) -> tuple[list[torch.Tensor], list[Optional[torch.dtype]]]:
"""
Returns all buffers in the module tree rooted at ``root_module`` and a
corresponding list of the buffer dtypes for computation. Each buffer dtype

@@ -1581,9 +1581,9 @@ def _get_buffers_and_dtypes_for_computation(
low precision dtype otherwise.
"""
_p_assert(state._is_root, "Expects the root to cast buffers")
buffers: List[torch.Tensor] = []
buffer_dtypes: List[Optional[torch.dtype]] = []
visited_buffers: Set[torch.Tensor] = set()
buffers: list[torch.Tensor] = []
buffer_dtypes: list[Optional[torch.dtype]] = []
visited_buffers: set[torch.Tensor] = set()
# Traverse the FSDP states bottom-up so that we prefer the owning FSDP
# instance's mixed precision setting for each buffer
fsdp_states, fsdp_modules = traversal_utils._get_fsdp_states_with_modules(

@@ -1605,12 +1605,12 @@ def _get_buffers_and_dtypes_for_computation(
@no_type_check
def _get_orig_buffer_dtypes(
state: _FSDPState,
buffer_names: List[str],
) -> List[torch.dtype]:
buffer_names: list[str],
) -> list[torch.dtype]:
"""
Returns the original buffer types of the given buffer names.
"""
buffer_dtypes: List[torch.dtype] = []
buffer_dtypes: list[torch.dtype] = []
for buffer_name in buffer_names:
_p_assert(
buffer_name in state._buffer_name_to_orig_dtype,

@@ -1623,8 +1623,8 @@ def _get_orig_buffer_dtypes(

def _cast_buffers_to_dtype_and_device(
buffers: List[torch.Tensor],
buffer_dtypes: List[Optional[torch.dtype]],
buffers: list[torch.Tensor],
buffer_dtypes: list[Optional[torch.dtype]],
device: torch.device,
) -> None:
"""
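As a reference for the return-type hunks above: `tuple[...]` composes with the new `list[...]` exactly as `Tuple[List[...], ...]` did. A hypothetical helper (not from this PR) showing the shape of such a signature:

    import torch

    def split_by_dtype(tensors: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.dtype]]:
        # Return the tensors alongside their dtypes, mirroring the paired-list style above.
        return list(tensors), [t.dtype for t in tensors]

    xs, dtypes = split_by_dtype([torch.zeros(1), torch.ones(2, dtype=torch.int64)])
    print(dtypes)  # [torch.float32, torch.int64]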
@@ -3,7 +3,8 @@ import contextlib
import logging
import math
import warnings
from typing import Any, Callable, cast, Dict, Generator, Iterator, List, no_type_check
from collections.abc import Generator, Iterator
from typing import Any, Callable, cast, no_type_check

import torch
import torch.distributed as dist

@@ -170,10 +171,10 @@ def _common_unshard_pre_state_dict_hook(
def _common_unshard_post_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
param_hook: Callable,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
The post-state_dict flow that shared by all state_dict types that require
``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this

@@ -302,9 +303,9 @@ def _full_pre_state_dict_hook(
def _full_post_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Hook that runs after model.state_dict() is called before returning result to
user. For FSDP, we may have to clone the tensors in state_dict as params go

@@ -313,7 +314,7 @@ def _full_post_state_dict_hook(
"""

def param_hook(
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
fqn: str,
) -> None:

@@ -347,7 +348,7 @@ def _full_post_state_dict_hook(
def _full_pre_load_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
) -> None:
_lazy_init(fsdp_state, module)

@@ -393,9 +394,9 @@ def _local_pre_state_dict_hook(
def _local_post_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
This hook create a ShardedTensor from the local flat_param and replace
the state_dict[f"{prefix}{FLAT_PARAM}] with the ShardedTensor. No copy

@@ -448,7 +449,7 @@ def _local_post_load_state_dict_hook(
def _local_pre_load_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
) -> None:
"""

@@ -526,15 +527,15 @@ def _sharded_pre_state_dict_hook(
def _sharded_post_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
The hook replaces the unflattened, unsharded parameter in the state_dict
with a unflattened, sharded parameter (a ShardedTensor).
"""

def param_hook(state_dict: Dict[str, Any], prefix: str, fqn: str):
def param_hook(state_dict: dict[str, Any], prefix: str, fqn: str):
param = state_dict[fqn]
if not fsdp_state._state_dict_config._use_dtensor:
sharded_tensor = _ext_chunk_tensor(

@@ -574,7 +575,7 @@ def _sharded_post_load_state_dict_hook(
def _sharded_pre_load_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
) -> None:
"""

@@ -686,10 +687,10 @@ def _replace_with_full_state_dict_type(fsdp_state: _FSDPState) -> Generator:
@torch.no_grad()
def _post_state_dict_hook(
module: nn.Module,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
*args: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
_post_state_dict_hook() is called after the state_dict() of this
FSDP module is executed. ``fsdp_state._state_dict_type`` is used to decide

@@ -802,7 +803,7 @@ def _set_use_dtensor(fsdp_state: _FSDPState) -> None:
@torch.no_grad()
def _pre_load_state_dict_hook(
module: nn.Module,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
*args: Any,
) -> None:

@@ -845,7 +846,7 @@ def _pre_load_state_dict_hook(
@torch.no_grad()
def _post_load_state_dict_hook(
module: nn.Module,
incompatible_keys: tuple[List[str], List[str]],
incompatible_keys: tuple[list[str], list[str]],
*args: Any,
) -> None:
fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)

@@ -906,7 +907,7 @@ def _register_state_dict_hooks_base(
state: _FSDPState,
hook_registration_fn_name: str,
hook: Callable,
hook_registration_fn_kwargs: Dict[str, Any],
hook_registration_fn_kwargs: dict[str, Any],
) -> None:
"""Registers ``hook`` using ``hook_registration_fn``."""
if not _is_composable(state):
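On the import hunk at the top of this file (editor note): `typing.Generator`, `typing.Iterator`, and `typing.Iterable` have been deprecated aliases of their `collections.abc` counterparts since Python 3.9, so the new import is the preferred spelling. A minimal sketch, independent of the FSDP code:

    from collections.abc import Generator, Iterator

    def chunked(items: Iterator[int], size: int) -> Generator[list[int], None, None]:
        # Yield fixed-size chunks from any iterator.
        batch: list[int] = []
        for item in items:
            batch.append(item)
            if len(batch) == size:
                yield batch
                batch = []
        if batch:
            yield batch

    print(list(chunked(iter(range(5)), 2)))  # [[0, 1], [2, 3], [4]]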
@@ -2,7 +2,7 @@
import functools
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set
from typing import Any, Callable, NamedTuple, Optional

import torch
import torch.nn as nn

@@ -29,7 +29,7 @@ class TracingConfig:
"""

tracer: torch.fx.Tracer = field(default_factory=torch.fx.Tracer)
concrete_args: Optional[Dict[str, Any]] = None
concrete_args: Optional[dict[str, Any]] = None

class _ParamUsageInfo(NamedTuple):

@@ -52,7 +52,7 @@ class _ParamUsageInfo(NamedTuple):
"""

module: nn.Module
named_params: List[tuple[str, nn.Parameter]]
named_params: list[tuple[str, nn.Parameter]]

class _ExecutionInfo:

@@ -79,12 +79,12 @@ class _ExecutionInfo:

def __init__(self, root_module: nn.Module) -> None:
self.curr_module: nn.Module = root_module
self.module_forward_order: List[nn.Module] = [root_module]
self.module_to_param_usage_infos: Dict[nn.Module, List[_ParamUsageInfo]] = {
self.module_forward_order: list[nn.Module] = [root_module]
self.module_to_param_usage_infos: dict[nn.Module, list[_ParamUsageInfo]] = {
root_module: []
}
self.param_forward_order: List[nn.Parameter] = []
self.visited_params: Set[nn.Parameter] = set()
self.param_forward_order: list[nn.Parameter] = []
self.visited_params: set[nn.Parameter] = set()

class _ExecOrderTracer:

@@ -120,7 +120,7 @@ class _ExecOrderTracer:
module: nn.Module,
forward: Callable,
args: tuple[Any, ...],
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
) -> Any:
"""
Overrides ``call_module`` to save execution information to

@@ -160,12 +160,12 @@ class _ExecOrderTracer:
self,
create_proxy: Callable,
exec_info: _ExecutionInfo,
fqn_to_param: Dict[str, nn.Parameter],
fqn_to_param: dict[str, nn.Parameter],
# Below are the expected arguments to `create_proxy()`
kind: str,
target: torch.fx.node.Target,
args: tuple[Any, ...],
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
proxy_factory_fn: Optional[Callable[[torch.fx.Node], torch.fx.Proxy]] = None,

@@ -210,7 +210,7 @@ class _ExecOrderTracer:
curr_module = exec_info.curr_module
if kind in ("call_function", "call_method"):
if args is not None:
named_params: List[tuple[str, nn.Parameter]] = []
named_params: list[tuple[str, nn.Parameter]] = []
for arg in args:
if (
isinstance(arg, torch.fx.Proxy)
@@ -6,7 +6,6 @@ imports. For brevity, we may import the file as ``traversal_utils``.
"""

import collections
from typing import Deque, List, Set

import torch.nn as nn
from torch.distributed._composable.contract import _get_registry

@@ -48,7 +47,7 @@ def _composable(module: nn.Module) -> bool:
# `FlatParameter` registration, which is not needed for `use_orig_params=True`.
def _get_fsdp_states_with_modules(
module: nn.Module,
) -> tuple[List[_FSDPState], List[nn.Module]]:
) -> tuple[list[_FSDPState], list[nn.Module]]:
"""
Returns a tuple containing:
1. A list of the ``_FSDPState`` instances in the module tree rooted at

@@ -65,19 +64,19 @@ def _get_fsdp_states_with_modules(
NOTE: The traversal does not proceed into any module annotated by an
incompatible API (e.g. ``replicate``).
"""
fsdp_states: List[_FSDPState] = []
fsdp_modules: List[nn.Module] = []
fsdp_states: list[_FSDPState] = []
fsdp_modules: list[nn.Module] = []
# Track the visited FSDP states since multiple modules may share the same
# one and we want to return a de-duplicated list
visited_fsdp_states: Set[_FSDPState] = set()
visited_fsdp_states: set[_FSDPState] = set()
# Track the visited modules in case of shared modules, which implies the
# module graph is no longer a tree
visited_modules: Set[nn.Module] = set()
visited_modules: set[nn.Module] = set()

# Perform depth-first search from `module` to ensure that we do not
# traverse into an incompatible API's subtree (use DFS instead of BFS to
# match `.modules()` order)
deque: Deque[nn.Module] = collections.deque([module])
deque: collections.deque[nn.Module] = collections.deque([module])
while deque:
submodule = deque.popleft()
visited_modules.add(submodule)

@@ -94,13 +93,13 @@ def _get_fsdp_states_with_modules(
return fsdp_states, fsdp_modules

def _get_fsdp_states(module: nn.Module) -> List[_FSDPState]:
def _get_fsdp_states(module: nn.Module) -> list[_FSDPState]:
"""See :func:`_get_fsdp_states_with_modules`."""
fsdp_states, _ = _get_fsdp_states_with_modules(module)
return fsdp_states

def _get_fsdp_handles(module: nn.Module) -> List:
def _get_fsdp_handles(module: nn.Module) -> list:
"""
Returns all ``FlatParamHandle`` s in the module tree rooted at ``module``
following the rules in :func:`_get_fsdp_state`.
@@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
import contextlib
import warnings
from typing import cast, Generator
from collections.abc import Generator
from typing import cast

import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils
@@ -4,7 +4,7 @@ import functools
import inspect
import warnings
from functools import partial
from typing import Any, Callable, Dict, List, Set, Type, Union
from typing import Any, Callable, Union

import torch.nn as nn
from torch.distributed.fsdp._common_utils import (

@@ -25,9 +25,9 @@ from torch.distributed.fsdp.wrap import (
def _auto_wrap(
root_module: nn.Module,
policy: Union[Callable, _Policy],
ignored_modules: Set[nn.Module],
ignored_params: Set[nn.Parameter],
root_kwargs: Dict[str, Any],
ignored_modules: set[nn.Module],
ignored_params: set[nn.Parameter],
root_kwargs: dict[str, Any],
fsdp_fn: Callable, # e.g. `FullyShardedDataParallel` or `fully_shard`
):
"""

@@ -111,7 +111,7 @@ def _check_nested_wrapping(root_module: nn.Module):

def _warn_on_overridden_mixed_precision(
overridden_module_classes: Set[Type[nn.Module]],
overridden_module_classes: set[type[nn.Module]],
):
if len(overridden_module_classes) == 0:
return

@@ -125,8 +125,8 @@ def _warn_on_overridden_mixed_precision(

def _validate_frozen_params(
root_module: nn.Module,
modules_to_wrap: Set[nn.Module],
ignored_params: Set[nn.Parameter],
modules_to_wrap: set[nn.Module],
ignored_params: set[nn.Parameter],
use_orig_params: bool,
):
"""

@@ -136,15 +136,15 @@ def _validate_frozen_params(
recommended for ``use_orig_params=True`` (user warning).
"""
post_order_named_modules = _get_post_order_named_modules(root_module)
visited_modules: Set[nn.Module] = set()
visited_modules: set[nn.Module] = set()
for module_name, module in post_order_named_modules:
if module in modules_to_wrap:
param_to_fqn = _get_managed_param_to_fqn(
module, ignored_params, visited_modules, module_name
)
frozen_param_fqns: List[str] = []
frozen_param_fqns: list[str] = []
frozen_param_numel = 0
nonfrozen_param_fqns: List[str] = []
nonfrozen_param_fqns: list[str] = []
nonfrozen_param_numel = 0
for param, fqn in param_to_fqn.items():
if param.requires_grad:

@@ -178,7 +178,7 @@ def _validate_frozen_params(

def _get_post_order_named_modules(
root_module: nn.Module,
) -> List[tuple[str, nn.Module]]:
) -> list[tuple[str, nn.Module]]:
"""
This returns the named modules following a post-order traversal, which is a
valid reverse topological sort. We achieve this using the reverse of a

@@ -202,7 +202,7 @@ def _get_post_order_named_modules(
visited_modules = {root_module}
stack = [("", root_module)]
# Append and reverse at the end for linear-time algorithm
reverse_post_order_named_modules: List[tuple[str, nn.Module]] = []
reverse_post_order_named_modules: list[tuple[str, nn.Module]] = []
while stack:
module_name, module = stack.pop()
reverse_post_order_named_modules.append((module_name, module))

@@ -220,10 +220,10 @@ def _get_post_order_named_modules(

def _get_managed_param_to_fqn(
module_to_wrap: nn.Module,
ignored_params: Set[nn.Parameter],
visited_modules: Set[nn.Module],
ignored_params: set[nn.Parameter],
visited_modules: set[nn.Module],
root_prefix: str,
) -> Dict[nn.Parameter, str]:
) -> dict[nn.Parameter, str]:
"""
This returns a dict that maps managed parameter to its FQN for the given
``module_to_wrap``. The dict's keys are exactly the parameters that would

@@ -238,7 +238,7 @@ def _get_managed_param_to_fqn(
on the full module tree in one shot. Given those differences, we do not try
to unify the two.
"""
param_to_fqn: Dict[nn.Parameter, str] = {}
param_to_fqn: dict[nn.Parameter, str] = {}
# Run BFS (or any tree traversal works)
queue = collections.deque([(module_to_wrap, root_prefix)])
visited_modules.add(module_to_wrap)
@@ -3,9 +3,10 @@ This file includes public APIs for FSDP such as the classes used for the
constructor arguments.
"""

from collections.abc import Sequence
from dataclasses import dataclass
from enum import auto, Enum
from typing import Optional, Sequence, Type
from typing import Optional

import torch
from torch.nn.modules.batchnorm import _BatchNorm

@@ -223,7 +224,7 @@ class MixedPrecision:
keep_low_precision_grads: bool = False
cast_forward_inputs: bool = False
cast_root_forward_inputs: bool = True
_module_classes_to_ignore: Sequence[Type[torch.nn.Module]] = (_BatchNorm,)
_module_classes_to_ignore: Sequence[type[torch.nn.Module]] = (_BatchNorm,)

@dataclass
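For the `MixedPrecision` hunk above (editor note): `type[...]` is the PEP 585 replacement for `typing.Type[...]`. A small, hypothetical config sketch showing the same idiom; `CastPolicy` is made up for illustration and is not part of the FSDP API:

    from collections.abc import Sequence
    from dataclasses import dataclass

    import torch.nn as nn

    @dataclass
    class CastPolicy:
        # Classes listed here are skipped when casting, analogous to `_module_classes_to_ignore`.
        module_classes_to_ignore: Sequence[type[nn.Module]] = (nn.BatchNorm1d,)

    policy = CastPolicy()
    print(nn.BatchNorm1d in policy.module_classes_to_ignore)  # True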
@ -6,19 +6,10 @@ import functools
|
|||
import math
|
||||
import traceback
|
||||
import warnings
|
||||
from collections.abc import Generator, Iterable, Iterator
|
||||
from contextlib import contextmanager
|
||||
from enum import auto, Enum
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
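
The import rewrite above is the whole pattern of this commit: PEP 585 (Python 3.9+) makes the builtin containers subscriptable, so typing.Dict/List/Set/Type and the typing re-exports of collections.abc names can be dropped. A tiny before/after illustration with a made-up function, not taken from the diff:

# Old style (pre-PEP 585):
#     from typing import Dict, List, Set
#     def group_params(fqns: List[str]) -> Dict[str, Set[str]]: ...

# New style: builtins are generic; typing keeps only what has no builtin form.
from typing import Optional


def group_params(fqns: list[str], prefix: Optional[str] = None) -> dict[str, set[str]]:
    grouped: dict[str, set[str]] = {}
    for fqn in fqns:
        if prefix is None or fqn.startswith(prefix):
            grouped.setdefault(fqn.split(".")[0], set()).add(fqn)
    return grouped


print(group_params(["layer1.weight", "layer1.bias", "layer2.weight"]))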
@@ -562,7 +553,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
     def fsdp_modules(
         module: nn.Module,
         root_only: bool = False,
-    ) -> List["FullyShardedDataParallel"]:
+    ) -> list["FullyShardedDataParallel"]:
         """Return all nested FSDP instances.

         This possibly includes ``module`` itself and only includes FSDP root modules if ``root_only=True``.
@@ -1013,7 +1004,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
                 param_name = param_name.replace(FSDP_PREFIX, "")
             yield (param_name, param)

-    def _assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None:
+    def _assert_state(self, state: Union[TrainingState, list[TrainingState]]) -> None:
         """Assert we are in the given state."""
         # Since assert can be turned off and this error checking
         # is really important, we use explicit error checking
@@ -1135,7 +1126,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         # iteration order and hence deterministic total norm computation
         sharded_params = []
         nonsharded_params = []
-        grads: List[torch.Tensor] = []
+        grads: list[torch.Tensor] = []
         for handle in self._all_handles:
             if handle.uses_sharded_strategy:
                 target_set = sharded_params_set
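
The gradient-clipping hunk above only retypes the grads list, but the surrounding logic is the interesting part: with sharded strategies each rank holds only a shard of every gradient, so the total norm combines a local reduction with a collective. A hedged sketch of that idea for a finite norm order p (this is not FSDP's internal _get_grad_norm; device placement and dtype handling are simplified):

import torch
import torch.distributed as dist


def sharded_total_grad_norm(local_grad_shards: list[torch.Tensor], p: float = 2.0) -> torch.Tensor:
    # Local contribution: sum of ||g_shard||^p over this rank's gradient shards.
    if local_grad_shards:
        local = sum(torch.linalg.vector_norm(g, p) ** p for g in local_grad_shards)
    else:
        local = torch.zeros(())
    local = torch.as_tensor(local, dtype=torch.float32)
    # Combine the per-rank sums, then take the p-th root of the global sum.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(local, op=dist.ReduceOp.SUM)
    return local ** (1.0 / p)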
@@ -1257,10 +1248,10 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
     def _optim_state_dict_impl(
         model: torch.nn.Module,
         optim: torch.optim.Optimizer,
-        optim_state_dict: Dict[str, Any],
+        optim_state_dict: dict[str, Any],
         optim_input: Optional[
             Union[
-                List[Dict[str, Any]],
+                list[dict[str, Any]],
                 Iterable[torch.nn.Parameter],
             ]
         ] = None,
@@ -1270,7 +1261,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         cpu_offload: bool = True,
         *,
         _stacklevel: int = 1,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Transform the state-dict of an optimizer corresponding to a sharded model.

         This is the internal API that is used by all the optim_state_dict implementations.
@@ -1312,11 +1303,11 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):

     @staticmethod
     def _optim_state_dict_to_load_impl(
-        optim_state_dict: Dict[str, Any],
+        optim_state_dict: dict[str, Any],
         model: torch.nn.Module,
         optim_input: Optional[
             Union[
-                List[Dict[str, Any]],
+                list[dict[str, Any]],
                 Iterable[torch.nn.Parameter],
             ]
         ] = None,
@@ -1325,7 +1316,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         rank0_only: bool = False,
         is_named_optimizer: bool = False,
         group: Optional[dist.ProcessGroup] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """
         Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.

@@ -1376,13 +1367,13 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         optim: torch.optim.Optimizer,
         optim_input: Optional[
             Union[
-                List[Dict[str, Any]],
+                list[dict[str, Any]],
                 Iterable[torch.nn.Parameter],
             ]
         ] = None,
         rank0_only: bool = True,
         group: Optional[dist.ProcessGroup] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Return the full optimizer state-dict.

         Consolidates the full optimizer state on rank 0 and returns it
@@ -1451,7 +1442,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         model: torch.nn.Module,
         optim: torch.optim.Optimizer,
         group: Optional[dist.ProcessGroup] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Return the optimizer state-dict in its sharded form.

         The API is similar to :meth:`full_optim_state_dict` but this API chunks
@@ -1482,16 +1473,16 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):

     @staticmethod
     def shard_full_optim_state_dict(
-        full_optim_state_dict: Dict[str, Any],
+        full_optim_state_dict: dict[str, Any],
         model: torch.nn.Module,
         optim_input: Optional[
             Union[
-                List[Dict[str, Any]],
+                list[dict[str, Any]],
                 Iterable[torch.nn.Parameter],
             ]
         ] = None,
         optim: Optional[torch.optim.Optimizer] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Shard a full optimizer state-dict.

         Remaps the state in ``full_optim_state_dict`` to flattened parameters instead of unflattened
@@ -1561,10 +1552,10 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):

     @staticmethod
     def flatten_sharded_optim_state_dict(
-        sharded_optim_state_dict: Dict[str, Any],
+        sharded_optim_state_dict: dict[str, Any],
         model: torch.nn.Module,
         optim: torch.optim.Optimizer,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Flatten a sharded optimizer state-dict.

         The API is similar to :meth:`shard_full_optim_state_dict`. The only
@@ -1600,17 +1591,17 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):

     @staticmethod
     def scatter_full_optim_state_dict(
-        full_optim_state_dict: Optional[Dict[str, Any]],
+        full_optim_state_dict: Optional[dict[str, Any]],
         model: torch.nn.Module,
         optim_input: Optional[
             Union[
-                List[Dict[str, Any]],
+                list[dict[str, Any]],
                 Iterable[torch.nn.Parameter],
             ]
         ] = None,
         optim: Optional[torch.optim.Optimizer] = None,
         group: Optional[Any] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Scatter the full optimizer state dict from rank 0 to all other ranks.

         Returns the sharded optimizer state dict on each rank.
@@ -1684,17 +1675,17 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):

     @staticmethod
     def rekey_optim_state_dict(
-        optim_state_dict: Dict[str, Any],
+        optim_state_dict: dict[str, Any],
         optim_state_key_type: OptimStateKeyType,
         model: torch.nn.Module,
         optim_input: Optional[
             Union[
-                List[Dict[str, Any]],
+                list[dict[str, Any]],
                 Iterable[torch.nn.Parameter],
             ]
         ] = None,
         optim: Optional[torch.optim.Optimizer] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Re-keys the optimizer state dict ``optim_state_dict`` to use the key type ``optim_state_key_type``.

         This can be used to achieve compatibility between optimizer state dicts from models with FSDP
@@ -1762,7 +1753,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
                 else _get_param_key_to_param(optim)
             )
             param_to_param_name = _get_param_to_fqn(model)
-            param_id_to_param_name: List[str] = [
+            param_id_to_param_name: list[str] = [
                 param_to_param_name[param] for param in param_id_to_param.values()
             ]
             new_osd["state"] = {
@@ -1811,9 +1802,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
     def optim_state_dict(
         model: torch.nn.Module,
         optim: torch.optim.Optimizer,
-        optim_state_dict: Optional[Dict[str, Any]] = None,
+        optim_state_dict: Optional[dict[str, Any]] = None,
         group: Optional[dist.ProcessGroup] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """
         Transform the state-dict of an optimizer corresponding to a sharded model.

@@ -1907,11 +1898,11 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
     def optim_state_dict_to_load(
         model: torch.nn.Module,
         optim: torch.optim.Optimizer,
-        optim_state_dict: Dict[str, Any],
+        optim_state_dict: dict[str, Any],
         is_named_optimizer: bool = False,
         load_directly: bool = False,
         group: Optional[dist.ProcessGroup] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """
         Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.

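
Among the APIs retyped above, optim_state_dict and optim_state_dict_to_load form the recommended save/load pair. A hedged usage sketch (assumes an initialized process group and an FSDP-wrapped model; how much state each rank materializes depends on the configured StateDictType, which defaults to FULL_STATE_DICT):

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def save_optim_state(model, optim, path: str) -> None:
    # Produce an FSDP-aware (FQN-keyed) optimizer state dict.
    osd = FSDP.optim_state_dict(model, optim)
    if dist.get_rank() == 0:
        torch.save(osd, path)


def load_optim_state(model, optim, path: str) -> None:
    osd = torch.load(path, map_location="cpu")
    # Remap the saved state onto this rank's flat parameters, then load.
    flattened_osd = FSDP.optim_state_dict_to_load(model, optim, osd)
    optim.load_state_dict(flattened_osd)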
@@ -2146,7 +2137,7 @@ def _get_grad_norm(

 def _get_param_to_fqn(
     model: torch.nn.Module,
-) -> Dict[torch.nn.Parameter, str]:
+) -> dict[torch.nn.Parameter, str]:
     """
     Construct a mapping from parameters to their parameter names.

@@ -2178,7 +2169,7 @@ def _get_param_to_fqn(

 def _get_fqn_to_param(
     model: torch.nn.Module,
-) -> Dict[str, torch.nn.Parameter]:
+) -> dict[str, torch.nn.Parameter]:
     """Construct the inverse mapping of :meth:`_get_param_to_fqn`."""
     param_to_param_name = _get_param_to_fqn(model)
     return dict(zip(param_to_param_name.values(), param_to_param_name.keys()))

@@ -1,7 +1,8 @@
 # mypy: allow-untyped-defs
 import logging
 from collections import abc, defaultdict
-from typing import Any, Dict, Iterable, List, Optional, overload, Union
+from collections.abc import Iterable
+from typing import Any, Optional, overload, Union

 import torch
 import torch.distributed as dist
@@ -12,7 +13,7 @@ from torch.distributed.distributed_c10d import ProcessGroup
 logger = logging.getLogger(__name__)


-def _refresh_per_optimizer_state() -> Dict[str, Any]:
+def _refresh_per_optimizer_state() -> dict[str, Any]:
     return {"stage": OptState.READY, "found_inf_per_device": {}}


@@ -36,7 +37,7 @@ class _GeneralMultiDeviceReplicator(_MultiDeviceReplicator):
     def __init__(self, master_tensor: torch.Tensor) -> None:
         assert _is_supported_device(master_tensor)
         self.master = master_tensor
-        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
+        self._per_device_tensors: dict[torch.device, torch.Tensor] = {}


 class ShardedGradScaler(GradScaler):
@@ -115,7 +116,7 @@ class ShardedGradScaler(GradScaler):
         ...

     @overload
-    def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
+    def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]:
         ...

     @overload
@@ -145,7 +146,7 @@ class ShardedGradScaler(GradScaler):
             # format (fp16, bf16) and so the scaled loss should be of the same dtype.
             return scaled_output.type(outputs.dtype)

-        stash: List[_GeneralMultiDeviceReplicator] = []
+        stash: list[_GeneralMultiDeviceReplicator] = []

         def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
             if isinstance(val, torch.Tensor):
@@ -175,7 +176,7 @@ class ShardedGradScaler(GradScaler):
         inv_scale: torch.Tensor,
         found_inf: torch.Tensor,
         allow_fp16: bool = True,
-    ) -> Dict[torch.device, torch.Tensor]:
+    ) -> dict[torch.device, torch.Tensor]:
         per_device_inv_scale = _GeneralMultiDeviceReplicator(inv_scale)
         per_device_found_inf = _GeneralMultiDeviceReplicator(found_inf)

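
ShardedGradScaler mirrors the torch.amp GradScaler API but coordinates the inf/NaN bookkeeping across ranks so that every rank agrees on whether to skip a step. A usage sketch (assumes an initialized process group, an FSDP-wrapped model, an optimizer, and a CUDA autocast region; the training-loop names are placeholders):

import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

scaler = ShardedGradScaler()

# for inputs, targets in dataloader:
#     optim.zero_grad(set_to_none=True)
#     with torch.autocast(device_type="cuda", dtype=torch.float16):
#         loss = loss_fn(model(inputs), targets)
#     scaler.scale(loss).backward()
#     scaler.step(optim)   # skipped consistently on all ranks if any rank saw inf/NaN
#     scaler.update()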

@@ -7,19 +7,8 @@
 import contextlib
 import copy
 from abc import ABC, abstractmethod
-from typing import (
-    Any,
-    Callable,
-    cast,
-    Dict,
-    Generator,
-    Iterable,
-    Optional,
-    Sequence,
-    Set,
-    Type,
-    Union,
-)
+from collections.abc import Generator, Iterable, Sequence
+from typing import Any, Callable, cast, Optional, Union

 import torch.nn as nn

@@ -51,7 +40,7 @@ def _post_order_apply(
     not changed.
     """
     # Track visited modules to avoid visiting shared modules multiple times
-    visited_modules: Set[nn.Module] = {root_module}
+    visited_modules: set[nn.Module] = {root_module}

     def _post_order_apply_inner(
         module: nn.Module,
@@ -82,7 +71,7 @@ def _post_order_apply(

 def _construct_wrap_fn(
     root_module: nn.Module,
-    target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
+    target_module_to_kwargs: dict[nn.Module, dict[str, Any]],
     fsdp_fn: Callable,
 ) -> Callable[[nn.Module], Optional[nn.Module]]:
     """
@@ -104,10 +93,10 @@ def _construct_wrap_fn(

 def _run_mixed_precision_override_policy(
     root_module: nn.Module,
-    module_classes: Iterable[Type[nn.Module]],
-    ignored_modules: Set[nn.Module],
-    root_kwargs: Dict[str, Any],
-    target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
+    module_classes: Iterable[type[nn.Module]],
+    ignored_modules: set[nn.Module],
+    root_kwargs: dict[str, Any],
+    target_module_to_kwargs: dict[nn.Module, dict[str, Any]],
 ):
     module_classes_tuple = tuple(set(module_classes))
     for module in root_module.modules():
@@ -141,9 +130,9 @@ class _Policy(ABC):
     def _run_policy(
         self,
         root_module: nn.Module,
-        ignored_modules: Set[nn.Module],
-        root_kwargs: Dict[str, Any],
-    ) -> Dict[nn.Module, Dict[str, Any]]:
+        ignored_modules: set[nn.Module],
+        root_kwargs: dict[str, Any],
+    ) -> dict[nn.Module, dict[str, Any]]:
         """
         This should return a dict ``target_module_to_kwargs`` that maps from
         each target module to wrap to its kwargs.
@@ -155,7 +144,7 @@ def _module_wrap_policy(
     module: nn.Module,
     recurse: bool,
     nonwrapped_numel: int,
-    module_classes: Set[Type[nn.Module]],
+    module_classes: set[type[nn.Module]],
 ) -> bool:
     """
     This auto wrap policy wraps every module that is an instance of any type in
@@ -189,7 +178,7 @@ class ModuleWrapPolicy(_Policy):
         passing in the kwargs given to the root.
     """

-    def __init__(self, module_classes: Iterable[Type[nn.Module]]):
+    def __init__(self, module_classes: Iterable[type[nn.Module]]):
         module_classes_set = set(module_classes)
         self._module_classes = module_classes_set
         self._module_classes_str = str(module_classes_set)
@@ -197,11 +186,11 @@ class ModuleWrapPolicy(_Policy):
     def _run_policy(
         self,
         root_module: nn.Module,
-        ignored_modules: Set[nn.Module],
-        root_kwargs: Dict[str, Any],
-    ) -> Dict[nn.Module, Dict[str, Any]]:
+        ignored_modules: set[nn.Module],
+        root_kwargs: dict[str, Any],
+    ) -> dict[nn.Module, dict[str, Any]]:
         module_classes = tuple(self._module_classes)
-        target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]] = {}
+        target_module_to_kwargs: dict[nn.Module, dict[str, Any]] = {}
         for module in root_module.modules():
             if module in ignored_modules:
                 continue
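
ModuleWrapPolicy is the declarative counterpart of _module_wrap_policy: list the layer classes that should each become their own FSDP unit. A usage sketch (the commented FSDP call needs an initialized process group; the layer class here is only an example):

import torch.nn as nn
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

policy = ModuleWrapPolicy({nn.TransformerEncoderLayer})
# fsdp_model = FSDP(model, auto_wrap_policy=policy)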
@@ -245,16 +234,16 @@ class CustomPolicy(_Policy):
         >>> fsdp_model = FSDP(model, auto_wrap_policy=policy)
     """

-    def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, Dict[str, Any]]]):
+    def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, dict[str, Any]]]):
         self._lambda_fn = lambda_fn

     def _run_policy(
         self,
         root_module: nn.Module,
-        ignored_modules: Set[nn.Module],
-        root_kwargs: Dict[str, Any],
-    ) -> Dict[nn.Module, Dict[str, Any]]:
-        target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]] = {}
+        ignored_modules: set[nn.Module],
+        root_kwargs: dict[str, Any],
+    ) -> dict[nn.Module, dict[str, Any]]:
+        target_module_to_kwargs: dict[nn.Module, dict[str, Any]] = {}
         for module in root_module.modules():
             if module in ignored_modules:
                 continue
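
CustomPolicy generalizes this: the callable may return False (do not wrap), True (wrap with the root's kwargs), or a kwargs dict that overrides the root's kwargs for that module. A hedged sketch using standard nn classes as stand-ins for a real model's layers:

import torch.nn as nn
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import CustomPolicy


def lambda_fn(module: nn.Module):
    if isinstance(module, nn.TransformerEncoderLayer):
        return True  # wrap with the kwargs passed to the root FSDP call
    if isinstance(module, nn.Embedding):
        return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
    return False


policy = CustomPolicy(lambda_fn)
# fsdp_model = FSDP(model, auto_wrap_policy=policy)  # requires an initialized process group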
@@ -307,7 +296,7 @@ def transformer_auto_wrap_policy(
     module: nn.Module,
     recurse: bool,
     nonwrapped_numel: int,
-    transformer_layer_cls: Set[Type[nn.Module]],
+    transformer_layer_cls: set[type[nn.Module]],
 ) -> bool:
     """
     See :func:`_module_wrap_policy`, where ``transformer_layer_cls`` is the
@@ -352,8 +341,8 @@ def size_based_auto_wrap_policy(
     nonwrapped_numel: int,
     # Additional custom arguments
     min_num_params: int = int(1e8),
-    force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
-    exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
+    force_leaf_modules: Optional[set[type[nn.Module]]] = None,
+    exclude_wrap_modules: Optional[set[type[nn.Module]]] = None,
 ) -> bool:
     """
     A size-based auto wrap policy.
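
Unlike the class-based policies, size_based_auto_wrap_policy is a plain function, so its extra arguments are bound with functools.partial before being handed to FSDP. A usage sketch lowering the default threshold from 1e8 to 1e6 parameters:

import functools

from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

policy = functools.partial(size_based_auto_wrap_policy, min_num_params=int(1e6))
# fsdp_model = FSDP(model, auto_wrap_policy=policy)  # requires an initialized process group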
@@ -495,8 +484,8 @@ def _recursive_wrap(
     module: nn.Module,
     auto_wrap_policy: Callable,
     wrapper_cls: Callable,
-    ignored_modules: Set[nn.Module],
-    ignored_params: Set[nn.Parameter],
+    ignored_modules: set[nn.Module],
+    ignored_params: set[nn.Parameter],
     only_wrap_children: bool = False,
     **kwargs: Any,
 ) -> tuple[nn.Module, int]:
@@ -573,9 +562,9 @@ class _ConfigAutoWrap:

     in_autowrap_context: bool = False  # Context flag
     wrapper_cls: Optional[Callable] = None  # The wrapper class
-    kwargs: Dict[str, Any] = {}  # Wrapper's args
+    kwargs: dict[str, Any] = {}  # Wrapper's args

-    def __init__(self, **kwargs: Dict[str, Any]):
+    def __init__(self, **kwargs: dict[str, Any]):
         self.kwargs = kwargs

     @staticmethod