PEP585 update - torch/distributed/fsdp (#145162)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145162
Approved by: https://github.com/bobrenjc93
Aaron Orenstein 2025-01-18 14:57:31 -08:00 committed by PyTorch MergeBot
parent 371a361db9
commit c64e657632
26 changed files with 497 additions and 576 deletions
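The hunks below all apply the same mechanical change: the deprecated `typing` aliases (`Dict`, `List`, `Set`, `Tuple`, `Type`, `Deque`) are replaced by the builtin generics standardized in PEP 585, and ABCs such as `Generator`, `Iterable`, `Iterator`, and `Sequence` are imported from `collections.abc` instead of `typing`. As a rough, hypothetical sketch of the pattern (the `collect_fqns` function and its names below are illustrative only, not taken from any file in this PR):

    # Before: deprecated typing aliases (pre-PEP 585 style)
    from typing import Deque, Dict, Iterable, List, Optional, Set

    def collect_fqns(params: Iterable[str], ignored: Set[str]) -> Dict[str, List[str]]:
        ...

    # After: builtin generics plus collections.abc, as applied throughout this PR
    import collections
    from collections.abc import Iterable
    from typing import Optional

    def collect_fqns(params: Iterable[str], ignored: set[str]) -> dict[str, list[str]]:
        ...

    # typing.Deque[str] becomes the subscripted runtime class
    queue: collections.deque[str] = collections.deque()
    # In these diffs, Optional, Union, Any, and Callable are still imported from typing
    visited: Optional[set[str]] = None

The change is annotation-only, which is consistent with the hunks touching nothing but imports and type hints.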

View File

@ -6,22 +6,10 @@ import logging
import traceback
import warnings
import weakref
from collections.abc import Generator, Iterable
from enum import auto, Enum
from functools import partial
from typing import (
Any,
Callable,
cast,
Dict,
Generator,
Iterable,
List,
no_type_check,
Optional,
Set,
Type,
TYPE_CHECKING,
)
from typing import Any, Callable, cast, no_type_check, Optional, TYPE_CHECKING
import torch
import torch.distributed as dist
@ -118,10 +106,10 @@ class _FSDPState(_State):
def __init__(self) -> None:
# TODO: Move all the attributes to this class to enable typing for
# FSDP/fully_shard.
self._ignored_modules: Set[nn.Module] = set()
self._ignored_params: Set[nn.Parameter] = set()
self._ignored_modules: set[nn.Module] = set()
self._ignored_params: set[nn.Parameter] = set()
# Buffer names are cleaned (without wrapper prefixes)
self._ignored_buffer_names: Set[str] = set()
self._ignored_buffer_names: set[str] = set()
self.process_group: Optional[dist.ProcessGroup] = None
self.rank: int = -1
self.world_size: int = -1
@ -129,13 +117,13 @@ class _FSDPState(_State):
self.sharding_strategy = ShardingStrategy.FULL_SHARD
self._use_orig_params: bool = False
self.training_state = TrainingState.IDLE
self._unshard_params_ctx: Dict[nn.Module, Generator] = {}
self._unshard_params_ctx: dict[nn.Module, Generator] = {}
self._state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT
self._state_dict_config: StateDictConfig = FullStateDictConfig()
self._optim_state_dict_config: OptimStateDictConfig = FullOptimStateDictConfig()
self._is_root: Optional[bool] = None
self._handle: Optional[flat_param_file.FlatParamHandle] = None
self._fully_sharded_module_to_handle: Dict[
self._fully_sharded_module_to_handle: dict[
nn.Module, Optional[flat_param_file.FlatParamHandle]
] = {}
self.compute_device: Optional[torch.device] = None
@ -149,8 +137,8 @@ class _FSDPState(_State):
self._device_handle: _FSDPDeviceHandle = _UninitializedDeviceHandle()
# All following attributes should only be used for root states:
# Save these static lists to avoid the repeated tree traversals
self._all_fsdp_states: List[_FSDPState] = []
self._all_handles: List[flat_param_file.FlatParamHandle] = []
self._all_fsdp_states: list[_FSDPState] = []
self._all_handles: list[flat_param_file.FlatParamHandle] = []
self._fsdp_extension: Optional[FSDPExtensions] = None
@ -262,7 +250,7 @@ def _is_fsdp_flattened(tensor: torch.Tensor) -> bool:
def _named_parameters_with_duplicates(
module: nn.Module, **kwargs: Any
) -> List[tuple[str, nn.Parameter]]:
) -> list[tuple[str, nn.Parameter]]:
"""
This API is required as some modules overwrite `named_parameters()` but do not support
`remove_duplicate`.
@ -282,7 +270,7 @@ def _named_parameters_with_duplicates(
def _get_param_to_fqns(
model: torch.nn.Module,
dedup_shared_params: bool = True,
) -> Dict[nn.Parameter, List[str]]:
) -> dict[nn.Parameter, list[str]]:
"""
Constructs a mapping from parameter to a list of its \"canonical\" FQNs. Here,
we use canonical to mean the fully-qualified name assigned to the parameter
@ -352,7 +340,7 @@ def _get_param_to_fqns(
def return_fn(param_to_fqns):
return param_to_fqns
param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {}
param_to_unflat_param_names: dict[torch.nn.Parameter, list[str]] = {}
return _apply_to_modules(
model,
module_fn,
@ -377,7 +365,7 @@ def _log_post_backward_hook(
@no_type_check
def _get_handle_fqns_from_root(
state: _FSDPState, handle: "FlatParamHandle"
) -> Optional[List[str]]:
) -> Optional[list[str]]:
if handle is None:
return None
param_to_fqn = state._exec_order_data.param_to_fqn
@ -392,7 +380,7 @@ def _apply_to_modules(
root_module: torch.nn.Module,
module_fn: Callable,
return_fn: Callable,
filter_fqns: Optional[List[str]] = None,
filter_fqns: Optional[list[str]] = None,
*args,
**kwargs,
):
@ -443,7 +431,7 @@ def _apply_to_modules(
@no_type_check
def _assert_in_training_states(
state: _FSDPState,
training_states: List[TrainingState],
training_states: list[TrainingState],
) -> None:
"""Asserts that FSDP is in the states ``_training_states``."""
# Raise a `ValueError` instead of using `assert` to ensure that these
@ -462,7 +450,7 @@ def _assert_in_training_states(
raise ValueError(msg)
def _get_root_modules(modules: Set[nn.Module]) -> Set[nn.Module]:
def _get_root_modules(modules: set[nn.Module]) -> set[nn.Module]:
"""
Returns:
Set[nn.Module]: The subset of ``modules`` that are root modules (i.e.
@ -470,7 +458,7 @@ def _get_root_modules(modules: Set[nn.Module]) -> Set[nn.Module]:
words, these are the modules in ``modules`` that are not the child of
any other module in ``modules``.
"""
root_modules: Set[nn.Module] = set()
root_modules: set[nn.Module] = set()
module_to_submodules = {module: set(module.modules()) for module in modules}
for candidate_module in modules:
is_root_module = True
@ -488,12 +476,12 @@ def _get_root_modules(modules: Set[nn.Module]) -> Set[nn.Module]:
def _override_module_mixed_precision(
root: torch.nn.Module,
module_classes_to_override: Iterable[Type[nn.Module]],
wrap_override_dict: Dict[str, Any] = {"mixed_precision": None}, # noqa: B006
) -> Set[Type[nn.Module]]:
module_classes_to_override: Iterable[type[nn.Module]],
wrap_override_dict: dict[str, Any] = {"mixed_precision": None}, # noqa: B006
) -> set[type[nn.Module]]:
module_classes_to_override = tuple(set(module_classes_to_override))
# Return a set of the actually overridden module classes
overridden_module_classes: Set[Type[nn.Module]] = set()
overridden_module_classes: set[type[nn.Module]] = set()
for mod in root.modules():
if isinstance(mod, module_classes_to_override):
overridden_module_classes.add(type(mod))

View File

@ -2,9 +2,9 @@
import logging
import time
from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
from enum import Enum
from typing import Dict, Iterator, List, Set
import torch
import torch.distributed as dist
@ -28,8 +28,8 @@ class SimpleProfiler:
H2D = "H2D"
D2H = "D2H"
results: Dict[str, float] = defaultdict(float)
profiling: Set[str] = set()
results: dict[str, float] = defaultdict(float)
profiling: set[str] = set()
@classmethod
def reset(cls) -> None:
@ -65,7 +65,7 @@ class SimpleProfiler:
def _get_sharded_module_tree_with_module_name_to_fqns(
model: torch.nn.Module,
) -> tuple[str, Dict[str, List[str]]]:
) -> tuple[str, dict[str, list[str]]]:
"""
It is used for the composable fully_shard() code path; it returns
1. sharded module tree info: each line represents a submodule name that contains the
@ -143,10 +143,10 @@ def _get_sharded_module_tree_with_module_name_to_fqns(
return sharded_tree_info[0], sharded_module_name_to_fqns
# Use List to mutate its value in place while running the recursive functions
sharded_tree_info: List[str] = [
sharded_tree_info: list[str] = [
"",
]
sharded_module_name_to_fqns: Dict[str, List[str]] = {}
sharded_module_name_to_fqns: dict[str, list[str]] = {}
return _apply_to_modules(
model,
module_fn,

View File

@ -1,11 +1,9 @@
from typing import Set
import torch.nn as nn
def _annotate_modules_for_dynamo(
module: nn.Module,
ignored_modules: Set[nn.Module],
ignored_modules: set[nn.Module],
use_orig_params: bool,
) -> None:
"""

View File

@ -2,7 +2,7 @@
import itertools
import warnings
from enum import auto, Enum
from typing import Dict, List, Optional, Union
from typing import Optional, Union
import torch
import torch.distributed as dist
@ -37,9 +37,9 @@ class _ExecOrderData:
) -> None:
# Tracks the (static) pre-forward order for execution order validation
# and forward prefetching
self.handles_pre_forward_order: List[FlatParamHandle] = []
self.handles_pre_forward_order: list[FlatParamHandle] = []
# Tracks the post-forward order for pre-backward prefetching
self.handles_post_forward_order: List[Optional[FlatParamHandle]] = []
self.handles_post_forward_order: list[Optional[FlatParamHandle]] = []
self._iter = 0
# Gives the max number of backward/forward prefetched all-gathers by a
@ -51,9 +51,9 @@ class _ExecOrderData:
self._checking_order: bool = debug_level == dist.DebugLevel.DETAIL
self.process_group: Optional[dist.ProcessGroup] = None
self.world_size: Optional[int] = None
self.all_handles: List[FlatParamHandle] = []
self.all_handles: list[FlatParamHandle] = []
# Names are prefixed from the root module
self.param_to_fqn: Dict[nn.Parameter, List[str]] = {}
self.param_to_fqn: dict[nn.Parameter, list[str]] = {}
# Current index in the pre-forward execution order
self.current_order_index = 0
self.warn_status = _ExecOrderWarnStatus.NONE
@ -197,7 +197,7 @@ class _ExecOrderData:
num_valid_indices = sum(
(index is not None) for index in optional_local_indices
)
tensor_kwargs: Dict[str, Union[torch.dtype, torch.device]] = {
tensor_kwargs: dict[str, Union[torch.dtype, torch.device]] = {
"dtype": torch.int32,
"device": device,
}
@ -313,7 +313,7 @@ class _ExecOrderData:
corresponding to the handles in ``handle``. An entry in the
returned tuple is ``None`` if the handle is invalid.
"""
indices: List[Optional[int]] = []
indices: list[Optional[int]] = []
if handle:
indices.append(handle._handle_index)
return tuple(indices)
@ -321,13 +321,13 @@ class _ExecOrderData:
def _get_names_from_handle_indices(
self,
handle_indices: tuple[int, ...],
) -> List[List[str]]:
) -> list[list[str]]:
"""
Returns a list of FQNs for each handle in ``handle_indices``. If a
handle index is invalid, then its FQNs are omitted from the returned
list.
"""
fqns: List[List[str]] = []
fqns: list[list[str]] = []
for index in handle_indices:
if index is None or index < 0 or index >= len(self.all_handles):
continue
@ -339,12 +339,12 @@ class _ExecOrderData:
def _get_names_from_handles(
self,
handle: FlatParamHandle,
) -> List[List[str]]:
) -> list[list[str]]:
"""
Returns a list of FQNs for each handle in ``handles_key``. If a handle
is invalid, then its FQNs are omitted from the returned list.
"""
fqns: List[List[str]] = []
fqns: list[list[str]] = []
if handle:
flat_param = handle.flat_param
if flat_param in self.param_to_fqn:

View File

@ -4,24 +4,10 @@ import functools
import logging
import os
import warnings
from collections.abc import Generator, Iterator, Sequence
from enum import auto, Enum
from itertools import accumulate, chain
from typing import (
Any,
Callable,
cast,
Dict,
Generator,
Iterator,
List,
NamedTuple,
no_type_check,
Optional,
Sequence,
Set,
Tuple,
Union,
)
from typing import Any, Callable, cast, NamedTuple, no_type_check, Optional, Union
import torch
import torch.distributed as dist
@ -354,7 +340,7 @@ class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
_numels: tuple[int, ...]
_shard_param_infos: tuple[_ShardParamInfo, ...]
_shared_param_infos: tuple[SharedParamInfo, ...]
_modules: Set[nn.Module]
_modules: set[nn.Module]
_shard_numel_padded: int
_local_shard: Tensor
_full_param_padded: Tensor
@ -366,12 +352,12 @@ class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
_mp_shard: Tensor
_cpu_grad: Tensor
_saved_grad_shard: Tensor
_params: Optional[List[nn.Parameter]]
_shared_params: Optional[List[nn.Parameter]]
_tensors: Optional[List[Optional[Tensor]]]
_is_grad_none_mask: Optional[List[bool]]
_params: Optional[list[nn.Parameter]]
_shared_params: Optional[list[nn.Parameter]]
_tensors: Optional[list[Optional[Tensor]]]
_is_grad_none_mask: Optional[list[bool]]
_is_padding_mask: List[bool]
_is_padding_mask: list[bool]
def __new__(cls, data=None, requires_grad=True):
assert cls is FlatParameter, "subclasses FlatParameter not supported"
@ -386,17 +372,17 @@ class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
def _init_metadata(
cls,
self,
param_infos: List[ParamInfo],
numels: List[int],
shapes: List[torch.Size],
strides: List[tuple[int, ...]],
contiguities: List[bool],
fqns: List[str],
shared_param_infos: List[SharedParamInfo],
param_extensions: List[Optional[Any]],
params: Optional[List[nn.Parameter]],
shared_params: Optional[List[nn.Parameter]],
is_padding_mask: List[bool],
param_infos: list[ParamInfo],
numels: list[int],
shapes: list[torch.Size],
strides: list[tuple[int, ...]],
contiguities: list[bool],
fqns: list[str],
shared_param_infos: list[SharedParamInfo],
param_extensions: list[Optional[Any]],
params: Optional[list[nn.Parameter]],
shared_params: Optional[list[nn.Parameter]],
is_padding_mask: list[bool],
) -> None:
"""
Initialize attributes holding metadata about the original parameters comprising the flat parameter.
@ -426,7 +412,7 @@ class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
self._param_extensions = param_extensions
self._is_padding_mask = is_padding_mask
numels_without_padding: List[int] = []
numels_without_padding: list[int] = []
for numel, is_padding in zip(numels, is_padding_mask):
if not is_padding:
numels_without_padding.append(numel)
@ -624,7 +610,7 @@ class FlatParamHandle:
def _init_flat_param_and_metadata(
self,
params: List[Union[Tensor, nn.Parameter]],
params: list[Union[Tensor, nn.Parameter]],
module: nn.Module,
aligned_numel: int,
use_orig_params: bool,
@ -653,20 +639,20 @@ class FlatParamHandle:
params_set = set(params)
# For alignment padding, only `numels` gets strictly non-`None`
# elements, and all other lists get `None` elements for padding.
param_infos: List[ParamInfo] = []
numels: List[int] = []
shapes: List[torch.Size] = []
strides: List[tuple[int, ...]] = []
contiguities: List[bool] = []
fqns: List[str] = []
shared_param_infos: List[SharedParamInfo] = []
shared_param_memo: Dict[
param_infos: list[ParamInfo] = []
numels: list[int] = []
shapes: list[torch.Size] = []
strides: list[tuple[int, ...]] = []
contiguities: list[bool] = []
fqns: list[str] = []
shared_param_infos: list[SharedParamInfo] = []
shared_param_memo: dict[
Union[Tensor, nn.Parameter], tuple[nn.Module, str, str]
] = {}
params_to_flatten: List[Union[Tensor, nn.Parameter]] = []
shared_params: List[Union[Tensor, nn.Parameter]] = []
param_extensions: List[Any] = []
is_padding_mask: List[bool] = []
params_to_flatten: list[Union[Tensor, nn.Parameter]] = []
shared_params: list[Union[Tensor, nn.Parameter]] = []
param_extensions: list[Any] = []
is_padding_mask: list[bool] = []
total_numel = total_numel_without_padding = 0
for submodule_name, submodule in module.named_modules(remove_duplicate=False):
for param_name, param in _named_parameters_with_duplicates(
@ -779,8 +765,8 @@ class FlatParamHandle:
)
def _validate_tensors_to_flatten(
self, tensors: List[Union[Tensor, nn.Parameter]]
) -> Tuple:
self, tensors: list[Union[Tensor, nn.Parameter]]
) -> tuple:
"""Validate the tensors to flatten and returns any necessary metadata."""
dtype: Optional[torch.dtype] = None
# Return as the logical OR over each tensor's value
@ -819,7 +805,7 @@ class FlatParamHandle:
def flatten_tensors(
self,
tensors: List[Tensor],
tensors: list[Tensor],
aligned_numel: int,
) -> Tensor:
"""
@ -841,7 +827,7 @@ class FlatParamHandle:
f"Expects non-negative `aligned_numel` but got {aligned_numel}"
)
dtype, _, device = self._validate_tensors_to_flatten(tensors)
flat_tensors: List[Tensor] = []
flat_tensors: list[Tensor] = []
if aligned_numel > 0:
total_numel = 0
for tensor in tensors:
@ -876,7 +862,7 @@ class FlatParamHandle:
def flatten_tensors_into_flat_param(
self,
tensors: List[Tensor],
tensors: list[Tensor],
aligned_numel: int,
requires_grad: bool,
) -> FlatParameter:
@ -1013,7 +999,7 @@ class FlatParamHandle:
assert len(flat_param_offsets) == len(
self.flat_param._numels_with_padding
), f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
shard_param_infos: List[_ShardParamInfo] = []
shard_param_infos: list[_ShardParamInfo] = []
sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1
# `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices
# into the unsharded flat parameter (inclusive) of the given parameter
@ -1130,7 +1116,7 @@ class FlatParamHandle:
assert len(unpadded_sharded_size) == 1, f"{unpadded_sharded_size}"
return torch.Size([unpadded_sharded_size[0] + numel_to_pad])
def _get_flat_param_offsets(self) -> List[tuple[int, int]]:
def _get_flat_param_offsets(self) -> list[tuple[int, int]]:
"""
Return [start, end] offsets of each original parameter's flattened data in the unsharded flat parameter (without padding).
@ -1884,7 +1870,7 @@ class FlatParamHandle:
def _get_unflat_views_aligned(
self,
tensor: Optional[Tensor] = None,
) -> List[Tensor]:
) -> list[Tensor]:
"""
Return unflattened ``Tensor`` views into ``tensor`` with handling for padding.
@ -1895,11 +1881,11 @@ class FlatParamHandle:
flat_param = self.flat_param
if tensor is None:
tensor = flat_param
splits: List[Tensor] = torch.split(
splits: list[Tensor] = torch.split(
tensor, flat_param._numels_with_padding, dim=0
)
idx = 0
views: List[Tensor] = []
views: list[Tensor] = []
for split, is_padding in zip(splits, flat_param._is_padding_mask):
if is_padding:
continue
@ -2462,7 +2448,7 @@ class FlatParamHandle:
else:
self._use_unsharded_views(as_params=True)
def _get_modules(self) -> Set[nn.Module]:
def _get_modules(self) -> set[nn.Module]:
"""Return a :class:`set` of the modules whose parameters are included in this handle's flat parameter."""
return {pi.module for pi in self.flat_param._param_infos}.union(
{spi.module for spi in self.flat_param._shared_param_infos}
@ -2514,9 +2500,9 @@ class FlatParamHandle:
yield (param_name, module_name)
@property
def _fqns_in_shard(self) -> List[str]:
def _fqns_in_shard(self) -> list[str]:
"""Return the FQNs of the parameters present in this rank's shard."""
fqns_in_shard: List[str] = []
fqns_in_shard: list[str] = []
for fqn, shard_param_info in zip(
self.flat_param._fqns, self.flat_param._shard_param_infos # type: ignore[attr-defined]
):
@ -2708,8 +2694,8 @@ def _safe_setattr_tensor_or_param(
def _convert_to_params(
tensors: List[Union[torch.Tensor, nn.Parameter]]
) -> List[nn.Parameter]:
tensors: list[Union[torch.Tensor, nn.Parameter]]
) -> list[nn.Parameter]:
return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors]

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from typing import Any, Optional
import torch
import torch.distributed as dist
@ -64,7 +64,7 @@ class FSDPExtensions(ABC):
def pre_load_state_dict_transform(
self,
tensor: torch.Tensor,
) -> tuple[torch.Tensor, List[Shard]]:
) -> tuple[torch.Tensor, list[Shard]]:
"""
This is to be called before loading a *sharded* model state dict and
should return the tensor and list of shards from which to load data.
@ -157,7 +157,7 @@ def _ext_chunk_dtensor(
def _ext_pre_load_state_dict_transform(
tensor: torch.Tensor,
fsdp_extension: Optional[FSDPExtensions] = None,
) -> tuple[torch.Tensor, List[Shard]]:
) -> tuple[torch.Tensor, list[Shard]]:
if fsdp_extension is not None:
return fsdp_extension.pre_load_state_dict_transform(tensor)

View File

@ -1,4 +1,4 @@
from typing import cast, List, NamedTuple, Optional, Union
from typing import cast, NamedTuple, Optional, Union
import torch
import torch.distributed as dist
@ -20,12 +20,12 @@ class AllGatherResult(NamedTuple):
all_gather_event: Optional[torch.Event]
all_gather_work: Optional[dist.distributed_c10d.Work]
# For each parameter, the all-gather input dtype for each input
param_all_gather_input_dtypes: List[List[torch.dtype]]
param_all_gather_input_dtypes: list[list[torch.dtype]]
# For each parameter, the all-gather input numel for each input
param_all_gather_input_numels: List[List[int]]
param_all_gather_input_numels: list[list[int]]
# 1D flattened version of `param_all_gather_input_numels` saved to avoid
# CPU overhead from recomputing
all_gather_input_split_sizes: List[int]
all_gather_input_split_sizes: list[int]
lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901
@ -47,8 +47,8 @@ lib.define(
@torch.library.impl(lib, "all_gather_copy_in", "Meta")
def all_gather_copy_in_meta(
all_gather_inputs: List[torch.Tensor],
inp_split_sizes: List[int],
all_gather_inputs: list[torch.Tensor],
inp_split_sizes: list[int],
all_gather_input_numel: int,
world_size: int,
rank: int,
@ -68,8 +68,8 @@ def all_gather_copy_in_meta(
@torch.library.impl(lib, "all_gather_copy_in", "XPU")
@torch.library.impl(lib, "all_gather_copy_in", "CPU")
def all_gather_copy_in_cuda(
all_gather_inputs: List[torch.Tensor],
inp_split_sizes: List[int],
all_gather_inputs: list[torch.Tensor],
inp_split_sizes: list[int],
all_gather_input_numel: int,
world_size: int,
rank: int,
@ -99,9 +99,9 @@ lib.define(
@torch.library.impl(lib, "split_with_sizes_copy", "CPU")
def split_with_sizes_copy(
all_gather_output: torch.Tensor,
all_gather_input_split_sizes: List[int],
all_gather_input_split_sizes: list[int],
dim: int,
out: List[torch.Tensor],
out: list[torch.Tensor],
) -> None:
torch.split_with_sizes_copy(
all_gather_output, all_gather_input_split_sizes, dim=dim, out=out
@ -118,7 +118,7 @@ lib.define(
@torch.library.impl(lib, "chunk_cat", "XPU")
@torch.library.impl(lib, "chunk_cat", "CPU")
def chunk_cat(
tensors: List[torch.Tensor],
tensors: list[torch.Tensor],
dim: int,
num_chunks: int,
out: torch.Tensor,
@ -128,7 +128,7 @@ def chunk_cat(
@torch.no_grad()
def foreach_all_gather(
fsdp_params: List[FSDPParam],
fsdp_params: list[FSDPParam],
group: dist.ProcessGroup,
async_op: bool,
all_gather_copy_in_stream: torch.Stream,
@ -183,8 +183,8 @@ def foreach_all_gather(
@torch.no_grad()
def _get_param_all_gather_inputs(
fsdp_params: List[FSDPParam],
) -> List[List[torch.Tensor]]:
fsdp_params: list[FSDPParam],
) -> list[list[torch.Tensor]]:
if compiled_autograd_enabled():
return [fsdp_param.all_gather_inputs for fsdp_param in fsdp_params]
@ -198,10 +198,10 @@ def _get_param_all_gather_inputs(
and not hasattr(fsdp_param._sharded_local_tensor, "fsdp_pre_all_gather")
)
param_all_gather_inputs: List[List[torch.Tensor]] = [[] for _ in fsdp_params]
foreach_copy_indices: List[int] = []
foreach_copy_inputs: List[torch.Tensor] = []
foreach_copy_input_numels: List[int] = []
param_all_gather_inputs: list[list[torch.Tensor]] = [[] for _ in fsdp_params]
foreach_copy_indices: list[int] = []
foreach_copy_inputs: list[torch.Tensor] = []
foreach_copy_input_numels: list[int] = []
# 1st pass: for foreach-copy parameters, get inputs and metadata for the
# foreach copy, and for the others, actually get their all-gather inputs
@ -236,7 +236,7 @@ def _get_param_all_gather_inputs(
@torch.no_grad()
def foreach_all_gather_copy_out(
all_gather_result: AllGatherResult,
fsdp_params: List[FSDPParam],
fsdp_params: list[FSDPParam],
group: dist.ProcessGroup,
) -> None:
(
@ -255,8 +255,8 @@ def foreach_all_gather_copy_out(
all_gather_work.wait()
world_size, device = group.size(), all_gather_output.device
split_with_sizes_out: List[torch.Tensor] = []
shard_i_copy_infos: List[tuple[FSDPParam, List[torch.Tensor]]] = []
split_with_sizes_out: list[torch.Tensor] = []
shard_i_copy_infos: list[tuple[FSDPParam, list[torch.Tensor]]] = []
for all_gather_input_numels, all_gather_input_dtypes, fsdp_param in zip(
param_all_gather_input_numels, param_all_gather_input_dtypes, fsdp_params
):
@ -320,8 +320,8 @@ def foreach_all_gather_copy_out(
@torch.no_grad()
def foreach_reduce(
fsdp_params: List[FSDPParam],
unsharded_grads: List[torch.Tensor],
fsdp_params: list[FSDPParam],
unsharded_grads: list[torch.Tensor],
reduce_scatter_group: dist.ProcessGroup,
reduce_scatter_stream: torch.Stream,
orig_dtype: torch.dtype,
@ -488,7 +488,7 @@ def foreach_reduce(
def foreach_reduce_scatter_copy_in(
unsharded_grads: List[torch.Tensor],
unsharded_grads: list[torch.Tensor],
reduce_scatter_input: torch.Tensor,
world_size: int,
) -> None:
@ -499,14 +499,14 @@ def foreach_reduce_scatter_copy_in(
def _get_all_gather_input_metadatas(
param_all_gather_inputs: List[List[torch.Tensor]],
) -> tuple[List[List[torch.dtype]], List[List[int]], torch.dtype]:
param_all_gather_input_dtypes: List[List[torch.dtype]] = []
param_all_gather_input_numels: List[List[int]] = []
param_all_gather_inputs: list[list[torch.Tensor]],
) -> tuple[list[list[torch.dtype]], list[list[int]], torch.dtype]:
param_all_gather_input_dtypes: list[list[torch.dtype]] = []
param_all_gather_input_numels: list[list[int]] = []
all_gather_dtype = param_all_gather_inputs[0][0].dtype
for all_gather_inputs in param_all_gather_inputs:
input_dtypes: List[torch.dtype] = []
input_numels: List[int] = []
input_dtypes: list[torch.dtype] = []
input_numels: list[int] = []
for all_gather_input in all_gather_inputs:
if all_gather_input.dtype != all_gather_dtype:
all_gather_dtype = torch.uint8

View File

@ -3,7 +3,7 @@ import math
import traceback
from dataclasses import dataclass
from enum import auto, Enum
from typing import Any, List, Optional
from typing import Any, Optional
import torch
import torch.distributed as dist
@ -120,7 +120,7 @@ def _get_dim0_padded_size(tensor_size: torch.Size, dim0_factor: int) -> torch.Si
def _chunk_with_empty(
tensor: torch.Tensor, num_chunks: int, dim: int
) -> List[torch.Tensor]:
) -> list[torch.Tensor]:
chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
while len(chunks) < num_chunks:
chunks.append(chunks[0].new_empty(0))

View File

@ -1,5 +1,5 @@
import itertools
from typing import List, Optional, Set, Union
from typing import Optional, Union
import torch
import torch.distributed as dist
@ -70,11 +70,11 @@ def _get_device_from_mesh(mesh: DeviceMesh) -> torch.device:
return torch.device(mesh.device_type, device_handle.current_device())
def _get_managed_modules(root_modules: tuple[nn.Module, ...]) -> List[nn.Module]:
modules: List[nn.Module] = []
def _get_managed_modules(root_modules: tuple[nn.Module, ...]) -> list[nn.Module]:
modules: list[nn.Module] = []
root_modules_set = set(root_modules)
# Track visited modules to avoid visiting shared modules multiple times
visited_modules: Set[nn.Module] = set()
visited_modules: set[nn.Module] = set()
def dfs(module: nn.Module) -> None:
"""
@ -113,14 +113,14 @@ def _verify_managed_param(name: str, param: nn.Parameter) -> None:
def _get_managed_states(
modules: List[nn.Module],
) -> tuple[List[nn.Parameter], List[torch.Tensor]]:
params: List[nn.Parameter] = []
buffers: List[torch.Tensor] = []
modules: list[nn.Module],
) -> tuple[list[nn.Parameter], list[torch.Tensor]]:
params: list[nn.Parameter] = []
buffers: list[torch.Tensor] = []
# Track visited parameters/buffers to avoid visiting shared parameters and
# buffers multiple times
visited_params: Set[nn.Parameter] = set()
visited_buffers: Set[torch.Tensor] = set()
visited_params: set[nn.Parameter] = set()
visited_buffers: set[torch.Tensor] = set()
for module in modules:
for name, param in module.named_parameters(recurse=False):
if param not in visited_params:
@ -135,8 +135,8 @@ def _get_managed_states(
def _move_states_to_device(
params: List[nn.Parameter],
buffers: List[torch.Tensor],
params: list[nn.Parameter],
buffers: list[torch.Tensor],
device: torch.device,
) -> None:
"""

View File

@ -1,9 +1,10 @@
# mypy: allow-untyped-defs
import inspect
import itertools
from collections.abc import Sequence
from dataclasses import dataclass, field
from enum import auto, Enum
from typing import Any, Callable, cast, List, Optional, Sequence
from typing import Any, Callable, cast, Optional
import torch
import torch.nn as nn
@ -171,8 +172,8 @@ class ParamModuleInfo:
# Parameter names are unprefixed, e.g. "weight", not "lin.weight"
module: nn.Module
param_name: str
shared_modules: List[nn.Module] = field(default_factory=list)
shared_param_names: List[str] = field(default_factory=list)
shared_modules: list[nn.Module] = field(default_factory=list)
shared_param_names: list[str] = field(default_factory=list)
@dataclass
@ -211,10 +212,10 @@ class FSDPParam:
_sharding_spec: DTensorSpec
# DTensor attributes (only defined for DTensor `param`):
_tp_spec: DTensorSpec
all_gather_outputs: List[torch.Tensor] # 1D
all_gather_outputs: list[torch.Tensor] # 1D
# All-gather extension attributes
_extensions_data: ExtensionsData
_unsharded_inner_tensors: List[torch.Tensor]
_unsharded_inner_tensors: list[torch.Tensor]
def __init__(
self,
@ -241,7 +242,7 @@ class FSDPParam:
if self.post_forward_mesh_info:
self._init_sharded_post_forward_param_metadata(param)
self._init_extensions()
self.all_gather_outputs: List[torch.Tensor] = []
self.all_gather_outputs: list[torch.Tensor] = []
self.unsharded_accumulated_grad = None
self._param_fqn: Optional[str] = None # prefixed from root module
# TODO: Remove this padding logic once DTensor pads the local tensor:
@ -439,12 +440,12 @@ class FSDPParam:
)
if has_fsdp_pre_all_gather:
self._extensions_data = ExtensionsData()
self._unsharded_inner_tensors: List[torch.Tensor] = []
self._unsharded_inner_tensors: list[torch.Tensor] = []
def init_all_gather_outputs(
self,
all_gather_input_numels: List[int],
all_gather_input_dtypes: List[torch.dtype],
all_gather_input_numels: list[int],
all_gather_input_dtypes: list[torch.dtype],
world_size: int,
device: torch.device,
force_recreate: bool = False,
@ -680,7 +681,7 @@ class FSDPParam:
free_storage(tensor)
@property
def all_gather_inputs(self) -> List[torch.Tensor]: # 1D
def all_gather_inputs(self) -> list[torch.Tensor]: # 1D
self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD)
if self.sharded_state == ShardedState.SHARDED:
if not compiled_autograd_enabled() and hasattr(

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import contextlib
import logging
from typing import Any, Callable, cast, Dict, List, NamedTuple, Optional, Set
from typing import Any, Callable, cast, NamedTuple, Optional
import torch
import torch.distributed as dist
@ -31,7 +31,7 @@ from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState
logger = logging.getLogger("torch.distributed.fsdp.fully_shard")
_ModuleToHandleDict = Dict[nn.Module, RemovableHandle] # for state dict
_ModuleToHandleDict = dict[nn.Module, RemovableHandle] # for state dict
"""
@ -77,7 +77,7 @@ class FSDPCommContext:
self.all_gather_state: Optional[AllGatherState] = None
self.reduce_scatter_state: Optional[ReduceScatterState] = None
# Post-forward order for explicit backward prefetching
self.post_forward_order: List[FSDPParamGroup] = [] # will cause ref cycles
self.post_forward_order: list[FSDPParamGroup] = [] # will cause ref cycles
def get_all_gather_streams(
self, async_op: bool, training_state: TrainingState
@ -116,7 +116,7 @@ class FSDPParamGroup:
def __init__(
self,
params: List[nn.Parameter],
params: list[nn.Parameter],
modules: tuple[nn.Module, ...],
mesh_info: FSDPMeshInfo,
post_forward_mesh_info: Optional[FSDPMeshInfo],
@ -162,7 +162,7 @@ class FSDPParamGroup:
# - Communication and communication/computation overlap
self.comm_ctx = FSDPCommContext()
# Group's indices in the shared post-forward order
self._post_forward_indices: List[int] = []
self._post_forward_indices: list[int] = []
# Whether to reduce gradients at all (whether for FSDP or HSDP)
self.reduce_grads: bool = True
# Whether to all-reduce gradients for HSDP; only used if
@ -325,8 +325,8 @@ class FSDPParamGroup:
self._to_sharded()
def pre_forward(
self, module: nn.Module, args: tuple[Any, ...], kwargs: Dict[str, Any]
) -> tuple[tuple[Any, ...], Dict[str, Any]]:
self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[tuple[Any, ...], dict[str, Any]]:
if not compiled_autograd_enabled():
logger.debug("%s", self._with_fqn("FSDP::pre_forward"))
with record_function(self._with_fqn("FSDP::pre_forward")):
@ -389,8 +389,8 @@ class FSDPParamGroup:
return
# Save the autograd-computed gradients before resharding to only
# access the unsharded parameters when their data is present
fsdp_params_with_grad: List[FSDPParam] = []
unsharded_grads: List[torch.Tensor] = []
fsdp_params_with_grad: list[FSDPParam] = []
unsharded_grads: list[torch.Tensor] = []
for fsdp_param in self.fsdp_params:
if not hasattr(fsdp_param, "_unsharded_param"):
continue
@ -543,8 +543,8 @@ class FSDPParamGroup:
# Hook Registration #
def _register_post_backward_hook(
self, args: tuple[Any, ...], kwargs: Dict[str, Any]
) -> tuple[tuple[Any, ...], Dict[str, Any]]:
self, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[tuple[Any, ...], dict[str, Any]]:
# Traceable FSDP2 relies on `root_post_backward_callback` to call each
# `FSDPParamGroup.post_backward`
if (not torch._dynamo.config.skip_fsdp_hooks) or compiled_autograd_enabled():
@ -554,8 +554,8 @@ class FSDPParamGroup:
args_list, args_spec = tree_flatten(args)
kwargs_list, kwargs_spec = tree_flatten(kwargs)
args_kwargs_list = list(args_list) + list(kwargs_list)
inp_tensor_indices: List[int] = []
inp_tensors: List[torch.Tensor] = []
inp_tensor_indices: list[int] = []
inp_tensors: list[torch.Tensor] = []
for i, obj in enumerate(args_kwargs_list):
if torch.is_tensor(obj) and obj.requires_grad:
inp_tensor_indices.append(i)
@ -579,7 +579,7 @@ class FSDPParamGroup:
), f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}"
if num_pre_save_hooks > 0:
return # already registered
modules_with_fsdp_params: Set[nn.Module] = {
modules_with_fsdp_params: set[nn.Module] = {
fsdp_param._module_info.module for fsdp_param in self.fsdp_params
}
@ -670,8 +670,8 @@ class FSDPParamGroup:
def _get_param_module_infos(
params: List[nn.Parameter], modules: tuple[nn.Module, ...]
) -> List[ParamModuleInfo]:
params: list[nn.Parameter], modules: tuple[nn.Module, ...]
) -> list[ParamModuleInfo]:
"""
Shared parameter: lin1.weight = lin2.weight
Shared module: mlp.lin1 = mlp.lin2
@ -679,7 +679,7 @@ def _get_param_module_infos(
find shared modules' parameters and shared parameters within a module.
"""
params_set = set(params)
param_to_module_info: Dict[nn.Parameter, ParamModuleInfo] = {}
param_to_module_info: dict[nn.Parameter, ParamModuleInfo] = {}
for module in modules:
for _, submodule in module.named_modules(remove_duplicate=False):
for param_name, param in _named_parameters_with_duplicates(

View File

@ -2,7 +2,8 @@
# mypy: allow-untyped-defs
import functools
import logging
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, TYPE_CHECKING
from collections.abc import Sequence
from typing import Any, Callable, Optional, TYPE_CHECKING
import torch
import torch.nn as nn
@ -40,7 +41,7 @@ class FSDPStateContext:
def __init__(self) -> None:
# All FSDP states in the root state's module tree
self.all_states: List[FSDPState] = []
self.all_states: list[FSDPState] = []
# Iteration's forward root runs the once-per-forward logic; this root
# may not be the overall root set by lazy initialization in cases where
# only a submodule runs forward (e.g. encoder-only for eval)
@ -73,9 +74,9 @@ class FSDPState(_State):
self._state_ctx = FSDPStateContext()
self._comm_ctx = FSDPCommContext()
self._training_state: TrainingState = TrainingState.IDLE
self._states_to_forward_prefetch: List[FSDPState] = []
self._states_to_backward_prefetch: List[FSDPState] = []
self._modules_to_run_forward: Set[nn.Module] = set()
self._states_to_forward_prefetch: list[FSDPState] = []
self._states_to_backward_prefetch: list[FSDPState] = []
self._modules_to_run_forward: set[nn.Module] = set()
# Define a separate init since `__init__` is called in the contract
def init(
@ -108,8 +109,8 @@ class FSDPState(_State):
self._post_forward_hook_handle = hook_handle
def _root_pre_forward(
self, module: nn.Module, args: tuple[Any, ...], kwargs: Dict[str, Any]
) -> tuple[tuple[Any, ...], Dict[str, Any]]:
self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[tuple[Any, ...], dict[str, Any]]:
self._lazy_init()
if self._state_ctx.iter_forward_root is not None:
return args, kwargs
@ -150,7 +151,7 @@ class FSDPState(_State):
)
detect_compiled_autograd()
root_module = self._modules[0]
visited_states: Set[FSDPState] = set()
visited_states: set[FSDPState] = set()
for module_name, module in root_module.named_modules():
if (state := _get_module_fsdp_state(module)) is None:
continue
@ -188,8 +189,8 @@ class FSDPState(_State):
"""Sets module and parameter FQN attributes for debugging."""
assert self._is_root
root_module = self._modules[0]
param_to_fsdp_param: Dict[nn.Parameter, FSDPParam] = {}
module_to_fsdp_param_group: Dict[nn.Module, FSDPParamGroup] = {}
param_to_fsdp_param: dict[nn.Parameter, FSDPParam] = {}
module_to_fsdp_param_group: dict[nn.Module, FSDPParamGroup] = {}
for state in self._state_ctx.all_states:
if fsdp_param_group := state._fsdp_param_group:
for fsdp_param in fsdp_param_group.fsdp_params:
@ -211,8 +212,8 @@ class FSDPState(_State):
@disable_if_config_true
def _pre_forward(
self, module: nn.Module, args: tuple[Any, ...], kwargs: Dict[str, Any]
) -> tuple[tuple[Any, ...], Dict[str, Any]]:
self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[tuple[Any, ...], dict[str, Any]]:
# When composing with module-hook-based activation checkpointing,
# the pre-backward hook is responsible for the unshard
if self._training_state == TrainingState.PRE_BACKWARD:
@ -343,7 +344,7 @@ def _register_group_forward_hooks(
modules: Sequence[nn.Module],
pre_hook: Callable,
post_hook: Callable,
modules_to_run: Set[nn.Module],
modules_to_run: set[nn.Module],
):
"""
Registers group forward pre and post-hooks. The pre-hook runs upon the

View File

@ -1,18 +1,8 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
List,
NoReturn,
Optional,
Type,
Union,
)
from collections.abc import Iterable
from typing import Any, Callable, cast, NoReturn, Optional, Union
import torch
import torch.nn as nn
@ -42,14 +32,14 @@ __all__ = [
]
cls_to_fsdp_cls: Dict[Type, Type] = {}
cls_to_fsdp_cls: dict[type, type] = {}
# The decorator adds a state object to `module` that can be accessed via
# `fully_shard.state(module)`. The state object and module are 1:1.
@contract(state_cls=FSDPState)
def fully_shard(
module: Union[nn.Module, List[nn.Module]],
module: Union[nn.Module, list[nn.Module]],
*,
mesh: Optional[DeviceMesh] = None,
reshard_after_forward: Union[bool, int] = True,
@ -331,7 +321,7 @@ class FSDPModule:
if fsdp_param_group := state._fsdp_param_group:
fsdp_param_group.reshard_after_backward = reshard_after_backward
def set_modules_to_forward_prefetch(self, modules: List["FSDPModule"]) -> None:
def set_modules_to_forward_prefetch(self, modules: list["FSDPModule"]) -> None:
"""
Sets the FSDP modules for which this FSDP module should explicitly
prefetch all-gathers in forward. The prefetching runs after this
@ -351,7 +341,7 @@ class FSDPModule:
module._get_fsdp_state() for module in modules
]
def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None:
def set_modules_to_backward_prefetch(self, modules: list["FSDPModule"]) -> None:
"""
Sets the FSDP modules for which this FSDP module should explicitly
prefetch all-gathers in backward. This overrides the default backward

View File

@ -3,21 +3,8 @@ import collections
import itertools
import os
import warnings
from typing import (
Any,
Callable,
Deque,
Dict,
Generator,
Iterable,
Iterator,
List,
no_type_check,
Optional,
Set,
TYPE_CHECKING,
Union,
)
from collections.abc import Generator, Iterable, Iterator
from typing import Any, Callable, no_type_check, Optional, TYPE_CHECKING, Union
import torch
import torch.distributed as dist
@ -329,7 +316,7 @@ def _init_ignored_module_states(
def _check_ignored_states(
ignored_states: List[Any], passed_as_ignored_states: bool
ignored_states: list[Any], passed_as_ignored_states: bool
) -> None:
"""
Check that the ignored states are uniformly parameters or uniformly modules.
@ -361,7 +348,7 @@ def _check_ignored_states(
def _init_device_handle(
state: _FSDPState,
module: nn.Module,
ignored_params: Set[nn.Parameter],
ignored_params: set[nn.Parameter],
device_id: Optional[Union[int, torch.device]],
) -> _FSDPState:
"""
@ -416,7 +403,7 @@ def _init_buffer_state(
# `module`) to its original dtype for restoring that dtype during model
# checkpointing when buffer mixed precision is enabled. The names should
# be clean since the casting happens in a `summon_full_params()` context.
_buffer_name_to_orig_dtype: Dict[str, torch.dtype] = {}
_buffer_name_to_orig_dtype: dict[str, torch.dtype] = {}
for buffer_name, buffer in module.named_buffers():
buffer_name = clean_tensor_name(buffer_name)
_buffer_name_to_orig_dtype[buffer_name] = buffer.dtype
@ -479,13 +466,13 @@ def _init_core_state(
state._unshard_event = None
# Mapping from fully sharded module to the handles it is responsible to
# unshard and reshard (see [Note: Fully Sharded Module])
_fully_sharded_module_to_handle: Dict[nn.Module, FlatParamHandle] = {}
_fully_sharded_module_to_handle: dict[nn.Module, FlatParamHandle] = {}
state._fully_sharded_module_to_handle = _fully_sharded_module_to_handle
# Invariant: `state.params` contains exactly the `FlatParameter`s of the
# handles in `state._handle`
_handle: Optional[FlatParamHandle] = None
state._handle = _handle
params: List[FlatParameter] = []
params: list[FlatParameter] = []
state.params = params
return state
@ -494,11 +481,11 @@ def _init_core_state(
def _init_runtime_state(
state: _FSDPState,
) -> _FSDPState:
_root_pre_forward_handles: List[RemovableHandle] = []
_root_pre_forward_handles: list[RemovableHandle] = []
state._root_pre_forward_handles = _root_pre_forward_handles
_pre_forward_handles: List[RemovableHandle] = []
_pre_forward_handles: list[RemovableHandle] = []
state._pre_forward_handles = _pre_forward_handles
_post_forward_handles: List[RemovableHandle] = []
_post_forward_handles: list[RemovableHandle] = []
state._post_forward_handles = _post_forward_handles
state._sync_gradients = True
state._comm_hook = None
@ -542,13 +529,13 @@ def _init_state_dict_state(state: _FSDPState) -> _FSDPState:
state_dict_config: StateDictConfig = FullStateDictConfig()
state._optim_state_dict_config = FullOptimStateDictConfig()
state._state_dict_config = state_dict_config
unshard_params_ctx: Dict[nn.Module, Generator] = {}
unshard_params_ctx: dict[nn.Module, Generator] = {}
state._unshard_params_ctx = unshard_params_ctx
return state
def _verify_managed_params(module: nn.Module, params: List[nn.Parameter]) -> None:
def _verify_managed_params(module: nn.Module, params: list[nn.Parameter]) -> None:
"""
Verify if the parameters are accepted by FSDP. The only restriction now
is that the parameter cannot be a scalar tensor (param.shape == []).
@ -639,7 +626,7 @@ def _init_param_handle_from_module(
@no_type_check
def _init_param_handle_from_params(
state: _FSDPState,
params: List[nn.Parameter],
params: list[nn.Parameter],
fully_sharded_module: nn.Module,
):
if len(params) == 0:
@ -670,7 +657,7 @@ def _init_param_handle_from_params(
def _get_ignored_modules(
root_module: nn.Module,
_ignored_modules: Optional[Iterable[torch.nn.Module]],
) -> Set[nn.Module]:
) -> set[nn.Module]:
"""
Check that ``_ignored_modules`` is an iterable of ``nn.Module`` s without any FSDP instances.
@ -726,15 +713,15 @@ def _get_ignored_modules(
def _get_ignored_params(
root_module: torch.nn.Module,
ignored_modules: Set[torch.nn.Module],
ignored_modules: set[torch.nn.Module],
ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
) -> Set[torch.nn.Parameter]:
) -> set[torch.nn.Parameter]:
"""
Return the parameters of the modules in ``ignored_modules`` and the parameters in ``ignored_parameters``.
:class:`FlatParameter` s are excluded from the result.
"""
all_ignored_params: Set[torch.nn.Parameter] = set()
all_ignored_params: set[torch.nn.Parameter] = set()
params_in_ignored_modules = {
p for m in ignored_modules for p in m.parameters() if not _is_fsdp_flattened(p)
@ -760,10 +747,10 @@ def _get_ignored_params(
def _get_ignored_buffer_names(
root_module: torch.nn.Module,
ignored_modules: Set[torch.nn.Module],
) -> Set[str]:
ignored_modules: set[torch.nn.Module],
) -> set[str]:
"""Return the cleaned buffer FQNs in ``ignored_modules``."""
all_ignored_buffer_names: Set[str] = set()
all_ignored_buffer_names: set[str] = set()
buffers_in_ignored_modules = {
buffer for m in ignored_modules for buffer in m.buffers()
@ -787,7 +774,7 @@ def _get_ignored_buffer_names(
return all_ignored_buffer_names
def _get_buffer_names(root_module: nn.Module) -> Set[str]:
def _get_buffer_names(root_module: nn.Module) -> set[str]:
"""Return the fully prefixed names of all buffers in the module hierarchy rooted at ``root_module`` as a class:`set`."""
return {
clean_tensor_name(buffer_name) for buffer_name, _ in root_module.named_buffers()
@ -796,7 +783,7 @@ def _get_buffer_names(root_module: nn.Module) -> Set[str]:
def _check_single_device_module(
module: nn.Module,
ignored_params: Set[nn.Parameter],
ignored_params: set[nn.Parameter],
device_id: Optional[Union[int, torch.device]],
) -> None:
"""
@ -855,8 +842,8 @@ def _get_device_from_device_id(
def _need_to_materialize_module(
module: nn.Module,
ignored_params: Set[nn.Parameter],
ignored_modules: Set[nn.Module],
ignored_params: set[nn.Parameter],
ignored_modules: set[nn.Module],
) -> tuple[bool, bool]:
"""
Return if ``module`` has parameters on meta device and if ``module`` is using torchdistX deferred initialization.
@ -886,7 +873,7 @@ def _need_to_materialize_module(
def _materialize_with_param_init_fn(
root_module: nn.Module,
param_init_fn: Callable[[nn.Module], None],
ignored_modules: Set[nn.Module],
ignored_modules: set[nn.Module],
) -> None:
if not callable(param_init_fn):
raise ValueError(
@ -900,7 +887,7 @@ def _materialize_with_param_init_fn(
def _materialize_meta_module(
root_module: nn.Module,
device_from_device_id: Optional[torch.device],
ignored_modules: Set[nn.Module],
ignored_modules: set[nn.Module],
device_handle: _FSDPDeviceHandle,
):
# Run default meta device initialization
@ -933,13 +920,13 @@ def _materialize_meta_module(
def _get_modules_to_materialize(
root_module: nn.Module, ignored_modules: Set[nn.Module]
) -> List[nn.Module]:
root_module: nn.Module, ignored_modules: set[nn.Module]
) -> list[nn.Module]:
# Run BFS to collect the modules to materialize via `reset_parameters()`,
# stopping at any module with FSDP already applied or at ignored modules.
modules_to_materialize: List[nn.Module] = []
modules_to_materialize: list[nn.Module] = []
queue = collections.deque([root_module])
visited_modules: Set[nn.Module] = {root_module}
visited_modules: set[nn.Module] = {root_module}
while queue:
module = queue.popleft()
modules_to_materialize.append(module)
@ -956,8 +943,8 @@ def _get_modules_to_materialize(
def _move_module_to_device(
module: nn.Module,
ignored_params: Set[nn.Parameter],
ignored_buffers: Set[torch.Tensor],
ignored_params: set[nn.Parameter],
ignored_buffers: set[torch.Tensor],
device_from_device_id: Optional[torch.device],
) -> None:
"""
@ -976,10 +963,10 @@ def _move_module_to_device(
if device_from_device_id is not None:
# BFS from `module` without traversing any nested FSDP instances to
# collect the parameters/buffers that have not yet been managed
queue: Deque[nn.Module] = collections.deque()
queue: collections.deque[nn.Module] = collections.deque()
queue.append(module)
params: List[nn.Parameter] = []
buffers: List[torch.Tensor] = []
params: list[nn.Parameter] = []
buffers: list[torch.Tensor] = []
while queue:
curr_module = queue.popleft()
# NOTE: We include a check to only move parameters/buffers that are
@ -1009,8 +996,8 @@ def _move_module_to_device(
def _move_states_to_device(
params: List[nn.Parameter],
buffers: List[torch.Tensor],
params: list[nn.Parameter],
buffers: list[torch.Tensor],
device_from_device_id: Optional[torch.device],
) -> None:
"""
@ -1053,7 +1040,7 @@ def _warn_cpu_init():
def _get_compute_device(
module: nn.Module,
ignored_params: Set[nn.Parameter],
ignored_params: set[nn.Parameter],
device_from_device_id: Optional[torch.device],
rank: int,
device_handle: _FSDPDeviceHandle,
@ -1088,7 +1075,7 @@ def _get_compute_device(
# TODO: See how to deprecate!
def _sync_module_params_and_buffers(
module: nn.Module,
params: List[nn.Parameter],
params: list[nn.Parameter],
process_group: dist.ProcessGroup,
) -> None:
"""
@ -1097,7 +1084,7 @@ def _sync_module_params_and_buffers(
Precondition: ``sync_module_states == True`` and ``self.process_group`` has
been set.
"""
module_states: List[torch.Tensor] = []
module_states: list[torch.Tensor] = []
for buffer in module.buffers():
# Avoid re-synchronizing buffers in case of nested wrapping
if not getattr(buffer, FSDP_SYNCED, False):
@ -1131,7 +1118,7 @@ def _sync_module_params_and_buffers(
def _check_module_states_for_sync_module_states(
module_states: List[torch.Tensor],
module_states: list[torch.Tensor],
) -> None:
if module_states and any(
tensor.device == torch.device("cpu") for tensor in module_states
@ -1145,7 +1132,7 @@ def _check_module_states_for_sync_module_states(
def _get_orig_params(
module: nn.Module,
ignored_params: Set[nn.Parameter],
ignored_params: set[nn.Parameter],
) -> Iterator[nn.Parameter]:
"""
Return an iterator over the original parameters in ``module``.
@ -1167,7 +1154,7 @@ def _get_orig_params(
def _check_orig_params_flattened(
fsdp_module,
ignored_params: Set[nn.Parameter],
ignored_params: set[nn.Parameter],
) -> None:
"""
Check that original parameters in ``fsdp_module`` have been flattened.

View File

@ -1,5 +1,5 @@
import collections
from typing import Deque, Optional
from typing import Optional
import torch
@ -12,7 +12,7 @@ class _FreeEventQueue:
"""
def __init__(self) -> None:
self._queue: Deque[torch.Event] = collections.deque()
self._queue: collections.deque[torch.Event] = collections.deque()
self._max_num_inflight_all_gathers = 2 # empirically chosen
def enqueue(self, free_event: torch.Event) -> None:

View File

@ -3,23 +3,10 @@ import copy
import functools
import logging
import warnings
from collections.abc import Iterable, Iterator, Sequence
from contextlib import ExitStack
from dataclasses import dataclass, field
from typing import (
Any,
cast,
Dict,
Iterable,
Iterator,
List,
NamedTuple,
no_type_check,
Optional,
Sequence,
Set,
TYPE_CHECKING,
Union,
)
from typing import Any, cast, NamedTuple, no_type_check, Optional, TYPE_CHECKING, Union
import torch
import torch.distributed as dist
@ -66,11 +53,11 @@ logger = logging.getLogger(__name__)
class FSDPParamInfo:
state: _FSDPState
handle: FlatParamHandle
param_indices: Dict[str, int]
param_requires_grad: List[bool]
param_indices: dict[str, int]
param_requires_grad: list[bool]
def sorted_items(dictionary: Dict[str, Any]) -> Iterator[tuple[str, Any]]:
def sorted_items(dictionary: dict[str, Any]) -> Iterator[tuple[str, Any]]:
keys = sorted(dictionary.keys())
for k in keys:
yield k, dictionary[k]
@ -97,9 +84,9 @@ class _ConsolidatedOptimState:
name to its value.
"""
tensor_state: Dict[str, torch.Tensor] = field(default_factory=dict)
zero_dim_tensor_state: Dict[str, torch.Tensor] = field(default_factory=dict)
non_tensor_state: Dict[str, Any] = field(default_factory=dict)
tensor_state: dict[str, torch.Tensor] = field(default_factory=dict)
zero_dim_tensor_state: dict[str, torch.Tensor] = field(default_factory=dict)
non_tensor_state: dict[str, Any] = field(default_factory=dict)
class _PosDimTensorInfo(NamedTuple):
@ -131,11 +118,11 @@ class _OptimStateKey(NamedTuple):
def _unflatten_optim_state(
fsdp_param_info: FSDPParamInfo,
flat_param_state: Dict[str, Any],
flat_param_state: dict[str, Any],
to_save: bool,
shard_state: bool,
cpu_offload: bool,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
Unflattens the optimizer state, consisting of the "state" part and the
"param_groups" part. Unflattening the "state" part involves consolidating
@ -190,7 +177,7 @@ def _is_zero_dim_tensor(x: Any) -> bool:
def _communicate_optim_state(
fsdp_param_info: FSDPParamInfo,
flat_param_state: Dict[str, Any],
flat_param_state: dict[str, Any],
) -> _ConsolidatedOptimState:
"""
Communicates the optimizer state for a flat parameter across ranks. All
@ -262,7 +249,7 @@ def _unflatten_communicated_optim_state(
fsdp_param_info: FSDPParamInfo,
state: _ConsolidatedOptimState,
shard_state: bool,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
Unflattens the communicated optimizer state (given by ``tensor_state``,
``non_tensor_state``, and ``zero_dim_tensor_state``) for a single flat
@ -283,8 +270,8 @@ def _unflatten_communicated_optim_state(
fsdp_state = fsdp_param_info.state
handle = fsdp_param_info.handle
flat_param = handle.flat_param
unflat_param_state: List[Dict[str, Any]] = []
flat_param_views: Dict[str, Iterator] = {}
unflat_param_state: list[dict[str, Any]] = []
flat_param_views: dict[str, Iterator] = {}
num_unflat_params = flat_param._num_params
tensor_state, zero_dim_tensor_state, non_tensor_state = (
state.tensor_state,
@ -337,10 +324,10 @@ def _unflatten_communicated_optim_state(
def _broadcast_processed_state(
fsdp_state: _FSDPState,
optim_state: Dict[str, Any],
optim_state: dict[str, Any],
group: Optional[dist.ProcessGroup],
) -> Dict[str, Any]:
objects: List[Any] = [None]
) -> dict[str, Any]:
objects: list[Any] = [None]
if dist.get_rank(group) == 0:
objects[0] = tree_map_only(
torch.Tensor,
@ -380,8 +367,8 @@ def _broadcast_state(
def _shard_orig_param_state(
fsdp_param_info: FSDPParamInfo,
fqn: str,
optim_state: Dict[str, Any],
) -> Dict[str, Any]:
optim_state: dict[str, Any],
) -> dict[str, Any]:
"""
Shard the optimizer state for the original parameter with the name ``fqn``.
This API should only be used when ``use_orig_params`` is True.
@ -398,7 +385,7 @@ def _shard_orig_param_state(
if not shard_param_info.in_shard:
return {}
# Flatten and shard the state.
new_optim_state: Dict[str, Any] = {}
new_optim_state: dict[str, Any] = {}
intra_param_start_idx = shard_param_info.intra_param_start_idx
intra_param_end_idx = shard_param_info.intra_param_end_idx
for state_name, value in optim_state.items():
@ -413,13 +400,13 @@ def _shard_orig_param_state(
def _flatten_optim_state_dict(
optim_state_dict: Dict[str, Any],
optim_state_dict: dict[str, Any],
model: nn.Module,
use_orig_params: bool = False,
optim: Optional[torch.optim.Optimizer] = None,
rank0_only: bool = False,
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Flattens the full optimizer state dict, still keying by unflattened parameter
names.
@ -461,7 +448,7 @@ def _flatten_optim_state_dict(
unflat_osd = _broadcast_processed_state(fsdp_state, unflat_osd, group=group)
# Construct the "state" part
flat_osd_state: Dict[Union[_OptimStateKey, str], Any] = {}
flat_osd_state: dict[Union[_OptimStateKey, str], Any] = {}
unflat_osd_state = unflat_osd["state"]
all_state_keys = set(unflat_osd_state.keys())
@ -557,9 +544,9 @@ def _flatten_optim_state_dict(
def _flatten_optim_state(
fsdp_param_info: FSDPParamInfo,
unflat_osd_state: Dict[str, Dict[str, Any]],
unflat_param_names: List[str],
) -> Dict[str, Any]:
unflat_osd_state: dict[str, dict[str, Any]],
unflat_param_names: list[str],
) -> dict[str, Any]:
"""
Flattens the optimizer state in ``full_optim_state_dict`` for a single
flat parameter in ``fsdp_param_info`` corresponding to the unflattened
@ -629,7 +616,7 @@ def _flatten_optim_state(
assert state_names is not None
# Flatten the state
flat_state: Dict[str, Optional[torch.Tensor]] = {}
flat_state: dict[str, Optional[torch.Tensor]] = {}
for state_name in state_names:
state_values = [
unflat_param_state[state_name] if unflat_param_state is not None else None
@ -695,8 +682,8 @@ def _flatten_optim_state(
def _flatten_tensor_optim_state(
state_name: str,
pos_dim_tensors: List[torch.Tensor],
unflat_param_names: List[str],
pos_dim_tensors: list[torch.Tensor],
unflat_param_names: list[str],
unflat_param_shapes: Sequence[torch.Size],
handle: FlatParamHandle,
) -> torch.Tensor:
@ -780,8 +767,8 @@ def _flatten_tensor_optim_state(
def _flatten_zero_dim_tensor_optim_state(
state_name: str,
zero_dim_tensors: List[torch.Tensor],
unflat_param_names: List[str],
zero_dim_tensors: list[torch.Tensor],
unflat_param_names: list[str],
) -> torch.Tensor:
"""
Flattens the zero-dimension tensor optimizer state given by the values
@ -834,8 +821,8 @@ def _flatten_zero_dim_tensor_optim_state(
def _flatten_non_tensor_optim_state(
state_name: str,
non_tensors: List[Any],
unflat_param_names: List[str],
non_tensors: list[Any],
unflat_param_names: list[str],
) -> Any:
"""
Flattens the non-tensor optimizer state given by the values ``non_tensors``
@ -872,18 +859,18 @@ def _flatten_non_tensor_optim_state(
def _rekey_sharded_optim_state_dict(
sharded_osd: Dict[str, Any],
sharded_osd: dict[str, Any],
model: nn.Module,
optim: torch.optim.Optimizer,
optim_input: Optional[
Union[
List[Dict[str, Any]],
list[dict[str, Any]],
Iterable[nn.Parameter],
]
],
using_optim_input: bool,
is_named_optimizer: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Rekeys the optimizer state dict from unflattened parameter names to flat
parameter IDs according to the calling rank's ``optim``, which may be
@ -892,8 +879,8 @@ def _rekey_sharded_optim_state_dict(
"""
param_to_fqns = _get_param_to_fqns(model)
flat_param_to_fqn = _get_flat_param_to_fqn(model)
param_to_param_key: Dict[nn.Parameter, Union[int, str]] = cast(
Dict[nn.Parameter, Union[int, str]],
param_to_param_key: dict[nn.Parameter, Union[int, str]] = cast(
dict[nn.Parameter, Union[int, str]],
(
_get_param_to_param_id_from_optim_input(model, optim_input)
if using_optim_input
@ -907,10 +894,10 @@ def _rekey_sharded_optim_state_dict(
# passed to the optimizer
assert len(param_to_param_key) <= len(param_to_fqns)
unflat_param_names_to_flat_param_key: Dict[
unflat_param_names_to_flat_param_key: dict[
tuple[str, ...], Union[int, str]
] = {} # for "state"
unflat_param_name_to_flat_param_key: Dict[
unflat_param_name_to_flat_param_key: dict[
str, Union[int, str]
] = {} # for "param_groups"
for param, unflat_param_names in param_to_fqns.items():
@ -923,7 +910,7 @@ def _rekey_sharded_optim_state_dict(
unflat_param_name_to_flat_param_key[unflat_param_name] = flat_param_key
sharded_osd_state = sharded_osd["state"]
rekeyed_osd_state: Dict[Union[str, int], Any] = {}
rekeyed_osd_state: dict[Union[str, int], Any] = {}
for key, param_state in sharded_osd_state.items():
if isinstance(key, str):
rekeyed_osd_state[key] = param_state
@ -935,7 +922,7 @@ def _rekey_sharded_optim_state_dict(
# Only process param_groups if it exists in sharded_osd
if "param_groups" in sharded_osd:
rekeyed_osd_param_groups: List[Dict[str, Any]] = []
rekeyed_osd_param_groups: list[dict[str, Any]] = []
for unflat_param_group in sharded_osd["param_groups"]:
flat_param_group = copy.deepcopy(unflat_param_group)
flat_param_keys = sorted(
@ -955,11 +942,11 @@ def _get_param_id_to_param_from_optim_input(
model: nn.Module,
optim_input: Optional[
Union[
List[Dict[str, Any]],
list[dict[str, Any]],
Iterable[nn.Parameter],
]
] = None,
) -> Dict[int, nn.Parameter]:
) -> dict[int, nn.Parameter]:
"""
Constructs a mapping from parameter IDs to parameters. This may be used
both for models with ``FlatParameter`` s and without.
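When ``optim_input`` is ``None``, the mapping is simply positional over ``model.parameters()``; a small sketch with a toy model:

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
param_id_to_param = dict(enumerate(model.parameters()))
assert param_id_to_param[0] is model[0].weight  # parameter ID 0 is the first parameter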
@ -993,7 +980,7 @@ def _get_param_id_to_param_from_optim_input(
if optim_input is None:
return dict(enumerate(model.parameters()))
try:
params = cast(List[nn.Parameter], list(optim_input))
params = cast(list[nn.Parameter], list(optim_input))
except TypeError as e:
raise TypeError(
"Optimizer input should be an iterable of Tensors or dicts, "
@ -1013,7 +1000,7 @@ def _get_param_id_to_param_from_optim_input(
if all_tensors:
return dict(enumerate(params))
assert all_dicts
param_id_to_param: List[nn.Parameter] = []
param_id_to_param: list[nn.Parameter] = []
for param_group in params:
has_params_key = "params" in param_group # type: ignore[operator]
assert has_params_key, (
@ -1026,7 +1013,7 @@ def _get_param_id_to_param_from_optim_input(
return dict(enumerate(param_id_to_param))
def _get_flat_param_to_fqn(model: torch.nn.Module) -> Dict[FlatParameter, str]:
def _get_flat_param_to_fqn(model: torch.nn.Module) -> dict[FlatParameter, str]:
"""
Constructs a mapping from ``FlatParameter`` to a cleaned (devoid of prefixes
from wrappers) fully qualified name (FQN). Note that this FQN is "non-canonical"
@ -1053,7 +1040,7 @@ def _get_flat_param_to_fqn(model: torch.nn.Module) -> Dict[FlatParameter, str]:
def return_fn(flat_param_to_fqn):
return flat_param_to_fqn
flat_param_to_fqn_ret: Dict[FlatParameter, str] = {}
flat_param_to_fqn_ret: dict[FlatParameter, str] = {}
return _apply_to_modules(
model,
module_fn,
@ -1067,16 +1054,16 @@ def _get_param_key_to_param(
optim: torch.optim.Optimizer,
model: Optional[nn.Module] = None,
is_named_optimizer: bool = False,
param_to_fqns: Optional[Dict[nn.Parameter, List[str]]] = None,
flat_param_to_fqn: Optional[Dict[FlatParameter, str]] = None,
) -> Dict[Union[int, str], nn.Parameter]:
param_to_fqns: Optional[dict[nn.Parameter, list[str]]] = None,
flat_param_to_fqn: Optional[dict[FlatParameter, str]] = None,
) -> dict[Union[int, str], nn.Parameter]:
"""
Constructs a mapping from parameter keys to parameters. For the regular
optimizers, the keys are parameter IDs. For NamedOptimizer, the keys
are FQNs. This API may be used both for models with ``FlatParameter`` s and
without.
"""
clean_fqn_to_curr_fqn: Dict[str, str] = {}
clean_fqn_to_curr_fqn: dict[str, str] = {}
if is_named_optimizer:
assert (
param_to_fqns is not None and flat_param_to_fqn is not None
@ -1085,7 +1072,7 @@ def _get_param_key_to_param(
for key, _ in _named_parameters_with_duplicates(model):
clean_fqn_to_curr_fqn[clean_tensor_name(key)] = key
param_key_to_param: Dict[Union[str, int], nn.Parameter] = {}
param_key_to_param: dict[Union[str, int], nn.Parameter] = {}
pid = 0
for param_group in optim.param_groups:
if is_named_optimizer:
@ -1118,9 +1105,9 @@ def _get_param_to_param_key(
optim: torch.optim.Optimizer,
model: Optional[nn.Module] = None,
is_named_optimizer: bool = False,
param_to_fqns: Optional[Dict[nn.Parameter, List[str]]] = None,
flat_param_to_fqn: Optional[Dict[FlatParameter, str]] = None,
) -> Dict[nn.Parameter, Union[int, str]]:
param_to_fqns: Optional[dict[nn.Parameter, list[str]]] = None,
flat_param_to_fqn: Optional[dict[FlatParameter, str]] = None,
) -> dict[nn.Parameter, Union[int, str]]:
"""
Constructs the inverse mapping of :func:`_get_param_key_to_param`. This API
only supports the case where `optim` is a regular optimizer, not NamedOptimizer.
@ -1136,25 +1123,25 @@ def _get_param_to_param_id_from_optim_input(
model: nn.Module,
optim_input: Optional[
Union[
List[Dict[str, Any]],
list[dict[str, Any]],
Iterable[nn.Parameter],
]
] = None,
) -> Dict[nn.Parameter, int]:
) -> dict[nn.Parameter, int]:
"""Constructs the inverse mapping of :func:`_get_param_id_to_param_from_optim_input`."""
param_id_to_param = _get_param_id_to_param_from_optim_input(model, optim_input)
return {param: param_id for param_id, param in param_id_to_param.items()}
def _check_missing_keys_on_rank(
r0_optim_state_keys: List[_OptimStateKey],
optim_state_key_to_param_key: Dict[_OptimStateKey, Union[str, int]],
param_key_to_param: Dict[Union[str, int], nn.Parameter],
r0_optim_state_keys: list[_OptimStateKey],
optim_state_key_to_param_key: dict[_OptimStateKey, Union[str, int]],
param_key_to_param: dict[Union[str, int], nn.Parameter],
group: Optional[dist.ProcessGroup],
) -> None:
# Ensure that all ranks have at least the optimizer states needed by
# rank 0's optimizer
missing_keys: List[_OptimStateKey] = []
missing_keys: list[_OptimStateKey] = []
for r0_optim_state_key in r0_optim_state_keys:
if r0_optim_state_key not in optim_state_key_to_param_key:
# A parameter from rank 0's optimizer does not exist for this
@ -1179,7 +1166,7 @@ def _check_missing_keys_on_rank(
"are missing some of those states"
)
for rank, keys in enumerate(obj_list):
keys = cast(List[_OptimStateKey], keys)
keys = cast(list[_OptimStateKey], keys)
if len(keys) > 0:
error_msg += (
f"\nRank {rank} is missing states for the parameters: "
@ -1189,13 +1176,13 @@ def _check_missing_keys_on_rank(
def _map_param_key_to_optim_keys(
optim_state_dict: Dict[str, Any],
optim_state_dict: dict[str, Any],
group: Optional[dist.ProcessGroup],
param_key_to_param: Dict[Union[int, str], nn.Parameter],
param_to_fqns: Dict[nn.Parameter, List[str]],
fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
param_key_to_param: dict[Union[int, str], nn.Parameter],
param_to_fqns: dict[nn.Parameter, list[str]],
fqn_to_fsdp_param_info: dict[str, FSDPParamInfo],
merge_keys: bool = False,
) -> tuple[List[_OptimStateKey], Dict[_OptimStateKey, Union[int, str]]]:
) -> tuple[list[_OptimStateKey], dict[_OptimStateKey, Union[int, str]]]:
"""
Construct the local mapping between the ``_OptimStateKey`` and parameter keys
and all the ``_OptimStateKey`` across ranks. If ``merge_keys`` is False, rank0
@ -1203,8 +1190,8 @@ def _map_param_key_to_optim_keys(
Note that ``merge_keys`` should be equal to ``use_orig_params``.
"""
rank = dist.get_rank(group)
optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]] = {} # local
all_optim_state_keys: List[_OptimStateKey] = []
optim_state_key_to_param_key: dict[_OptimStateKey, Union[int, str]] = {} # local
all_optim_state_keys: list[_OptimStateKey] = []
for param_key, param in param_key_to_param.items():
# Do not include parameters without state to avoid empty mappings
@ -1228,7 +1215,7 @@ def _map_param_key_to_optim_keys(
optim_state_key_to_param_key[optim_state_key] = param_key
if merge_keys:
all_keys: List[List[_OptimStateKey]] = [
all_keys: list[list[_OptimStateKey]] = [
[] for _ in range(dist.get_world_size(group))
]
dist.all_gather_object(all_keys, all_optim_state_keys, group=group)
@ -1237,7 +1224,7 @@ def _map_param_key_to_optim_keys(
]
all_optim_state_keys = sorted(set(merge_all_optim_state_keys))
else:
key_obj_list: List[Optional[List[_OptimStateKey]]] = (
key_obj_list: list[Optional[list[_OptimStateKey]]] = (
[all_optim_state_keys] if rank == 0 else [None]
)
dist.broadcast_object_list(key_obj_list, src=0, group=group)
@ -1254,11 +1241,11 @@ def _map_param_key_to_optim_keys(
def _unflatten_param_groups(
state_dict: Dict[str, Any],
param_key_to_param: Dict[Union[int, str], nn.Parameter],
param_to_fqns: Dict[nn.Parameter, List[str]],
) -> List[Dict[str, Any]]:
param_groups: List[Dict[str, Any]] = []
state_dict: dict[str, Any],
param_key_to_param: dict[Union[int, str], nn.Parameter],
param_to_fqns: dict[nn.Parameter, list[str]],
) -> list[dict[str, Any]]:
param_groups: list[dict[str, Any]] = []
for flat_param_group in state_dict["param_groups"]:
unflat_param_group = copy.deepcopy(flat_param_group)
param_group_params = [
@ -1277,7 +1264,7 @@ def _unflatten_param_groups(
return param_groups
def _is_named_optimizer(optim_state_dict: Dict[str, Any]) -> bool:
def _is_named_optimizer(optim_state_dict: dict[str, Any]) -> bool:
"""
Returns whether the state_dict is from a NamedOptimizer.
This function checks that the keys in the state_dict['state'] are strings
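A rough sketch of the distinction being checked, using toy dicts (not from this diff): regular optimizers key ``state`` by integer parameter IDs, while NamedOptimizer keys it by FQN strings.

regular_osd = {"state": {0: {"step": 1}, 1: {"step": 1}}}
named_osd = {"state": {"lin.weight": {"step": 1}, "lin.bias": {"step": 1}}}

def looks_named(osd: dict) -> bool:
    # Minimal stand-in for the check described above.
    return any(isinstance(key, str) for key in osd["state"])

assert not looks_named(regular_osd) and looks_named(named_osd)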
@ -1299,22 +1286,22 @@ def _is_named_optimizer(optim_state_dict: Dict[str, Any]) -> bool:
@dataclass
class StateInfo:
# The key of these dictionaries are the state name, e.g., `exp_avg`.
tensors: Dict[str, _PosDimTensorInfo]
scalar_tensors: Dict[str, torch.Tensor]
non_tensors: Dict[str, Any]
tensors: dict[str, _PosDimTensorInfo]
scalar_tensors: dict[str, torch.Tensor]
non_tensors: dict[str, Any]
def _allgather_state_info(
fsdp_state: _FSDPState,
input_states: Dict[str, Any],
) -> List[Dict[str, StateInfo]]:
input_states: dict[str, Any],
) -> list[dict[str, StateInfo]]:
"""
Given the ``input_states``, allgather StateInfo for each state. The function
uses all_gather_object to gather StateInfo so no GPU tensors are sent.
"""
processed_state_dict: Dict[str, StateInfo] = {}
gathered_state_info: List[Dict[str, StateInfo]] = [
processed_state_dict: dict[str, StateInfo] = {}
gathered_state_info: list[dict[str, StateInfo]] = [
{} for _ in range(fsdp_state.world_size)
]
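A minimal sketch of the ``all_gather_object`` exchange described above; ``local_info`` is a placeholder for this rank's ``StateInfo`` mapping, and an initialized process group is assumed.

import torch.distributed as dist

local_info = {"lin.weight": {"exp_avg": "metadata only"}}  # placeholder payload
gathered = [None for _ in range(dist.get_world_size())]
# all_gather_object ships picklable Python objects, so no GPU tensors travel here.
dist.all_gather_object(gathered, local_info)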
@ -1343,10 +1330,10 @@ def _allgather_state_info(
def _convert_all_state_info(
fsdp_param_info: FSDPParamInfo,
gathered_state_info: List[Dict[str, StateInfo]],
input_states: Dict[str, Any],
output_states: Dict[str, Dict[str, Any]],
) -> tuple[Optional[torch.dtype], Dict[str, List[Optional[torch.Tensor]]]]:
gathered_state_info: list[dict[str, StateInfo]],
input_states: dict[str, Any],
output_states: dict[str, dict[str, Any]],
) -> tuple[Optional[torch.dtype], dict[str, list[Optional[torch.Tensor]]]]:
"""
Given the ``gathered_state_info`` and ``input_states``, the API converts
the StateInfo into the original state if the state is not a non-scalar
@ -1354,20 +1341,20 @@ def _convert_all_state_info(
``state_buffer`` in the correct order for the later allgather.
"""
state_buffers: Dict[str, List[Optional[torch.Tensor]]] = {}
state_buffers: dict[str, list[Optional[torch.Tensor]]] = {}
for fqn, gathered_state in output_states.items():
state_info = [s[fqn] for s in gathered_state_info]
all_tensor_states = sorted(
{n for state in state_info for n in state.tensors.keys()}
)
empty_ranks: Set[int] = set()
empty_ranks: set[int] = set()
dtype: Optional[torch.dtype] = None
# First check all the non-scalar states and get the information of
# states on each rank.
for state_name in all_tensor_states:
numels = []
_empty_ranks: Set[int] = set()
_empty_ranks: set[int] = set()
for rank, object_state in enumerate(state_info):
numels.append(0)
info = object_state.tensors.get(state_name, None)
@ -1426,7 +1413,7 @@ def _convert_all_state_info(
def _unflatten_orig_param_states(
fsdp_param_info: FSDPParamInfo,
output_states: Dict[str, Dict[str, Any]],
output_states: dict[str, dict[str, Any]],
state_name: str,
shard_state: bool,
to_save: bool,
@ -1498,12 +1485,12 @@ def _unflatten_orig_param_states(
def _allgather_orig_param_states(
fsdp_param_info: FSDPParamInfo,
gathered_state_info: List[Dict[str, StateInfo]],
input_states: Dict[str, Any],
gathered_state_info: list[dict[str, StateInfo]],
input_states: dict[str, Any],
shard_state: bool,
to_save: bool,
cpu_offload: bool,
) -> Dict[str, Dict[str, Any]]:
) -> dict[str, dict[str, Any]]:
"""
Given the ``gathered_state_info`` and ``input_states``, the API allgathers
all tensor states and restores non-tensor states from ``gathered_state_info``.
@ -1515,7 +1502,7 @@ def _allgather_orig_param_states(
fsdp_state._device_handle.memory_summary(),
)
output_states: Dict[str, Dict[str, Any]] = {fqn: {} for fqn in input_states.keys()}
output_states: dict[str, dict[str, Any]] = {fqn: {} for fqn in input_states.keys()}
dtype, state_buffers = _convert_all_state_info(
fsdp_param_info, gathered_state_info, input_states, output_states
@ -1524,7 +1511,7 @@ def _allgather_orig_param_states(
if len(state_buffers) == 0:
return output_states
has_state_params: List[bool] = [
has_state_params: list[bool] = [
True if fqn in output_states else False
for fqn, idx in fsdp_param_info.param_indices.items()
]
@ -1545,7 +1532,7 @@ def _allgather_orig_param_states(
# Synchronize can be slow but this will be easier for us to debug.
fsdp_state._device_handle.synchronize()
for state_name, buffers in state_buffers.items():
local_buffers: List[torch.Tensor] = []
local_buffers: list[torch.Tensor] = []
begin = fsdp_state.rank * flat_param._sharded_size.numel()
# End is inclusive.
end = begin + flat_param._sharded_size.numel() - 1
@ -1666,11 +1653,11 @@ def _allgather_orig_param_states(
def _gather_all_orig_param_state(
fsdp_param_info: FSDPParamInfo,
input_states: Dict[str, Any],
input_states: dict[str, Any],
shard_state: bool,
to_save: bool,
cpu_offload: bool,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Given an optimizer state dict, ``input_states``, whose keys are FQNs of the
original parameters (not FlatParameters nor parameter IDs), gather all the
@ -1715,20 +1702,20 @@ def _gather_all_orig_param_state(
def _convert_state_with_orig_params(
all_optim_state_keys: List[_OptimStateKey],
optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]],
fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
optim_state_dict: Dict[Union[str, int], Any],
all_optim_state_keys: list[_OptimStateKey],
optim_state_key_to_param_key: dict[_OptimStateKey, Union[int, str]],
fqn_to_fsdp_param_info: dict[str, FSDPParamInfo],
optim_state_dict: dict[Union[str, int], Any],
to_save: bool,
shard_state: bool,
cpu_offload: bool = True,
) -> Dict[str, Any]:
fsdp_osd_state: Dict[str, Any] = {}
) -> dict[str, Any]:
fsdp_osd_state: dict[str, Any] = {}
# This variable is used to deduplicate the FSDPParamInfo as one FSDPParamInfo
# usually corresponds to multiple parameters. We cannot use FSDPParamInfo
# as the key because FSDPParamInfo is not hashable. As a result, we fall back
# to `id(FSDPParamInfo)`, which is an integer.
all_states: Dict[int, Dict[str, Any]] = {}
all_states: dict[int, dict[str, Any]] = {}
# Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers
# across ranks
for optim_state_key in all_optim_state_keys:
@ -1806,15 +1793,15 @@ def _convert_state_with_orig_params(
def _convert_state_with_flat_params(
all_optim_state_keys: List[_OptimStateKey],
optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]],
fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
optim_state_dict: Dict[Union[str, int], Any],
all_optim_state_keys: list[_OptimStateKey],
optim_state_key_to_param_key: dict[_OptimStateKey, Union[int, str]],
fqn_to_fsdp_param_info: dict[str, FSDPParamInfo],
optim_state_dict: dict[Union[str, int], Any],
to_save: bool,
shard_state: bool,
cpu_offload: bool = True,
) -> Dict[str, Any]:
fsdp_osd_state: Dict[str, Any] = {}
) -> dict[str, Any]:
fsdp_osd_state: dict[str, Any] = {}
# Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers
# across ranks
for optim_state_key in all_optim_state_keys:
@ -1866,10 +1853,10 @@ def _convert_state_with_flat_params(
def _optim_state_dict(
model: nn.Module,
optim: torch.optim.Optimizer,
optim_state_dict: Dict[str, Any],
optim_state_dict: dict[str, Any],
optim_input: Optional[
Union[
List[Dict[str, Any]],
list[dict[str, Any]],
Iterable[nn.Parameter],
]
],
@ -1879,7 +1866,7 @@ def _optim_state_dict(
using_optim_input: bool,
use_orig_params: bool = False,
cpu_offload: bool = True,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Consolidates the optimizer state and returns it as a :class:`dict`
following the convention of :meth:`torch.optim.Optimizer.state_dict`,
@ -1940,7 +1927,7 @@ def _optim_state_dict(
is_named_optimizer = _is_named_optimizer(optim_state_dict)
param_key_to_param = cast(
Dict[Union[int, str], nn.Parameter],
dict[Union[int, str], nn.Parameter],
(
_get_param_id_to_param_from_optim_input(model, optim_input)
if using_optim_input
@ -1985,7 +1972,7 @@ def _optim_state_dict(
if not to_save:
return {}
fsdp_osd: Dict[str, Any] = {"state": fsdp_osd_state}
fsdp_osd: dict[str, Any] = {"state": fsdp_osd_state}
flat_param_fqns = set(flat_param_to_fqn.values())
for key, value in optim_state_dict["state"].items():
@ -2019,7 +2006,7 @@ def _optim_state_dict(
return fsdp_osd
def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]:
def _get_fqn_to_fsdp_param_info(model: nn.Module) -> dict[str, FSDPParamInfo]:
"""
Construct the mapping from a param's fqn to its corresponding ``FSDPParamInfo``
if the param is managed by FSDP. Shared parameters, or original parameters that
@ -2055,7 +2042,7 @@ def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]:
def return_fn(fqn_to_param_info):
return fqn_to_param_info
fqn_to_param_info: Dict[str, FSDPParamInfo] = {}
fqn_to_param_info: dict[str, FSDPParamInfo] = {}
# FlatParameter._fqns stores the local fqn, starting from the root of the
# FSDP. Using _apply_to_modules() with model (may not be the FSDP root
# module) allows us to construct the global fqn.

View File

@ -2,7 +2,7 @@
import functools
import logging
from enum import auto, Enum
from typing import Any, Callable, Dict, List, no_type_check, Optional, Set
from typing import Any, Callable, no_type_check, Optional
import torch
import torch.distributed as dist
@ -57,7 +57,7 @@ class _PrefetchMode(Enum):
def _get_fsdp_root_states_with_modules(
module: nn.Module,
) -> tuple[List[_FSDPState], List[nn.Module]]:
) -> tuple[list[_FSDPState], list[nn.Module]]:
"""
Returns a tuple containing:
1. A list of the root ``_FSDPState`` instances in the module tree rooted at
@ -70,9 +70,9 @@ def _get_fsdp_root_states_with_modules(
must call :func:`_is_fsdp_root` to force a lazy initialization to determine
the FSDP root in case lazy initialization has not yet happened.
"""
fsdp_root_states: List[_FSDPState] = []
fsdp_root_modules: List[nn.Module] = []
visited_fsdp_states: Set[_FSDPState] = set()
fsdp_root_states: list[_FSDPState] = []
fsdp_root_modules: list[nn.Module] = []
visited_fsdp_states: set[_FSDPState] = set()
# NOTE: This function assumes that `module.modules()` proceeds top-down.
for submodule in module.modules():
optional_state = _get_module_fsdp_state(submodule)
@ -87,7 +87,7 @@ def _get_fsdp_root_states_with_modules(
return fsdp_root_states, fsdp_root_modules
def _get_fsdp_root_states(module: nn.Module) -> List[_FSDPState]:
def _get_fsdp_root_states(module: nn.Module) -> list[_FSDPState]:
"""See :func:`_get_fsdp_root_states_with_modules`."""
fsdp_root_states, _ = _get_fsdp_root_states_with_modules(module)
return fsdp_root_states
@ -178,7 +178,7 @@ def _share_state_and_init_handle_attrs(
handle = root_state._handle
if handle:
handle.init_flat_param_attributes()
attr_name_to_values: Dict[str, Set[Any]] = {}
attr_name_to_values: dict[str, set[Any]] = {}
for attr_name in HOMOGENEOUS_ATTR_NAMES:
attr_name_to_values[attr_name] = set()
root_state._all_handles = root_state._exec_order_data.all_handles # share reference
@ -347,8 +347,8 @@ def _pre_forward(
unshard_fn: Callable,
module: nn.Module,
args: tuple[Any, ...],
kwargs: Dict[str, Any],
) -> tuple[tuple[Any, ...], Dict[str, Any]]:
kwargs: dict[str, Any],
) -> tuple[tuple[Any, ...], dict[str, Any]]:
"""
Runs the pre-forward logic. This includes an opportunity to unshard
currently sharded parameters such as those for the current forward and
@ -1469,7 +1469,7 @@ def _register_post_backward_reshard_only_hook(
state: _FSDPState,
handle: Optional[FlatParamHandle],
args: tuple[Any, ...],
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
) -> None:
"""
Registers post-backward hooks to reshard flat parameters that do not
@ -1483,7 +1483,7 @@ def _register_post_backward_reshard_only_hook(
return
# Construct `inp_tensors` lazily to avoid CPU overhead in typical case
# where each flat parameter requires gradient
inp_tensors: Optional[List[torch.Tensor]] = None
inp_tensors: Optional[list[torch.Tensor]] = None
if not handle:
return
flat_param = handle.flat_param
@ -1555,7 +1555,7 @@ def _wait_for_computation_stream(
def _reset_flat_param_grad_info_if_needed(
handles: List[FlatParamHandle],
handles: list[FlatParamHandle],
):
"""
Clears the original parameters' gradients if needed. This method's CPU
@ -1573,7 +1573,7 @@ def _reset_flat_param_grad_info_if_needed(
def _get_buffers_and_dtypes_for_computation(
state: _FSDPState,
root_module: nn.Module,
) -> tuple[List[torch.Tensor], List[Optional[torch.dtype]]]:
) -> tuple[list[torch.Tensor], list[Optional[torch.dtype]]]:
"""
Returns all buffers in the module tree rooted at ``root_module`` and a
corresponding list of the buffer dtypes for computation. Each buffer dtype
@ -1581,9 +1581,9 @@ def _get_buffers_and_dtypes_for_computation(
low precision dtype otherwise.
"""
_p_assert(state._is_root, "Expects the root to cast buffers")
buffers: List[torch.Tensor] = []
buffer_dtypes: List[Optional[torch.dtype]] = []
visited_buffers: Set[torch.Tensor] = set()
buffers: list[torch.Tensor] = []
buffer_dtypes: list[Optional[torch.dtype]] = []
visited_buffers: set[torch.Tensor] = set()
# Traverse the FSDP states bottom-up so that we prefer the owning FSDP
# instance's mixed precision setting for each buffer
fsdp_states, fsdp_modules = traversal_utils._get_fsdp_states_with_modules(
@ -1605,12 +1605,12 @@ def _get_buffers_and_dtypes_for_computation(
@no_type_check
def _get_orig_buffer_dtypes(
state: _FSDPState,
buffer_names: List[str],
) -> List[torch.dtype]:
buffer_names: list[str],
) -> list[torch.dtype]:
"""
Returns the original buffer dtypes of the given buffer names.
"""
buffer_dtypes: List[torch.dtype] = []
buffer_dtypes: list[torch.dtype] = []
for buffer_name in buffer_names:
_p_assert(
buffer_name in state._buffer_name_to_orig_dtype,
@ -1623,8 +1623,8 @@ def _get_orig_buffer_dtypes(
def _cast_buffers_to_dtype_and_device(
buffers: List[torch.Tensor],
buffer_dtypes: List[Optional[torch.dtype]],
buffers: list[torch.Tensor],
buffer_dtypes: list[Optional[torch.dtype]],
device: torch.device,
) -> None:
"""

View File

@ -3,7 +3,8 @@ import contextlib
import logging
import math
import warnings
from typing import Any, Callable, cast, Dict, Generator, Iterator, List, no_type_check
from collections.abc import Generator, Iterator
from typing import Any, Callable, cast, no_type_check
import torch
import torch.distributed as dist
@ -170,10 +171,10 @@ def _common_unshard_pre_state_dict_hook(
def _common_unshard_post_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
param_hook: Callable,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
The post-state_dict flow that is shared by all state_dict types that require
``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this
@ -302,9 +303,9 @@ def _full_pre_state_dict_hook(
def _full_post_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Hook that runs after model.state_dict() is called and before returning the result
to the user. For FSDP, we may have to clone the tensors in state_dict as params go
@ -313,7 +314,7 @@ def _full_post_state_dict_hook(
"""
def param_hook(
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
fqn: str,
) -> None:
@ -347,7 +348,7 @@ def _full_post_state_dict_hook(
def _full_pre_load_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
) -> None:
_lazy_init(fsdp_state, module)
@ -393,9 +394,9 @@ def _local_pre_state_dict_hook(
def _local_post_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
This hook creates a ShardedTensor from the local flat_param and replaces
the state_dict[f"{prefix}{FLAT_PARAM}"] with the ShardedTensor. No copy
@ -448,7 +449,7 @@ def _local_post_load_state_dict_hook(
def _local_pre_load_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
) -> None:
"""
@ -526,15 +527,15 @@ def _sharded_pre_state_dict_hook(
def _sharded_post_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
The hook replaces the unflattened, unsharded parameter in the state_dict
with an unflattened, sharded parameter (a ShardedTensor).
"""
def param_hook(state_dict: Dict[str, Any], prefix: str, fqn: str):
def param_hook(state_dict: dict[str, Any], prefix: str, fqn: str):
param = state_dict[fqn]
if not fsdp_state._state_dict_config._use_dtensor:
sharded_tensor = _ext_chunk_tensor(
@ -574,7 +575,7 @@ def _sharded_post_load_state_dict_hook(
def _sharded_pre_load_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
) -> None:
"""
@ -686,10 +687,10 @@ def _replace_with_full_state_dict_type(fsdp_state: _FSDPState) -> Generator:
@torch.no_grad()
def _post_state_dict_hook(
module: nn.Module,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
*args: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
_post_state_dict_hook() is called after the state_dict() of this
FSDP module is executed. ``fsdp_state._state_dict_type`` is used to decide
@ -802,7 +803,7 @@ def _set_use_dtensor(fsdp_state: _FSDPState) -> None:
@torch.no_grad()
def _pre_load_state_dict_hook(
module: nn.Module,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
*args: Any,
) -> None:
@ -845,7 +846,7 @@ def _pre_load_state_dict_hook(
@torch.no_grad()
def _post_load_state_dict_hook(
module: nn.Module,
incompatible_keys: tuple[List[str], List[str]],
incompatible_keys: tuple[list[str], list[str]],
*args: Any,
) -> None:
fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
@ -906,7 +907,7 @@ def _register_state_dict_hooks_base(
state: _FSDPState,
hook_registration_fn_name: str,
hook: Callable,
hook_registration_fn_kwargs: Dict[str, Any],
hook_registration_fn_kwargs: dict[str, Any],
) -> None:
"""Registers ``hook`` using ``hook_registration_fn``."""
if not _is_composable(state):

View File

@ -2,7 +2,7 @@
import functools
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set
from typing import Any, Callable, NamedTuple, Optional
import torch
import torch.nn as nn
@ -29,7 +29,7 @@ class TracingConfig:
"""
tracer: torch.fx.Tracer = field(default_factory=torch.fx.Tracer)
concrete_args: Optional[Dict[str, Any]] = None
concrete_args: Optional[dict[str, Any]] = None
class _ParamUsageInfo(NamedTuple):
@ -52,7 +52,7 @@ class _ParamUsageInfo(NamedTuple):
"""
module: nn.Module
named_params: List[tuple[str, nn.Parameter]]
named_params: list[tuple[str, nn.Parameter]]
class _ExecutionInfo:
@ -79,12 +79,12 @@ class _ExecutionInfo:
def __init__(self, root_module: nn.Module) -> None:
self.curr_module: nn.Module = root_module
self.module_forward_order: List[nn.Module] = [root_module]
self.module_to_param_usage_infos: Dict[nn.Module, List[_ParamUsageInfo]] = {
self.module_forward_order: list[nn.Module] = [root_module]
self.module_to_param_usage_infos: dict[nn.Module, list[_ParamUsageInfo]] = {
root_module: []
}
self.param_forward_order: List[nn.Parameter] = []
self.visited_params: Set[nn.Parameter] = set()
self.param_forward_order: list[nn.Parameter] = []
self.visited_params: set[nn.Parameter] = set()
class _ExecOrderTracer:
@ -120,7 +120,7 @@ class _ExecOrderTracer:
module: nn.Module,
forward: Callable,
args: tuple[Any, ...],
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
) -> Any:
"""
Overrides ``call_module`` to save execution information to
@ -160,12 +160,12 @@ class _ExecOrderTracer:
self,
create_proxy: Callable,
exec_info: _ExecutionInfo,
fqn_to_param: Dict[str, nn.Parameter],
fqn_to_param: dict[str, nn.Parameter],
# Below are the expected arguments to `create_proxy()`
kind: str,
target: torch.fx.node.Target,
args: tuple[Any, ...],
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
proxy_factory_fn: Optional[Callable[[torch.fx.Node], torch.fx.Proxy]] = None,
@ -210,7 +210,7 @@ class _ExecOrderTracer:
curr_module = exec_info.curr_module
if kind in ("call_function", "call_method"):
if args is not None:
named_params: List[tuple[str, nn.Parameter]] = []
named_params: list[tuple[str, nn.Parameter]] = []
for arg in args:
if (
isinstance(arg, torch.fx.Proxy)

View File

@ -6,7 +6,6 @@ imports. For brevity, we may import the file as ``traversal_utils``.
"""
import collections
from typing import Deque, List, Set
import torch.nn as nn
from torch.distributed._composable.contract import _get_registry
@ -48,7 +47,7 @@ def _composable(module: nn.Module) -> bool:
# `FlatParameter` registration, which is not needed for `use_orig_params=True`.
def _get_fsdp_states_with_modules(
module: nn.Module,
) -> tuple[List[_FSDPState], List[nn.Module]]:
) -> tuple[list[_FSDPState], list[nn.Module]]:
"""
Returns a tuple containing:
1. A list of the ``_FSDPState`` instances in the module tree rooted at
@ -65,19 +64,19 @@ def _get_fsdp_states_with_modules(
NOTE: The traversal does not proceed into any module annotated by an
incompatible API (e.g. ``replicate``).
"""
fsdp_states: List[_FSDPState] = []
fsdp_modules: List[nn.Module] = []
fsdp_states: list[_FSDPState] = []
fsdp_modules: list[nn.Module] = []
# Track the visited FSDP states since multiple modules may share the same
# one and we want to return a de-duplicated list
visited_fsdp_states: Set[_FSDPState] = set()
visited_fsdp_states: set[_FSDPState] = set()
# Track the visited modules in case of shared modules, which implies the
# module graph is no longer a tree
visited_modules: Set[nn.Module] = set()
visited_modules: set[nn.Module] = set()
# Perform depth-first search from `module` to ensure that we do not
# traverse into an incompatible API's subtree (use DFS instead of BFS to
# match `.modules()` order)
deque: Deque[nn.Module] = collections.deque([module])
deque: collections.deque[nn.Module] = collections.deque([module])
while deque:
submodule = deque.popleft()
visited_modules.add(submodule)
@ -94,13 +93,13 @@ def _get_fsdp_states_with_modules(
return fsdp_states, fsdp_modules
def _get_fsdp_states(module: nn.Module) -> List[_FSDPState]:
def _get_fsdp_states(module: nn.Module) -> list[_FSDPState]:
"""See :func:`_get_fsdp_states_with_modules`."""
fsdp_states, _ = _get_fsdp_states_with_modules(module)
return fsdp_states
def _get_fsdp_handles(module: nn.Module) -> List:
def _get_fsdp_handles(module: nn.Module) -> list:
"""
Returns all ``FlatParamHandle`` s in the module tree rooted at ``module``
following the rules in :func:`_get_fsdp_state`.

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
import contextlib
import warnings
from typing import cast, Generator
from collections.abc import Generator
from typing import cast
import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils

View File

@ -4,7 +4,7 @@ import functools
import inspect
import warnings
from functools import partial
from typing import Any, Callable, Dict, List, Set, Type, Union
from typing import Any, Callable, Union
import torch.nn as nn
from torch.distributed.fsdp._common_utils import (
@ -25,9 +25,9 @@ from torch.distributed.fsdp.wrap import (
def _auto_wrap(
root_module: nn.Module,
policy: Union[Callable, _Policy],
ignored_modules: Set[nn.Module],
ignored_params: Set[nn.Parameter],
root_kwargs: Dict[str, Any],
ignored_modules: set[nn.Module],
ignored_params: set[nn.Parameter],
root_kwargs: dict[str, Any],
fsdp_fn: Callable, # e.g. `FullyShardedDataParallel` or `fully_shard`
):
"""
@ -111,7 +111,7 @@ def _check_nested_wrapping(root_module: nn.Module):
def _warn_on_overridden_mixed_precision(
overridden_module_classes: Set[Type[nn.Module]],
overridden_module_classes: set[type[nn.Module]],
):
if len(overridden_module_classes) == 0:
return
@ -125,8 +125,8 @@ def _warn_on_overridden_mixed_precision(
def _validate_frozen_params(
root_module: nn.Module,
modules_to_wrap: Set[nn.Module],
ignored_params: Set[nn.Parameter],
modules_to_wrap: set[nn.Module],
ignored_params: set[nn.Parameter],
use_orig_params: bool,
):
"""
@ -136,15 +136,15 @@ def _validate_frozen_params(
recommended for ``use_orig_params=True`` (user warning).
"""
post_order_named_modules = _get_post_order_named_modules(root_module)
visited_modules: Set[nn.Module] = set()
visited_modules: set[nn.Module] = set()
for module_name, module in post_order_named_modules:
if module in modules_to_wrap:
param_to_fqn = _get_managed_param_to_fqn(
module, ignored_params, visited_modules, module_name
)
frozen_param_fqns: List[str] = []
frozen_param_fqns: list[str] = []
frozen_param_numel = 0
nonfrozen_param_fqns: List[str] = []
nonfrozen_param_fqns: list[str] = []
nonfrozen_param_numel = 0
for param, fqn in param_to_fqn.items():
if param.requires_grad:
@ -178,7 +178,7 @@ def _validate_frozen_params(
def _get_post_order_named_modules(
root_module: nn.Module,
) -> List[tuple[str, nn.Module]]:
) -> list[tuple[str, nn.Module]]:
"""
This returns the named modules following a post-order traversal, which is a
valid reverse topological sort. We achieve this using the reverse of a
@ -202,7 +202,7 @@ def _get_post_order_named_modules(
visited_modules = {root_module}
stack = [("", root_module)]
# Append and reverse at the end for linear-time algorithm
reverse_post_order_named_modules: List[tuple[str, nn.Module]] = []
reverse_post_order_named_modules: list[tuple[str, nn.Module]] = []
while stack:
module_name, module = stack.pop()
reverse_post_order_named_modules.append((module_name, module))
@ -220,10 +220,10 @@ def _get_post_order_named_modules(
def _get_managed_param_to_fqn(
module_to_wrap: nn.Module,
ignored_params: Set[nn.Parameter],
visited_modules: Set[nn.Module],
ignored_params: set[nn.Parameter],
visited_modules: set[nn.Module],
root_prefix: str,
) -> Dict[nn.Parameter, str]:
) -> dict[nn.Parameter, str]:
"""
This returns a dict that maps each managed parameter to its FQN for the given
``module_to_wrap``. The dict's keys are exactly the parameters that would
@ -238,7 +238,7 @@ def _get_managed_param_to_fqn(
on the full module tree in one shot. Given those differences, we do not try
to unify the two.
"""
param_to_fqn: Dict[nn.Parameter, str] = {}
param_to_fqn: dict[nn.Parameter, str] = {}
# Run BFS (or any tree traversal works)
queue = collections.deque([(module_to_wrap, root_prefix)])
visited_modules.add(module_to_wrap)

View File

@ -3,9 +3,10 @@ This file includes public APIs for FSDP such as the classes used for the
constructor arguments.
"""
from collections.abc import Sequence
from dataclasses import dataclass
from enum import auto, Enum
from typing import Optional, Sequence, Type
from typing import Optional
import torch
from torch.nn.modules.batchnorm import _BatchNorm
@ -223,7 +224,7 @@ class MixedPrecision:
keep_low_precision_grads: bool = False
cast_forward_inputs: bool = False
cast_root_forward_inputs: bool = True
_module_classes_to_ignore: Sequence[Type[torch.nn.Module]] = (_BatchNorm,)
_module_classes_to_ignore: Sequence[type[torch.nn.Module]] = (_BatchNorm,)
@dataclass

View File

@ -6,19 +6,10 @@ import functools
import math
import traceback
import warnings
from collections.abc import Generator, Iterable, Iterator
from contextlib import contextmanager
from enum import auto, Enum
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
Union,
)
from typing import Any, Callable, Optional, Union
import torch
import torch.distributed as dist
@ -562,7 +553,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
def fsdp_modules(
module: nn.Module,
root_only: bool = False,
) -> List["FullyShardedDataParallel"]:
) -> list["FullyShardedDataParallel"]:
"""Return all nested FSDP instances.
This possibly includes ``module`` itself and only includes FSDP root modules if ``root_only=True``.
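A brief usage sketch, assuming an initialized process group and a placeholder ``my_module`` for the user's model:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

fsdp_model = FSDP(my_module)
all_instances = FSDP.fsdp_modules(fsdp_model)
root_instances = FSDP.fsdp_modules(fsdp_model, root_only=True)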
@ -1013,7 +1004,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
param_name = param_name.replace(FSDP_PREFIX, "")
yield (param_name, param)
def _assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None:
def _assert_state(self, state: Union[TrainingState, list[TrainingState]]) -> None:
"""Assert we are in the given state."""
# Since assert can be turned off and this error checking
# is really important, we use explicit error checking
@ -1135,7 +1126,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
# iteration order and hence deterministic total norm computation
sharded_params = []
nonsharded_params = []
grads: List[torch.Tensor] = []
grads: list[torch.Tensor] = []
for handle in self._all_handles:
if handle.uses_sharded_strategy:
target_set = sharded_params_set
@ -1257,10 +1248,10 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
def _optim_state_dict_impl(
model: torch.nn.Module,
optim: torch.optim.Optimizer,
optim_state_dict: Dict[str, Any],
optim_state_dict: dict[str, Any],
optim_input: Optional[
Union[
List[Dict[str, Any]],
list[dict[str, Any]],
Iterable[torch.nn.Parameter],
]
] = None,
@ -1270,7 +1261,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
cpu_offload: bool = True,
*,
_stacklevel: int = 1,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Transform the state-dict of an optimizer corresponding to a sharded model.
This is the internal API that is used by all the optim_state_dict implementations.
@ -1312,11 +1303,11 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
@staticmethod
def _optim_state_dict_to_load_impl(
optim_state_dict: Dict[str, Any],
optim_state_dict: dict[str, Any],
model: torch.nn.Module,
optim_input: Optional[
Union[
List[Dict[str, Any]],
list[dict[str, Any]],
Iterable[torch.nn.Parameter],
]
] = None,
@ -1325,7 +1316,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
rank0_only: bool = False,
is_named_optimizer: bool = False,
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.
@ -1376,13 +1367,13 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
optim: torch.optim.Optimizer,
optim_input: Optional[
Union[
List[Dict[str, Any]],
list[dict[str, Any]],
Iterable[torch.nn.Parameter],
]
] = None,
rank0_only: bool = True,
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Return the full optimizer state-dict.
Consolidates the full optimizer state on rank 0 and returns it
@ -1451,7 +1442,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
model: torch.nn.Module,
optim: torch.optim.Optimizer,
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Return the optimizer state-dict in its sharded form.
The API is similar to :meth:`full_optim_state_dict` but this API chunks
@ -1482,16 +1473,16 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
@staticmethod
def shard_full_optim_state_dict(
full_optim_state_dict: Dict[str, Any],
full_optim_state_dict: dict[str, Any],
model: torch.nn.Module,
optim_input: Optional[
Union[
List[Dict[str, Any]],
list[dict[str, Any]],
Iterable[torch.nn.Parameter],
]
] = None,
optim: Optional[torch.optim.Optimizer] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Shard a full optimizer state-dict.
Remaps the state in ``full_optim_state_dict`` to flattened parameters instead of unflattened
@ -1561,10 +1552,10 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
@staticmethod
def flatten_sharded_optim_state_dict(
sharded_optim_state_dict: Dict[str, Any],
sharded_optim_state_dict: dict[str, Any],
model: torch.nn.Module,
optim: torch.optim.Optimizer,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Flatten a sharded optimizer state-dict.
The API is similar to :meth:`shard_full_optim_state_dict`. The only
@ -1600,17 +1591,17 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
@staticmethod
def scatter_full_optim_state_dict(
full_optim_state_dict: Optional[Dict[str, Any]],
full_optim_state_dict: Optional[dict[str, Any]],
model: torch.nn.Module,
optim_input: Optional[
Union[
List[Dict[str, Any]],
list[dict[str, Any]],
Iterable[torch.nn.Parameter],
]
] = None,
optim: Optional[torch.optim.Optimizer] = None,
group: Optional[Any] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Scatter the full optimizer state dict from rank 0 to all other ranks.
Returns the sharded optimizer state dict on each rank.
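A hedged usage sketch, assuming ``model`` is FSDP-wrapped and ``optim`` is its optimizer on every rank:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

full_osd = FSDP.full_optim_state_dict(model, optim)  # non-empty only on rank 0
sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model, optim=optim)
optim.load_state_dict(sharded_osd)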
@ -1684,17 +1675,17 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
@staticmethod
def rekey_optim_state_dict(
optim_state_dict: Dict[str, Any],
optim_state_dict: dict[str, Any],
optim_state_key_type: OptimStateKeyType,
model: torch.nn.Module,
optim_input: Optional[
Union[
List[Dict[str, Any]],
list[dict[str, Any]],
Iterable[torch.nn.Parameter],
]
] = None,
optim: Optional[torch.optim.Optimizer] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Re-keys the optimizer state dict ``optim_state_dict`` to use the key type ``optim_state_key_type``.
This can be used to achieve compatibility between optimizer state dicts from models with FSDP
@ -1762,7 +1753,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
else _get_param_key_to_param(optim)
)
param_to_param_name = _get_param_to_fqn(model)
param_id_to_param_name: List[str] = [
param_id_to_param_name: list[str] = [
param_to_param_name[param] for param in param_id_to_param.values()
]
new_osd["state"] = {
@ -1811,9 +1802,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
def optim_state_dict(
model: torch.nn.Module,
optim: torch.optim.Optimizer,
optim_state_dict: Optional[Dict[str, Any]] = None,
optim_state_dict: Optional[dict[str, Any]] = None,
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Transform the state-dict of an optimizer corresponding to a sharded model.
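A usage sketch for the save path, assuming ``model`` is FSDP-wrapped and ``optim`` is its optimizer; the state-dict type configuration shown is one common choice, not the only one.

from torch.distributed.fsdp import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)

FSDP.set_state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(rank0_only=False),
    FullOptimStateDictConfig(rank0_only=False),
)
osd = FSDP.optim_state_dict(model, optim)  # keyed by unflattened parameter FQNs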
@ -1907,11 +1898,11 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
def optim_state_dict_to_load(
model: torch.nn.Module,
optim: torch.optim.Optimizer,
optim_state_dict: Dict[str, Any],
optim_state_dict: dict[str, Any],
is_named_optimizer: bool = False,
load_directly: bool = False,
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.
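And the corresponding load path, with the same assumptions about ``model`` and ``optim``, where ``osd`` is a state-dict produced by ``FSDP.optim_state_dict``:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

flattened_osd = FSDP.optim_state_dict_to_load(model, optim, osd)
optim.load_state_dict(flattened_osd)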
@ -2146,7 +2137,7 @@ def _get_grad_norm(
def _get_param_to_fqn(
model: torch.nn.Module,
) -> Dict[torch.nn.Parameter, str]:
) -> dict[torch.nn.Parameter, str]:
"""
Construct a mapping from parameters to their parameter names.
@ -2178,7 +2169,7 @@ def _get_param_to_fqn(
def _get_fqn_to_param(
model: torch.nn.Module,
) -> Dict[str, torch.nn.Parameter]:
) -> dict[str, torch.nn.Parameter]:
"""Construct the inverse mapping of :meth:`_get_param_to_fqn`."""
param_to_param_name = _get_param_to_fqn(model)
return dict(zip(param_to_param_name.values(), param_to_param_name.keys()))

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
import logging
from collections import abc, defaultdict
from typing import Any, Dict, Iterable, List, Optional, overload, Union
from collections.abc import Iterable
from typing import Any, Optional, overload, Union
import torch
import torch.distributed as dist
@ -12,7 +13,7 @@ from torch.distributed.distributed_c10d import ProcessGroup
logger = logging.getLogger(__name__)
def _refresh_per_optimizer_state() -> Dict[str, Any]:
def _refresh_per_optimizer_state() -> dict[str, Any]:
return {"stage": OptState.READY, "found_inf_per_device": {}}
@ -36,7 +37,7 @@ class _GeneralMultiDeviceReplicator(_MultiDeviceReplicator):
def __init__(self, master_tensor: torch.Tensor) -> None:
assert _is_supported_device(master_tensor)
self.master = master_tensor
self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
self._per_device_tensors: dict[torch.device, torch.Tensor] = {}
class ShardedGradScaler(GradScaler):
@ -115,7 +116,7 @@ class ShardedGradScaler(GradScaler):
...
@overload
def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]:
...
@overload
@ -145,7 +146,7 @@ class ShardedGradScaler(GradScaler):
# format (fp16, bf16) and so the scaled loss should be of the same dtype.
return scaled_output.type(outputs.dtype)
stash: List[_GeneralMultiDeviceReplicator] = []
stash: list[_GeneralMultiDeviceReplicator] = []
def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
if isinstance(val, torch.Tensor):
@ -175,7 +176,7 @@ class ShardedGradScaler(GradScaler):
inv_scale: torch.Tensor,
found_inf: torch.Tensor,
allow_fp16: bool = True,
) -> Dict[torch.device, torch.Tensor]:
) -> dict[torch.device, torch.Tensor]:
per_device_inv_scale = _GeneralMultiDeviceReplicator(inv_scale)
per_device_found_inf = _GeneralMultiDeviceReplicator(found_inf)

View File

@ -7,19 +7,8 @@
import contextlib
import copy
from abc import ABC, abstractmethod
from typing import (
Any,
Callable,
cast,
Dict,
Generator,
Iterable,
Optional,
Sequence,
Set,
Type,
Union,
)
from collections.abc import Generator, Iterable, Sequence
from typing import Any, Callable, cast, Optional, Union
import torch.nn as nn
@ -51,7 +40,7 @@ def _post_order_apply(
not changed.
"""
# Track visited modules to avoid visiting shared modules multiple times
visited_modules: Set[nn.Module] = {root_module}
visited_modules: set[nn.Module] = {root_module}
def _post_order_apply_inner(
module: nn.Module,
@ -82,7 +71,7 @@ def _post_order_apply(
def _construct_wrap_fn(
root_module: nn.Module,
target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
target_module_to_kwargs: dict[nn.Module, dict[str, Any]],
fsdp_fn: Callable,
) -> Callable[[nn.Module], Optional[nn.Module]]:
"""
@ -104,10 +93,10 @@ def _construct_wrap_fn(
def _run_mixed_precision_override_policy(
root_module: nn.Module,
module_classes: Iterable[Type[nn.Module]],
ignored_modules: Set[nn.Module],
root_kwargs: Dict[str, Any],
target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
module_classes: Iterable[type[nn.Module]],
ignored_modules: set[nn.Module],
root_kwargs: dict[str, Any],
target_module_to_kwargs: dict[nn.Module, dict[str, Any]],
):
module_classes_tuple = tuple(set(module_classes))
for module in root_module.modules():
@ -141,9 +130,9 @@ class _Policy(ABC):
def _run_policy(
self,
root_module: nn.Module,
ignored_modules: Set[nn.Module],
root_kwargs: Dict[str, Any],
) -> Dict[nn.Module, Dict[str, Any]]:
ignored_modules: set[nn.Module],
root_kwargs: dict[str, Any],
) -> dict[nn.Module, dict[str, Any]]:
"""
This should return a dict ``target_module_to_kwargs`` that maps each
target module to wrap to its kwargs.
@ -155,7 +144,7 @@ def _module_wrap_policy(
module: nn.Module,
recurse: bool,
nonwrapped_numel: int,
module_classes: Set[Type[nn.Module]],
module_classes: set[type[nn.Module]],
) -> bool:
"""
This auto wrap policy wraps every module that is an instance of any type in
@ -189,7 +178,7 @@ class ModuleWrapPolicy(_Policy):
passing in the kwargs given to the root.
"""
def __init__(self, module_classes: Iterable[Type[nn.Module]]):
def __init__(self, module_classes: Iterable[type[nn.Module]]):
module_classes_set = set(module_classes)
self._module_classes = module_classes_set
self._module_classes_str = str(module_classes_set)
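A usage sketch for this policy, assuming an initialized process group:

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

policy = ModuleWrapPolicy({nn.Linear})  # wrap every nn.Linear in its own FSDP unit
model = FSDP(nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2)), auto_wrap_policy=policy)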
@ -197,11 +186,11 @@ class ModuleWrapPolicy(_Policy):
def _run_policy(
self,
root_module: nn.Module,
ignored_modules: Set[nn.Module],
root_kwargs: Dict[str, Any],
) -> Dict[nn.Module, Dict[str, Any]]:
ignored_modules: set[nn.Module],
root_kwargs: dict[str, Any],
) -> dict[nn.Module, dict[str, Any]]:
module_classes = tuple(self._module_classes)
target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]] = {}
target_module_to_kwargs: dict[nn.Module, dict[str, Any]] = {}
for module in root_module.modules():
if module in ignored_modules:
continue
@ -245,16 +234,16 @@ class CustomPolicy(_Policy):
>>> fsdp_model = FSDP(model, auto_wrap_policy=policy)
"""
def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, Dict[str, Any]]]):
def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, dict[str, Any]]]):
self._lambda_fn = lambda_fn
def _run_policy(
self,
root_module: nn.Module,
ignored_modules: Set[nn.Module],
root_kwargs: Dict[str, Any],
) -> Dict[nn.Module, Dict[str, Any]]:
target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]] = {}
ignored_modules: set[nn.Module],
root_kwargs: dict[str, Any],
) -> dict[nn.Module, dict[str, Any]]:
target_module_to_kwargs: dict[nn.Module, dict[str, Any]] = {}
for module in root_module.modules():
if module in ignored_modules:
continue
@ -307,7 +296,7 @@ def transformer_auto_wrap_policy(
module: nn.Module,
recurse: bool,
nonwrapped_numel: int,
transformer_layer_cls: Set[Type[nn.Module]],
transformer_layer_cls: set[type[nn.Module]],
) -> bool:
"""
See :func:`_module_wrap_policy`, where ``transformer_layer_cls`` is the
@ -352,8 +341,8 @@ def size_based_auto_wrap_policy(
nonwrapped_numel: int,
# Additional custom arguments
min_num_params: int = int(1e8),
force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
force_leaf_modules: Optional[set[type[nn.Module]]] = None,
exclude_wrap_modules: Optional[set[type[nn.Module]]] = None,
) -> bool:
"""
A size-based auto wrap policy.
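A usage sketch, assuming an initialized process group and a placeholder ``big_model`` for the user's ``nn.Module``:

import functools
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

policy = functools.partial(size_based_auto_wrap_policy, min_num_params=int(1e5))
model = FSDP(big_model, auto_wrap_policy=policy)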
@ -495,8 +484,8 @@ def _recursive_wrap(
module: nn.Module,
auto_wrap_policy: Callable,
wrapper_cls: Callable,
ignored_modules: Set[nn.Module],
ignored_params: Set[nn.Parameter],
ignored_modules: set[nn.Module],
ignored_params: set[nn.Parameter],
only_wrap_children: bool = False,
**kwargs: Any,
) -> tuple[nn.Module, int]:
@ -573,9 +562,9 @@ class _ConfigAutoWrap:
in_autowrap_context: bool = False # Context flag
wrapper_cls: Optional[Callable] = None # The wrapper class
kwargs: Dict[str, Any] = {} # Wrapper's args
kwargs: dict[str, Any] = {} # Wrapper's args
def __init__(self, **kwargs: Dict[str, Any]):
def __init__(self, **kwargs: dict[str, Any]):
self.kwargs = kwargs
@staticmethod